Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from datetime import datetime, timedelta | |
| import pickle | |
| import os | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| # TensorFlow/Keras imports | |
| from tensorflow.keras.models import load_model | |
| from sklearn.preprocessing import MinMaxScaler | |
| # ARIMA and Prophet | |
| from statsmodels.tsa.arima.model import ARIMA | |
| from prophet import Prophet | |
# --------------------------
# Load models safely
# --------------------------
def _load_pickle(path):
    """Unpickle and return the object stored at *path*, or None on failure.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only ship
    artifacts you produced yourself alongside the app.
    """
    try:
        with open(path, 'rb') as f:
            return pickle.load(f)
    except Exception as e:
        print(f"Error loading {path}: {e}")
        return None


def load_models():
    """Load the pre-trained artifacts from the working directory.

    Each artifact is loaded independently, so one missing or corrupt file
    does not discard the others.

    Returns:
        tuple: ``(arima_model, prophet_model, lstm_model, scaler)``; any
        element that could not be loaded is ``None``.
    """
    arima_model = _load_pickle('arima_model.pkl')
    prophet_model = _load_pickle('prophet_model.pkl')
    try:
        from tensorflow.keras.models import load_model
        lstm_model = load_model('lstm_model.keras')
    except Exception as e:
        print(f"Error loading lstm_model.keras: {e}")
        lstm_model = None
    # The LSTM needs the scaler it was trained with.  Assumes it was pickled
    # as 'scaler.pkl' next to the model files — TODO confirm the filename.
    scaler = _load_pickle('scaler.pkl')
    return arima_model, prophet_model, lstm_model, scaler
# Look-back window (number of past steps) fed to the LSTM.
SEQ_LENGTH = 60
# Load every artifact once at import time; load_models() yields None for
# anything it could not load.
(arima_model, prophet_model, lstm_model, scaler) = load_models()
# --------------------------
# Fetch stock data
# --------------------------
def fetch_stock_data(ticker, days=365):
    """Load recent prices for *ticker* from a local ``<TICKER>.csv`` file.

    Community Spaces cannot access the internet, so the CSV must be
    uploaded to the Space root.  The first CSV column is used as the index.

    Args:
        ticker: Ticker symbol; it is stripped and upper-cased to build the
            file name.
        days: Number of most recent rows (not calendar days) to keep.

    Returns:
        ``(df, None)`` on success, where *df* has a single numeric
        ``'Price'`` column sorted by index; ``(None, message)`` on failure.
    """
    ticker = ticker.upper().strip()
    # The ticker comes from user input: refuse anything that could point
    # outside the working directory.
    if not ticker or '/' in ticker or '\\' in ticker:
        return None, f"Invalid ticker symbol: {ticker!r}."
    filename = f"{ticker}.csv"
    if not os.path.exists(filename):
        return None, f"No data found for {ticker}. Upload {ticker}.csv in the Space root."
    try:
        df = pd.read_csv(filename, index_col=0, parse_dates=True)
    except Exception as e:
        return None, f"Could not read {filename}: {e}"

    if df.columns.empty:
        return None, f"No valid data found in {filename} for {ticker}."
    # Prefer a closing-price column; otherwise fall back to the first column
    # (renaming a multi-column frame to a single name would raise).
    price_col = next(
        (c for c in ('Close', 'Adj Close', 'Price') if c in df.columns),
        df.columns[0],
    )
    df = df[[price_col]].copy()
    df.columns = ['Price']
    df['Price'] = pd.to_numeric(df['Price'], errors='coerce')
    # Sort so tail() really keeps the most recent rows, even for files
    # stored newest-first.
    df = df.dropna().sort_index().tail(days)
    if df.empty:
        return None, f"No valid data found in {filename} for {ticker}."
    return df, None
# --------------------------
# Forecasting functions
# --------------------------
def make_arima_forecast(data, days, order=(1, 1, 1)):
    """Fit an ARIMA model on ``data['Price']`` and forecast *days* steps ahead.

    The caller's DataFrame is not modified.

    Args:
        data: DataFrame with a ``'Price'`` column.
        days: Number of steps to forecast (coerced to ``int``).
        order: ARIMA ``(p, d, q)`` order; defaults to ``(1, 1, 1)``.

    Returns:
        1-D numpy array of forecasts, or ``None`` if fitting/forecasting fails.
    """
    try:
        # Clean into a local Series instead of writing back into ``data``.
        prices = pd.to_numeric(data['Price'], errors='coerce').dropna()
        fitted = ARIMA(prices, order=order).fit()
        forecast = fitted.forecast(steps=int(days))
        # np.asarray accepts both a pandas Series and a plain ndarray.
        return np.asarray(forecast)
    except Exception as e:
        print(f"ARIMA Error: {e}")
        return None
def make_prophet_forecast(data, days):
    """Fit Prophet on ``data['Price']`` and return predictions for the next *days* days.

    Args:
        data: DataFrame indexed by date with a ``'Price'`` column.
        days: Forecast horizon in calendar days (coerced to ``int``).

    Returns:
        1-D numpy array of ``yhat`` values for the future dates only, or
        ``None`` if fitting/prediction fails.
    """
    try:
        ds = pd.to_datetime(data.index)
        # Prophet rejects a timezone-aware 'ds' column; drop the tz while
        # keeping local wall-clock times.
        if ds.tz is not None:
            ds = ds.tz_localize(None)
        prophet_data = pd.DataFrame({'ds': ds, 'y': data['Price'].values})
        model = Prophet(
            # NOTE(review): daily seasonality is meant for sub-daily data;
            # if the CSVs hold one row per day, consider disabling it.
            daily_seasonality=True,
            weekly_seasonality=True,
            yearly_seasonality=True,
            changepoint_prior_scale=0.05,
        )
        model.fit(prophet_data)
        horizon = int(days)
        future = model.make_future_dataframe(periods=horizon)
        forecast = model.predict(future)
        # make_future_dataframe includes the history; keep only new dates.
        return forecast['yhat'].tail(horizon).values
    except Exception as e:
        print(f"Prophet Error: {e}")
        return None
def make_lstm_forecast(data, days, model, scaler, seq_length=60):
    """Roll an LSTM forward *days* steps from the last *seq_length* prices.

    Each prediction is appended as the newest input step and the oldest
    step is dropped (recursive multi-step forecasting).

    Args:
        data: DataFrame with a ``'Price'`` column.
        days: Number of steps to forecast (coerced to ``int``).
        model: Object with a Keras-style ``predict(x, verbose=0)`` accepting
            input of shape ``(1, seq_length, 1)``; ``pred[0, 0]`` is used.
        scaler: Fitted scaler exposing ``transform``/``inverse_transform``.
        seq_length: Look-back window length.

    Returns:
        1-D numpy array of forecasts in price units, or ``None`` on failure.
    """
    try:
        if model is None or scaler is None:
            raise ValueError("LSTM model or scaler is not loaded")
        # Gradio sliders may deliver a float; range() needs an int.
        horizon = int(days)
        # Pass a plain array so a scaler fitted under a different column
        # name does not reject the input on a feature-name mismatch.
        scaled_data = scaler.transform(data[['Price']].to_numpy(dtype=float))
        if len(scaled_data) < seq_length:
            raise ValueError(f"need at least {seq_length} rows, got {len(scaled_data)}")
        window = scaled_data[-seq_length:].reshape(1, seq_length, 1)
        predictions = []
        for _ in range(horizon):
            pred = model.predict(window, verbose=0)
            next_value = float(pred[0, 0])
            predictions.append(next_value)
            # Slide the window: drop the oldest step, append the prediction.
            window = np.append(window[:, 1:, :], [[[next_value]]], axis=1)
        predictions = scaler.inverse_transform(np.array(predictions).reshape(-1, 1))
        return predictions.flatten()
    except Exception as e:
        print(f"LSTM Error: {e}")
        return None
# --------------------------
# Plot function
# --------------------------
def create_forecast_plot(historical_data, forecasts, ticker, model_names):
    """Build a Plotly figure of the historical prices plus each model's forecast.

    Args:
        historical_data: DataFrame indexed by date with a ``'Price'`` column.
        forecasts: Sequence of 1-D forecast arrays; ``None`` entries are skipped.
        ticker: Symbol shown in the title.
        model_names: Display names, aligned with *forecasts*.

    Returns:
        plotly ``go.Figure``.
    """
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=historical_data.index,
        y=historical_data['Price'],
        mode='lines',
        name='Historical Price',
        line=dict(color='blue', width=2),
    ))
    first_future_day = pd.to_datetime(historical_data.index[-1]) + timedelta(days=1)
    colors = ['red', 'purple', 'orange']
    for i, (forecast, name) in enumerate(zip(forecasts, model_names)):
        if forecast is None:
            continue
        # Date each forecast by its own length, so a None first entry or
        # differing lengths cannot break the x-axis.
        future_dates = pd.date_range(start=first_future_day, periods=len(forecast))
        fig.add_trace(go.Scatter(
            x=future_dates,
            y=forecast,
            mode='lines+markers',
            name=f'{name} Forecast',
            # Cycle the palette so more than three models cannot IndexError.
            line=dict(color=colors[i % len(colors)], width=2, dash='dash'),
            marker=dict(size=6),
        ))
    fig.update_layout(
        title=f'{ticker} Stock Price Forecast',
        xaxis_title='Date',
        yaxis_title='Price ($)',
        hovermode='x unified',
        template='plotly_white',
        height=600,
        showlegend=True,
    )
    return fig
# --------------------------
# Main prediction function
# --------------------------
def predict_stock(ticker, forecast_days, model_choice):
    """Gradio callback: load the CSV data, run the selected model(s), build outputs.

    Args:
        ticker: Ticker symbol typed by the user.
        forecast_days: Horizon from the slider (may arrive as a float).
        model_choice: "All Models", "ARIMA", "Prophet" or "LSTM".

    Returns:
        ``(figure, summary_markdown, forecast_table)``; on failure the figure
        and table are ``None`` and the summary holds the error message.
    """
    # Normalize once so whitespace-only input is rejected and the title and
    # summary show the same symbol used for the file lookup.
    ticker = (ticker or '').strip().upper()
    if not ticker:
        return None, "Please enter a stock ticker symbol", None
    # Gradio sliders may deliver a float; range()/date_range() need an int.
    forecast_days = int(forecast_days)

    data, error = fetch_stock_data(ticker, days=730)
    if error:
        return None, f"Error: {error}", None

    forecasts = []
    model_names = []
    if model_choice in ["All Models", "ARIMA"]:
        arima_forecast = make_arima_forecast(data, forecast_days)
        if arima_forecast is not None:
            forecasts.append(arima_forecast)
            model_names.append("ARIMA")
    if model_choice in ["All Models", "Prophet"]:
        prophet_forecast = make_prophet_forecast(data, forecast_days)
        if prophet_forecast is not None:
            forecasts.append(prophet_forecast)
            model_names.append("Prophet")
    # The LSTM needs both the network and its training-time scaler.
    if model_choice in ["All Models", "LSTM"] and lstm_model is not None and scaler is not None:
        lstm_forecast = make_lstm_forecast(data, forecast_days, lstm_model, scaler, SEQ_LENGTH)
        if lstm_forecast is not None:
            forecasts.append(lstm_forecast)
            model_names.append("LSTM")

    if not forecasts:
        return None, "Failed to generate forecasts.", None

    fig = create_forecast_plot(data, forecasts, ticker, model_names)

    # Forecast table: one row per future calendar day, one column per model.
    future_dates = pd.date_range(
        start=pd.to_datetime(data.index[-1]) + timedelta(days=1),
        periods=forecast_days,
    )
    forecast_df = pd.DataFrame({'Date': future_dates.strftime('%Y-%m-%d')})
    for forecast, name in zip(forecasts, model_names):
        forecast_df[f'{name} Prediction ($)'] = np.round(forecast, 2)

    # Summary
    current_price = data['Price'].iloc[-1]
    summary = (
        f"📊 **Forecast Summary for {ticker}**\n\n"
        f"- Current Price: ${current_price:.2f}\n"
        f"- Forecast Period: {forecast_days} days\n"
        f"- Models Used: {', '.join(model_names)}\n\n"
        f"**Predicted Price Range (Day {forecast_days}):**"
    )
    for forecast, name in zip(forecasts, model_names):
        final_price = forecast[-1]
        # Avoid ZeroDivisionError on a zero last price.
        if current_price:
            change = ((final_price - current_price) / current_price) * 100
        else:
            change = float('nan')
        summary += f"\n- {name}: ${final_price:.2f} ({change:+.2f}%)"
    return fig, summary, forecast_df
# --------------------------
# Gradio Interface
# --------------------------
# Emojis below repair mojibake (UTF-8 bytes shown as ISO-8859-11 Thai
# characters) in the user-facing labels.
with gr.Blocks(title="Stock Price Forecasting", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# 📈 Stock Price Forecasting App\n"
        "Predict future stock prices using ARIMA, Prophet, and LSTM models.\n"
        "Upload CSV files in the Space root for offline use."
    )
    with gr.Row():
        with gr.Column(scale=1):
            ticker_input = gr.Textbox(label="Stock Ticker Symbol", placeholder="e.g., AAPL", value="AAPL")
            forecast_days = gr.Slider(minimum=1, maximum=90, value=30, step=1, label="Forecast Days")
            model_choice = gr.Radio(
                choices=["All Models", "ARIMA", "Prophet", "LSTM"],
                value="All Models",
                label="Select Model(s)",
            )
            predict_btn = gr.Button("🔮 Generate Forecast", variant="primary")
        with gr.Column(scale=2):
            output_plot = gr.Plot(label="Forecast Visualization")
            output_summary = gr.Markdown(label="Forecast Summary")
            output_table = gr.Dataframe(label="Detailed Forecast", interactive=False)
    predict_btn.click(
        fn=predict_stock,
        inputs=[ticker_input, forecast_days, model_choice],
        outputs=[output_plot, output_summary, output_table],
    )

# Launch
if __name__ == "__main__":
    demo.launch()