# Hugging Face Space application source (Space status when captured: Sleeping; file size: 3,331 bytes).
import gradio as gr
import pandas as pd
import numpy as np
import torch
import pickle
import matplotlib.pyplot as plt
import io
from torch import nn
# Restore the pre-fitted ARIMA model that ships alongside this script.
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# "arima.pkl" must only ever come from a trusted source.
with open("arima.pkl", mode="rb") as arima_file:
    arima_model = pickle.load(arima_file)
class LSTMModel(nn.Module):
    """Stacked LSTM followed by a linear layer on the last time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of each LSTM layer's hidden state.
        num_layers: Number of stacked LSTM layers.
        output_size: Number of values produced by the final linear layer.

    The sub-module names (``lstm`` and ``fc``) are kept as-is so previously
    saved state dicts continue to load.
    """

    def __init__(self, input_size=1, hidden_size=50, num_layers=2, output_size=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return the prediction for the final time step of each sequence.

        Args:
            x: Tensor laid out as (batch, sequence, input_size); the LSTM is
                built with ``batch_first=True``.

        Returns:
            Tensor of shape (batch, output_size).
        """
        # Size the initial states from the configured LSTM instead of the
        # literals 2 and 50, so non-default num_layers/hidden_size work, and
        # create them on the input's device/dtype so the module can be moved.
        state_shape = (self.lstm.num_layers, x.size(0), self.lstm.hidden_size)
        h0 = torch.zeros(state_shape, dtype=x.dtype, device=x.device)
        c0 = torch.zeros(state_shape, dtype=x.dtype, device=x.device)
        out, _ = self.lstm(x, (h0, c0))
        # Only the output at the last time step feeds the linear head.
        return self.fc(out[:, -1, :])
# Build the network with its default hyperparameters, restore the trained
# weights onto the CPU, and switch to inference mode.
# NOTE(review): torch.load may unpickle arbitrary objects depending on the
# installed torch version - only load a trusted "lstm.pth".
lstm_model = LSTMModel()
cpu_weights = torch.load("lstm.pth", map_location=torch.device("cpu"))
lstm_model.load_state_dict(cpu_weights)
lstm_model.eval()
def predict_arima(values, horizon=10):
    """Return ``horizon`` forecast steps from the pre-fitted ARIMA model.

    Args:
        values: Accepted so the signature parallels ``predict_lstm``, but not
            read here - the forecast comes solely from the module-level
            ``arima_model``.
            NOTE(review): this means the result does not depend on the
            uploaded data; confirm that is intended.
        horizon: Number of steps passed to ``arima_model.forecast``.

    Returns:
        The forecast converted with ``.tolist()``.
    """
    predicted = arima_model.forecast(steps=horizon)
    return predicted.tolist()
def predict_lstm(values, horizon=10):
    """Roll the LSTM forward ``horizon`` steps, feeding predictions back in.

    Args:
        values: Sequence of numbers; only the last 50 entries seed the window.
            NOTE(review): no scaling is applied here - presumably the model
            was trained on values in the same scale; confirm.
        horizon: Number of future steps to generate.

    Returns:
        List of ``horizon`` Python floats, one per predicted step.
    """
    # Shape the seed window as (batch=1, sequence, features=1).
    window = torch.tensor(values[-50:], dtype=torch.float32).view(1, -1, 1)
    forecasts = []
    with torch.no_grad():
        for _step in range(horizon):
            next_value = lstm_model(window).item()
            forecasts.append(next_value)
            # Slide the window: drop the oldest point, append the prediction.
            newest = torch.tensor([[[next_value]]])
            window = torch.cat([window[:, 1:, :], newest], dim=1)
    return forecasts
def forecast(file, horizon, model_choice):
    """Forecast future 'Close' prices from an uploaded CSV and plot them.

    Args:
        file: The uploaded CSV, given either as a filesystem path or as an
            object exposing that path through a ``.name`` attribute.
        horizon: Number of future steps to forecast; coerced to ``int``
            because UI sliders may deliver a float.
        model_choice: "ARIMA", "LSTM" or "Compare Both"; selects which
            forecast lines are drawn on the plot. The returned table always
            contains both forecasts.

    Returns:
        ``(forecast_df, image)`` where ``forecast_df`` has the columns
        "Future", "ARIMA Forecast" and "LSTM Forecast", and ``image`` is the
        rendered plot as an RGBA ``uint8`` NumPy array.

    Raises:
        gr.Error: If no file was supplied, the CSV has no 'Close' column, or
            the column holds no usable values. Raising (instead of returning
            a message string) keeps the text out of the DataFrame output,
            which would otherwise try to interpret it as a CSV path.
    """
    # Local imports: build the figure through the object-oriented API so no
    # global pyplot state is touched and no figure is left open per request.
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure

    if file is None:
        raise gr.Error("❌ Please upload a CSV file")

    # Accept both a plain path and a file wrapper carrying the path in .name.
    csv_path = getattr(file, "name", file)
    df = pd.read_csv(csv_path)
    if "Close" not in df.columns:
        raise gr.Error("❌ CSV must contain a 'Close' column")

    # Missing prices would propagate NaN through the LSTM roll-out.
    values = df["Close"].dropna().tolist()
    if not values:
        raise gr.Error("❌ The 'Close' column contains no values")

    horizon = int(horizon)

    # Run forecasts (both are always computed because the table shows both).
    preds_arima = predict_arima(values, horizon)
    preds_lstm = predict_lstm(values, horizon)

    # Prepare the results table.
    forecast_df = pd.DataFrame({
        "Future": [f"t+{step}" for step in range(1, horizon + 1)],
        "ARIMA Forecast": preds_arima,
        "LSTM Forecast": preds_lstm,
    })

    # Plot history followed by the selected forecast(s).
    fig = Figure(figsize=(10, 5))
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot(111)
    history_len = len(values)
    future_x = range(history_len, history_len + horizon)
    ax.plot(range(history_len), values, label="Historical")
    if model_choice in ("ARIMA", "Compare Both"):
        ax.plot(future_x, preds_arima, label="ARIMA Forecast")
    if model_choice in ("LSTM", "Compare Both"):
        ax.plot(future_x, preds_lstm, label="LSTM Forecast")
    ax.set_title(f"{model_choice} Stock Forecast")
    ax.set_xlabel("Time")
    ax.set_ylabel("Price")
    ax.legend()

    # Render to a pixel array: the image output accepts NumPy arrays, whereas
    # a raw BytesIO buffer is not a supported output value. Copy so the array
    # does not alias the canvas buffer once the figure is discarded.
    canvas.draw()
    image = np.asarray(canvas.buffer_rgba()).copy()
    return forecast_df, image
# Assemble the UI: header text, a row of inputs, the two outputs, and a
# button that wires everything to `forecast`.
with gr.Blocks() as demo:
    gr.Markdown(value="# 📈 Stock Price Forecasting Demo")
    gr.Markdown(
        value=(
            "Upload a CSV containing stock prices (must have a **'Close'** column). "
            "Choose ARIMA, LSTM, or Compare Both, then set forecast horizon."
        )
    )

    # Inputs share a single row.
    with gr.Row():
        file = gr.File(label="Upload CSV", file_types=[".csv"])
        horizon = gr.Slider(
            minimum=5,
            maximum=30,
            value=10,
            step=1,
            label="Forecast Horizon (days)",
        )
        model_choice = gr.Radio(
            choices=["ARIMA", "LSTM", "Compare Both"],
            label="Model",
            value="Compare Both",
        )

    # Outputs: the forecast table and the rendered plot.
    output_table = gr.DataFrame(label="Forecasted Prices")
    output_plot = gr.Image(type="pil", label="Forecast Plot")

    submit = gr.Button(value="Run Forecast")
    submit.click(
        fn=forecast,
        inputs=[file, horizon, model_choice],
        outputs=[output_table, output_plot],
    )

# Start the web server as soon as the script is executed.
demo.launch()
|