|
|
import os
|
|
|
from fastapi import APIRouter, Security, Depends, HTTPException
|
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
|
from pydantic import BaseModel
|
|
|
import jwt
|
|
|
import numpy as np
|
|
|
import tensorflow as tf
|
|
|
from rnn.models.deep_portfolio import DeepPortfolioAI
|
|
|
|
|
|
|
|
|
# Router collecting the endpoints declared in this module.
router = APIRouter()

# Bearer-token scheme: extracts the credentials from an
# ``Authorization: Bearer <token>`` header. ``auto_error=True`` is
# HTTPBearer's default, written out here for clarity.
security = HTTPBearer(auto_error=True)

# Single model instance, constructed once at import time and shared by
# every request handled by this module.
MODEL = DeepPortfolioAI(num_assets=10)

# HMAC secret used to verify JWTs; ``None`` when the variable is not set.
SECRET_KEY = os.environ.get("SECRET_KEY")
|
|
|
|
|
|
class PredictionRequest(BaseModel):
    """Request body for the ``/predict`` endpoint.

    pydantic validates both fields only as lists; neither the element types
    nor the shapes are constrained by this model.
    """

    # The endpoint converts this to a float32 tensor, so presumably numeric
    # (possibly nested) values — TODO confirm the expected shape.
    market_data: list

    # Handed to the model without conversion; the element format is not
    # visible here — verify against DeepPortfolioAI.
    news_data: list
|
|
|
|
|
|
def verify_jwt(credentials: HTTPAuthorizationCredentials = Security(security)):
    """Validate the bearer token and return its decoded claims.

    FastAPI resolves ``credentials`` through the module-level ``HTTPBearer``
    scheme. This dependency then verifies the token's HS256 signature
    against ``SECRET_KEY``.

    Args:
        credentials: Parsed ``Authorization: Bearer <token>`` header.

    Returns:
        The decoded JWT payload (claims).

    Raises:
        HTTPException: 500 if ``SECRET_KEY`` is unset or empty. 401 if the
            token is malformed, has a bad signature, or fails PyJWT's
            built-in claim checks.
    """
    if not SECRET_KEY:
        # A missing or empty secret is a server misconfiguration, not a bad
        # token. Report it as such instead of rejecting every client with a
        # misleading 401 (or, for "", accepting tokens signed with an empty
        # HMAC key).
        raise HTTPException(
            status_code=500, detail="Authentication is not configured"
        )
    try:
        return jwt.decode(credentials.credentials, SECRET_KEY, algorithms=["HS256"])
    except jwt.InvalidTokenError as exc:
        # Catch only token-validation failures. A bare ``except`` would also
        # swallow BaseException subclasses such as task cancellation.
        raise HTTPException(
            status_code=401,
            detail="Invalid token",
            # RFC 6750: a 401 for bearer auth should advertise the scheme.
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
|
|
|
|
|
|
@router.post("/predict")
async def predict(
    request: PredictionRequest,
    user=Depends(verify_jwt),
):
    """Run the portfolio model on the submitted data and return its output.

    Args:
        request: Body containing ``market_data`` and ``news_data`` lists.
        user: Decoded JWT claims from ``verify_jwt``. Unused in the body,
            but declaring the dependency makes the endpoint require a valid
            bearer token.

    Returns:
        A dict with ``success`` (always True), ``prediction`` (the model
        output as nested Python lists) and ``timestamp`` (string form of
        ``np.datetime64('now')``).

    Raises:
        HTTPException: 422 if ``market_data`` is empty or cannot be
            converted to a float32 tensor. 500 if model inference fails;
            details are logged server-side, not returned to the client.
    """
    import logging

    if not request.market_data:
        raise HTTPException(status_code=422, detail="market_data must not be empty")

    # Conversion failures come from malformed client input (non-numeric
    # values, ragged nested lists), so they are client errors, not 500s.
    try:
        market_tensor = tf.convert_to_tensor(request.market_data, dtype=tf.float32)
    except (ValueError, TypeError) as exc:
        raise HTTPException(
            status_code=422,
            detail="market_data must be a rectangular array of numbers",
        ) from exc

    # NOTE(review): inference is synchronous and runs on the event loop
    # inside this ``async def``, so it blocks other requests while it
    # executes. Consider offloading it (e.g. to a threadpool) once MODEL is
    # confirmed safe to call concurrently.
    # news_data is passed through unconverted — presumably the model does
    # its own preprocessing; verify against DeepPortfolioAI.
    try:
        prediction = MODEL([market_tensor, request.news_data])
        prediction_list = prediction.numpy().tolist()
    except Exception as exc:
        # Keep the traceback in the server log, but do not echo internal
        # error text (paths, shapes, library internals) back to the client.
        logging.getLogger(__name__).exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed") from exc

    return {
        "success": True,
        "prediction": prediction_list,
        # String form of numpy's 'now' (no timezone suffix in the output).
        "timestamp": np.datetime64('now').astype(str),
    }
|
|
|
|