Shaikat01 commited on
Commit
f19cf54
·
verified ·
1 Parent(s): 03e4f68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -184
app.py CHANGED
@@ -4,69 +4,80 @@ import numpy as np
4
  import plotly.graph_objects as go
5
  from datetime import datetime, timedelta
6
  import pickle
7
- import yfinance as yf
8
- from statsmodels.tsa.arima.model import ARIMA
9
- from prophet import Prophet
10
- from tensorflow import keras
11
- from sklearn.preprocessing import MinMaxScaler
12
  import warnings
13
  warnings.filterwarnings('ignore')
14
- import os
15
- from datetime import datetime, timedelta
16
 
17
- # Load your saved models (update paths as needed)
18
- # For Hugging Face, these will be in the same directory as app.py
 
 
 
 
 
 
 
 
 
19
  def load_models():
20
- """Load all three models"""
21
  try:
22
- # Load ARIMA model
23
  with open('arima_model.pkl', 'rb') as f:
24
  arima_model = pickle.load(f)
25
-
26
- # Load Prophet model
27
  with open('prophet_model.pkl', 'rb') as f:
28
  prophet_model = pickle.load(f)
29
-
30
- # Load LSTM model and scaler
31
- lstm_model = keras.models.load_model('lstm_model.h5')
32
  with open('lstm_scaler.pkl', 'rb') as f:
33
  scaler = pickle.load(f)
34
-
35
  return arima_model, prophet_model, lstm_model, scaler
36
  except Exception as e:
37
  print(f"Error loading models: {e}")
38
  return None, None, None, None
39
 
40
- # Global variables for models
41
  arima_model, prophet_model, lstm_model, scaler = load_models()
42
- SEQ_LENGTH = 60 # Should match your training
43
-
44
 
 
 
 
45
  def fetch_stock_data(ticker, days=365):
46
- ticker = ticker.upper()
 
 
 
 
47
  filename = f"{ticker}.csv"
48
  if not os.path.exists(filename):
49
- return None, f"No data found for ticker {ticker}. Upload {ticker}.csv to use offline mode."
50
 
51
  df = pd.read_csv(filename, index_col=0, parse_dates=True)
52
  if 'Close' in df.columns:
53
  df = df[['Close']].copy()
54
  else:
55
  df.columns = ['Price']
56
- df.columns = ['Price']
57
 
 
 
58
  df = df.dropna()
59
  df = df.tail(days)
 
60
  if df.empty:
61
- return None, f"No data available in {filename}"
62
  return df, None
63
 
64
-
 
 
65
  def make_arima_forecast(data, days):
66
- """Make ARIMA forecast"""
67
  try:
68
- # Retrain ARIMA with recent data (or use loaded model)
69
- model = ARIMA(data['Price'], order=(1, 1, 1))
 
70
  fitted = model.fit()
71
  forecast = fitted.forecast(steps=days)
72
  return forecast.values
@@ -75,15 +86,8 @@ def make_arima_forecast(data, days):
75
  return None
76
 
77
  def make_prophet_forecast(data, days):
78
- """Make Prophet forecast"""
79
  try:
80
- # Prepare data for Prophet
81
- prophet_data = pd.DataFrame({
82
- 'ds': data.index,
83
- 'y': data['Price'].values
84
- })
85
-
86
- # Create and fit model
87
  model = Prophet(
88
  daily_seasonality=True,
89
  weekly_seasonality=True,
@@ -91,8 +95,6 @@ def make_prophet_forecast(data, days):
91
  changepoint_prior_scale=0.05
92
  )
93
  model.fit(prophet_data)
94
-
95
- # Make forecast
96
  future = model.make_future_dataframe(periods=days)
97
  forecast = model.predict(future)
98
  return forecast['yhat'].tail(days).values
@@ -101,38 +103,28 @@ def make_prophet_forecast(data, days):
101
  return None
102
 
103
  def make_lstm_forecast(data, days, model, scaler, seq_length=60):
104
- """Make LSTM forecast"""
105
  try:
106
- # Scale the data
107
  scaled_data = scaler.transform(data[['Price']])
108
-
109
- # Prepare the last sequence
110
  last_sequence = scaled_data[-seq_length:].reshape(1, seq_length, 1)
111
-
112
  predictions = []
113
  current_sequence = last_sequence.copy()
114
-
115
- # Generate predictions day by day
116
  for _ in range(days):
117
  pred = model.predict(current_sequence, verbose=0)
118
- predictions.append(pred[0, 0])
119
-
120
- # Update sequence
121
- current_sequence = np.append(current_sequence[:, 1:, :],
122
- pred.reshape(1, 1, 1), axis=1)
123
-
124
- # Inverse transform predictions
125
- predictions = scaler.inverse_transform(np.array(predictions).reshape(-1, 1))
126
  return predictions.flatten()
127
  except Exception as e:
128
  print(f"LSTM Error: {e}")
129
  return None
130
 
 
 
 
131
  def create_forecast_plot(historical_data, forecasts, ticker, model_names):
132
- """Create interactive plotly chart"""
133
  fig = go.Figure()
134
-
135
- # Historical data
136
  fig.add_trace(go.Scatter(
137
  x=historical_data.index,
138
  y=historical_data['Price'],
@@ -140,13 +132,13 @@ def create_forecast_plot(historical_data, forecasts, ticker, model_names):
140
  name='Historical Price',
141
  line=dict(color='blue', width=2)
142
  ))
143
-
144
- # Generate future dates
145
- last_date = historical_data.index[-1]
146
- future_dates = pd.date_range(start=last_date + timedelta(days=1),
147
- periods=len(forecasts[0]))
148
-
149
- # Plot forecasts
150
  colors = ['red', 'purple', 'orange']
151
  for i, (forecast, name) in enumerate(zip(forecasts, model_names)):
152
  if forecast is not None:
@@ -158,7 +150,7 @@ def create_forecast_plot(historical_data, forecasts, ticker, model_names):
158
  line=dict(color=colors[i], width=2, dash='dash'),
159
  marker=dict(size=6)
160
  ))
161
-
162
  fig.update_layout(
163
  title=f'{ticker} Stock Price Forecast',
164
  xaxis_title='Date',
@@ -166,169 +158,88 @@ def create_forecast_plot(historical_data, forecasts, ticker, model_names):
166
  hovermode='x unified',
167
  template='plotly_white',
168
  height=600,
169
- showlegend=True,
170
- legend=dict(
171
- yanchor="top",
172
- y=0.99,
173
- xanchor="left",
174
- x=0.01
175
- )
176
  )
177
-
178
  return fig
179
 
 
 
 
180
  def predict_stock(ticker, forecast_days, model_choice):
181
- """Main prediction function"""
182
- # Validate inputs
183
  if not ticker:
184
  return None, "Please enter a stock ticker symbol", None
185
-
186
- ticker = ticker.upper().strip()
187
-
188
- # Fetch data
189
- data, error = fetch_stock_data(ticker, days=730) # 2 years of data
190
  if error:
191
  return None, f"Error: {error}", None
192
-
193
- # Make forecasts based on model choice
194
  forecasts = []
195
  model_names = []
196
-
197
  if model_choice in ["All Models", "ARIMA"]:
198
  arima_forecast = make_arima_forecast(data, forecast_days)
199
  if arima_forecast is not None:
200
  forecasts.append(arima_forecast)
201
  model_names.append("ARIMA")
202
-
203
  if model_choice in ["All Models", "Prophet"]:
204
  prophet_forecast = make_prophet_forecast(data, forecast_days)
205
  if prophet_forecast is not None:
206
  forecasts.append(prophet_forecast)
207
  model_names.append("Prophet")
208
-
209
  if model_choice in ["All Models", "LSTM"] and lstm_model is not None:
210
  lstm_forecast = make_lstm_forecast(data, forecast_days, lstm_model, scaler, SEQ_LENGTH)
211
  if lstm_forecast is not None:
212
  forecasts.append(lstm_forecast)
213
  model_names.append("LSTM")
214
-
215
  if not forecasts:
216
- return None, "Failed to generate forecasts. Please try again.", None
217
-
218
- # Create plot
219
  fig = create_forecast_plot(data, forecasts, ticker, model_names)
220
-
221
- # Create forecast table
222
  future_dates = pd.date_range(
223
- start=data.index[-1] + timedelta(days=1),
224
  periods=forecast_days
225
  )
226
-
227
  forecast_df = pd.DataFrame({'Date': future_dates.strftime('%Y-%m-%d')})
228
  for forecast, name in zip(forecasts, model_names):
229
  forecast_df[f'{name} Prediction ($)'] = np.round(forecast, 2)
230
-
231
- # Summary statistics
232
- summary = f"""
233
- 📊 **Forecast Summary for {ticker}**
234
-
235
- - Current Price: ${data['Price'].iloc[-1]:.2f}
236
- - Forecast Period: {forecast_days} days
237
- - Models Used: {', '.join(model_names)}
238
-
239
- **Predicted Price Range (Day {forecast_days}):**
240
- """
241
-
242
  for forecast, name in zip(forecasts, model_names):
243
  final_price = forecast[-1]
244
  change = ((final_price - data['Price'].iloc[-1]) / data['Price'].iloc[-1]) * 100
245
  summary += f"\n- {name}: ${final_price:.2f} ({change:+.2f}%)"
246
-
247
  return fig, summary, forecast_df
248
 
249
- # Create Gradio Interface
 
 
250
  with gr.Blocks(title="Stock Price Forecasting", theme=gr.themes.Soft()) as demo:
251
- gr.Markdown(
252
- """
253
- # 📈 Stock Price Forecasting App
254
-
255
- Predict future stock prices using ARIMA, Prophet, and LSTM models.
256
- Enter a stock ticker symbol and select forecast parameters below.
257
-
258
- **Note:** Predictions are for educational purposes only. Not financial advice.
259
- """
260
- )
261
-
262
  with gr.Row():
263
  with gr.Column(scale=1):
264
- ticker_input = gr.Textbox(
265
- label="Stock Ticker Symbol",
266
- placeholder="e.g., AAPL, GOOGL, TSLA",
267
- value="AAPL"
268
- )
269
-
270
- forecast_days = gr.Slider(
271
- minimum=1,
272
- maximum=90,
273
- value=30,
274
- step=1,
275
- label="Forecast Days"
276
- )
277
-
278
- model_choice = gr.Radio(
279
- choices=["All Models", "ARIMA", "Prophet", "LSTM"],
280
- value="All Models",
281
- label="Select Model(s)"
282
- )
283
-
284
- predict_btn = gr.Button("🔮 Generate Forecast", variant="primary", size="lg")
285
-
286
  with gr.Column(scale=2):
287
  output_plot = gr.Plot(label="Forecast Visualization")
288
-
289
- with gr.Row():
290
- output_summary = gr.Markdown(label="Forecast Summary")
291
-
292
- with gr.Row():
293
- output_table = gr.Dataframe(
294
- label="Detailed Forecast",
295
- wrap=True,
296
- interactive=False
297
- )
298
-
299
- # Examples
300
- gr.Examples(
301
- examples=[
302
- ["AAPL", 30, "All Models"],
303
- ["GOOGL", 14, "Prophet"],
304
- ["TSLA", 60, "LSTM"],
305
- ["MSFT", 45, "ARIMA"],
306
- ],
307
- inputs=[ticker_input, forecast_days, model_choice],
308
- )
309
-
310
- # Connect the button to the function
311
- predict_btn.click(
312
- fn=predict_stock,
313
- inputs=[ticker_input, forecast_days, model_choice],
314
- outputs=[output_plot, output_summary, output_table]
315
- )
316
-
317
- gr.Markdown(
318
- """
319
- ---
320
- ### 📚 About the Models
321
-
322
- - **ARIMA**: Statistical model for time series forecasting
323
- - **Prophet**: Facebook's forecasting tool, excellent for seasonality
324
- - **LSTM**: Deep learning model that captures complex patterns
325
-
326
- ### ⚠️ Disclaimer
327
- This tool is for educational and research purposes only. Stock market predictions are inherently uncertain.
328
- Always conduct thorough research and consult with financial advisors before making investment decisions.
329
- """
330
- )
331
 
332
- # Launch the app
333
  if __name__ == "__main__":
334
- demo.launch()
 
4
  import plotly.graph_objects as go
5
  from datetime import datetime, timedelta
6
  import pickle
7
+ import os
 
 
 
 
8
  import warnings
9
  warnings.filterwarnings('ignore')
 
 
10
 
11
+ # TensorFlow/Keras imports
12
+ from tensorflow.keras.models import load_model
13
+ from sklearn.preprocessing import MinMaxScaler
14
+
15
+ # ARIMA and Prophet
16
+ from statsmodels.tsa.arima.model import ARIMA
17
+ from prophet import Prophet
18
+
19
+ # --------------------------
20
+ # Load models safely
21
+ # --------------------------
22
  def load_models():
 
23
  try:
24
+ # ARIMA
25
  with open('arima_model.pkl', 'rb') as f:
26
  arima_model = pickle.load(f)
27
+
28
+ # Prophet
29
  with open('prophet_model.pkl', 'rb') as f:
30
  prophet_model = pickle.load(f)
31
+
32
+ # LSTM + scaler
33
+ lstm_model = load_model('lstm_model') # use SavedModel folder
34
  with open('lstm_scaler.pkl', 'rb') as f:
35
  scaler = pickle.load(f)
36
+
37
  return arima_model, prophet_model, lstm_model, scaler
38
  except Exception as e:
39
  print(f"Error loading models: {e}")
40
  return None, None, None, None
41
 
 
42
  arima_model, prophet_model, lstm_model, scaler = load_models()
43
+ SEQ_LENGTH = 60
 
44
 
45
+ # --------------------------
46
+ # Fetch stock data
47
+ # --------------------------
48
  def fetch_stock_data(ticker, days=365):
49
+ """
50
+ Fetch stock data from local CSV fallback.
51
+ Community Spaces cannot access the internet.
52
+ """
53
+ ticker = ticker.upper().strip()
54
  filename = f"{ticker}.csv"
55
  if not os.path.exists(filename):
56
+ return None, f"No data found for {ticker}. Upload {ticker}.csv in the Space root."
57
 
58
  df = pd.read_csv(filename, index_col=0, parse_dates=True)
59
  if 'Close' in df.columns:
60
  df = df[['Close']].copy()
61
  else:
62
  df.columns = ['Price']
 
63
 
64
+ df.columns = ['Price']
65
+ df['Price'] = pd.to_numeric(df['Price'], errors='coerce')
66
  df = df.dropna()
67
  df = df.tail(days)
68
+
69
  if df.empty:
70
+ return None, f"No valid data found in {filename} for {ticker}."
71
  return df, None
72
 
73
+ # --------------------------
74
+ # Forecasting functions
75
+ # --------------------------
76
  def make_arima_forecast(data, days):
 
77
  try:
78
+ data['Price'] = pd.to_numeric(data['Price'], errors='coerce')
79
+ data = data.dropna()
80
+ model = ARIMA(data['Price'], order=(1,1,1))
81
  fitted = model.fit()
82
  forecast = fitted.forecast(steps=days)
83
  return forecast.values
 
86
  return None
87
 
88
  def make_prophet_forecast(data, days):
 
89
  try:
90
+ prophet_data = pd.DataFrame({'ds': data.index, 'y': data['Price'].values})
 
 
 
 
 
 
91
  model = Prophet(
92
  daily_seasonality=True,
93
  weekly_seasonality=True,
 
95
  changepoint_prior_scale=0.05
96
  )
97
  model.fit(prophet_data)
 
 
98
  future = model.make_future_dataframe(periods=days)
99
  forecast = model.predict(future)
100
  return forecast['yhat'].tail(days).values
 
103
  return None
104
 
105
  def make_lstm_forecast(data, days, model, scaler, seq_length=60):
 
106
  try:
 
107
  scaled_data = scaler.transform(data[['Price']])
 
 
108
  last_sequence = scaled_data[-seq_length:].reshape(1, seq_length, 1)
109
+
110
  predictions = []
111
  current_sequence = last_sequence.copy()
 
 
112
  for _ in range(days):
113
  pred = model.predict(current_sequence, verbose=0)
114
+ predictions.append(pred[0,0])
115
+ current_sequence = np.append(current_sequence[:,1:,:], pred.reshape(1,1,1), axis=1)
116
+
117
+ predictions = scaler.inverse_transform(np.array(predictions).reshape(-1,1))
 
 
 
 
118
  return predictions.flatten()
119
  except Exception as e:
120
  print(f"LSTM Error: {e}")
121
  return None
122
 
123
+ # --------------------------
124
+ # Plot function
125
+ # --------------------------
126
  def create_forecast_plot(historical_data, forecasts, ticker, model_names):
 
127
  fig = go.Figure()
 
 
128
  fig.add_trace(go.Scatter(
129
  x=historical_data.index,
130
  y=historical_data['Price'],
 
132
  name='Historical Price',
133
  line=dict(color='blue', width=2)
134
  ))
135
+
136
+ last_date = pd.to_datetime(historical_data.index[-1])
137
+ future_dates = pd.date_range(
138
+ start=last_date + timedelta(days=1),
139
+ periods=len(forecasts[0])
140
+ )
141
+
142
  colors = ['red', 'purple', 'orange']
143
  for i, (forecast, name) in enumerate(zip(forecasts, model_names)):
144
  if forecast is not None:
 
150
  line=dict(color=colors[i], width=2, dash='dash'),
151
  marker=dict(size=6)
152
  ))
153
+
154
  fig.update_layout(
155
  title=f'{ticker} Stock Price Forecast',
156
  xaxis_title='Date',
 
158
  hovermode='x unified',
159
  template='plotly_white',
160
  height=600,
161
+ showlegend=True
 
 
 
 
 
 
162
  )
 
163
  return fig
164
 
165
+ # --------------------------
166
+ # Main prediction function
167
+ # --------------------------
168
  def predict_stock(ticker, forecast_days, model_choice):
 
 
169
  if not ticker:
170
  return None, "Please enter a stock ticker symbol", None
171
+
172
+ data, error = fetch_stock_data(ticker, days=730)
 
 
 
173
  if error:
174
  return None, f"Error: {error}", None
175
+
 
176
  forecasts = []
177
  model_names = []
178
+
179
  if model_choice in ["All Models", "ARIMA"]:
180
  arima_forecast = make_arima_forecast(data, forecast_days)
181
  if arima_forecast is not None:
182
  forecasts.append(arima_forecast)
183
  model_names.append("ARIMA")
184
+
185
  if model_choice in ["All Models", "Prophet"]:
186
  prophet_forecast = make_prophet_forecast(data, forecast_days)
187
  if prophet_forecast is not None:
188
  forecasts.append(prophet_forecast)
189
  model_names.append("Prophet")
190
+
191
  if model_choice in ["All Models", "LSTM"] and lstm_model is not None:
192
  lstm_forecast = make_lstm_forecast(data, forecast_days, lstm_model, scaler, SEQ_LENGTH)
193
  if lstm_forecast is not None:
194
  forecasts.append(lstm_forecast)
195
  model_names.append("LSTM")
196
+
197
  if not forecasts:
198
+ return None, "Failed to generate forecasts.", None
199
+
 
200
  fig = create_forecast_plot(data, forecasts, ticker, model_names)
201
+
202
+ # Forecast table
203
  future_dates = pd.date_range(
204
+ start=pd.to_datetime(data.index[-1]) + timedelta(days=1),
205
  periods=forecast_days
206
  )
 
207
  forecast_df = pd.DataFrame({'Date': future_dates.strftime('%Y-%m-%d')})
208
  for forecast, name in zip(forecasts, model_names):
209
  forecast_df[f'{name} Prediction ($)'] = np.round(forecast, 2)
210
+
211
+ # Summary
212
+ summary = f"📊 **Forecast Summary for {ticker}**\n\n" \
213
+ f"- Current Price: ${data['Price'].iloc[-1]:.2f}\n" \
214
+ f"- Forecast Period: {forecast_days} days\n" \
215
+ f"- Models Used: {', '.join(model_names)}\n\n" \
216
+ f"**Predicted Price Range (Day {forecast_days}):**"
 
 
 
 
 
217
  for forecast, name in zip(forecasts, model_names):
218
  final_price = forecast[-1]
219
  change = ((final_price - data['Price'].iloc[-1]) / data['Price'].iloc[-1]) * 100
220
  summary += f"\n- {name}: ${final_price:.2f} ({change:+.2f}%)"
221
+
222
  return fig, summary, forecast_df
223
 
224
+ # --------------------------
225
+ # Gradio Interface
226
+ # --------------------------
227
  with gr.Blocks(title="Stock Price Forecasting", theme=gr.themes.Soft()) as demo:
228
+ gr.Markdown("# 📈 Stock Price Forecasting App\nPredict future stock prices using ARIMA, Prophet, and LSTM models.\nUpload CSV files in the Space root for offline use.")
 
 
 
 
 
 
 
 
 
 
229
  with gr.Row():
230
  with gr.Column(scale=1):
231
+ ticker_input = gr.Textbox(label="Stock Ticker Symbol", placeholder="e.g., AAPL", value="AAPL")
232
+ forecast_days = gr.Slider(minimum=1, maximum=90, value=30, step=1, label="Forecast Days")
233
+ model_choice = gr.Radio(choices=["All Models", "ARIMA", "Prophet", "LSTM"], value="All Models", label="Select Model(s)")
234
+ predict_btn = gr.Button("🔮 Generate Forecast", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  with gr.Column(scale=2):
236
  output_plot = gr.Plot(label="Forecast Visualization")
237
+ output_summary = gr.Markdown(label="Forecast Summary")
238
+ output_table = gr.Dataframe(label="Detailed Forecast", interactive=False)
239
+
240
+ predict_btn.click(fn=predict_stock, inputs=[ticker_input, forecast_days, model_choice],
241
+ outputs=[output_plot, output_summary, output_table])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
+ # Launch
244
  if __name__ == "__main__":
245
+ demo.launch()