mrshibly commited on
Commit
7c8a504
·
verified ·
1 Parent(s): 073592c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -2
app.py CHANGED
@@ -6,12 +6,15 @@ import pickle
6
  import matplotlib.pyplot as plt
7
  import io
8
  from torch import nn
 
9
 
10
 
 
11
  with open("arima.pkl", "rb") as f:
12
  arima_model = pickle.load(f)
13
 
14
 
 
15
  class LSTMModel(nn.Module):
16
  def __init__(self, input_size=1, hidden_size=50, num_layers=2, output_size=1):
17
  super(LSTMModel, self).__init__()
@@ -25,16 +28,20 @@ class LSTMModel(nn.Module):
25
  out = self.fc(out[:, -1, :])
26
  return out
27
 
 
28
  # Load trained LSTM
29
  lstm_model = LSTMModel()
30
  lstm_model.load_state_dict(torch.load("lstm.pth", map_location=torch.device('cpu')))
31
  lstm_model.eval()
32
 
33
 
 
34
  def predict_arima(values, horizon=10):
35
  forecast = arima_model.forecast(steps=horizon)
36
  return forecast.tolist()
37
 
 
 
38
  def predict_lstm(values, horizon=10):
39
  seq = torch.tensor(values[-50:], dtype=torch.float32).view(1, -1, 1)
40
  preds = []
@@ -46,6 +53,7 @@ def predict_lstm(values, horizon=10):
46
  return preds
47
 
48
 
 
49
  def forecast(file, horizon, model_choice):
50
  df = pd.read_csv(file.name)
51
  if "Close" not in df.columns:
@@ -78,13 +86,17 @@ def forecast(file, horizon, model_choice):
78
  plt.ylabel("Price")
79
  plt.legend()
80
 
 
81
  buf = io.BytesIO()
82
  plt.savefig(buf, format="png")
83
  buf.seek(0)
84
-
85
- return forecast_df, buf
 
 
86
 
87
 
 
88
  with gr.Blocks() as demo:
89
  gr.Markdown("# 📈 Stock Price Forecasting Demo")
90
  gr.Markdown(
 
6
  import matplotlib.pyplot as plt
7
  import io
8
  from torch import nn
9
+ from PIL import Image
10
 
11
 
12
+ # Load ARIMA model
13
  with open("arima.pkl", "rb") as f:
14
  arima_model = pickle.load(f)
15
 
16
 
17
+ # Define LSTM Model
18
  class LSTMModel(nn.Module):
19
  def __init__(self, input_size=1, hidden_size=50, num_layers=2, output_size=1):
20
  super(LSTMModel, self).__init__()
 
28
  out = self.fc(out[:, -1, :])
29
  return out
30
 
31
+
32
  # Load trained LSTM
33
  lstm_model = LSTMModel()
34
  lstm_model.load_state_dict(torch.load("lstm.pth", map_location=torch.device('cpu')))
35
  lstm_model.eval()
36
 
37
 
38
+ # ARIMA Prediction
39
  def predict_arima(values, horizon=10):
40
  forecast = arima_model.forecast(steps=horizon)
41
  return forecast.tolist()
42
 
43
+
44
+ # LSTM Prediction
45
  def predict_lstm(values, horizon=10):
46
  seq = torch.tensor(values[-50:], dtype=torch.float32).view(1, -1, 1)
47
  preds = []
 
53
  return preds
54
 
55
 
56
+ # Forecast Function
57
  def forecast(file, horizon, model_choice):
58
  df = pd.read_csv(file.name)
59
  if "Close" not in df.columns:
 
86
  plt.ylabel("Price")
87
  plt.legend()
88
 
89
+ # Save plot to buffer and convert to PIL
90
  buf = io.BytesIO()
91
  plt.savefig(buf, format="png")
92
  buf.seek(0)
93
+ plt.close()
94
+ img = Image.open(buf)
95
+
96
+ return forecast_df, img
97
 
98
 
99
+ # Gradio UI
100
  with gr.Blocks() as demo:
101
  gr.Markdown("# 📈 Stock Price Forecasting Demo")
102
  gr.Markdown(