Spaces:
Sleeping
Sleeping
| import os | |
| import pickle | |
| import traceback | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torch.nn as nn | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
# Hugging Face Hub repository holding the artifacts, and the filenames of the
# model checkpoint and the pickled scaler that are downloaded from it.
REPO_ID = "Volko76/stock-prediction-lstm"
MODEL_FILE = "best_stock_model.pth"
SCALER_FILE = "scaler.pkl"

# Input columns, in the order they are selected from the user's table before
# being handed to the scaler and the model.
# NOTE(review): presumably this must match the column order used when the
# scaler/model were trained — confirm against the training code.
FEATURES = [
    "Returns",
    "HighLowRatio",
    "CloseOpenRatio",
    "MA_5",
    "MA_20",
    "Volatility_20",
    "VolumeChange",
    "RSI",
]

# Minimum number of rows that must remain after cleaning for a prediction.
MIN_SEQ_LEN = 20
# ---------- Utilities to read shapes from the checkpoint ----------
def _infer_lstm_params(sd: dict):
    """Infer ``(input_size, hidden_size, num_layers)`` from LSTM weights.

    The state dict is expected to contain entries named

        lstm.weight_ih_l{k}: (4*hidden, input)
        lstm.weight_hh_l{k}: (4*hidden, hidden)

    Args:
        sd: Mapping of parameter names to objects exposing ``.shape``.

    Returns:
        Tuple of ints ``(input_size, hidden_size, num_layers)``.

    Raises:
        ValueError: If no LSTM weights are found, if layer 0 is missing, if a
            key carries an unrecognised layer suffix, or if the checkpoint is
            bidirectional (``*_reverse`` keys), which the unidirectional model
            built from these parameters cannot load.
    """
    prefix = "lstm.weight_ih_l"
    layer_ids = []
    for key in sd:
        if not key.startswith(prefix):
            continue
        # Parse only what follows the known prefix; splitting the whole key on
        # "l" breaks on suffixes such as "0_reverse".
        suffix = key[len(prefix):]
        if suffix.endswith("_reverse"):
            raise ValueError(
                f"Bidirectional LSTM checkpoints are not supported (found key {key!r})."
            )
        if not suffix.isdecimal():
            raise ValueError(f"Unrecognised LSTM weight key: {key!r}")
        layer_ids.append(int(suffix))

    if not layer_ids:
        raise ValueError("No LSTM weights found in checkpoint.")
    if 0 not in layer_ids or "lstm.weight_hh_l0" not in sd:
        raise ValueError("LSTM layer 0 weights are missing from checkpoint.")

    num_layers = max(layer_ids) + 1

    # Layer 0 is the only layer whose input width equals the feature count.
    w_ih = sd["lstm.weight_ih_l0"]
    w_hh = sd["lstm.weight_hh_l0"]
    hidden_size = w_hh.shape[1]
    input_size = w_ih.shape[1]
    return int(input_size), int(hidden_size), int(num_layers)
def _infer_fc_sizes(sd: dict):
    """Collect the fully-connected layers stored in a checkpoint.

    Looks for layers named ``fc1``, ``fc2``, ``fc3`` and ``fc``. A layer is
    picked up only when both its ``.weight`` and ``.bias`` entries are present.
    Weight shapes are read as ``(out_features, in_features)``.

    Args:
        sd: Mapping of parameter names to objects exposing ``.shape``.

    Returns:
        List of ``(name, in_features, out_features)`` tuples ordered
        fc1 -> fc2 -> fc3 -> fc.

    Raises:
        ValueError: If none of the candidate layers is present.
    """
    rank = {"fc1": 1, "fc2": 2, "fc3": 3, "fc": 99}
    present = [
        layer
        for layer in ("fc1", "fc2", "fc3", "fc")
        if f"{layer}.weight" in sd and f"{layer}.bias" in sd
    ]
    if not present:
        raise ValueError("No FC layers (fc/fc1/fc2/fc3) found in checkpoint.")

    specs = []
    for layer in sorted(present, key=lambda n: rank.get(n, 50)):
        n_out, n_in = sd[f"{layer}.weight"].shape
        specs.append((layer, int(n_in), int(n_out)))
    return specs
# ---------- Model that we construct to match checkpoint ----------
class StockPredictorDynamic(nn.Module):
    """Batch-first LSTM followed by a stack of named Linear layers.

    The Linear layers are registered under the names given in ``fc_specs`` so
    that the module's parameter names line up with those names.
    """

    def __init__(self, input_size, hidden_size, num_layers, fc_specs):
        """Build the LSTM and one ``nn.Linear`` per entry of ``fc_specs``.

        Args:
            input_size: Feature count per time step.
            hidden_size: LSTM hidden width.
            num_layers: Number of stacked LSTM layers.
            fc_specs: Ordered list of ``(name, in_features, out_features)``
                tuples; each Linear layer uses exactly those sizes.
        """
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

        # Register each FC layer under its own attribute name and remember the
        # order in which they must be applied.
        self.fc_names = []
        for spec in fc_specs:
            layer_name, n_in, n_out = spec
            setattr(self, layer_name, nn.Linear(n_in, n_out))
            self.fc_names.append(layer_name)

        self.activation = nn.ReLU()

    def forward(self, x):
        """Run ``x`` of shape (batch, seq, input_size) through the network.

        Returns:
            Tensor of shape (batch, out_features_of_last_fc).
        """
        seq_out, _ = self.lstm(x)
        out = seq_out[:, -1, :]  # last time step: (batch, hidden)

        # ReLU goes between FC layers, never after the final one.
        final_pos = len(self.fc_names) - 1
        for pos, layer_name in enumerate(self.fc_names):
            out = getattr(self, layer_name)(out)
            if pos != final_pos:
                out = self.activation(out)
        return out
# ---------- Lazy globals ----------
# Device used for inference, plus the model/scaler cache filled on first use.
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_model = None
_scaler = None


def _load_artifacts():
    """Return the cached ``(model, scaler)`` pair, loading it on first use.

    On the first call the checkpoint and scaler are downloaded from
    ``REPO_ID``, the network architecture is inferred from the checkpoint's
    tensor shapes, and the weights are loaded strictly. The module-level cache
    is only written once everything has loaded, so a failed attempt leaves no
    partially initialised state behind.

    Returns:
        Tuple ``(model, scaler)``; the model is on ``_device`` in eval mode.

    Raises:
        ValueError: If the checkpoint lacks the expected LSTM/FC weights.
        Exception: Download, unpickling, or ``load_state_dict`` errors
            propagate unchanged.
    """
    global _model, _scaler
    if _model is not None and _scaler is not None:
        return _model, _scaler

    # Download
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE)
    scaler_path = hf_hub_download(repo_id=REPO_ID, filename=SCALER_FILE)

    # SECURITY: pickle.load runs arbitrary code embedded in the file. This is
    # only acceptable while REPO_ID points at a trusted repository.
    # The original note says the scaler requires pinned sklearn (1.6.1).
    with open(scaler_path, "rb") as f:
        scaler = pickle.load(f)

    # Load weights (CPU)
    state = torch.load(model_path, map_location="cpu")

    # Checkpoints may wrap the weights in a dict under a well-known key.
    state_dict = state
    if isinstance(state, dict):
        for wrapper_key in ("state_dict", "model_state_dict"):
            if wrapper_key in state:
                state_dict = state[wrapper_key]
                break

    # Infer architecture from weights. The first FC layer's in_features is
    # taken from the checkpoint as-is rather than assumed to be hidden_size.
    in_size, hidden_size, num_layers = _infer_lstm_params(state_dict)
    fc_specs = _infer_fc_sizes(state_dict)
    model = StockPredictorDynamic(in_size, hidden_size, num_layers, fc_specs)

    # Strict load: any key or shape mismatch should fail loudly here.
    model.load_state_dict(state_dict, strict=True)
    model.to(_device)
    model.eval()

    # Commit to the cache only after both artifacts loaded successfully.
    _model, _scaler = model, scaler
    return _model, _scaler
# ---------- Pre/post ----------
def _sanitize_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Select the feature columns and drop rows that are not fully numeric.

    Columns are restricted to ``FEATURES`` (in that order), coerced to numeric
    float32, and every row containing a missing, non-numeric, or infinite
    value is removed. Infinite values are treated like missing ones so they
    never reach the scaler.

    Args:
        df: Raw table supplied by the user.

    Returns:
        A float32 DataFrame with exactly the ``FEATURES`` columns and at least
        ``MIN_SEQ_LEN`` rows.

    Raises:
        gr.Error: If a required column is missing or too few rows remain
            after cleaning.
    """
    missing = [c for c in FEATURES if c not in df.columns]
    if missing:
        raise gr.Error(f"Missing columns: {missing}. Expected: {FEATURES}")

    df = df[FEATURES].apply(pd.to_numeric, errors="coerce").astype("float32")
    # Replace after the float32 cast so values that overflow to inf are caught.
    df = df.replace([np.inf, -np.inf], np.nan)
    df = df.dropna(axis=0, how="any")

    if len(df) < MIN_SEQ_LEN:
        raise gr.Error(f"Sequence too short after cleaning (got {len(df)} rows). Need ≥ {MIN_SEQ_LEN} rows.")
    return df
def _to_batch(seq_2d: np.ndarray) -> torch.Tensor:
    """Turn one sequence into a float32 batch of size 1 on ``_device``.

    Args:
        seq_2d: Array for a single sequence; a leading batch axis is added.

    Returns:
        Tensor with one extra leading dimension of length 1.
    """
    arr = np.asanyarray(seq_2d)
    batched = arr.reshape((1,) + arr.shape).astype("float32", copy=False)
    tensor = torch.from_numpy(batched)
    return tensor.to(_device)
# ---------- Inference endpoints ----------
def predict_from_table(df: pd.DataFrame):
    """Predict from a table of feature rows.

    The table is sorted by ``Date`` when that column exists, cleaned, scaled
    with the loaded scaler, and fed to the model as a single sequence.

    Args:
        df: Table containing at least the ``FEATURES`` columns.

    Returns:
        Dict with ``pred_7d_return_percent`` (float) and ``pretty`` (signed,
        two-decimal string ending in ``%``).

    Raises:
        gr.Error: For empty input, validation failures, or any unexpected
            error (whose message and traceback are embedded in the text).
    """
    try:
        if df is None or len(df) == 0:
            raise gr.Error("Upload or paste a table first.")

        ordered = df.sort_values("Date") if "Date" in df.columns else df
        features = _sanitize_dataframe(ordered).to_numpy()  # (seq_len, 8)

        model, scaler = _load_artifacts()
        scaled = scaler.transform(features).astype("float32")
        batch = _to_batch(scaled)

        with torch.no_grad():
            output = model(batch)
            pred = output.squeeze().cpu().numpy().item()

        prefix = ""
        if pred >= 0:
            prefix = "+"
        return {"pred_7d_return_percent": float(pred), "pretty": f"{prefix}{pred:.2f}%"}
    except gr.Error:
        # Already user-facing; pass through untouched.
        raise
    except Exception as e:
        details = traceback.format_exc()
        raise gr.Error(f"Unexpected error during prediction: {e}\n{details}")
def predict_from_csv(file_obj):
    """Read an uploaded CSV and delegate to ``predict_from_table``.

    Args:
        file_obj: Upload handed over by the file component. Either an object
            whose ``.name`` attribute is the file path, or the path itself.
            ``None`` means nothing was uploaded.

    Returns:
        Whatever ``predict_from_table`` returns for the parsed table.

    Raises:
        gr.Error: If no file was provided or the CSV cannot be parsed.
    """
    if file_obj is None:
        raise gr.Error("Upload a CSV file first.")

    # Accept both file-like wrappers exposing `.name` and plain path strings.
    path = getattr(file_obj, "name", file_obj)
    try:
        df = pd.read_csv(path)
    except Exception as e:
        raise gr.Error(f"CSV parse error: {e}") from e
    return predict_from_table(df)
# ---------- UI ----------
# Markdown shown under the page title; the column list is built from FEATURES.
EXPLAIN = (
    "Upload a CSV or paste a table (oldest → newest rows) with these columns:\n\n"
    f"`{', '.join(FEATURES)}`\n\n"
    "Output is the predicted **7-day return (%)**. Educational use only."
)

# Two-tab interface: both tabs show the prediction as JSON and differ only in
# how the feature table is supplied.
with gr.Blocks(title="Stock Prediction LSTM (PyTorch)") as demo:
    gr.Markdown("# 📈 Stock Price Prediction (LSTM)\n" + EXPLAIN)

    # Tab 1: upload a CSV file, handled by predict_from_csv.
    with gr.Tab("CSV upload"):
        csv_in = gr.File(file_types=[".csv"], label="Upload features CSV")
        btn1 = gr.Button("Predict")
        out1 = gr.JSON(label="Prediction")
        btn1.click(predict_from_csv, inputs=csv_in, outputs=out1)

    # Tab 2: type or paste rows into a grid with one fixed column per feature;
    # the grid starts with MIN_SEQ_LEN rows and the row count is dynamic.
    with gr.Tab("Paste table"):
        df_in = gr.Dataframe(
            headers=FEATURES,
            datatype=["number"] * len(FEATURES),
            row_count=(MIN_SEQ_LEN, "dynamic"),
            col_count=(len(FEATURES), "fixed"),
            label="Time series features (oldest → newest)"
        )
        btn2 = gr.Button("Predict")
        out2 = gr.JSON(label="Prediction")
        btn2.click(predict_from_table, inputs=df_in, outputs=out2)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()