Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from datetime import datetime, timedelta | |
| import pickle | |
| import yfinance as yf | |
| from statsmodels.tsa.arima.model import ARIMA | |
| from prophet import Prophet | |
| from tensorflow import keras | |
| from sklearn.preprocessing import MinMaxScaler | |
| import warnings | |
# Suppress every warning (all categories, all modules) for the whole process.
warnings.simplefilter('ignore')
# Load the saved models (update paths as needed).
# Paths are relative to the current working directory; on Hugging Face Spaces
# the files sit next to app.py.
def _load_pickle(path):
    """Unpickle and return the object stored at *path*.

    NOTE(review): unpickling can execute arbitrary code; only use this with
    trusted model files shipped alongside the app, never with uploads.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _try_load(label, load):
    """Call the zero-argument *load*; on any error print it and return None."""
    try:
        return load()
    except Exception as e:  # best effort: one bad artifact must not disable the others
        print(f"Error loading {label}: {e}")
        return None


def load_models():
    """Load all three models plus the LSTM scaler.

    Each artifact is loaded independently, so a missing or corrupt file only
    disables the model it belongs to instead of every model at once.

    Returns:
        Tuple ``(arima_model, prophet_model, lstm_model, scaler)``. Any element
        whose file could not be loaded is ``None`` (the error is printed).
    """
    arima_model = _try_load('ARIMA model', lambda: _load_pickle('arima_model.pkl'))
    prophet_model = _try_load('Prophet model', lambda: _load_pickle('prophet_model.pkl'))
    # Lambda defers the keras lookup so even an import problem is contained here.
    lstm_model = _try_load('LSTM model', lambda: keras.models.load_model('lstm_model.h5'))
    scaler = _try_load('LSTM scaler', lambda: _load_pickle('lstm_scaler.pkl'))
    return arima_model, prophet_model, lstm_model, scaler
# Window length fed to the LSTM; should match your training.
SEQ_LENGTH = 60

# Global variables for models, loaded once at import time. On a loading error
# load_models() reports it and returns None in place of the affected objects.
arima_model, prophet_model, lstm_model, scaler = load_models()
def _extract_close_prices(df):
    """Return a one-column ``Price`` frame from a yfinance download, or None.

    Handles both layouts yfinance produces: MultiIndex columns such as
    ``('Close', 'AAPL')`` and flat single-level columns.
    """
    if isinstance(df.columns, pd.MultiIndex):
        if 'Close' not in df.columns.get_level_values(0):
            return None
        close = df['Close']
    elif 'Close' in df.columns:
        close = df['Close']
    else:
        # Try to find any column whose name mentions "close".
        candidates = [col for col in df.columns if 'close' in str(col).lower()]
        if not candidates:
            return None
        close = df[candidates[0]]
    # Selecting a top-level MultiIndex label (or a duplicated column name)
    # yields a DataFrame, which has no .to_frame(); keep its first column.
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0]
    return close.to_frame(name='Price')


def fetch_stock_data(ticker, days=365, *, max_retries=3, retry_delay=1.0, downloader=None):
    """Fetch daily closing prices for *ticker* from Yahoo Finance.

    Args:
        ticker: Symbol to download (e.g. "AAPL").
        days: Calendar days of history to request, ending now.
        max_retries: Download attempts; an empty result or an exception
            triggers another attempt. An exception on the last attempt is
            reported as an error message.
        retry_delay: Seconds to wait between attempts (not after the last one).
        downloader: Callable with ``yfinance.download``'s call signature;
            defaults to ``yf.download``. Lets callers plug in a cache or a
            test double.

    Returns:
        ``(prices, None)`` on success, where ``prices`` is a DataFrame with a
        single ``Price`` column and NaN rows dropped; otherwise
        ``(None, error_message)``. Fewer than 100 usable rows is an error.
    """
    import time

    try:
        download = yf.download if downloader is None else downloader
        end_date = datetime.now()
        start_date = end_date - timedelta(days=days)

        df = None
        for attempt in range(max_retries):
            is_last = attempt == max_retries - 1
            try:
                df = download(ticker, start=start_date, end=end_date,
                              progress=False, auto_adjust=True)
            except Exception:
                if is_last:
                    raise
            else:
                if df is not None and not df.empty:
                    break
            if not is_last:
                time.sleep(retry_delay)  # back off before retrying

        if df is None or df.empty:
            return None, f"No data found for ticker: {ticker}. Please verify the ticker symbol is correct."

        prices = _extract_close_prices(df)
        if prices is None:
            return None, f"Could not find price data for {ticker}"

        prices = prices.dropna()
        if len(prices) < 100:
            return None, f"Insufficient data for {ticker}. Only {len(prices)} days available."
        return prices, None
    except Exception as e:
        return None, f"Error fetching data for {ticker}: {str(e)}"
def make_arima_forecast(data, days, order=(1, 1, 1)):
    """Fit a fresh ARIMA model on ``data['Price']`` and forecast ahead.

    The model is refit on every call from the supplied prices; the pickled
    ``arima_model`` loaded at start-up is not used here.

    Args:
        data: DataFrame with a ``Price`` column, in chronological order (assumed).
        days: Number of steps to forecast; coerced to ``int``.
        order: ARIMA ``(p, d, q)`` order; defaults to the original (1, 1, 1).

    Returns:
        NumPy array of ``days`` forecast values, or None if fitting or
        forecasting failed (the error is printed).
    """
    try:
        steps = int(days)
        # Fit on raw values: the trading-day index has gaps and no fixed
        # frequency, so statsmodels cannot extend it into future dates anyway.
        fitted = ARIMA(data['Price'].to_numpy(), order=order).fit()
        return np.asarray(fitted.forecast(steps=steps))
    except Exception as e:
        print(f"ARIMA Error: {e}")
        return None
def make_prophet_forecast(data, days):
    """Fit a fresh Prophet model on ``data['Price']`` and forecast ahead.

    Args:
        data: DataFrame indexed by date with a ``Price`` column.
        days: Number of calendar days to forecast past the last date; coerced
            to ``int``.

    Returns:
        NumPy array of the last ``days`` ``yhat`` values (the future part of the
        prediction), or None if fitting or prediction failed (error printed).
    """
    try:
        horizon = int(days)
        dates = pd.to_datetime(data.index)
        # Prophet rejects timezone-aware 'ds' values, so drop any timezone.
        if dates.tz is not None:
            dates = dates.tz_localize(None)
        prophet_data = pd.DataFrame({'ds': dates, 'y': data['Price'].to_numpy()})

        model = Prophet(
            daily_seasonality=True,
            weekly_seasonality=True,
            yearly_seasonality=True,
            changepoint_prior_scale=0.05,
        )
        model.fit(prophet_data)

        # The future frame holds the history followed by `horizon` new daily rows.
        future = model.make_future_dataframe(periods=horizon)
        forecast = model.predict(future)
        return forecast['yhat'].tail(horizon).to_numpy()
    except Exception as e:
        print(f"Prophet Error: {e}")
        return None
def make_lstm_forecast(data, days, model, scaler, seq_length=60):
    """Forecast prices by rolling the LSTM forward one step at a time.

    Each prediction is appended to the input window and the oldest step is
    dropped, so later predictions are built on earlier ones.

    Args:
        data: DataFrame with a ``Price`` column, oldest row first (assumed).
        days: Number of future steps to generate; coerced to ``int``.
        model: Object with a Keras-style ``predict(x, verbose=0)`` accepting an
            array of shape ``(1, seq_length, 1)``; element ``[0, 0]`` of its
            output is taken as the next scaled price.
        scaler: Fitted scaler with ``transform`` / ``inverse_transform``.
        seq_length: Window length the model was trained on.

    Returns:
        1-D NumPy array of ``days`` predicted prices in original units, or None
        if anything failed (the error is printed).
    """
    try:
        steps = int(days)
        # Pass a plain array: recent scikit-learn raises when a DataFrame's
        # column name ('Price') differs from the name used at fit time.
        prices = data[['Price']].to_numpy()
        if len(prices) < seq_length:
            print(f"LSTM Error: need at least {seq_length} prices, got {len(prices)}")
            return None

        scaled = scaler.transform(prices)
        window = scaled[-seq_length:].reshape(1, seq_length, 1)

        predictions = []
        for _ in range(steps):
            pred = model.predict(window, verbose=0)
            next_value = pred[0, 0]
            predictions.append(next_value)
            # Slide the window: drop the oldest step, append the new prediction.
            window = np.append(window[:, 1:, :], np.reshape(next_value, (1, 1, 1)), axis=1)

        restored = scaler.inverse_transform(np.asarray(predictions).reshape(-1, 1))
        return restored.flatten()
    except Exception as e:
        print(f"LSTM Error: {e}")
        return None
def create_forecast_plot(historical_data, forecasts, ticker, model_names):
    """Create an interactive Plotly chart of history plus each model's forecast.

    Args:
        historical_data: DataFrame indexed by date with a ``Price`` column.
        forecasts: Sequence of forecast arrays aligned with ``model_names``;
            ``None`` entries are skipped (their palette slot is still consumed).
        ticker: Symbol shown in the chart title.
        model_names: Display name for each entry of ``forecasts``.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()

    # Historical data
    fig.add_trace(go.Scatter(
        x=historical_data.index,
        y=historical_data['Price'],
        mode='lines',
        name='Historical Price',
        line=dict(color='blue', width=2),
    ))

    colors = ['red', 'purple', 'orange']
    lengths = [len(forecast) for forecast in forecasts if forecast is not None]
    if lengths:
        # Forecasts start the calendar day after the last observation; size the
        # date axis for the longest forecast and slice it per model.
        last_date = historical_data.index[-1]
        future_dates = pd.date_range(start=last_date + timedelta(days=1), periods=max(lengths))
        for i, (forecast, name) in enumerate(zip(forecasts, model_names)):
            if forecast is None:
                continue
            fig.add_trace(go.Scatter(
                x=future_dates[:len(forecast)],
                y=forecast,
                mode='lines+markers',
                name=f'{name} Forecast',
                # Cycle the palette so more than three models cannot overflow it.
                line=dict(color=colors[i % len(colors)], width=2, dash='dash'),
                marker=dict(size=6),
            ))

    fig.update_layout(
        title=f'{ticker} Stock Price Forecast',
        xaxis_title='Date',
        yaxis_title='Price ($)',
        hovermode='x unified',
        template='plotly_white',
        height=600,
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01,
        ),
    )
    return fig
def predict_stock(ticker, forecast_days, model_choice):
    """Run the selected model(s) for *ticker* and build the three UI outputs.

    Args:
        ticker: Symbol typed by the user; surrounding whitespace is stripped
            and it is upper-cased.
        forecast_days: Forecast horizon in days; coerced to ``int`` (the UI
            slider may deliver a float).
        model_choice: "All Models", "ARIMA", "Prophet" or "LSTM".

    Returns:
        ``(figure, summary_markdown, forecast_table)`` on success, otherwise
        ``(None, error_message, None)``.
    """
    # Validate inputs (strip first so a whitespace-only ticker is rejected).
    ticker = (ticker or '').strip().upper()
    if not ticker:
        return None, "Please enter a stock ticker symbol", None
    forecast_days = int(forecast_days)
    if forecast_days < 1:
        return None, "Forecast days must be at least 1", None

    data, error = fetch_stock_data(ticker, days=730)  # ~2 years of history
    if error:
        return None, f"Error: {error}", None

    # Candidate models in a fixed order (ARIMA, Prophet, LSTM) so table
    # columns and plot colours stay stable.
    runners = [
        ("ARIMA", lambda: make_arima_forecast(data, forecast_days)),
        ("Prophet", lambda: make_prophet_forecast(data, forecast_days)),
    ]
    # The LSTM needs both the network and its scaler from load_models().
    if lstm_model is not None and scaler is not None:
        runners.append(
            ("LSTM", lambda: make_lstm_forecast(data, forecast_days, lstm_model, scaler, SEQ_LENGTH))
        )

    forecasts, model_names = [], []
    for name, run in runners:
        if model_choice not in ("All Models", name):
            continue
        result = run()
        if result is not None:
            forecasts.append(result)
            model_names.append(name)

    if not forecasts:
        return None, "Failed to generate forecasts. Please try again.", None

    fig = create_forecast_plot(data, forecasts, ticker, model_names)

    # Forecast table: one row per future calendar day, one column per model.
    future_dates = pd.date_range(start=data.index[-1] + timedelta(days=1), periods=forecast_days)
    forecast_df = pd.DataFrame({'Date': future_dates.strftime('%Y-%m-%d')})
    for forecast, name in zip(forecasts, model_names):
        # Align on the row index so a model returning a different length
        # pads/truncates instead of raising.
        forecast_df[f'{name} Prediction ($)'] = pd.Series(np.round(np.asarray(forecast, dtype=float), 2))

    # Build the Markdown explicitly line by line so its rendering does not
    # depend on source indentation inside a triple-quoted string.
    current_price = data['Price'].iloc[-1]
    lines = [
        f"📊 **Forecast Summary for {ticker}**",
        "",
        f"- Current Price: ${current_price:.2f}",
        f"- Forecast Period: {forecast_days} days",
        f"- Models Used: {', '.join(model_names)}",
        "",
        f"**Predicted Price Range (Day {forecast_days}):**",
        "",
    ]
    for forecast, name in zip(forecasts, model_names):
        final_price = forecast[-1]
        change = ((final_price - current_price) / current_price) * 100
        lines.append(f"- {name}: ${final_price:.2f} ({change:+.2f}%)")
    summary = "\n".join(lines)

    return fig, summary, forecast_df
# Create Gradio Interface.
# gr.Markdown dedents its text, so the indented triple-quoted strings below
# render as normal Markdown.
with gr.Blocks(title="Stock Price Forecasting", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 📈 Stock Price Forecasting App

        Predict future stock prices using ARIMA, Prophet, and LSTM models.
        Enter a stock ticker symbol and select forecast parameters below.

        **Note:** Predictions are for educational purposes only. Not financial advice.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            ticker_input = gr.Textbox(
                label="Stock Ticker Symbol",
                placeholder="e.g., AAPL, GOOGL, TSLA",
                value="AAPL",
            )
            forecast_days = gr.Slider(
                minimum=1,
                maximum=90,
                value=30,
                step=1,
                label="Forecast Days",
            )
            model_choice = gr.Radio(
                choices=["All Models", "ARIMA", "Prophet", "LSTM"],
                value="All Models",
                label="Select Model(s)",
            )
            predict_btn = gr.Button("🔮 Generate Forecast", variant="primary", size="lg")
        with gr.Column(scale=2):
            output_plot = gr.Plot(label="Forecast Visualization")

    with gr.Row():
        output_summary = gr.Markdown(label="Forecast Summary")

    with gr.Row():
        output_table = gr.Dataframe(
            label="Detailed Forecast",
            wrap=True,
            interactive=False,
        )

    # Examples only pre-fill the inputs; the user still clicks the button.
    gr.Examples(
        examples=[
            ["AAPL", 30, "All Models"],
            ["GOOGL", 14, "Prophet"],
            ["TSLA", 60, "LSTM"],
            ["MSFT", 45, "ARIMA"],
        ],
        inputs=[ticker_input, forecast_days, model_choice],
    )

    # Connect the button to the prediction function.
    predict_btn.click(
        fn=predict_stock,
        inputs=[ticker_input, forecast_days, model_choice],
        outputs=[output_plot, output_summary, output_table],
    )

    gr.Markdown(
        """
        ---
        ### 📚 About the Models

        - **ARIMA**: Statistical model for time series forecasting
        - **Prophet**: Facebook's forecasting tool, excellent for seasonality
        - **LSTM**: Deep learning model that captures complex patterns

        ### ⚠️ Disclaimer

        This tool is for educational and research purposes only. Stock market predictions are inherently uncertain.
        Always conduct thorough research and consult with financial advisors before making investment decisions.
        """
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()