# HuggingFace Space: S&P 500 ARIMA forecaster (Gradio app).
| import gradio as gr | |
| import yfinance as yf | |
| import pandas as pd | |
| import numpy as np | |
| from statsmodels.tsa.arima.model import ARIMA | |
| import joblib | |
| from datetime import datetime, timedelta | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| # Load the saved ARIMA model (upload 'arima_model.pkl' to your Space) | |
| try: | |
| checkpoint = joblib.load('arima_model.pkl') | |
| loaded_fit = checkpoint['model_fit'] | |
| last_train_date = checkpoint['last_date'] | |
| order = checkpoint['order'] | |
| print(f"Model loaded successfully. Last training date: {last_train_date}") | |
| except FileNotFoundError: | |
| print("Model file not found. Starting with a fresh fit.") | |
| loaded_fit = None | |
| last_train_date = None | |
| order = (5, 1, 0) | |
| # Function to fetch S&P 500 data | |
| def fetch_data(start_date=None, period="max"): | |
| ticker = yf.Ticker("^GSPC") | |
| if start_date: | |
| data = ticker.history(start=start_date, period=period) | |
| else: | |
| data = ticker.history(period=period) | |
| data = data['Close'].dropna() | |
| # normalize index to tz-naive datetimes to avoid tz-aware vs tz-naive comparisons | |
| try: | |
| idx = data.index | |
| # try removing timezone if present | |
| if getattr(idx, 'tz', None) is not None: | |
| try: | |
| data.index = idx.tz_convert(None) | |
| except Exception: | |
| data.index = idx.tz_localize(None) | |
| except Exception: | |
| # fallback: ensure datetime conversion | |
| data.index = pd.to_datetime(data.index) | |
| return data | |
| # Function to update model with new data if needed | |
| def update_model(model_fit, new_data, order): | |
| if hasattr(model_fit.model.endog, 'index'): | |
| updated_fit = model_fit.append(new_data, refit=False) | |
| else: | |
| updated_fit = model_fit.append(new_data.values, refit=False) | |
| return updated_fit | |
| # Function to predict next n steps | |
| def predict_arima(model_fit, n_steps=1): | |
| predictions = model_fit.forecast(steps=n_steps) | |
| return predictions | |
| # Main prediction function for Gradio | |
| def forecast_sp500_arima(n_days, refit=False): | |
| global loaded_fit, last_train_date # To update global state if needed | |
| data = fetch_data() | |
| if refit or loaded_fit is None: | |
| # Refit on full current data | |
| model = ARIMA(data, order=order) | |
| loaded_fit = model.fit() | |
| last_train_date = data.index[-1].date() | |
| print("Model refitted on latest data.") | |
| else: | |
| # Determine last model date | |
| if hasattr(loaded_fit.model.endog, 'index'): | |
| # ensure we have a pandas.Timestamp | |
| last_model_date = pd.to_datetime(loaded_fit.model.endog.index[-1]) | |
| else: | |
| # last_train_date was saved as a date object; convert to Timestamp | |
| last_model_date = pd.to_datetime(last_train_date) | |
| # Use date() for comparison to avoid tz-aware vs tz-naive issues | |
| new_start_str = (last_model_date.date() + timedelta(days=1)).strftime('%Y-%m-%d') | |
| new_data = fetch_data(start_date=new_start_str) | |
| appended = False | |
| if len(new_data) > 0: | |
| new_first = pd.to_datetime(new_data.index[0]) | |
| # compare dates (tz-naive) to avoid TypeError when indices have tz info | |
| if new_first.date() > last_model_date.date(): | |
| # Instead of using append (which can change the model's index to a RangeIndex), | |
| # refit the ARIMA on the full current data to preserve a DatetimeIndex and | |
| # avoid indexing issues during prediction. | |
| try: | |
| model = ARIMA(data, order=order) | |
| loaded_fit = model.fit() | |
| appended = True | |
| print("Model refitted with new data.") | |
| except Exception as e: | |
| print(f"Refit failed: {e}. Using existing model.") | |
| else: | |
| print(f"New data starts at {new_first}, model ends at {last_model_date}; no extension.") | |
| else: | |
| print("No new data available.") | |
| if appended: | |
| # keep last_train_date as a date for consistency | |
| last_train_date = data.index[-1].date() | |
| predictions = predict_arima(loaded_fit, n_days) | |
| last_date = data.index[-1] | |
| future_dates = [last_date + timedelta(days=i+1) for i in range(n_days)] | |
| results = pd.DataFrame({ | |
| 'Date': future_dates, | |
| 'Predicted Close': predictions | |
| }) | |
| # Last actual price | |
| last_actual = data.iloc[-1] | |
| return f"Last Actual Close ({last_date.date()}): ${last_actual:.2f}\n\nForecast:\n{results.to_string(index=False)}" | |
| # Gradio interface | |
| with gr.Blocks(title="S&P 500 ARIMA Forecaster (Saved Model)") as demo: | |
| gr.Markdown("# S&P 500 Stock Price Forecaster\nUsing saved ARIMA model with optional updates. \n Use int number for Price Forecast Prediction.") | |
| with gr.Row(): | |
| n_days = gr.Slider(minimum=1, maximum=30, value=5, label="Number of days to forecast") | |
| refit_btn = gr.Checkbox(label="Refit model on latest data (ignores saved model)", value=False) | |
| predict_btn = gr.Button("Generate Forecast") | |
| output = gr.Textbox(label="Forecast Results") | |
| predict_btn.click( | |
| fn=forecast_sp500_arima, | |
| inputs=[n_days, refit_btn], | |
| outputs=output | |
| ) | |
| gr.Markdown("### Notes:\n- Loads saved ARIMA model from 'arima_model.pkl'.\n- Checks and appends new data only if it extends the model's index.\n- Falls back gracefully if append fails.\n- Data fetched via yfinance.\n- ARIMA order (5,1,0) used.\n- Upload 'arima_model.pkl' to your Space.") | |
| if __name__ == "__main__": | |
| demo.launch() |