File size: 1,302 Bytes
5f10e37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import logging
import os

import jwt
import numpy as np
import tensorflow as tf
from fastapi import APIRouter, Depends, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel

from rnn.models.deep_portfolio import DeepPortfolioAI


# Router exposing the prediction endpoint; presumably included into the
# FastAPI application elsewhere.
router = APIRouter()
# Bearer-token extractor; verify_jwt receives its result via Security(security).
security = HTTPBearer()
# Single model instance built once at import time and shared by every request.
# NOTE(review): nothing in this file loads trained weights — confirm that
# DeepPortfolioAI does so in its constructor, otherwise predictions come from
# freshly initialised parameters.
MODEL = DeepPortfolioAI(num_assets=10)
# HS256 secret used by verify_jwt. os.getenv returns None when the variable is
# unset, so a missing secret is not detected at import time.
SECRET_KEY = os.getenv("SECRET_KEY")

class PredictionRequest(BaseModel):
    """Request body for the POST /predict endpoint."""

    # Numeric market features; predict() converts this to a float32 tensor.
    # NOTE(review): nesting/shape is not validated by the schema — confirm
    # against the model's first input.
    market_data: list
    # Passed to the model as its second input without conversion.
    # NOTE(review): element type and shape are unspecified — confirm with
    # DeepPortfolioAI.
    news_data: list

def verify_jwt(credentials: HTTPAuthorizationCredentials = Security(security)):
    """Verify the request's bearer token and return its decoded claims.

    Used as a FastAPI dependency; the returned payload is injected into the
    route (``predict`` receives it as ``user``).

    Args:
        credentials: Bearer credentials extracted by the module-level
            ``HTTPBearer`` scheme.

    Returns:
        The decoded JWT payload, once PyJWT has verified the HS256 signature
        against ``SECRET_KEY`` (plus PyJWT's default registered-claim checks,
        e.g. ``exp``).

    Raises:
        HTTPException: 500 if ``SECRET_KEY`` is unset or empty (server
            misconfiguration); 401 if the token is malformed, expired, or
            fails signature verification.
    """
    # Refuse to verify with a missing/empty secret: None would otherwise fail
    # inside jwt.decode and be reported as the client's fault, and an empty
    # HMAC secret provides no protection against forged tokens.
    if not SECRET_KEY:
        raise HTTPException(
            status_code=500, detail="Authentication is not configured"
        )
    try:
        return jwt.decode(credentials.credentials, SECRET_KEY, algorithms=["HS256"])
    except jwt.PyJWTError as err:
        # Only token problems map to 401; anything else is a genuine server
        # error and must not be disguised as an authentication failure.
        raise HTTPException(
            status_code=401,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from err

@router.post("/predict")
async def predict(
    request: PredictionRequest,
    user=Depends(verify_jwt),
):
    """Run the portfolio model on the submitted market and news data.

    Args:
        request: Validated request body with ``market_data`` and ``news_data``.
        user: Decoded JWT claims from ``verify_jwt``; unused here, but the
            dependency enforces authentication.

    Returns:
        A dict with ``success`` (always True), ``prediction`` (the model
        output as nested Python lists), and ``timestamp`` (current time as a
        string from ``np.datetime64('now')``).

    Raises:
        HTTPException: 422 if ``market_data`` is empty or cannot be converted
            to a float32 tensor; 500 if the model call fails (details are
            logged server-side, not returned to the client).
    """
    # Client input problems are reported as 422 rather than masquerading as
    # server errors.
    if not request.market_data:
        raise HTTPException(status_code=422, detail="market_data must not be empty")
    try:
        market_tensor = tf.convert_to_tensor(request.market_data, dtype=tf.float32)
    except (ValueError, TypeError) as err:
        raise HTTPException(
            status_code=422,
            detail="market_data must be a rectangular array of numbers",
        ) from err

    # NOTE(review): model inference is synchronous and runs on the event loop
    # inside this async handler; consider offloading if latency matters and
    # the model is confirmed safe to call from worker threads.
    try:
        prediction = MODEL([market_tensor, request.news_data])
        prediction_values = prediction.numpy().tolist()
    except Exception as err:
        # Log the full traceback for operators, but do not echo internal error
        # text (paths, tensor internals) back to the client.
        logging.getLogger(__name__).exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed") from err

    return {
        "success": True,
        "prediction": prediction_values,
        "timestamp": np.datetime64('now').astype(str),
    }