# finsight-space / app.py
# Author: gaidasalsaa
# Last commit: bad3546 ("change app")
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional
import torch
import joblib
import logging
import numpy as np
import os
import requests
import yfinance as yf
from huggingface_hub import hf_hub_download
from model import LSTMModel
from preprocess import create_input_sequence
# ===========================
# LOGGING
# ===========================
# Configure the root logger at INFO first; the "app" logger sets no level of
# its own, so it inherits INFO from root and its messages are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("app")
# ===========================
# CONFIG
# ===========================
# Hugging Face Hub repository holding the trained LSTM weight files.
HF_MODEL_REPO = "gaidasalsaa/lstm-exchange-rate-prediction-model"

# ExchangeRate-API endpoint. The key comes from the environment and is None
# when the EXCHANGE_API_KEY variable is unset.
EXCHANGE_API_KEY = os.getenv("EXCHANGE_API_KEY")
EXCHANGE_API_URL = "https://v6.exchangerate-api.com/v6"

# Currency pairs and forecast horizons that have a trained model
# (the horizon appears as "<n>d" in each weight file name).
_PAIRS = (("USD", "IDR"), ("MYR", "IDR"))
_HORIZONS = (1, 7)

# (base, target, horizon) -> weight file name inside HF_MODEL_REPO,
# e.g. ("USD", "IDR", 7) -> "lstm_usd_idr_7d.pth".
MODEL_MAP = {
    (base, target, days): f"lstm_{base.lower()}_{target.lower()}_{days}d.pth"
    for base, target in _PAIRS
    for days in _HORIZONS
}

# "BASE_TARGET" -> local path of the joblib-serialised scaler for that pair.
SCALER_MAP = {
    f"{base}_{target}": f"scalers/scaler_{base.lower()}_{target.lower()}.pkl"
    for base, target in _PAIRS
}

# Number of most recent daily closes used as model input.
LOOKBACK = 30
# ===========================
# GLOBAL CACHE
# ===========================
# Populated by load_models_once() and read by predict_forex():
#   models:  (base, target, horizon) -> LSTMModel switched to eval mode
#   scalers: "BASE_TARGET"           -> scaler object loaded with joblib
models: dict = {}
scalers: dict = {}
# ===========================
# LOAD MODELS ONCE
# ===========================
def load_models_once():
    """Populate the module-level ``scalers`` and ``models`` caches.

    Scalers are loaded with ``joblib`` from the local paths in ``SCALER_MAP``.
    Each weight file in ``MODEL_MAP`` is fetched from ``HF_MODEL_REPO`` with
    ``hf_hub_download`` and loaded into an ``LSTMModel`` built with
    ``output_size=horizon``. The state dict is mapped to the CPU and the model
    is put in eval mode.

    The function is idempotent: if ``models`` is already non-empty it logs
    and returns. Everything is first loaded into local dicts and published
    to the global caches only after every load has succeeded. Otherwise a
    failure part-way through would leave ``models`` half-filled, and the
    guard above would then skip loading on every later call.
    """
    if models:
        logger.info("Models already loaded.")
        return

    logger.info("Loading scalers...")
    # NOTE(review): joblib.load unpickles the file; acceptable only because
    # these scaler files ship with the app rather than coming from users.
    loaded_scalers = {pair: joblib.load(path) for pair, path in SCALER_MAP.items()}

    logger.info("Downloading & loading models...")
    loaded_models = {}
    for key, filename in MODEL_MAP.items():
        _, _, horizon = key
        model_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename=filename)
        model = LSTMModel(output_size=horizon)
        # NOTE(review): torch.load unpickles a file downloaded over the
        # network. On PyTorch < 2.6 the default is weights_only=False, which
        # can run code from a tampered checkpoint. Consider passing
        # weights_only=True once the minimum torch version is confirmed.
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
        model.eval()
        loaded_models[key] = model

    # Mutate in place (no rebinding), so any reference to these dicts held
    # elsewhere sees the loaded objects.
    scalers.update(loaded_scalers)
    models.update(loaded_models)
    logger.info("ALL MODELS READY")
# ===========================
# FASTAPI
# ===========================
app = FastAPI(title="Forex Prediction API")


@app.on_event("startup")
def startup_event():
    """Fill the model and scaler caches when the server starts up."""
    load_models_once()
@app.get("/health")
def health():
    """Liveness probe; always answers {"status": "ok"} without checking model state."""
    return dict(status="ok")
# ===========================
# REQUEST / RESPONSE
# ===========================
class PredictionRequest(BaseModel):
    # JSON body accepted by POST /predict.
    base_currency: str  # e.g. "USD"
    target_currency: str  # e.g. "IDR"
    # Only horizons present in MODEL_MAP (1 or 7) produce a prediction;
    # any other value yields a "Prediction failed" response.
    horizon: int


class PredictionResponse(BaseModel):
    # Envelope returned by POST /predict.
    message: str  # "Prediction success" or "Prediction failed"
    data: Optional[dict] = None  # prediction payload; None on failure
# ===========================
# REAL-TIME RATE
# ===========================
def get_realtime_rate(base, target):
    """Return the live ``base`` -> ``target`` conversion rate from ExchangeRate-API.

    Sends a GET to ``{EXCHANGE_API_URL}/{EXCHANGE_API_KEY}/pair/{base}/{target}``
    with a 10-second timeout. Returns the response's ``conversion_rate`` as a
    float.

    Raises:
        ValueError: if ``EXCHANGE_API_KEY`` is not configured, or the response
            does not report ``"result": "success"``.
        requests.RequestException: on connection errors or timeouts.
    """
    if not EXCHANGE_API_KEY:
        # Without this guard the literal string "None" is put into the URL
        # and the request is sent only to fail with an opaque API error.
        raise ValueError("EXCHANGE_API_KEY is not set")

    url = f"{EXCHANGE_API_URL}/{EXCHANGE_API_KEY}/pair/{base}/{target}"
    response = requests.get(url, timeout=10)
    data = response.json()
    if data.get("result") != "success":
        # ExchangeRate-API documents an "error-type" field on error responses
        # (e.g. "invalid-key"); include it so failures are diagnosable.
        reason = data.get("error-type", "unknown error")
        raise ValueError(f"ExchangeRate API failed: {reason}")
    return float(data["conversion_rate"])
# ===========================
# HISTORICAL DATA (YAHOO)
# ===========================
def fetch_last_30_days(base, target):
    """Fetch the most recent ``LOOKBACK`` daily closing prices for a pair.

    Downloads 60 days of daily bars for the Yahoo Finance symbol
    ``f"{base}{target}=X"`` and keeps the last ``LOOKBACK`` non-missing
    ``Close`` values.

    Returns:
        A numpy array of shape ``(LOOKBACK, 1)``. Returns ``None`` if Yahoo
        returned no data or fewer than ``LOOKBACK`` non-missing closes.
    """
    symbol = f"{base}{target}=X"
    df = yf.download(
        symbol,
        period="60d",
        interval="1d",
        progress=False,
    )
    if df is None or df.empty:
        return None

    # Drop missing closes *before* counting. Checking len(df) first let rows
    # with NaN closes through, so fewer than LOOKBACK values could be returned.
    closes = df["Close"].dropna()
    if len(closes) < LOOKBACK:
        return None
    return closes.tail(LOOKBACK).values.reshape(-1, 1)
# ===========================
# INFERENCE
# ===========================
def predict_forex(base, target, horizon):
    """Forecast a currency pair over a horizon with its cached LSTM model.

    Args:
        base: Base currency code, e.g. "USD". Case-insensitive; surrounding
            whitespace is ignored.
        target: Target currency code, e.g. "IDR". Same normalisation as
            ``base``.
        horizon: Forecast horizon; must be part of a key in ``MODEL_MAP``.

    Returns:
        ``(current_price, predictions)``: the latest fetched close as a float,
        and the model output mapped back through the scaler as a list.
        Returns ``(None, None)`` (with a logged warning) when no model/scaler
        exists for the request or not enough price history is available.
    """
    # Normalise codes so "usd" / " USD " hit the same cache keys as "USD".
    base = base.strip().upper()
    target = target.strip().upper()

    model = models.get((base, target, horizon))
    scaler = scalers.get(f"{base}_{target}")
    if model is None or scaler is None:
        logger.warning(
            "No model/scaler for %s/%s with horizon=%s", base, target, horizon
        )
        return None, None

    prices = fetch_last_30_days(base, target)
    if prices is None:
        logger.warning("Not enough price history for %s/%s", base, target)
        return None, None

    scaled = scaler.transform(prices)
    X = torch.tensor(create_input_sequence(scaled, LOOKBACK), dtype=torch.float32)
    with torch.no_grad():
        preds = model(X).numpy()
    # The model works in the scaler's normalised space; map back to prices.
    preds = scaler.inverse_transform(preds.reshape(-1, 1)).flatten()
    # Cast explicitly so the response carries a plain Python float.
    return float(prices[-1][0]), preds.tolist()
# ===========================
# ROUTE
# ===========================
@app.post("/predict", response_model=PredictionResponse)
def predict(req: PredictionRequest):
    """Serve a forecast for the requested currency pair and horizon.

    On success, returns message "Prediction success" plus a payload with
    the request fields, the current price and the predictions. When
    predict_forex yields no predictions, returns message "Prediction failed"
    with data=None, using the route's default status code rather than an
    HTTP error.
    """
    current_price, predictions = predict_forex(
        req.base_currency, req.target_currency, req.horizon
    )
    if predictions is None:
        return PredictionResponse(message="Prediction failed", data=None)

    payload = {
        "base": req.base_currency,
        "target": req.target_currency,
        "horizon": req.horizon,
        "current_price": current_price,
        "predictions": predictions,
    }
    return PredictionResponse(message="Prediction success", data=payload)