Shaikat01 commited on
Commit
650f909
·
verified ·
1 Parent(s): d1e0d1d

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -406
app.py DELETED
@@ -1,406 +0,0 @@
1
- import gradio as gr
2
- import pandas as pd
3
- import numpy as np
4
- import plotly.graph_objects as go
5
- from datetime import datetime, timedelta
6
- import yfinance as yf
7
- from statsmodels.tsa.arima.model import ARIMA
8
- from prophet import Prophet
9
- import warnings
10
- warnings.filterwarnings('ignore')
11
-
12
- # NO PRE-TRAINED MODELS - Train on demand with user's data
13
- # This avoids the 50GB storage limit issue
14
-
15
- def fetch_stock_data(ticker, days=730):
16
- """Fetch stock data from Yahoo Finance"""
17
- try:
18
- end_date = datetime.now()
19
- start_date = end_date - timedelta(days=days)
20
- df = yf.download(ticker, start=start_date, end=end_date, progress=False)
21
- if df.empty:
22
- return None, f"No data found for ticker: {ticker}"
23
- df = df[['Close']].copy()
24
- df.columns = ['Price']
25
- df = df.dropna()
26
- return df, None
27
- except Exception as e:
28
- return None, str(e)
29
-
30
- def make_arima_forecast(data, days):
31
- """Train ARIMA and make forecast"""
32
- try:
33
- # Train ARIMA model on-the-fly
34
- model = ARIMA(data['Price'], order=(1, 1, 1))
35
- fitted = model.fit()
36
- forecast = fitted.forecast(steps=days)
37
- return forecast.values
38
- except Exception as e:
39
- print(f"ARIMA Error: {e}")
40
- return None
41
-
42
- def make_prophet_forecast(data, days):
43
- """Train Prophet and make forecast"""
44
- try:
45
- # Prepare data for Prophet
46
- prophet_data = pd.DataFrame({
47
- 'ds': data.index,
48
- 'y': data['Price'].values
49
- })
50
-
51
- # Create and train model on-the-fly
52
- model = Prophet(
53
- daily_seasonality=False,
54
- weekly_seasonality=True,
55
- yearly_seasonality=True,
56
- changepoint_prior_scale=0.05,
57
- seasonality_mode='multiplicative'
58
- )
59
- model.fit(prophet_data)
60
-
61
- # Make forecast
62
- future = model.make_future_dataframe(periods=days)
63
- forecast = model.predict(future)
64
- return forecast['yhat'].tail(days).values
65
- except Exception as e:
66
- print(f"Prophet Error: {e}")
67
- return None
68
-
69
- def make_simple_ml_forecast(data, days):
70
- """Simple exponential smoothing forecast (lightweight alternative to LSTM)"""
71
- try:
72
- from statsmodels.tsa.holtwinters import ExponentialSmoothing
73
-
74
- # Train exponential smoothing model
75
- model = ExponentialSmoothing(
76
- data['Price'],
77
- seasonal_periods=30,
78
- trend='add',
79
- seasonal='add'
80
- )
81
- fitted = model.fit()
82
- forecast = fitted.forecast(steps=days)
83
- return forecast.values
84
- except Exception as e:
85
- print(f"ML Forecast Error: {e}")
86
- return None
87
-
88
- def calculate_moving_average_forecast(data, days, window=20):
89
- """Simple moving average forecast"""
90
- try:
91
- ma = data['Price'].rolling(window=window).mean().iloc[-1]
92
- trend = (data['Price'].iloc[-1] - data['Price'].iloc[-window]) / window
93
- forecast = [ma + trend * i for i in range(1, days + 1)]
94
- return np.array(forecast)
95
- except Exception as e:
96
- print(f"MA Error: {e}")
97
- return None
98
-
99
- def create_forecast_plot(historical_data, forecasts, ticker, model_names):
100
- """Create interactive plotly chart"""
101
- fig = go.Figure()
102
-
103
- # Show last 90 days of historical data for clarity
104
- recent_data = historical_data.tail(90)
105
-
106
- # Historical data
107
- fig.add_trace(go.Scatter(
108
- x=recent_data.index,
109
- y=recent_data['Price'],
110
- mode='lines',
111
- name='Historical Price',
112
- line=dict(color='blue', width=2)
113
- ))
114
-
115
- # Generate future dates
116
- last_date = historical_data.index[-1]
117
- future_dates = pd.date_range(start=last_date + timedelta(days=1),
118
- periods=len(forecasts[0]))
119
-
120
- # Plot forecasts
121
- colors = ['red', 'purple', 'orange', 'green']
122
- for i, (forecast, name) in enumerate(zip(forecasts, model_names)):
123
- if forecast is not None:
124
- fig.add_trace(go.Scatter(
125
- x=future_dates,
126
- y=forecast,
127
- mode='lines+markers',
128
- name=f'{name} Forecast',
129
- line=dict(color=colors[i], width=2, dash='dash'),
130
- marker=dict(size=4)
131
- ))
132
-
133
- # Add vertical line at prediction start
134
- fig.add_vline(
135
- x=last_date,
136
- line_dash="dash",
137
- line_color="gray",
138
- annotation_text="Forecast Start"
139
- )
140
-
141
- fig.update_layout(
142
- title=f'{ticker} Stock Price Forecast',
143
- xaxis_title='Date',
144
- yaxis_title='Price ($)',
145
- hovermode='x unified',
146
- template='plotly_white',
147
- height=600,
148
- showlegend=True,
149
- legend=dict(
150
- yanchor="top",
151
- y=0.99,
152
- xanchor="left",
153
- x=0.01,
154
- bgcolor="rgba(255, 255, 255, 0.8)"
155
- )
156
- )
157
-
158
- return fig
159
-
160
- def predict_stock(ticker, forecast_days, model_choice):
161
- """Main prediction function"""
162
- # Validate inputs
163
- if not ticker:
164
- return None, "❌ Please enter a stock ticker symbol", None
165
-
166
- ticker = ticker.upper().strip()
167
-
168
- # Show loading message
169
- status_msg = f"🔄 Fetching data for {ticker}..."
170
-
171
- # Fetch data (2 years for better training)
172
- data, error = fetch_stock_data(ticker, days=730)
173
- if error:
174
- return None, f"❌ Error: {error}", None
175
-
176
- if len(data) < 60:
177
- return None, f"❌ Insufficient data for {ticker}. Need at least 60 days of history.", None
178
-
179
- status_msg += f"\n✅ Found {len(data)} days of data\n🔄 Training models..."
180
-
181
- # Make forecasts based on model choice
182
- forecasts = []
183
- model_names = []
184
-
185
- if model_choice in ["All Models", "ARIMA"]:
186
- arima_forecast = make_arima_forecast(data, forecast_days)
187
- if arima_forecast is not None:
188
- forecasts.append(arima_forecast)
189
- model_names.append("ARIMA")
190
-
191
- if model_choice in ["All Models", "Prophet"]:
192
- prophet_forecast = make_prophet_forecast(data, forecast_days)
193
- if prophet_forecast is not None:
194
- forecasts.append(prophet_forecast)
195
- model_names.append("Prophet")
196
-
197
- if model_choice in ["All Models", "Exp. Smoothing"]:
198
- ml_forecast = make_simple_ml_forecast(data, forecast_days)
199
- if ml_forecast is not None:
200
- forecasts.append(ml_forecast)
201
- model_names.append("Exp. Smoothing")
202
-
203
- if model_choice in ["All Models", "Moving Average"]:
204
- ma_forecast = calculate_moving_average_forecast(data, forecast_days)
205
- if ma_forecast is not None:
206
- forecasts.append(ma_forecast)
207
- model_names.append("Moving Average")
208
-
209
- if not forecasts:
210
- return None, "❌ Failed to generate forecasts. Please try again.", None
211
-
212
- # Create plot
213
- fig = create_forecast_plot(data, forecasts, ticker, model_names)
214
-
215
- # Create forecast table
216
- future_dates = pd.date_range(
217
- start=data.index[-1] + timedelta(days=1),
218
- periods=forecast_days
219
- )
220
-
221
- forecast_df = pd.DataFrame({'Date': future_dates.strftime('%Y-%m-%d')})
222
- for forecast, name in zip(forecasts, model_names):
223
- forecast_df[f'{name} ($)'] = np.round(forecast, 2)
224
-
225
- # Calculate statistics
226
- current_price = data['Price'].iloc[-1]
227
- avg_forecast = np.mean([f[-1] for f in forecasts])
228
- avg_change = ((avg_forecast - current_price) / current_price) * 100
229
-
230
- # Summary statistics
231
- summary = f"""
232
- ## 📊 Forecast Summary for **{ticker}**
233
-
234
- ### Current Information
235
- - **Current Price**: ${current_price:.2f}
236
- - **Data Points**: {len(data)} days
237
- - **Last Updated**: {data.index[-1].strftime('%Y-%m-%d')}
238
-
239
- ### Forecast Details
240
- - **Forecast Period**: {forecast_days} days
241
- - **Models Used**: {', '.join(model_names)}
242
- - **End Date**: {future_dates[-1].strftime('%Y-%m-%d')}
243
-
244
- ### Predicted Prices (Day {forecast_days})
245
- """
246
-
247
- for forecast, name in zip(forecasts, model_names):
248
- final_price = forecast[-1]
249
- change = ((final_price - current_price) / current_price) * 100
250
- emoji = "📈" if change > 0 else "📉"
251
- summary += f"\n{emoji} **{name}**: ${final_price:.2f} ({change:+.2f}%)"
252
-
253
- summary += f"""
254
-
255
- ### Average Prediction
256
- - **Average Price**: ${avg_forecast:.2f}
257
- - **Expected Change**: {avg_change:+.2f}%
258
-
259
- ---
260
- ⚠️ **Risk Warning**: Past performance does not guarantee future results. Use for research only.
261
- """
262
-
263
- return fig, summary, forecast_df
264
-
265
- # Create Gradio Interface
266
- with gr.Blocks(title="Stock Price Forecasting", theme=gr.themes.Soft()) as demo:
267
- gr.Markdown(
268
- """
269
- # 📈 AI Stock Price Forecasting
270
-
271
- ### Predict future stock prices using multiple time-series models
272
-
273
- This app trains models **in real-time** using the latest stock data. No pre-trained models needed!
274
-
275
- **✨ Features:**
276
- - Real-time data from Yahoo Finance
277
- - Multiple forecasting algorithms
278
- - Interactive visualizations
279
- - No storage limits - models train on demand
280
-
281
- ---
282
- """
283
- )
284
-
285
- with gr.Row():
286
- with gr.Column(scale=1):
287
- gr.Markdown("### 🎯 Input Parameters")
288
-
289
- ticker_input = gr.Textbox(
290
- label="📊 Stock Ticker Symbol",
291
- placeholder="e.g., AAPL, GOOGL, TSLA, MSFT",
292
- value="AAPL",
293
- info="Enter any valid stock ticker"
294
- )
295
-
296
- forecast_days = gr.Slider(
297
- minimum=7,
298
- maximum=90,
299
- value=30,
300
- step=1,
301
- label="📅 Forecast Period (Days)",
302
- info="Number of days to forecast"
303
- )
304
-
305
- model_choice = gr.Radio(
306
- choices=["All Models", "ARIMA", "Prophet", "Exp. Smoothing", "Moving Average"],
307
- value="All Models",
308
- label="🤖 Select Model(s)",
309
- info="Choose which forecasting model to use"
310
- )
311
-
312
- predict_btn = gr.Button(
313
- "🔮 Generate Forecast",
314
- variant="primary",
315
- size="lg",
316
- scale=1
317
- )
318
-
319
- gr.Markdown(
320
- """
321
- ### 💡 Quick Tips
322
- - Use 30 days for short-term
323
- - Use 60-90 days for trends
324
- - "All Models" shows comparison
325
- """
326
- )
327
-
328
- with gr.Column(scale=2):
329
- output_plot = gr.Plot(label="📈 Forecast Visualization")
330
-
331
- with gr.Row():
332
- with gr.Column():
333
- output_summary = gr.Markdown(label="📋 Analysis Summary")
334
-
335
- with gr.Row():
336
- output_table = gr.Dataframe(
337
- label="📊 Detailed Forecast Table",
338
- wrap=True,
339
- interactive=False,
340
- height=400
341
- )
342
-
343
- # Examples
344
- gr.Markdown("### 🎯 Try These Examples")
345
- gr.Examples(
346
- examples=[
347
- ["AAPL", 30, "All Models"],
348
- ["GOOGL", 14, "Prophet"],
349
- ["TSLA", 60, "ARIMA"],
350
- ["MSFT", 45, "Exp. Smoothing"],
351
- ["NVDA", 30, "All Models"],
352
- ],
353
- inputs=[ticker_input, forecast_days, model_choice],
354
- label="Popular Stocks"
355
- )
356
-
357
- # Connect the button to the function
358
- predict_btn.click(
359
- fn=predict_stock,
360
- inputs=[ticker_input, forecast_days, model_choice],
361
- outputs=[output_plot, output_summary, output_table]
362
- )
363
-
364
- gr.Markdown(
365
- """
366
- ---
367
- ## 📚 About the Models
368
-
369
- | Model | Best For | Speed | Accuracy |
370
- |-------|----------|-------|----------|
371
- | **ARIMA** | Short-term, stationary data | ⚡⚡⚡ Fast | ⭐⭐⭐ |
372
- | **Prophet** | Seasonality, trends | ⚡⚡ Medium | ⭐⭐⭐⭐ |
373
- | **Exp. Smoothing** | Smooth trends | ⚡⚡⚡ Fast | ⭐⭐⭐ |
374
- | **Moving Average** | Simple baseline | ⚡⚡⚡⚡ Very Fast | ⭐⭐ |
375
-
376
- ## ⚠️ Important Disclaimer
377
-
378
- **This tool is for educational and research purposes only.**
379
-
380
- - Stock predictions are inherently uncertain
381
- - Past performance ≠ future results
382
- - Always do your own research
383
- - Consult financial advisors before investing
384
- - Never invest more than you can afford to lose
385
-
386
- ## 🔒 Privacy & Data
387
-
388
- - No data is stored permanently
389
- - Models train fresh for each prediction
390
- - Stock data fetched from Yahoo Finance API
391
- - No personal information collected
392
-
393
- ---
394
-
395
- **Made with ❤️ using Gradio & Python**
396
- """
397
- )
398
-
399
- # Launch the app
400
- if __name__ == "__main__":
401
- demo.launch(
402
- share=False,
403
- show_error=True,
404
- server_name="0.0.0.0",
405
- server_port=7860
406
- )