Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from datetime import datetime, timedelta | |
| import yfinance as yf | |
| from statsmodels.tsa.arima.model import ARIMA | |
| from prophet import Prophet | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| # NO PRE-TRAINED MODELS - Train on demand with user's data | |
| # This avoids the 50GB storage limit issue | |
def fetch_stock_data(ticker, days=730):
    """Download closing prices for ``ticker`` covering the last ``days`` days.

    Returns a ``(frame, error)`` pair. On success ``frame`` is a DataFrame
    with a single ``Price`` column (rows with missing values removed) and
    ``error`` is ``None``. On failure ``frame`` is ``None`` and ``error`` is
    a human-readable message.
    """
    try:
        window_end = datetime.now()
        window_start = window_end - timedelta(days=days)
        raw = yf.download(
            ticker,
            start=window_start,
            end=window_end,
            progress=False,
        )
        if raw.empty:
            return None, f"No data found for ticker: {ticker}"

        # Keep only the closing price and expose it under the column name
        # the forecasting helpers read.
        prices = raw[['Close']].copy()
        prices.columns = ['Price']
        return prices.dropna(), None
    except Exception as exc:
        # Any failure is reported to the caller as text rather than raised.
        return None, str(exc)
def make_arima_forecast(data, days):
    """Fit an ARIMA(1, 1, 1) model on ``data['Price']`` and forecast ``days`` steps.

    Returns the forecast values as an array, or ``None`` if fitting or
    forecasting raises (the error is printed).
    """
    try:
        # The model is fitted from scratch on every call.
        fitted_model = ARIMA(data['Price'], order=(1, 1, 1)).fit()
        return fitted_model.forecast(steps=days).values
    except Exception as exc:
        print(f"ARIMA Error: {exc}")
        return None
def make_prophet_forecast(data, days):
    """Fit a Prophet model on ``data['Price']`` and forecast ``days`` periods.

    Returns the last ``days`` values of the predicted ``yhat`` column as an
    array, or ``None`` if anything raises (the error is printed).
    """
    prophet_options = dict(
        daily_seasonality=False,
        weekly_seasonality=True,
        yearly_seasonality=True,
        changepoint_prior_scale=0.05,
        seasonality_mode='multiplicative',
    )
    try:
        # Prophet expects the timestamps in 'ds' and the observations in 'y'.
        training_frame = pd.DataFrame({
            'ds': data.index,
            'y': data['Price'].values,
        })

        # The model is fitted from scratch on every call.
        model = Prophet(**prophet_options)
        model.fit(training_frame)

        # The future frame holds the history plus ``days`` extra rows, so the
        # forecast is the tail of the prediction.
        horizon = model.make_future_dataframe(periods=days)
        predicted = model.predict(horizon)
        return predicted['yhat'].tail(days).values
    except Exception as exc:
        print(f"Prophet Error: {exc}")
        return None
def make_simple_ml_forecast(data, days):
    """Forecast ``days`` steps with additive Holt-Winters exponential smoothing.

    Uses an additive trend and an additive seasonal component with a
    30-observation period. Returns the forecast values as an array, or
    ``None`` if anything raises (the error is printed).
    """
    try:
        from statsmodels.tsa.holtwinters import ExponentialSmoothing

        smoothing_options = {
            'seasonal_periods': 30,
            'trend': 'add',
            'seasonal': 'add',
        }
        fitted_model = ExponentialSmoothing(data['Price'], **smoothing_options).fit()
        return fitted_model.forecast(steps=days).values
    except Exception as exc:
        print(f"ML Forecast Error: {exc}")
        return None
def calculate_moving_average_forecast(data, days, window=20):
    """Extrapolate a moving average along the recent linear trend.

    The forecast for step ``i`` (1-based) is ``ma + trend * i`` where ``ma``
    is the mean of the last ``window`` prices and ``trend`` is the average
    change per observation across those same prices.

    Args:
        data: DataFrame with a ``Price`` column.
        days: Number of steps to forecast; ``0`` yields an empty array.
        window: Number of trailing observations used for both the average
            and the trend.

    Returns:
        A NumPy array of ``days`` forecast values, or ``None`` when there are
        fewer than ``window`` prices, ``window`` is below 1, or the
        computation raises (a message is printed in each case).
    """
    try:
        prices = data['Price']
        if window < 1 or len(prices) < window:
            # Made explicit instead of relying on an IndexError from iloc.
            print(f"MA Error: need at least {window} prices, got {len(prices)}")
            return None

        ma = prices.rolling(window=window).mean().iloc[-1]

        # The first and last price of the window are ``window - 1`` steps
        # apart, so that is the divisor for a per-step slope. A one-element
        # window carries no trend information.
        intervals = window - 1
        if intervals:
            trend = (prices.iloc[-1] - prices.iloc[-window]) / intervals
        else:
            trend = 0.0

        steps = np.arange(1, int(days) + 1)
        return ma + trend * steps
    except Exception as e:
        print(f"MA Error: {e}")
        return None
def create_forecast_plot(historical_data, forecasts, ticker, model_names):
    """Build the interactive Plotly chart of recent history plus forecasts.

    Args:
        historical_data: DataFrame indexed by date with a ``Price`` column;
            must contain at least one row.
        forecasts: Sequence of forecast arrays, parallel to ``model_names``.
            ``None`` entries are skipped.
        ticker: Symbol shown in the chart title.
        model_names: Display name for each forecast.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()

    # Only the last 90 rows of history are drawn to keep the chart readable.
    recent_data = historical_data.tail(90)
    fig.add_trace(go.Scatter(
        x=recent_data.index,
        y=recent_data['Price'],
        mode='lines',
        name='Historical Price',
        line=dict(color='blue', width=2)
    ))

    last_date = historical_data.index[-1]
    first_forecast_date = last_date + timedelta(days=1)

    colors = ['red', 'purple', 'orange', 'green']
    for i, (forecast, name) in enumerate(zip(forecasts, model_names)):
        if forecast is None:
            continue
        # Dates are derived per forecast so an empty list or a missing first
        # entry cannot break the date axis.
        future_dates = pd.date_range(start=first_forecast_date,
                                     periods=len(forecast))
        fig.add_trace(go.Scatter(
            x=future_dates,
            y=forecast,
            mode='lines+markers',
            name=f'{name} Forecast',
            # Cycle the palette so more than four series cannot overflow it.
            line=dict(color=colors[i % len(colors)], width=2, dash='dash'),
            marker=dict(size=4)
        ))

    # Mark where the forecast begins. A shape plus a separate annotation is
    # used instead of add_vline(annotation_text=...), because add_vline does
    # arithmetic on x to place the annotation and that can raise TypeError
    # when x is a datetime.
    fig.add_shape(
        type="line",
        xref="x",
        yref="paper",
        x0=last_date,
        x1=last_date,
        y0=0,
        y1=1,
        line=dict(color="gray", dash="dash")
    )
    fig.add_annotation(
        x=last_date,
        y=1,
        xref="x",
        yref="paper",
        text="Forecast Start",
        showarrow=False,
        xanchor="left",
        yanchor="top"
    )

    fig.update_layout(
        title=f'{ticker} Stock Price Forecast',
        xaxis_title='Date',
        yaxis_title='Price ($)',
        hovermode='x unified',
        template='plotly_white',
        height=600,
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01,
            bgcolor="rgba(255, 255, 255, 0.8)"
        )
    )
    return fig
def predict_stock(ticker, forecast_days, model_choice):
    """Fetch history for ``ticker``, run the chosen model(s) and build the outputs.

    Args:
        ticker: Stock symbol as typed by the user; surrounding whitespace and
            letter case are ignored.
        forecast_days: Number of days to forecast; coerced to ``int``.
        model_choice: ``"All Models"`` or the display name of a single model
            (``"ARIMA"``, ``"Prophet"``, ``"Exp. Smoothing"``,
            ``"Moving Average"``).

    Returns:
        ``(figure, summary_markdown, forecast_table)`` on success, or
        ``(None, error_message, None)`` when the input is invalid, the data
        cannot be fetched, the history is too short, or every model fails.
    """
    # Normalise before validating so a whitespace-only entry is rejected here
    # rather than being sent on as an empty symbol.
    ticker = (ticker or "").strip().upper()
    if not ticker:
        return None, "❌ Please enter a stock ticker symbol", None

    # The value comes from a UI slider and may arrive as a float (e.g. 30.0);
    # the forecasters and pd.date_range need an integer count.
    forecast_days = int(forecast_days)

    # Two years of history are fetched for training.
    data, error = fetch_stock_data(ticker, days=730)
    if error is not None or data is None:
        return None, f"❌ Error: {error}", None
    if len(data) < 60:
        return None, f"❌ Insufficient data for {ticker}. Need at least 60 days of history.", None

    # Display name -> forecaster, in the order the results are shown.
    available_models = (
        ("ARIMA", make_arima_forecast),
        ("Prophet", make_prophet_forecast),
        ("Exp. Smoothing", make_simple_ml_forecast),
        ("Moving Average", calculate_moving_average_forecast),
    )
    forecasts = []
    model_names = []
    for name, forecaster in available_models:
        if model_choice not in ("All Models", name):
            continue
        forecast = forecaster(data, forecast_days)
        # Each forecaster returns None on failure; failed models are left out.
        if forecast is not None:
            forecasts.append(forecast)
            model_names.append(name)

    if not forecasts:
        return None, "❌ Failed to generate forecasts. Please try again.", None

    fig = create_forecast_plot(data, forecasts, ticker, model_names)

    # One table row per forecast day, one column per successful model.
    future_dates = pd.date_range(
        start=data.index[-1] + timedelta(days=1),
        periods=forecast_days
    )
    forecast_df = pd.DataFrame({'Date': future_dates.strftime('%Y-%m-%d')})
    for forecast, name in zip(forecasts, model_names):
        forecast_df[f'{name} ($)'] = np.round(forecast, 2)

    # Headline numbers compare each model's final-day value with the latest
    # observed price.
    current_price = data['Price'].iloc[-1]
    avg_forecast = np.mean([f[-1] for f in forecasts])
    avg_change = ((avg_forecast - current_price) / current_price) * 100

    summary = f"""
## 📊 Forecast Summary for **{ticker}**
### Current Information
- **Current Price**: ${current_price:.2f}
- **Data Points**: {len(data)} days
- **Last Updated**: {data.index[-1].strftime('%Y-%m-%d')}
### Forecast Details
- **Forecast Period**: {forecast_days} days
- **Models Used**: {', '.join(model_names)}
- **End Date**: {future_dates[-1].strftime('%Y-%m-%d')}
### Predicted Prices (Day {forecast_days})
"""
    for forecast, name in zip(forecasts, model_names):
        final_price = forecast[-1]
        change = ((final_price - current_price) / current_price) * 100
        emoji = "📈" if change > 0 else "📉"
        summary += f"\n{emoji} **{name}**: ${final_price:.2f} ({change:+.2f}%)"
    summary += f"""
### Average Prediction
- **Average Price**: ${avg_forecast:.2f}
- **Expected Change**: {avg_change:+.2f}%
---
⚠️ **Risk Warning**: Past performance does not guarantee future results. Use for research only.
"""
    return fig, summary, forecast_df
# Gradio interface: static text and option lists first, then the layout.
_HEADER_MD = """
# 📈 AI Stock Price Forecasting
### Predict future stock prices using multiple time-series models
This app trains models **in real-time** using the latest stock data. No pre-trained models needed!
**✨ Features:**
- Real-time data from Yahoo Finance
- Multiple forecasting algorithms
- Interactive visualizations
- No storage limits - models train on demand
---
"""

_TIPS_MD = """
### 💡 Quick Tips
- Use 30 days for short-term
- Use 60-90 days for trends
- "All Models" shows comparison
"""

_ABOUT_MD = """
---
## 📚 About the Models
| Model | Best For | Speed | Accuracy |
|-------|----------|-------|----------|
| **ARIMA** | Short-term, stationary data | ⚡⚡⚡ Fast | ⭐⭐⭐ |
| **Prophet** | Seasonality, trends | ⚡⚡ Medium | ⭐⭐⭐⭐ |
| **Exp. Smoothing** | Smooth trends | ⚡⚡⚡ Fast | ⭐⭐⭐ |
| **Moving Average** | Simple baseline | ⚡⚡⚡⚡ Very Fast | ⭐⭐ |
## ⚠️ Important Disclaimer
**This tool is for educational and research purposes only.**
- Stock predictions are inherently uncertain
- Past performance ≠ future results
- Always do your own research
- Consult financial advisors before investing
- Never invest more than you can afford to lose
## 🔒 Privacy & Data
- No data is stored permanently
- Models train fresh for each prediction
- Stock data fetched from Yahoo Finance API
- No personal information collected
---
**Made with ❤️ using Gradio & Python**
"""

_MODEL_CHOICES = ["All Models", "ARIMA", "Prophet", "Exp. Smoothing", "Moving Average"]

_EXAMPLE_ROWS = [
    ["AAPL", 30, "All Models"],
    ["GOOGL", 14, "Prophet"],
    ["TSLA", 60, "ARIMA"],
    ["MSFT", 45, "Exp. Smoothing"],
    ["NVDA", 30, "All Models"],
]

with gr.Blocks(title="Stock Price Forecasting", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Row():
        # Left column: user inputs and the trigger button.
        with gr.Column(scale=1):
            gr.Markdown("### 🎯 Input Parameters")
            ticker_input = gr.Textbox(
                label="📊 Stock Ticker Symbol",
                placeholder="e.g., AAPL, GOOGL, TSLA, MSFT",
                value="AAPL",
                info="Enter any valid stock ticker",
            )
            forecast_days = gr.Slider(
                minimum=7,
                maximum=90,
                value=30,
                step=1,
                label="📅 Forecast Period (Days)",
                info="Number of days to forecast",
            )
            model_choice = gr.Radio(
                choices=_MODEL_CHOICES,
                value="All Models",
                label="🤖 Select Model(s)",
                info="Choose which forecasting model to use",
            )
            predict_btn = gr.Button(
                "🔮 Generate Forecast",
                variant="primary",
                size="lg",
                scale=1,
            )
            gr.Markdown(_TIPS_MD)

        # Right column: the chart.
        with gr.Column(scale=2):
            output_plot = gr.Plot(label="📈 Forecast Visualization")

    with gr.Row():
        with gr.Column():
            output_summary = gr.Markdown(label="📋 Analysis Summary")

    with gr.Row():
        output_table = gr.Dataframe(
            label="📊 Detailed Forecast Table",
            wrap=True,
            interactive=False,
            height=400,
        )

    # The same three components feed both the examples and the button.
    _prediction_inputs = [ticker_input, forecast_days, model_choice]

    gr.Markdown("### 🎯 Try These Examples")
    gr.Examples(
        examples=_EXAMPLE_ROWS,
        inputs=_prediction_inputs,
        label="Popular Stocks",
    )

    # Wire the button to the prediction function.
    predict_btn.click(
        fn=predict_stock,
        inputs=_prediction_inputs,
        outputs=[output_plot, output_summary, output_table],
    )

    gr.Markdown(_ABOUT_MD)
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    launch_options = {
        "share": False,
        "show_error": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo.launch(**launch_options)