# StockSenseSpace / inference.py
# Author: amitke — commit 457b70c
import json
import os
import pickle
from datetime import datetime, timedelta, timezone

import numpy as np
import torch
import yfinance as yf

from models import StockLSTM
os.environ["CUDA_VISIBLE_DEVICES"] = ""
ARTIFACTS_DIR = "artifacts"
device = torch.device("cpu")
def _paths(symbol: str):
base = os.path.join(ARTIFACTS_DIR, symbol.upper())
return {
"model": os.path.join(base, "model.pt"),
"scaler": os.path.join(base, "scaler.pkl"),
"meta": os.path.join(base, "meta.json"),
}
def _load_artifacts(symbol: str):
p = _paths(symbol)
if not (os.path.exists(p["model"]) and os.path.exists(p["scaler"])):
raise FileNotFoundError(f"Model/scaler not found for {symbol}. Train first.")
with open(p["meta"], "r") as f:
meta = json.load(f)
with open(p["scaler"], "rb") as f:
scaler = pickle.load(f)
model = StockLSTM(input_dim=1, hidden_dim=64, num_layers=2, dropout=0.2)
model = model.to(device)
model.load_state_dict(torch.load(p["model"], map_location="cpu"))
model.eval()
return model, scaler, meta
def _last_close_series(symbol: str, days: int = 400):
end = datetime.utcnow().date()
start = end - timedelta(days=days)
df = yf.download(symbol, start=start.isoformat(), end=end.isoformat(), progress=False, auto_adjust=True)
if df.empty:
raise ValueError(f"No data for {symbol}")
return df["Close"].values.reshape(-1, 1)
@torch.no_grad()
def predict_next(symbol: str, n_days: int = 1):
model, scaler, meta = _load_artifacts(symbol)
seq_len = meta["seq_len"]
closes = _last_close_series(symbol, days=max(400, seq_len*5))
scaled = scaler.transform(closes)
# seed window
window = scaled[-seq_len:].reshape(1, seq_len, 1).astype(np.float32)
window_t = torch.from_numpy(window)
preds_scaled = []
for _ in range(n_days):
yhat = model(window_t).numpy() # [1,1] in scaled space
preds_scaled.append(yhat[0, 0])
# roll
next_window = np.concatenate([window[:, 1:, :], yhat.reshape(1, 1, 1)], axis=1)
window = next_window
window_t = torch.from_numpy(window.astype(np.float32))
preds_scaled = np.array(preds_scaled, dtype=np.float32).reshape(-1, 1)
preds_unscaled = scaler.inverse_transform(preds_scaled).flatten().tolist()
return {"symbol": symbol.upper(), "days": n_days, "predictions": preds_unscaled, "seq_len": seq_len, "meta": meta}