# DeepFin / agents/endpoints/predict.py
# Uploaded by amos-fernandes ("Upload 151 files", commit b3a7985).
import logging
import os

import jwt
import numpy as np
import tensorflow as tf
from fastapi import APIRouter, Depends, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel

from rnn.models.deep_portfolio import DeepPortfolioAI

# HMAC secret used to verify incoming JWTs; None when the variable is unset.
SECRET_KEY = os.environ.get("SECRET_KEY")

# Router holding this module's endpoints.
router = APIRouter()
# Extracts the bearer token from the Authorization header for verify_jwt.
security = HTTPBearer()

# Single model instance, built once at import time and reused by every request.
MODEL = DeepPortfolioAI(num_assets=10)
# Request body for POST /predict. Both fields are declared as bare `list`,
# so pydantic does not validate their element types here.
class PredictionRequest(BaseModel):
    # Converted to a float32 tensor by the endpoint before inference.
    market_data: list
    # Passed to the model unchanged; element type presumably text/features —
    # TODO confirm against DeepPortfolioAI's expected input.
    news_data: list
def verify_jwt(credentials: HTTPAuthorizationCredentials = Security(security)):
    """Verify the request's bearer token and return its decoded claims.

    The token is read via the module's ``HTTPBearer`` scheme and checked with
    HS256 against ``SECRET_KEY``.

    Returns:
        The claims returned by ``jwt.decode``.

    Raises:
        HTTPException: 500 if ``SECRET_KEY`` is unset or empty (server
            misconfiguration); 401 if the token fails verification.
    """
    # Previously an unset key surfaced as a TypeError that was swallowed into
    # a misleading 401, and an empty key would accept tokens forged with "".
    if not SECRET_KEY:
        raise HTTPException(
            status_code=500, detail="Authentication is not configured"
        )
    try:
        return jwt.decode(credentials.credentials, SECRET_KEY, algorithms=["HS256"])
    except jwt.InvalidTokenError as exc:
        # Base class for PyJWT's decode failures: bad signature, expired
        # token, malformed token, disallowed algorithm, ...
        raise HTTPException(
            status_code=401,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
@router.post("/predict")
async def predict(
    request: PredictionRequest,
    user=Depends(verify_jwt)
):
    """Run the portfolio model on the posted market and news data.

    Args:
        request: Request body. ``market_data`` is converted to a float32
            tensor; ``news_data`` is passed to the model unchanged.
        user: Claims from ``verify_jwt``. Not used in the body; the
            dependency exists so that a valid token is required.

    Returns:
        A dict with ``success`` (always True), ``prediction`` (the model
        output converted with ``.numpy().tolist()``) and ``timestamp``
        (``np.datetime64('now')`` rendered as a string).

    Raises:
        HTTPException: 422 if ``market_data`` cannot be converted to a
            float32 tensor; 500 with a generic message if inference fails.
    """
    # NOTE(review): MODEL(...) is a synchronous call inside an `async def`,
    # so it blocks the event loop while it runs. Consider offloading it to a
    # worker thread once the model's thread-safety is confirmed.
    try:
        market_tensor = tf.convert_to_tensor(request.market_data, dtype=tf.float32)
    except (ValueError, TypeError) as exc:
        # Ragged or non-numeric input is a client error, not a server fault.
        raise HTTPException(
            status_code=422, detail=f"Invalid market_data: {exc}"
        ) from exc

    try:
        prediction = MODEL([market_tensor, request.news_data])
        values = prediction.numpy().tolist()
    except Exception as exc:
        # Keep the full traceback in the server log, but don't echo internal
        # details (tensor shapes, paths, ...) back to the client.
        logging.getLogger(__name__).exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed") from exc

    return {
        "success": True,
        "prediction": values,
        "timestamp": np.datetime64('now').astype(str),
    }