Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| from typing import Optional | |
| import torch | |
| import joblib | |
| import logging | |
| import numpy as np | |
| import os | |
| import requests | |
| import yfinance as yf | |
| from huggingface_hub import hf_hub_download | |
| from model import LSTMModel | |
| from preprocess import create_input_sequence | |
| # =========================== | |
| # LOGGING | |
| # =========================== | |
# Configure the root logger at INFO, then grab the named logger that the
# rest of this module writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("app")
| # =========================== | |
| # CONFIG | |
| # =========================== | |
# Hugging Face Hub repository the model checkpoints are downloaded from.
HF_MODEL_REPO = "gaidasalsaa/lstm-exchange-rate-prediction-model"

# ExchangeRate-API credentials/endpoint; the key is None when the
# environment variable is not set.
EXCHANGE_API_KEY = os.environ.get("EXCHANGE_API_KEY")
EXCHANGE_API_URL = "https://v6.exchangerate-api.com/v6"

# (base currency, target currency, horizon) -> checkpoint filename in the repo.
MODEL_MAP = {
    ("USD", "IDR", 1): "lstm_usd_idr_1d.pth",
    ("USD", "IDR", 7): "lstm_usd_idr_7d.pth",
    ("MYR", "IDR", 1): "lstm_myr_idr_1d.pth",
    ("MYR", "IDR", 7): "lstm_myr_idr_7d.pth",
}

# "<BASE>_<TARGET>" -> local path of the pickled scaler for that pair.
SCALER_MAP = {
    "USD_IDR": "scalers/scaler_usd_idr.pkl",
    "MYR_IDR": "scalers/scaler_myr_idr.pkl",
}

# Number of most recent daily closes fed to the model.
LOOKBACK = 30
| # =========================== | |
| # GLOBAL CACHE | |
| # =========================== | |
# Process-wide caches, filled once by load_models_once():
#   models  - keyed by (base, target, horizon)
#   scalers - keyed by "<BASE>_<TARGET>"
models, scalers = {}, {}
| # =========================== | |
| # LOAD MODELS ONCE | |
| # =========================== | |
def load_models_once():
    """Populate the module-level ``models`` and ``scalers`` caches once.

    Scalers are read with ``joblib.load`` from the paths in ``SCALER_MAP``.
    Each checkpoint named in ``MODEL_MAP`` is fetched from ``HF_MODEL_REPO``
    via ``hf_hub_download``, loaded on the CPU into an ``LSTMModel`` whose
    ``output_size`` equals the horizon, and switched to eval mode.

    Everything is staged in local dicts and published to the caches only
    after every artifact has loaded. A failure part-way through therefore
    leaves ``models`` empty, so a later call retries from scratch instead of
    logging "already loaded" on a partially filled cache.

    Raises:
        Any exception raised by ``joblib.load``, ``hf_hub_download``,
        ``torch.load`` or ``load_state_dict``; the caches are left unchanged.
    """
    if models:
        logger.info("Models already loaded.")
        return

    logger.info("Loading scalers...")
    # NOTE(review): joblib.load unpickles its input; only load scaler files
    # that ship with (and are trusted by) this project.
    staged_scalers = {
        pair: joblib.load(path) for pair, path in SCALER_MAP.items()
    }

    logger.info("Downloading & loading models...")
    staged_models = {}
    for key, filename in MODEL_MAP.items():
        horizon = key[2]  # key is (base, target, horizon)
        model_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename=filename)
        model = LSTMModel(output_size=horizon)
        # NOTE(review): torch.load unpickles the checkpoint; consider
        # weights_only=True if the installed torch version supports it.
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
        model.eval()
        staged_models[key] = model

    # Publish scalers first so that a non-empty `models` always implies the
    # scalers are available as well.
    scalers.update(staged_scalers)
    models.update(staged_models)
    logger.info("ALL MODELS READY")
| # =========================== | |
| # FASTAPI | |
| # =========================== | |
# ASGI application object for this service.
app = FastAPI(title="Forex Prediction API")
def startup_event():
    """Fill the model and scaler caches by calling ``load_models_once``.

    NOTE(review): no registration of this hook on ``app`` is visible in this
    view (for example a startup decorator) - confirm it is wired up.
    """
    load_models_once()
def health():
    """Return a constant liveness payload, ``{"status": "ok"}``.

    NOTE(review): no route registration for this handler is visible in this
    view - confirm how it is exposed on ``app``.
    """
    payload = {"status": "ok"}
    return payload
| # =========================== | |
| # REQUEST / RESPONSE | |
| # =========================== | |
class PredictionRequest(BaseModel):
    # Request body for a prediction: the three values are passed unchanged to
    # predict_forex() and together select the cached model.
    base_currency: str  # e.g. "USD" (see MODEL_MAP keys)
    target_currency: str  # e.g. "IDR" (see MODEL_MAP keys)
    horizon: int  # third component of the MODEL_MAP key
class PredictionResponse(BaseModel):
    # Response envelope: a status message plus an optional result payload.
    message: str
    # Left as None when the prediction could not be produced.
    data: Optional[dict] = None
| # =========================== | |
| # REAL-TIME RATE | |
| # =========================== | |
def get_realtime_rate(base, target):
    """Return the current ``base`` -> ``target`` conversion rate as a float.

    Sends a GET request (10 second timeout) to the "pair" endpoint built from
    ``EXCHANGE_API_URL`` and ``EXCHANGE_API_KEY`` and reads
    ``conversion_rate`` from the JSON body.

    Raises:
        ValueError: if ``EXCHANGE_API_KEY`` is not configured, or the response
            body is not a JSON object whose ``result`` is ``"success"``.
        KeyError: if a successful response carries no ``conversion_rate``.
        requests.RequestException: on network errors or timeouts.
    """
    # Fail fast instead of calling the API with the literal "None" as the key.
    if not EXCHANGE_API_KEY:
        raise ValueError("EXCHANGE_API_KEY is not configured")

    url = f"{EXCHANGE_API_URL}/{EXCHANGE_API_KEY}/pair/{base}/{target}"
    response = requests.get(url, timeout=10)
    payload = response.json()

    # A non-object body (list, string, ...) has no .get(); report it the same
    # way as any other unsuccessful response.
    if not isinstance(payload, dict) or payload.get("result") != "success":
        raise ValueError("ExchangeRate API failed")
    return float(payload["conversion_rate"])
| # =========================== | |
| # HISTORICAL DATA (YAHOO) | |
| # =========================== | |
def fetch_last_30_days(base, target):
    """Return the most recent ``LOOKBACK`` daily closes for a currency pair.

    Downloads 60 days of daily bars for the Yahoo Finance symbol
    ``f"{base}{target}=X"`` and keeps the last ``LOOKBACK`` non-missing
    ``Close`` values.

    Returns:
        The closes reshaped to a single column (``reshape(-1, 1)``), or
        ``None`` when the download is empty or fewer than ``LOOKBACK``
        non-missing closes are available.
    """
    symbol = f"{base}{target}=X"
    history = yf.download(
        symbol,
        period="60d",
        interval="1d",
        progress=False,
    )
    if history.empty:
        return None

    # Drop missing closes *before* checking the length: counting raw rows
    # could pass the check and still yield fewer than LOOKBACK values.
    closes = history["Close"].dropna()
    if len(closes) < LOOKBACK:
        return None
    return closes.tail(LOOKBACK).values.reshape(-1, 1)
| # =========================== | |
| # INFERENCE | |
| # =========================== | |
def predict_forex(base, target, horizon):
    """Run the cached model for ``(base, target, horizon)`` on recent closes.

    Looks up the model in ``models`` and the scaler in ``scalers`` (key
    ``"<base>_<target>"``), fetches recent closes with
    ``fetch_last_30_days``, scales them, and feeds the result of
    ``create_input_sequence`` to the model under ``torch.no_grad()``. The
    model output is mapped back through the scaler's ``inverse_transform``.

    NOTE(review): the shape ``create_input_sequence`` produces is not visible
    here - confirm it matches what ``LSTMModel`` expects.

    Returns:
        ``(last fetched close, list of predictions)``, or ``(None, None)``
        when the model, the scaler, or the price history is unavailable.
    """
    net = models.get((base, target, horizon))
    fitted_scaler = scalers.get(f"{base}_{target}")
    if net is None or fitted_scaler is None:
        return None, None

    history = fetch_last_30_days(base, target)
    if history is None:
        return None, None

    window = create_input_sequence(fitted_scaler.transform(history), LOOKBACK)
    batch = torch.tensor(window, dtype=torch.float32)
    with torch.no_grad():
        raw_output = net(batch).numpy()

    # Flatten to one column so the scaler can invert every predicted step.
    forecast = fitted_scaler.inverse_transform(raw_output.reshape(-1, 1)).flatten()
    return history[-1][0], forecast.tolist()
| # =========================== | |
| # ROUTE | |
| # =========================== | |
def predict(req: PredictionRequest):
    """Handle a prediction request and wrap the outcome in a response model.

    Delegates to ``predict_forex``. When it yields no predictions the
    response carries the message "Prediction failed" and no data; otherwise
    the data echoes the request fields alongside the current price and the
    predictions.

    NOTE(review): no route registration for this handler is visible in this
    view - confirm how it is exposed on ``app``.
    """
    latest_close, forecast = predict_forex(
        req.base_currency, req.target_currency, req.horizon
    )
    if forecast is None:
        return PredictionResponse(message="Prediction failed", data=None)

    payload = {
        "base": req.base_currency,
        "target": req.target_currency,
        "horizon": req.horizon,
        "current_price": latest_close,
        "predictions": forecast,
    }
    return PredictionResponse(message="Prediction success", data=payload)