File size: 5,656 Bytes
f73838f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import gradio as gr
import yfinance as yf
import pandas as pd
import numpy as np
from statsmodels.tsa.arima.model import ARIMA
import joblib
from datetime import datetime, timedelta
import warnings
warnings.filterwarnings('ignore')

# Load the saved ARIMA model (upload 'arima_model.pkl' to your Space).
# Defaults apply when no checkpoint file exists: with no fitted model, the
# forecast function fits a fresh ARIMA of this order on first use.
loaded_fit = None
last_train_date = None
order = (5, 1, 0)
try:
    # joblib.load unpickles the file, so only deploy a checkpoint you created yourself.
    checkpoint = joblib.load('arima_model.pkl')
except FileNotFoundError:
    print("Model file not found. Starting with a fresh fit.")
else:
    # The checkpoint is indexed as a mapping with these three keys.
    loaded_fit = checkpoint['model_fit']
    last_train_date = checkpoint['last_date']
    order = checkpoint['order']
    print(f"Model loaded successfully. Last training date: {last_train_date}")

# Function to fetch S&P 500 data
def fetch_data(start_date=None, period="max", symbol="^GSPC"):
    """Download daily closing prices from Yahoo Finance via yfinance.

    Args:
        start_date: Optional first date to request (callers in this file pass a
            ``'YYYY-MM-DD'`` string). When given it is forwarded as ``start``
            together with *period*. NOTE(review): yfinance is expected to let
            ``start`` take precedence over ``period`` -- confirm for the pinned
            yfinance version.
        period: yfinance ``period`` string (``"max"`` requests the full history).
        symbol: Yahoo Finance ticker symbol; defaults to the S&P 500 index.

    Returns:
        Series of closing prices with missing values dropped, indexed by a
        tz-naive DatetimeIndex.
    """
    ticker = yf.Ticker(symbol)
    if start_date:
        history = ticker.history(start=start_date, period=period)
    else:
        history = ticker.history(period=period)
    data = history['Close'].dropna()

    index = pd.to_datetime(data.index)
    if getattr(index, 'tz', None) is not None:
        # Strip the timezone but keep the exchange-local wall-clock time, so each
        # bar stays on its own trading date. tz_convert(None) would move stamps
        # to UTC instead (e.g. a New York midnight bar becomes 04:00 or 05:00
        # depending on DST) and can push bars of exchanges east of UTC onto the
        # previous calendar date.
        index = index.tz_localize(None)
    data.index = index
    return data

# Function to update model with new data if needed
def update_model(model_fit, new_data, order):
    """Extend *model_fit* with *new_data* without re-estimating its parameters.

    Args:
        model_fit: Fitted statsmodels results object exposing ``append``.
        new_data: New observations following the model's sample; presumably a
            pandas Series (its ``.values`` are read on one branch).
        order: Unused; kept only so existing call sites keep working.

    Returns:
        Whatever ``model_fit.append(..., refit=False)`` returns.

    NOTE(review): statsmodels appears to store ``model.endog`` as a NumPy array,
    which has no ``index`` attribute, so the first branch is probably never taken
    and raw values are always appended -- confirm before relying on it. This
    helper is not called anywhere in the visible source; the forecast path
    refits the model instead.
    """
    endog_is_indexed = hasattr(model_fit.model.endog, 'index')
    observations = new_data if endog_is_indexed else new_data.values
    return model_fit.append(observations, refit=False)

# Function to predict next n steps
def predict_arima(model_fit, n_steps=1):
    """Return the out-of-sample forecast for the next *n_steps* periods.

    Thin wrapper around ``model_fit.forecast(steps=n_steps)``; the result is
    passed through unchanged (presumably a pandas Series or ndarray, depending
    on how the model was built).
    """
    return model_fit.forecast(steps=n_steps)

def _model_end_date(model_fit, fallback_date):
    """Return the calendar date of the last observation *model_fit* was fit on.

    Prefers the DatetimeIndex of the original pandas data the model was built
    from; otherwise falls back to *fallback_date* (e.g. the checkpoint's saved
    date). Returns None when neither source is available.
    """
    # model.endog is a plain array in statsmodels; the original pandas object
    # (and therefore its index) lives on model.data.orig_endog.
    model_data = getattr(getattr(model_fit, 'model', None), 'data', None)
    index = getattr(getattr(model_data, 'orig_endog', None), 'index', None)
    if isinstance(index, pd.DatetimeIndex) and len(index) > 0:
        return index[-1].date()
    if fallback_date is None:
        return None
    return pd.to_datetime(fallback_date).date()


def _next_business_days(after_date, count):
    """Return *count* consecutive weekdays strictly after *after_date*.

    Weekends are skipped; exchange holidays are NOT, so a label can still land
    on a market holiday.
    """
    return pd.bdate_range(start=pd.Timestamp(after_date) + pd.Timedelta(days=1), periods=count)


# Main prediction function for Gradio
def forecast_sp500_arima(n_days, refit=False):
    """Forecast the S&P 500 close for the next *n_days* trading days.

    Uses the module-level fit ``loaded_fit`` and refits it on the full freshly
    downloaded history when *refit* is true, when no model is loaded, or when
    the download contains dates newer than the model's last observation.

    Args:
        n_days: Number of steps to forecast. Gradio sliders may deliver a float,
            so the value is truncated to an int (minimum 1).
        refit: Force a refit on the latest data, ignoring the loaded model.

    Returns:
        Text with the last actual close and a table of dated forecasts.

    Side effects:
        May replace the module globals ``loaded_fit`` and ``last_train_date``.
    """
    global loaded_fit, last_train_date  # To update global state if needed

    # range() and forecast(steps=...) need an int; a slider can emit 5.0.
    n_days = max(1, int(n_days))

    data = fetch_data()
    if data.empty:
        return "No S&P 500 price data could be downloaded; please try again later."
    latest_date = data.index[-1].date()

    if refit or loaded_fit is None:
        loaded_fit = ARIMA(data, order=order).fit()
        last_train_date = latest_date
        print("Model refitted on latest data.")
    else:
        model_end = _model_end_date(loaded_fit, last_train_date)
        if model_end is None or latest_date > model_end:
            # Refit on the full history instead of append(): appending to a date
            # index without a frequency can degrade it to a RangeIndex.
            try:
                loaded_fit = ARIMA(data, order=order).fit()
            except Exception as e:  # keep serving the previous model
                print(f"Refit failed: {e}. Using existing model.")
            else:
                last_train_date = latest_date
                print("Model refitted with new data.")
        else:
            print("No new data available.")

    # Forecast step k is the k-th trading day after the model's own last
    # observation, which is earlier than latest_date if a refit failed.
    forecast_origin = _model_end_date(loaded_fit, last_train_date) or latest_date
    predictions = np.asarray(predict_arima(loaded_fit, n_days))
    future_dates = _next_business_days(forecast_origin, n_days)

    results = pd.DataFrame({
        'Date': future_dates,
        'Predicted Close': predictions,
    })

    # Last actual price
    last_date = data.index[-1]
    last_actual = data.iloc[-1]

    return f"Last Actual Close ({last_date.date()}): ${last_actual:.2f}\n\nForecast:\n{results.to_string(index=False)}"

# Gradio interface
with gr.Blocks(title="S&P 500 ARIMA Forecaster (Saved Model)") as demo:
    gr.Markdown(
        "# S&P 500 Stock Price Forecaster\n"
        "Uses a saved ARIMA model that is refitted when newer market data is available.\n"
        "Choose a whole number of days to forecast."
    )

    with gr.Row():
        # step=1 keeps the slider on whole numbers of days.
        n_days = gr.Slider(minimum=1, maximum=30, value=5, step=1, label="Number of days to forecast")
        refit_btn = gr.Checkbox(label="Refit model on latest data (ignores saved model)", value=False)

    predict_btn = gr.Button("Generate Forecast")

    output = gr.Textbox(label="Forecast Results")

    predict_btn.click(
        fn=forecast_sp500_arima,
        inputs=[n_days, refit_btn],
        outputs=output
    )

    # Notes describe what the forecast function actually does; the order shown
    # is the one in use (from the checkpoint if loaded, otherwise the default).
    gr.Markdown(
        "### Notes:\n"
        "- Loads a saved ARIMA model from 'arima_model.pkl' if present; otherwise fits a new one.\n"
        "- Refits the model on the full history when newer data is available.\n"
        "- Keeps the previous model if a refit fails.\n"
        "- Data fetched via yfinance.\n"
        f"- ARIMA order {order} used.\n"
        "- Upload 'arima_model.pkl' to your Space."
    )

if __name__ == "__main__":
    demo.launch()