File size: 1,302 Bytes
5f10e37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import logging
import os

import jwt
import numpy as np
import tensorflow as tf
from fastapi import APIRouter, Security, Depends, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel

from rnn.models.deep_portfolio import DeepPortfolioAI
# Router on which this module's endpoints (e.g. POST /predict) are registered.
router: APIRouter = APIRouter()

# Extracts the "Authorization: Bearer <token>" credentials consumed by verify_jwt.
security: HTTPBearer = HTTPBearer()

# Built once at import time; every request to this module shares this instance.
MODEL = DeepPortfolioAI(num_assets=10)

# HS256 signing key read from the environment; None when the variable is unset.
SECRET_KEY = os.environ.get("SECRET_KEY")
# Request body for POST /predict. Only the top-level type (list) is validated
# here; element contents are checked later, when the route converts them.
class PredictionRequest(BaseModel):
    # Converted to a float32 tensor by the predict route.
    market_data: list
    # Handed to the model as-is (no tensor conversion in the route).
    news_data: list
def verify_jwt(credentials: HTTPAuthorizationCredentials = Security(security)):
    """Verify the request's bearer token and return its decoded claims.

    Used as a FastAPI dependency; the returned claims dict is injected into
    routes (``predict`` receives it as ``user``).

    Args:
        credentials: Bearer credentials extracted by the module's ``HTTPBearer``.

    Returns:
        The decoded JWT payload, verified with ``SECRET_KEY`` using HS256.

    Raises:
        HTTPException: 401 if no signing key is configured or the token fails
            decoding/verification.
    """
    if not SECRET_KEY:
        # Fail closed when the key is unset or empty. Depending on the PyJWT
        # version, an empty string may be accepted as an HMAC secret, which
        # would let anyone mint tokens signed with "".
        raise HTTPException(status_code=401, detail="Invalid token")
    try:
        return jwt.decode(credentials.credentials, SECRET_KEY, algorithms=["HS256"])
    except jwt.PyJWTError as exc:
        # Only token/verification failures map to 401. Everything else
        # (including KeyboardInterrupt/SystemExit, previously caught by a bare
        # `except:`) propagates normally.
        raise HTTPException(status_code=401, detail="Invalid token") from exc
logger = logging.getLogger(__name__)


@router.post("/predict")
async def predict(
    request: PredictionRequest,
    user=Depends(verify_jwt),
):
    """Run the portfolio model on the submitted market and news data.

    Requires a valid HS256 bearer token. On success, responds with
    ``success``, the model output as nested lists under ``prediction``, and a
    ``timestamp``.

    Errors: 422 if ``market_data`` cannot be converted to a float32 tensor,
    500 if inference fails.
    """
    # `user` holds the decoded JWT claims and is unused here; the dependency is
    # present so verify_jwt rejects unauthenticated calls with 401.
    # NOTE(review): MODEL(...) runs synchronously inside an async handler, so it
    # blocks the event loop for the duration of inference. Consider a plain
    # `def` route or offloading to a worker thread.
    try:
        market_tensor = tf.convert_to_tensor(request.market_data, dtype=tf.float32)
    except (TypeError, ValueError, tf.errors.InvalidArgumentError) as exc:
        # Ragged or non-numeric market_data is a client error, not a server fault.
        raise HTTPException(
            status_code=422,
            detail="market_data must be a rectangular array of numbers",
        ) from exc

    try:
        # news_data is passed through as a raw Python list; presumably the
        # model performs its own conversion. TODO confirm.
        prediction = MODEL([market_tensor, request.news_data])
        prediction_values = prediction.numpy().tolist()
    except Exception as exc:
        # Record the full traceback server-side, but don't echo internal error
        # text (which may include tensor contents or paths) to the client.
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed") from exc

    return {
        "success": True,
        "prediction": prediction_values,
        "timestamp": np.datetime64("now").astype(str),
    }
|