"""Gradio app that forecasts S&P 500 closing prices with an ARIMA model.

A fitted model is loaded from 'arima_model.pkl' at import time when present;
otherwise one is fitted on demand on the full ^GSPC price history.
"""

import gradio as gr
import yfinance as yf
import pandas as pd
import numpy as np
from statsmodels.tsa.arima.model import ARIMA
import joblib
from datetime import datetime, timedelta
import warnings

# Silences all warnings process-wide (presumably the statsmodels
# frequency/convergence warnings emitted when fitting on a daily index).
warnings.filterwarnings('ignore')

DEFAULT_ORDER = (5, 1, 0)

# Load the saved ARIMA model (upload 'arima_model.pkl' to your Space).
# NOTE(review): joblib.load unpickles the file -- only ever load a checkpoint
# you produced yourself, never one supplied by an untrusted party.
try:
    checkpoint = joblib.load('arima_model.pkl')
    loaded_fit = checkpoint['model_fit']
    last_train_date = checkpoint['last_date']
    order = checkpoint['order']
    print(f"Model loaded successfully. Last training date: {last_train_date}")
except FileNotFoundError:
    print("Model file not found. Starting with a fresh fit.")
    loaded_fit = None
    last_train_date = None
    order = DEFAULT_ORDER
except KeyError as err:
    # A checkpoint without the expected keys is treated like a missing one
    # instead of crashing the app at import time.
    print(f"Model file is missing key {err}. Starting with a fresh fit.")
    loaded_fit = None
    last_train_date = None
    order = DEFAULT_ORDER


def fetch_data(start_date=None, period="max"):
    """Fetch daily S&P 500 (^GSPC) closing prices from Yahoo Finance.

    Args:
        start_date: Optional 'YYYY-MM-DD' string. When given, prices from that
            date onward are fetched and ``period`` is not used.
        period: yfinance period string used when ``start_date`` is falsy.

    Returns:
        A ``pd.Series`` of closing prices indexed by tz-naive dates; an empty
        Series when Yahoo returns no rows or no 'Close' column.
    """
    ticker = yf.Ticker("^GSPC")
    if start_date:
        # start and period are alternative ways of choosing the range; pass
        # only one so the request is unambiguous.
        history = ticker.history(start=start_date)
    else:
        history = ticker.history(period=period)

    if history.empty or 'Close' not in history:
        return pd.Series(dtype=float)

    data = history['Close'].dropna()
    # Drop the exchange timezone but keep local wall-clock dates.
    # tz_convert(None) would shift them to UTC (e.g. 05:00), which leaks a
    # time-of-day into every displayed date.
    index = pd.to_datetime(data.index)
    if index.tz is not None:
        index = index.tz_localize(None)
    data.index = index
    return data


def update_model(model_fit, new_data, order):
    """Extend a fitted ARIMA result with new observations without refitting.

    Args:
        model_fit: A fitted statsmodels ARIMA results object.
        new_data: ``pd.Series`` of observations that follow the model's sample.
        order: Unused; kept so existing callers keep working.

    Returns:
        A new results object whose sample includes ``new_data``. Parameters
        are not re-estimated (``refit=False``).
    """
    # statsmodels stores ``model.endog`` as a plain ndarray; the original
    # pandas index (if any) lives in ``model.data.row_labels``. append() needs
    # pandas input when the model was fitted on pandas data.
    row_labels = getattr(model_fit.model.data, 'row_labels', None)
    if isinstance(row_labels, pd.Index):
        return model_fit.append(new_data, refit=False)
    return model_fit.append(np.asarray(new_data), refit=False)


def predict_arima(model_fit, n_steps=1):
    """Return the ``n_steps``-ahead out-of-sample forecast of ``model_fit``."""
    return model_fit.forecast(steps=n_steps)


def _to_naive_day(value):
    """Convert a date-like value to a tz-naive midnight Timestamp (None -> None)."""
    if value is None:
        return None
    stamp = pd.Timestamp(value)
    if stamp.tzinfo is not None:
        stamp = stamp.tz_localize(None)
    return stamp.normalize()


def _model_end_date(model_fit, fallback_date):
    """Return the last training date of ``model_fit`` as a tz-naive Timestamp.

    Uses the model's own date index when it was fitted on a pandas series with
    a DatetimeIndex; otherwise falls back to ``fallback_date`` (None if unknown).
    """
    row_labels = getattr(model_fit.model.data, 'row_labels', None)
    if isinstance(row_labels, pd.DatetimeIndex) and len(row_labels) > 0:
        return _to_naive_day(row_labels[-1])
    return _to_naive_day(fallback_date)


def forecast_sp500_arima(n_days, refit=False):
    """Forecast the next ``n_days`` trading-day closes of the S&P 500.

    Updates the module-level ``loaded_fit`` / ``last_train_date`` when the
    model is (re)fitted.

    Args:
        n_days: Number of steps to forecast. Gradio sliders may deliver a
            float, so it is cast to int (minimum 1).
        refit: If True, refit on the full current history even when a saved
            model is available.

    Returns:
        A text report with the last actual close and a table of forecast dates
        and predicted closes, or an error message if no data could be fetched.
    """
    global loaded_fit, last_train_date  # model state shared across requests

    n_days = max(int(n_days), 1)

    data = fetch_data()
    if data.empty:
        return "Could not fetch S&P 500 data from Yahoo Finance. Please try again later."

    # Date of the last observation the forecast continues from.
    anchor_date = data.index[-1]

    if refit or loaded_fit is None:
        # Refit on full current data
        loaded_fit = ARIMA(data, order=order).fit()
        last_train_date = data.index[-1].date()
        print("Model refitted on latest data.")
    else:
        model_end = _model_end_date(loaded_fit, last_train_date)
        if model_end is None or data.index[-1].normalize() > model_end:
            # Refit rather than append: append() can replace the model's
            # DatetimeIndex with a RangeIndex, which breaks later indexing.
            try:
                loaded_fit = ARIMA(data, order=order).fit()
                last_train_date = data.index[-1].date()
                print("Model refitted with new data.")
            except Exception as e:
                print(f"Refit failed: {e}. Using existing model.")
                # The old model's forecasts continue from its own last date,
                # so label them from there rather than from the latest data.
                if model_end is not None:
                    anchor_date = model_end
        else:
            print("No new data available.")

    predictions = np.asarray(predict_arima(loaded_fit, n_days))
    # Each ARIMA step is one trading day, so label steps with business days.
    # Exchange holidays are not excluded, so dates may drift around holidays.
    future_dates = pd.bdate_range(start=anchor_date + pd.offsets.BDay(1), periods=n_days)
    results = pd.DataFrame({
        'Date': future_dates.date,
        'Predicted Close': predictions,
    })

    # Last actual price
    last_date = data.index[-1]
    last_actual = float(data.iloc[-1])
    return f"Last Actual Close ({last_date.date()}): ${last_actual:.2f}\n\nForecast:\n{results.to_string(index=False)}"


# Gradio interface
with gr.Blocks(title="S&P 500 ARIMA Forecaster (Saved Model)") as demo:
    gr.Markdown("# S&P 500 Stock Price Forecaster\nUsing saved ARIMA model with optional updates.")
    with gr.Row():
        # step=1 guarantees whole numbers of forecast days.
        n_days = gr.Slider(minimum=1, maximum=30, value=5, step=1, label="Number of days to forecast")
        refit_btn = gr.Checkbox(label="Refit model on latest data (ignores saved model)", value=False)
    predict_btn = gr.Button("Generate Forecast")
    output = gr.Textbox(label="Forecast Results")

    predict_btn.click(
        fn=forecast_sp500_arima,
        inputs=[n_days, refit_btn],
        outputs=output,
    )

    gr.Markdown(
        "### Notes:\n"
        "- Loads saved ARIMA model from 'arima_model.pkl'.\n"
        "- Refits the model when newer trading days are available than it was trained on.\n"
        "- Keeps the existing model if that refit fails.\n"
        "- Data fetched via yfinance.\n"
        f"- ARIMA order {order} used.\n"
        "- Forecast dates are business days (exchange holidays are not skipped).\n"
        "- Upload 'arima_model.pkl' to your Space."
    )

if __name__ == "__main__":
    demo.launch()