"""Inference utilities: load per-symbol LSTM artifacts and predict future closes."""

import json
import os
import pickle
from datetime import datetime, timedelta, timezone

import numpy as np
import torch
import yfinance as yf

from models import StockLSTM

# Force CPU-only execution; these small models do not need a GPU for inference.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

ARTIFACTS_DIR = "artifacts"
device = torch.device("cpu")


def _paths(symbol: str) -> dict:
    """Return the artifact file paths (model / scaler / meta) for *symbol*."""
    base = os.path.join(ARTIFACTS_DIR, symbol.upper())
    return {
        "model": os.path.join(base, "model.pt"),
        "scaler": os.path.join(base, "scaler.pkl"),
        "meta": os.path.join(base, "meta.json"),
    }


def _load_artifacts(symbol: str):
    """Load the trained model, fitted scaler and training metadata for *symbol*.

    Returns:
        (model, scaler, meta): model in eval mode on CPU, the fitted scaler,
        and the metadata dict (must contain at least "seq_len").

    Raises:
        FileNotFoundError: if any required artifact file is missing.
    """
    p = _paths(symbol)
    # Check all three artifacts up front; the original skipped meta.json and
    # crashed with a raw FileNotFoundError when only the metadata was missing.
    if not all(os.path.exists(p[k]) for k in ("model", "scaler", "meta")):
        raise FileNotFoundError(f"Model/scaler not found for {symbol}. Train first.")

    with open(p["meta"], "r") as f:
        meta = json.load(f)
    # NOTE: pickle is acceptable here only because the scaler is a locally
    # produced training artifact, never untrusted input.
    with open(p["scaler"], "rb") as f:
        scaler = pickle.load(f)

    # Read hyperparameters from the saved metadata when present so inference
    # stays in sync with training; fall back to the historical defaults.
    model = StockLSTM(
        input_dim=meta.get("input_dim", 1),
        hidden_dim=meta.get("hidden_dim", 64),
        num_layers=meta.get("num_layers", 2),
        dropout=meta.get("dropout", 0.2),
    )
    model = model.to(device)
    model.load_state_dict(torch.load(p["model"], map_location="cpu"))
    model.eval()
    return model, scaler, meta


def _last_close_series(symbol: str, days: int = 400) -> np.ndarray:
    """Download recent daily closes for *symbol* as an (n, 1) float array.

    Raises:
        ValueError: if yfinance returns no rows for the requested window.
    """
    # Timezone-aware "today" (UTC); datetime.utcnow() is deprecated in 3.12+.
    end = datetime.now(timezone.utc).date()
    start = end - timedelta(days=days)
    df = yf.download(
        symbol,
        start=start.isoformat(),
        end=end.isoformat(),
        progress=False,
        auto_adjust=True,
    )
    if df.empty:
        raise ValueError(f"No data for {symbol}")
    # reshape(-1, 1) yields the (n_samples, n_features) shape scalers expect.
    return df["Close"].values.reshape(-1, 1)


@torch.no_grad()
def predict_next(symbol: str, n_days: int = 1) -> dict:
    """Autoregressively predict the next *n_days* closing prices for *symbol*.

    Each predicted (scaled) close is fed back into the input window to
    produce the following day's prediction.

    Returns:
        dict with keys: "symbol", "days", "predictions" (list of floats in
        price space), "seq_len", and the full "meta" dict.
    """
    model, scaler, meta = _load_artifacts(symbol)
    seq_len = meta["seq_len"]

    # Fetch comfortably more history than one window so the seed is full.
    closes = _last_close_series(symbol, days=max(400, seq_len * 5))
    scaled = scaler.transform(closes)

    # Seed the rolling window with the most recent seq_len scaled closes.
    window = scaled[-seq_len:].reshape(1, seq_len, 1).astype(np.float32)
    window_t = torch.from_numpy(window)

    preds_scaled = []
    for _ in range(n_days):
        yhat = model(window_t).numpy()  # [1, 1] in scaled space
        preds_scaled.append(yhat[0, 0])
        # Slide the window one step forward, feeding the prediction back in.
        window = np.concatenate([window[:, 1:, :], yhat.reshape(1, 1, 1)], axis=1)
        window_t = torch.from_numpy(window.astype(np.float32))

    preds_scaled = np.array(preds_scaled, dtype=np.float32).reshape(-1, 1)
    preds_unscaled = scaler.inverse_transform(preds_scaled).flatten().tolist()
    return {
        "symbol": symbol.upper(),
        "days": n_days,
        "predictions": preds_unscaled,
        "seq_len": seq_len,
        "meta": meta,
    }