Spaces:
Sleeping
Sleeping
File size: 5,656 Bytes
import gradio as gr
import yfinance as yf
import pandas as pd
import numpy as np
from statsmodels.tsa.arima.model import ARIMA
import joblib
from datetime import datetime, timedelta
import warnings
warnings.filterwarnings('ignore')
# Load the saved ARIMA model (upload 'arima_model.pkl' to your Space)
# The checkpoint is a dict produced at training time with three keys:
# 'model_fit' (fitted ARIMAResults), 'last_date' (last training date),
# and 'order' (the (p, d, q) tuple used for the fit).
try:
    checkpoint = joblib.load('arima_model.pkl')
    loaded_fit = checkpoint['model_fit']
    last_train_date = checkpoint['last_date']
    order = checkpoint['order']
    print(f"Model loaded successfully. Last training date: {last_train_date}")
except FileNotFoundError:
    # No saved model present: fall back to lazy fitting inside
    # forecast_sp500_arima (loaded_fit is None triggers a fresh fit there).
    print("Model file not found. Starting with a fresh fit.")
    loaded_fit = None
    last_train_date = None
    # Default ARIMA order used when no checkpoint supplies one.
    order = (5, 1, 0)
# Function to fetch S&P 500 data
def fetch_data(start_date=None, period="max"):
    """Fetch S&P 500 ('^GSPC') closing prices via yfinance.

    Args:
        start_date: Optional 'YYYY-MM-DD' string. When given, history is
            fetched from this date onward and ``period`` is not passed —
            yfinance expects either a period or a start/end range, and
            previously both were forwarded together, which was misleading.
        period: yfinance period string used only when no start date is given.

    Returns:
        pandas Series of closing prices with a tz-naive DatetimeIndex
        (tz info is stripped so downstream date comparisons never mix
        tz-aware and tz-naive timestamps).
    """
    ticker = yf.Ticker("^GSPC")
    if start_date:
        # Fix: do not forward period alongside start; start/end and period
        # are mutually exclusive selection modes in yfinance.
        data = ticker.history(start=start_date)
    else:
        data = ticker.history(period=period)
    data = data['Close'].dropna()
    # normalize index to tz-naive datetimes to avoid tz-aware vs tz-naive comparisons
    try:
        idx = data.index
        # try removing timezone if present
        if getattr(idx, 'tz', None) is not None:
            try:
                data.index = idx.tz_convert(None)
            except Exception:
                data.index = idx.tz_localize(None)
    except Exception:
        # fallback: ensure datetime conversion
        data.index = pd.to_datetime(data.index)
    return data
# Function to update model with new data if needed
def update_model(model_fit, new_data, order):
    """Extend a fitted ARIMA result with new observations, without refitting.

    When the fitted endog carries a pandas index the Series is appended
    as-is; otherwise the raw numpy values are used so the result's
    integer-based index stays consistent.
    """
    endog_has_index = hasattr(model_fit.model.endog, 'index')
    payload = new_data if endog_has_index else new_data.values
    return model_fit.append(payload, refit=False)
# Function to predict next n steps
def predict_arima(model_fit, n_steps=1):
    """Forecast the next ``n_steps`` values from a fitted ARIMA result."""
    return model_fit.forecast(steps=n_steps)
# Main prediction function for Gradio
def forecast_sp500_arima(n_days, refit=False):
    """Forecast the next ``n_days`` S&P 500 closing prices.

    Fetches the full price history, optionally refits (or extends) the
    ARIMA model, then formats the forecast as a plain-text table.

    Args:
        n_days: Number of future steps to forecast (slider value, 1-30).
        refit: When True, ignore the saved model and refit on all data.

    Returns:
        A string containing the last actual close and the forecast table.

    Side effects:
        Updates the module-level ``loaded_fit`` and ``last_train_date``.
    """
    global loaded_fit, last_train_date  # To update global state if needed
    data = fetch_data()
    if refit or loaded_fit is None:
        # Refit on full current data
        model = ARIMA(data, order=order)
        loaded_fit = model.fit()
        last_train_date = data.index[-1].date()
        print("Model refitted on latest data.")
    else:
        # Determine last model date
        if hasattr(loaded_fit.model.endog, 'index'):
            # ensure we have a pandas.Timestamp
            last_model_date = pd.to_datetime(loaded_fit.model.endog.index[-1])
        else:
            # last_train_date was saved as a date object; convert to Timestamp
            last_model_date = pd.to_datetime(last_train_date)
        # Use date() for comparison to avoid tz-aware vs tz-naive issues
        new_start_str = (last_model_date.date() + timedelta(days=1)).strftime('%Y-%m-%d')
        new_data = fetch_data(start_date=new_start_str)
        appended = False
        if len(new_data) > 0:
            new_first = pd.to_datetime(new_data.index[0])
            # compare dates (tz-naive) to avoid TypeError when indices have tz info
            if new_first.date() > last_model_date.date():
                # Instead of using append (which can change the model's index to a RangeIndex),
                # refit the ARIMA on the full current data to preserve a DatetimeIndex and
                # avoid indexing issues during prediction.
                try:
                    model = ARIMA(data, order=order)
                    loaded_fit = model.fit()
                    appended = True
                    print("Model refitted with new data.")
                except Exception as e:
                    # Best-effort update: keep the previously loaded fit on failure.
                    print(f"Refit failed: {e}. Using existing model.")
            else:
                print(f"New data starts at {new_first}, model ends at {last_model_date}; no extension.")
        else:
            print("No new data available.")
        if appended:
            # keep last_train_date as a date for consistency
            last_train_date = data.index[-1].date()
    predictions = predict_arima(loaded_fit, n_days)
    last_date = data.index[-1]
    # NOTE(review): future dates are consecutive calendar days, so weekends
    # and market holidays are included — confirm this is intended.
    future_dates = [last_date + timedelta(days=i+1) for i in range(n_days)]
    results = pd.DataFrame({
        'Date': future_dates,
        'Predicted Close': predictions
    })
    # Last actual price
    last_actual = data.iloc[-1]
    return f"Last Actual Close ({last_date.date()}): ${last_actual:.2f}\n\nForecast:\n{results.to_string(index=False)}"
# Gradio interface
# Builds the Blocks layout: a slider for the horizon, a checkbox to force a
# refit, a button to trigger the forecast, and a textbox for the result.
with gr.Blocks(title="S&P 500 ARIMA Forecaster (Saved Model)") as demo:
    gr.Markdown("# S&P 500 Stock Price Forecaster\nUsing saved ARIMA model with optional updates. \n Use int number for Price Forecast Prediction.")
    with gr.Row():
        n_days = gr.Slider(minimum=1, maximum=30, value=5, label="Number of days to forecast")
        refit_btn = gr.Checkbox(label="Refit model on latest data (ignores saved model)", value=False)
    predict_btn = gr.Button("Generate Forecast")
    output = gr.Textbox(label="Forecast Results")
    # Wire the button click to the forecasting function.
    predict_btn.click(
        fn=forecast_sp500_arima,
        inputs=[n_days, refit_btn],
        outputs=output
    )
    gr.Markdown("### Notes:\n- Loads saved ARIMA model from 'arima_model.pkl'.\n- Checks and appends new data only if it extends the model's index.\n- Falls back gracefully if append fails.\n- Data fetched via yfinance.\n- ARIMA order (5,1,0) used.\n- Upload 'arima_model.pkl' to your Space.")
# Launch the app only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()