mrshibly commited on
Commit
1160e34
·
verified ·
1 Parent(s): 879011a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -63
app.py CHANGED
@@ -1,74 +1,106 @@
1
  import gradio as gr
2
- import torch
3
- import torch.nn as nn
4
- import numpy as np
5
- from sklearn.preprocessing import MinMaxScaler
6
  import pandas as pd
 
 
 
 
 
 
7
 
8
- # Load data and scaler
9
- df = pd.read_csv('HistoricalQuotes.csv')
10
- df['Date'] = pd.to_datetime(df['Date'], format='%m/%d/%Y')
11
- df = df.sort_index()
12
-
13
- # Find the closing price column
14
- possible_columns = [' Close/Last', 'Close', 'close', 'close_last']
15
- close_column = None
16
- for col in possible_columns:
17
- if col in df.columns:
18
- close_column = col
19
- break
20
-
21
- if close_column is None:
22
- raise KeyError(f"None of {possible_columns} found in columns: {list(df.columns)}")
23
 
24
- df = df[[close_column]].rename(columns={close_column: 'Close'})
25
- df['Close'] = df['Close'].replace({r'\$': ''}, regex=True).astype(float)
26
 
27
- scaler = MinMaxScaler(feature_range=(0, 1))
28
- scaler.fit(df['Close'].values.reshape(-1, 1))
29
 
30
- # Define LSTM model
31
  class LSTMModel(nn.Module):
32
- def __init__(self, input_size=1, hidden_size=50, num_layers=2, output_size=1, dropout=0.2):
33
- super().__init__()
34
- self.hidden_size = hidden_size
35
- self.num_layers = num_layers # Store num_layers as instance variable
36
- self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout)
37
  self.fc = nn.Linear(hidden_size, output_size)
38
 
39
  def forward(self, x):
40
- h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
41
- c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
42
  out, _ = self.lstm(x, (h0, c0))
43
- return self.fc(out[:, -1, :])
44
-
45
- # Load model
46
- model = LSTMModel()
47
- model.load_state_dict(torch.load('lstm_model.pth', map_location=torch.device('cpu')))
48
- model.eval()
49
-
50
- def forecast(past_prices, steps=30):
51
- try:
52
- prices = [float(x.strip().replace('$', '')) for x in past_prices.split(',')]
53
- if len(prices) < 60:
54
- return "Error: At least 60 prices required."
55
- prices_scaled = scaler.transform(np.array(prices).reshape(-1, 1))
56
- current_seq = torch.from_numpy(prices_scaled[-60:].reshape(1, 60, 1)).float()
57
- predictions = []
58
- for _ in range(steps):
59
- with torch.no_grad():
60
- pred_scaled = model(current_seq).item()
61
- predictions.append(pred_scaled)
62
- current_seq = torch.cat((current_seq[:, 1:, :], torch.tensor([[[pred_scaled]]]).float()), dim=1)
63
- predictions = scaler.inverse_transform(np.array(predictions).reshape(-1, 1)).flatten()
64
- return pd.DataFrame({'Forecast': predictions}).to_string()
65
- except Exception as e:
66
- return f"Error: {str(e)}"
67
-
68
- # Create Gradio interface
69
- iface = gr.Interface(
70
- fn=forecast,
71
- inputs=gr.Textbox(label="Past Prices (comma-separated, e.g., 273.36,273.52,...)"),
72
- outputs=gr.Textbox(label="Forecasted Prices")
73
- )
74
- iface.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
2
  import pandas as pd
3
+ import numpy as np
4
+ import torch
5
+ import pickle
6
+ import matplotlib.pyplot as plt
7
+ import io
8
+ from torch import nn
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ with open("arima.pkl", "rb") as f:
12
+ arima_model = pickle.load(f)
13
 
 
 
14
 
 
15
  class LSTMModel(nn.Module):
16
+ def __init__(self, input_size=1, hidden_size=50, num_layers=2, output_size=1):
17
+ super(LSTMModel, self).__init__()
18
+ self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
 
 
19
  self.fc = nn.Linear(hidden_size, output_size)
20
 
21
  def forward(self, x):
22
+ h0 = torch.zeros(2, x.size(0), 50)
23
+ c0 = torch.zeros(2, x.size(0), 50)
24
  out, _ = self.lstm(x, (h0, c0))
25
+ out = self.fc(out[:, -1, :])
26
+ return out
27
+
28
+ # Load trained LSTM
29
+ lstm_model = LSTMModel()
30
+ lstm_model.load_state_dict(torch.load("lstm.pth", map_location=torch.device('cpu')))
31
+ lstm_model.eval()
32
+
33
+
34
+ def predict_arima(values, horizon=10):
35
+ forecast = arima_model.forecast(steps=horizon)
36
+ return forecast.tolist()
37
+
38
+ def predict_lstm(values, horizon=10):
39
+ seq = torch.tensor(values[-50:], dtype=torch.float32).view(1, -1, 1)
40
+ preds = []
41
+ for _ in range(horizon):
42
+ with torch.no_grad():
43
+ pred = lstm_model(seq).item()
44
+ preds.append(pred)
45
+ seq = torch.cat([seq[:, 1:, :], torch.tensor([[[pred]]])], dim=1)
46
+ return preds
47
+
48
+
49
+ def forecast(file, horizon, model_choice):
50
+ df = pd.read_csv(file.name)
51
+ if "Close" not in df.columns:
52
+ return "❌ CSV must contain a 'Close' column", None
53
+
54
+ values = df["Close"].values.tolist()
55
+
56
+ # Run forecasts
57
+ preds_arima = predict_arima(values, horizon)
58
+ preds_lstm = predict_lstm(values, horizon)
59
+
60
+ # Prepare DataFrames
61
+ future_index = [f"t+{i+1}" for i in range(horizon)]
62
+ forecast_df = pd.DataFrame({
63
+ "Future": future_index,
64
+ "ARIMA Forecast": preds_arima,
65
+ "LSTM Forecast": preds_lstm
66
+ })
67
+
68
+ # Plot
69
+ plt.figure(figsize=(10,5))
70
+ plt.plot(range(len(values)), values, label="Historical")
71
+ if model_choice in ["ARIMA", "Compare Both"]:
72
+ plt.plot(range(len(values), len(values)+horizon), preds_arima, label="ARIMA Forecast")
73
+ if model_choice in ["LSTM", "Compare Both"]:
74
+ plt.plot(range(len(values), len(values)+horizon), preds_lstm, label="LSTM Forecast")
75
+
76
+ plt.title(f"{model_choice} Stock Forecast")
77
+ plt.xlabel("Time")
78
+ plt.ylabel("Price")
79
+ plt.legend()
80
+
81
+ buf = io.BytesIO()
82
+ plt.savefig(buf, format="png")
83
+ buf.seek(0)
84
+
85
+ return forecast_df, buf
86
+
87
+
88
+ with gr.Blocks() as demo:
89
+ gr.Markdown("# 📈 Stock Price Forecasting Demo")
90
+ gr.Markdown(
91
+ "Upload a CSV containing stock prices (must have a **'Close'** column). "
92
+ "Choose ARIMA, LSTM, or Compare Both, then set forecast horizon."
93
+ )
94
+
95
+ with gr.Row():
96
+ file = gr.File(label="Upload CSV", file_types=[".csv"])
97
+ horizon = gr.Slider(5, 30, value=10, step=1, label="Forecast Horizon (days)")
98
+ model_choice = gr.Radio(["ARIMA", "LSTM", "Compare Both"], label="Model", value="Compare Both")
99
+
100
+ output_table = gr.DataFrame(label="Forecasted Prices")
101
+ output_plot = gr.Image(type="pil", label="Forecast Plot")
102
+
103
+ submit = gr.Button("Run Forecast")
104
+ submit.click(forecast, inputs=[file, horizon, model_choice], outputs=[output_table, output_plot])
105
+
106
+ demo.launch()