Shaikat01 commited on
Commit
42aee05
ยท
verified ยท
1 Parent(s): e434763

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +186 -109
app.py CHANGED
@@ -4,83 +4,60 @@ import numpy as np
4
  import plotly.graph_objects as go
5
  from datetime import datetime, timedelta
6
  import pickle
7
- import os
8
- import warnings
9
- warnings.filterwarnings('ignore')
10
-
11
- # TensorFlow/Keras imports
12
- from tensorflow.keras.models import load_model
13
- from sklearn.preprocessing import MinMaxScaler
14
-
15
- # ARIMA and Prophet
16
  from statsmodels.tsa.arima.model import ARIMA
17
  from prophet import Prophet
 
 
 
 
18
 
19
- # --------------------------
20
- # Load models safely
21
- # --------------------------
22
  def load_models():
 
23
  try:
24
- # ARIMA
25
  with open('arima_model.pkl', 'rb') as f:
26
  arima_model = pickle.load(f)
27
-
28
- # Prophet
29
  with open('prophet_model.pkl', 'rb') as f:
30
  prophet_model = pickle.load(f)
31
-
32
- # LSTM model
33
- lstm_model = load_model('lstm_model.keras')
34
-
35
- # LSTM scaler
36
  with open('lstm_scaler.pkl', 'rb') as f:
37
  scaler = pickle.load(f)
38
-
39
  return arima_model, prophet_model, lstm_model, scaler
40
  except Exception as e:
41
  print(f"Error loading models: {e}")
42
  return None, None, None, None
43
 
44
-
45
  arima_model, prophet_model, lstm_model, scaler = load_models()
46
- SEQ_LENGTH = 60
47
 
48
- # --------------------------
49
- # Fetch stock data
50
- # --------------------------
51
  def fetch_stock_data(ticker, days=365):
52
- """
53
- Fetch stock data from local CSV fallback.
54
- Community Spaces cannot access the internet.
55
- """
56
- ticker = ticker.upper().strip()
57
- filename = f"{ticker}.csv"
58
- if not os.path.exists(filename):
59
- return None, f"No data found for {ticker}. Upload {ticker}.csv in the Space root."
60
-
61
- df = pd.read_csv(filename, index_col=0, parse_dates=True)
62
- if 'Close' in df.columns:
63
  df = df[['Close']].copy()
64
- else:
65
  df.columns = ['Price']
 
 
 
66
 
67
- df.columns = ['Price']
68
- df['Price'] = pd.to_numeric(df['Price'], errors='coerce')
69
- df = df.dropna()
70
- df = df.tail(days)
71
-
72
- if df.empty:
73
- return None, f"No valid data found in {filename} for {ticker}."
74
- return df, None
75
-
76
- # --------------------------
77
- # Forecasting functions
78
- # --------------------------
79
  def make_arima_forecast(data, days):
 
80
  try:
81
- data['Price'] = pd.to_numeric(data['Price'], errors='coerce')
82
- data = data.dropna()
83
- model = ARIMA(data['Price'], order=(1,1,1))
84
  fitted = model.fit()
85
  forecast = fitted.forecast(steps=days)
86
  return forecast.values
@@ -89,8 +66,15 @@ def make_arima_forecast(data, days):
89
  return None
90
 
91
  def make_prophet_forecast(data, days):
 
92
  try:
93
- prophet_data = pd.DataFrame({'ds': data.index, 'y': data['Price'].values})
 
 
 
 
 
 
94
  model = Prophet(
95
  daily_seasonality=True,
96
  weekly_seasonality=True,
@@ -98,6 +82,8 @@ def make_prophet_forecast(data, days):
98
  changepoint_prior_scale=0.05
99
  )
100
  model.fit(prophet_data)
 
 
101
  future = model.make_future_dataframe(periods=days)
102
  forecast = model.predict(future)
103
  return forecast['yhat'].tail(days).values
@@ -106,28 +92,38 @@ def make_prophet_forecast(data, days):
106
  return None
107
 
108
  def make_lstm_forecast(data, days, model, scaler, seq_length=60):
 
109
  try:
 
110
  scaled_data = scaler.transform(data[['Price']])
 
 
111
  last_sequence = scaled_data[-seq_length:].reshape(1, seq_length, 1)
112
-
113
  predictions = []
114
  current_sequence = last_sequence.copy()
 
 
115
  for _ in range(days):
116
  pred = model.predict(current_sequence, verbose=0)
117
- predictions.append(pred[0,0])
118
- current_sequence = np.append(current_sequence[:,1:,:], pred.reshape(1,1,1), axis=1)
119
-
120
- predictions = scaler.inverse_transform(np.array(predictions).reshape(-1,1))
 
 
 
 
121
  return predictions.flatten()
122
  except Exception as e:
123
  print(f"LSTM Error: {e}")
124
  return None
125
 
126
- # --------------------------
127
- # Plot function
128
- # --------------------------
129
  def create_forecast_plot(historical_data, forecasts, ticker, model_names):
 
130
  fig = go.Figure()
 
 
131
  fig.add_trace(go.Scatter(
132
  x=historical_data.index,
133
  y=historical_data['Price'],
@@ -135,13 +131,13 @@ def create_forecast_plot(historical_data, forecasts, ticker, model_names):
135
  name='Historical Price',
136
  line=dict(color='blue', width=2)
137
  ))
138
-
139
- last_date = pd.to_datetime(historical_data.index[-1])
140
- future_dates = pd.date_range(
141
- start=last_date + timedelta(days=1),
142
- periods=len(forecasts[0])
143
- )
144
-
145
  colors = ['red', 'purple', 'orange']
146
  for i, (forecast, name) in enumerate(zip(forecasts, model_names)):
147
  if forecast is not None:
@@ -153,7 +149,7 @@ def create_forecast_plot(historical_data, forecasts, ticker, model_names):
153
  line=dict(color=colors[i], width=2, dash='dash'),
154
  marker=dict(size=6)
155
  ))
156
-
157
  fig.update_layout(
158
  title=f'{ticker} Stock Price Forecast',
159
  xaxis_title='Date',
@@ -161,88 +157,169 @@ def create_forecast_plot(historical_data, forecasts, ticker, model_names):
161
  hovermode='x unified',
162
  template='plotly_white',
163
  height=600,
164
- showlegend=True
 
 
 
 
 
 
165
  )
 
166
  return fig
167
 
168
- # --------------------------
169
- # Main prediction function
170
- # --------------------------
171
  def predict_stock(ticker, forecast_days, model_choice):
 
 
172
  if not ticker:
173
  return None, "Please enter a stock ticker symbol", None
174
-
175
- data, error = fetch_stock_data(ticker, days=730)
 
 
 
176
  if error:
177
  return None, f"Error: {error}", None
178
-
 
179
  forecasts = []
180
  model_names = []
181
-
182
  if model_choice in ["All Models", "ARIMA"]:
183
  arima_forecast = make_arima_forecast(data, forecast_days)
184
  if arima_forecast is not None:
185
  forecasts.append(arima_forecast)
186
  model_names.append("ARIMA")
187
-
188
  if model_choice in ["All Models", "Prophet"]:
189
  prophet_forecast = make_prophet_forecast(data, forecast_days)
190
  if prophet_forecast is not None:
191
  forecasts.append(prophet_forecast)
192
  model_names.append("Prophet")
193
-
194
  if model_choice in ["All Models", "LSTM"] and lstm_model is not None:
195
  lstm_forecast = make_lstm_forecast(data, forecast_days, lstm_model, scaler, SEQ_LENGTH)
196
  if lstm_forecast is not None:
197
  forecasts.append(lstm_forecast)
198
  model_names.append("LSTM")
199
-
200
  if not forecasts:
201
- return None, "Failed to generate forecasts.", None
202
-
 
203
  fig = create_forecast_plot(data, forecasts, ticker, model_names)
204
-
205
- # Forecast table
206
  future_dates = pd.date_range(
207
- start=pd.to_datetime(data.index[-1]) + timedelta(days=1),
208
  periods=forecast_days
209
  )
 
210
  forecast_df = pd.DataFrame({'Date': future_dates.strftime('%Y-%m-%d')})
211
  for forecast, name in zip(forecasts, model_names):
212
  forecast_df[f'{name} Prediction ($)'] = np.round(forecast, 2)
213
-
214
- # Summary
215
- summary = f"๐Ÿ“Š **Forecast Summary for {ticker}**\n\n" \
216
- f"- Current Price: ${data['Price'].iloc[-1]:.2f}\n" \
217
- f"- Forecast Period: {forecast_days} days\n" \
218
- f"- Models Used: {', '.join(model_names)}\n\n" \
219
- f"**Predicted Price Range (Day {forecast_days}):**"
 
 
 
 
 
220
  for forecast, name in zip(forecasts, model_names):
221
  final_price = forecast[-1]
222
  change = ((final_price - data['Price'].iloc[-1]) / data['Price'].iloc[-1]) * 100
223
  summary += f"\n- {name}: ${final_price:.2f} ({change:+.2f}%)"
224
-
225
  return fig, summary, forecast_df
226
 
227
- # --------------------------
228
- # Gradio Interface
229
- # --------------------------
230
  with gr.Blocks(title="Stock Price Forecasting", theme=gr.themes.Soft()) as demo:
231
- gr.Markdown("# ๐Ÿ“ˆ Stock Price Forecasting App\nPredict future stock prices using ARIMA, Prophet, and LSTM models.\nUpload CSV files in the Space root for offline use.")
 
 
 
 
 
 
 
 
 
 
232
  with gr.Row():
233
  with gr.Column(scale=1):
234
- ticker_input = gr.Textbox(label="Stock Ticker Symbol", placeholder="e.g., AAPL", value="AAPL")
235
- forecast_days = gr.Slider(minimum=1, maximum=90, value=30, step=1, label="Forecast Days")
236
- model_choice = gr.Radio(choices=["All Models", "ARIMA", "Prophet", "LSTM"], value="All Models", label="Select Model(s)")
237
- predict_btn = gr.Button("๐Ÿ”ฎ Generate Forecast", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  with gr.Column(scale=2):
239
  output_plot = gr.Plot(label="Forecast Visualization")
240
- output_summary = gr.Markdown(label="Forecast Summary")
241
- output_table = gr.Dataframe(label="Detailed Forecast", interactive=False)
242
-
243
- predict_btn.click(fn=predict_stock, inputs=[ticker_input, forecast_days, model_choice],
244
- outputs=[output_plot, output_summary, output_table])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
- # Launch
247
  if __name__ == "__main__":
248
- demo.launch()
 
4
  import plotly.graph_objects as go
5
  from datetime import datetime, timedelta
6
  import pickle
7
+ import yfinance as yf
 
 
 
 
 
 
 
 
8
  from statsmodels.tsa.arima.model import ARIMA
9
  from prophet import Prophet
10
+ from tensorflow import keras
11
+ from sklearn.preprocessing import MinMaxScaler
12
+ import warnings
13
+ warnings.filterwarnings('ignore')
14
 
15
+ # Load your saved models (update paths as needed)
16
+ # For Hugging Face, these will be in the same directory as app.py
 
17
  def load_models():
18
+ """Load all three models"""
19
  try:
20
+ # Load ARIMA model
21
  with open('arima_model.pkl', 'rb') as f:
22
  arima_model = pickle.load(f)
23
+
24
+ # Load Prophet model
25
  with open('prophet_model.pkl', 'rb') as f:
26
  prophet_model = pickle.load(f)
27
+
28
+ # Load LSTM model and scaler
29
+ lstm_model = keras.models.load_model('lstm_model.h5')
 
 
30
  with open('lstm_scaler.pkl', 'rb') as f:
31
  scaler = pickle.load(f)
32
+
33
  return arima_model, prophet_model, lstm_model, scaler
34
  except Exception as e:
35
  print(f"Error loading models: {e}")
36
  return None, None, None, None
37
 
38
+ # Global variables for models
39
  arima_model, prophet_model, lstm_model, scaler = load_models()
40
+ SEQ_LENGTH = 60 # Should match your training
41
 
 
 
 
42
  def fetch_stock_data(ticker, days=365):
43
+ """Fetch stock data from Yahoo Finance"""
44
+ try:
45
+ end_date = datetime.now()
46
+ start_date = end_date - timedelta(days=days)
47
+ df = yf.download(ticker, start=start_date, end=end_date, progress=False)
48
+ if df.empty:
49
+ return None, f"No data found for ticker: {ticker}"
 
 
 
 
50
  df = df[['Close']].copy()
 
51
  df.columns = ['Price']
52
+ return df, None
53
+ except Exception as e:
54
+ return None, str(e)
55
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  def make_arima_forecast(data, days):
57
+ """Make ARIMA forecast"""
58
  try:
59
+ # Retrain ARIMA with recent data (or use loaded model)
60
+ model = ARIMA(data['Price'], order=(1, 1, 1))
 
61
  fitted = model.fit()
62
  forecast = fitted.forecast(steps=days)
63
  return forecast.values
 
66
  return None
67
 
68
  def make_prophet_forecast(data, days):
69
+ """Make Prophet forecast"""
70
  try:
71
+ # Prepare data for Prophet
72
+ prophet_data = pd.DataFrame({
73
+ 'ds': data.index,
74
+ 'y': data['Price'].values
75
+ })
76
+
77
+ # Create and fit model
78
  model = Prophet(
79
  daily_seasonality=True,
80
  weekly_seasonality=True,
 
82
  changepoint_prior_scale=0.05
83
  )
84
  model.fit(prophet_data)
85
+
86
+ # Make forecast
87
  future = model.make_future_dataframe(periods=days)
88
  forecast = model.predict(future)
89
  return forecast['yhat'].tail(days).values
 
92
  return None
93
 
94
  def make_lstm_forecast(data, days, model, scaler, seq_length=60):
95
+ """Make LSTM forecast"""
96
  try:
97
+ # Scale the data
98
  scaled_data = scaler.transform(data[['Price']])
99
+
100
+ # Prepare the last sequence
101
  last_sequence = scaled_data[-seq_length:].reshape(1, seq_length, 1)
102
+
103
  predictions = []
104
  current_sequence = last_sequence.copy()
105
+
106
+ # Generate predictions day by day
107
  for _ in range(days):
108
  pred = model.predict(current_sequence, verbose=0)
109
+ predictions.append(pred[0, 0])
110
+
111
+ # Update sequence
112
+ current_sequence = np.append(current_sequence[:, 1:, :],
113
+ pred.reshape(1, 1, 1), axis=1)
114
+
115
+ # Inverse transform predictions
116
+ predictions = scaler.inverse_transform(np.array(predictions).reshape(-1, 1))
117
  return predictions.flatten()
118
  except Exception as e:
119
  print(f"LSTM Error: {e}")
120
  return None
121
 
 
 
 
122
  def create_forecast_plot(historical_data, forecasts, ticker, model_names):
123
+ """Create interactive plotly chart"""
124
  fig = go.Figure()
125
+
126
+ # Historical data
127
  fig.add_trace(go.Scatter(
128
  x=historical_data.index,
129
  y=historical_data['Price'],
 
131
  name='Historical Price',
132
  line=dict(color='blue', width=2)
133
  ))
134
+
135
+ # Generate future dates
136
+ last_date = historical_data.index[-1]
137
+ future_dates = pd.date_range(start=last_date + timedelta(days=1),
138
+ periods=len(forecasts[0]))
139
+
140
+ # Plot forecasts
141
  colors = ['red', 'purple', 'orange']
142
  for i, (forecast, name) in enumerate(zip(forecasts, model_names)):
143
  if forecast is not None:
 
149
  line=dict(color=colors[i], width=2, dash='dash'),
150
  marker=dict(size=6)
151
  ))
152
+
153
  fig.update_layout(
154
  title=f'{ticker} Stock Price Forecast',
155
  xaxis_title='Date',
 
157
  hovermode='x unified',
158
  template='plotly_white',
159
  height=600,
160
+ showlegend=True,
161
+ legend=dict(
162
+ yanchor="top",
163
+ y=0.99,
164
+ xanchor="left",
165
+ x=0.01
166
+ )
167
  )
168
+
169
  return fig
170
 
 
 
 
171
  def predict_stock(ticker, forecast_days, model_choice):
172
+ """Main prediction function"""
173
+ # Validate inputs
174
  if not ticker:
175
  return None, "Please enter a stock ticker symbol", None
176
+
177
+ ticker = ticker.upper().strip()
178
+
179
+ # Fetch data
180
+ data, error = fetch_stock_data(ticker, days=730) # 2 years of data
181
  if error:
182
  return None, f"Error: {error}", None
183
+
184
+ # Make forecasts based on model choice
185
  forecasts = []
186
  model_names = []
187
+
188
  if model_choice in ["All Models", "ARIMA"]:
189
  arima_forecast = make_arima_forecast(data, forecast_days)
190
  if arima_forecast is not None:
191
  forecasts.append(arima_forecast)
192
  model_names.append("ARIMA")
193
+
194
  if model_choice in ["All Models", "Prophet"]:
195
  prophet_forecast = make_prophet_forecast(data, forecast_days)
196
  if prophet_forecast is not None:
197
  forecasts.append(prophet_forecast)
198
  model_names.append("Prophet")
199
+
200
  if model_choice in ["All Models", "LSTM"] and lstm_model is not None:
201
  lstm_forecast = make_lstm_forecast(data, forecast_days, lstm_model, scaler, SEQ_LENGTH)
202
  if lstm_forecast is not None:
203
  forecasts.append(lstm_forecast)
204
  model_names.append("LSTM")
205
+
206
  if not forecasts:
207
+ return None, "Failed to generate forecasts. Please try again.", None
208
+
209
+ # Create plot
210
  fig = create_forecast_plot(data, forecasts, ticker, model_names)
211
+
212
+ # Create forecast table
213
  future_dates = pd.date_range(
214
+ start=data.index[-1] + timedelta(days=1),
215
  periods=forecast_days
216
  )
217
+
218
  forecast_df = pd.DataFrame({'Date': future_dates.strftime('%Y-%m-%d')})
219
  for forecast, name in zip(forecasts, model_names):
220
  forecast_df[f'{name} Prediction ($)'] = np.round(forecast, 2)
221
+
222
+ # Summary statistics
223
+ summary = f"""
224
+ ๐Ÿ“Š **Forecast Summary for {ticker}**
225
+
226
+ - Current Price: ${data['Price'].iloc[-1]:.2f}
227
+ - Forecast Period: {forecast_days} days
228
+ - Models Used: {', '.join(model_names)}
229
+
230
+ **Predicted Price Range (Day {forecast_days}):**
231
+ """
232
+
233
  for forecast, name in zip(forecasts, model_names):
234
  final_price = forecast[-1]
235
  change = ((final_price - data['Price'].iloc[-1]) / data['Price'].iloc[-1]) * 100
236
  summary += f"\n- {name}: ${final_price:.2f} ({change:+.2f}%)"
237
+
238
  return fig, summary, forecast_df
239
 
240
+ # Create Gradio Interface
 
 
241
  with gr.Blocks(title="Stock Price Forecasting", theme=gr.themes.Soft()) as demo:
242
+ gr.Markdown(
243
+ """
244
+ # ๐Ÿ“ˆ Stock Price Forecasting App
245
+
246
+ Predict future stock prices using ARIMA, Prophet, and LSTM models.
247
+ Enter a stock ticker symbol and select forecast parameters below.
248
+
249
+ **Note:** Predictions are for educational purposes only. Not financial advice.
250
+ """
251
+ )
252
+
253
  with gr.Row():
254
  with gr.Column(scale=1):
255
+ ticker_input = gr.Textbox(
256
+ label="Stock Ticker Symbol",
257
+ placeholder="e.g., AAPL, GOOGL, TSLA",
258
+ value="AAPL"
259
+ )
260
+
261
+ forecast_days = gr.Slider(
262
+ minimum=1,
263
+ maximum=90,
264
+ value=30,
265
+ step=1,
266
+ label="Forecast Days"
267
+ )
268
+
269
+ model_choice = gr.Radio(
270
+ choices=["All Models", "ARIMA", "Prophet", "LSTM"],
271
+ value="All Models",
272
+ label="Select Model(s)"
273
+ )
274
+
275
+ predict_btn = gr.Button("๐Ÿ”ฎ Generate Forecast", variant="primary", size="lg")
276
+
277
  with gr.Column(scale=2):
278
  output_plot = gr.Plot(label="Forecast Visualization")
279
+
280
+ with gr.Row():
281
+ output_summary = gr.Markdown(label="Forecast Summary")
282
+
283
+ with gr.Row():
284
+ output_table = gr.Dataframe(
285
+ label="Detailed Forecast",
286
+ wrap=True,
287
+ interactive=False
288
+ )
289
+
290
+ # Examples
291
+ gr.Examples(
292
+ examples=[
293
+ ["AAPL", 30, "All Models"],
294
+ ["GOOGL", 14, "Prophet"],
295
+ ["TSLA", 60, "LSTM"],
296
+ ["MSFT", 45, "ARIMA"],
297
+ ],
298
+ inputs=[ticker_input, forecast_days, model_choice],
299
+ )
300
+
301
+ # Connect the button to the function
302
+ predict_btn.click(
303
+ fn=predict_stock,
304
+ inputs=[ticker_input, forecast_days, model_choice],
305
+ outputs=[output_plot, output_summary, output_table]
306
+ )
307
+
308
+ gr.Markdown(
309
+ """
310
+ ---
311
+ ### ๐Ÿ“š About the Models
312
+
313
+ - **ARIMA**: Statistical model for time series forecasting
314
+ - **Prophet**: Facebook's forecasting tool, excellent for seasonality
315
+ - **LSTM**: Deep learning model that captures complex patterns
316
+
317
+ ### โš ๏ธ Disclaimer
318
+ This tool is for educational and research purposes only. Stock market predictions are inherently uncertain.
319
+ Always conduct thorough research and consult with financial advisors before making investment decisions.
320
+ """
321
+ )
322
 
323
+ # Launch the app
324
  if __name__ == "__main__":
325
+ demo.launch()