File size: 8,556 Bytes
3c76f87
 
 
 
 
 
f19cf54
3c76f87
 
 
f19cf54
 
 
 
 
 
 
 
 
 
 
3c76f87
 
f19cf54
3c76f87
 
f19cf54
 
3c76f87
 
f19cf54
 
7e6a7e0
 
f19cf54
3c76f87
 
 
 
 
 
f19cf54
0926439
f19cf54
 
 
3c76f87
f19cf54
 
 
 
 
0926439
 
f19cf54
0926439
 
 
 
 
3c76f87
0926439
f19cf54
 
0926439
 
f19cf54
0926439
f19cf54
0926439
 
f19cf54
 
 
3c76f87
 
f19cf54
 
 
3c76f87
 
 
 
 
 
 
 
 
f19cf54
3c76f87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f19cf54
3c76f87
 
 
 
f19cf54
 
 
 
3c76f87
 
 
 
 
f19cf54
 
 
3c76f87
 
 
 
 
 
 
 
 
f19cf54
 
 
 
 
 
 
3c76f87
 
 
 
 
 
 
 
 
 
 
f19cf54
3c76f87
 
 
 
 
 
 
f19cf54
3c76f87
 
 
f19cf54
 
 
3c76f87
 
 
f19cf54
 
3c76f87
 
f19cf54
3c76f87
 
f19cf54
3c76f87
 
 
 
 
f19cf54
3c76f87
 
 
 
 
f19cf54
3c76f87
 
 
 
 
f19cf54
3c76f87
f19cf54
 
3c76f87
f19cf54
 
3c76f87
f19cf54
3c76f87
 
 
 
 
f19cf54
 
 
 
 
 
 
3c76f87
 
 
 
f19cf54
3c76f87
 
f19cf54
 
 
3c76f87
f19cf54
3c76f87
 
f19cf54
 
 
 
3c76f87
 
f19cf54
 
 
 
 
3c76f87
f19cf54
3c76f87
f19cf54
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import gradio as gr
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from datetime import datetime, timedelta
import pickle
import os
import warnings
warnings.filterwarnings('ignore')

# TensorFlow/Keras imports
from tensorflow.keras.models import load_model
from sklearn.preprocessing import MinMaxScaler

# ARIMA and Prophet
from statsmodels.tsa.arima.model import ARIMA
from prophet import Prophet

# --------------------------
# Load models safely
# --------------------------
def _load_pickle(path):
    """Unpickle *path*; print the error and return None on any failure."""
    try:
        # SECURITY: pickle executes arbitrary code while loading. Only ship
        # artifact files produced by a trusted training run.
        with open(path, 'rb') as f:
            return pickle.load(f)
    except Exception as e:
        print(f"Error loading {path}: {e}")
        return None


def load_models():
    """Load the persisted forecasting artifacts from the working directory.

    Each artifact is loaded independently, so one missing or corrupt file
    no longer discards the others.

    Returns:
        A ``(arima_model, prophet_model, lstm_model, scaler)`` tuple. Any
        element that could not be loaded is ``None``.
    """
    arima_model = _load_pickle('arima_model.pkl')
    prophet_model = _load_pickle('prophet_model.pkl')

    # The previous version returned `scaler` without ever assigning it, which
    # raised NameError and made the whole function return four Nones.
    # NOTE(review): 'scaler.pkl' is the assumed file name for the scaler that
    # was fitted alongside the LSTM -- confirm against the training script.
    scaler = _load_pickle('scaler.pkl')

    try:
        # Imported here as in the original, so an import failure is reported
        # like any other loading error.
        from tensorflow.keras.models import load_model
        lstm_model = load_model('lstm_model.keras')
    except Exception as e:
        print(f"Error loading lstm_model.keras: {e}")
        lstm_model = None

    return arima_model, prophet_model, lstm_model, scaler

# Load the persisted artifacts once at import time. Any of these names may be
# None when loading failed (load_models() reports the error on stdout).
arima_model, prophet_model, lstm_model, scaler = load_models()
# Window length handed to make_lstm_forecast() as `seq_length`.
SEQ_LENGTH = 60

# --------------------------
# Fetch stock data
# --------------------------
def fetch_stock_data(ticker, days=365):
    """Load prices for *ticker* from a local ``<TICKER>.csv`` file.

    The file is looked up in the current working directory; no network access
    is attempted. The first CSV column becomes the index and is parsed as
    dates where possible.

    Args:
        ticker: Stock symbol; it is upper-cased and stripped before use.
        days: Maximum number of trailing rows to keep.

    Returns:
        ``(df, None)`` on success, where ``df`` has a single numeric ``Price``
        column, or ``(None, message)`` when the data cannot be loaded.
    """
    ticker = ticker.upper().strip()
    # The ticker comes from a UI text box and is used to build a file name, so
    # reject anything (path separators in particular) that could point outside
    # the working directory.
    if not ticker or not all(ch.isalnum() or ch in ".-_^=" for ch in ticker):
        return None, f"Invalid ticker symbol: {ticker!r}."

    filename = f"{ticker}.csv"
    if not os.path.exists(filename):
        return None, f"No data found for {ticker}. Upload {ticker}.csv in the Space root."

    try:
        df = pd.read_csv(filename, index_col=0, parse_dates=True)
    except (OSError, ValueError) as e:
        # pandas' parser and empty-data errors are ValueError subclasses.
        return None, f"Could not read {filename}: {e}"

    if 'Close' in df.columns:
        prices = df[['Close']].copy()
    elif len(df.columns) == 1:
        # A single unnamed/other-named column is taken to be the price.
        prices = df.copy()
    else:
        # Renaming several columns to one name would raise; report it instead.
        found = ', '.join(map(str, df.columns))
        return None, f"{filename} has no 'Close' column (found: {found})."

    prices.columns = ['Price']
    prices['Price'] = pd.to_numeric(prices['Price'], errors='coerce')
    prices = prices.dropna().tail(days)

    if prices.empty:
        return None, f"No valid data found in {filename} for {ticker}."
    return prices, None

# --------------------------
# Forecasting functions
# --------------------------
def make_arima_forecast(data, days):
    """Fit an ARIMA(1, 1, 1) model on ``data['Price']`` and forecast ahead.

    Args:
        data: DataFrame with a ``Price`` column. It is not modified.
        days: Number of steps to forecast (coerced to ``int``).

    Returns:
        A NumPy array with *days* forecast values, or ``None`` if fitting or
        forecasting failed (the error is printed).
    """
    try:
        # Build a cleaned series rather than assigning back into `data`, so
        # the caller's DataFrame is left untouched. Only rows whose price is
        # missing or non-numeric are dropped.
        prices = pd.to_numeric(data['Price'], errors='coerce').dropna()
        fitted = ARIMA(prices, order=(1, 1, 1)).fit()
        forecast = fitted.forecast(steps=int(days))
        return np.asarray(forecast)
    except Exception as e:
        print(f"ARIMA Error: {e}")
        return None

def make_prophet_forecast(data, days):
    """Fit a fresh Prophet model on the price history and forecast ahead.

    Args:
        data: DataFrame whose index supplies the ``ds`` values and whose
            ``Price`` column supplies the ``y`` values.
        days: Number of future periods to predict.

    Returns:
        The last *days* ``yhat`` values as a NumPy array, or ``None`` if
        anything went wrong (the error is printed).
    """
    try:
        history = pd.DataFrame({'ds': data.index, 'y': data['Price'].values})

        settings = dict(
            daily_seasonality=True,
            weekly_seasonality=True,
            yearly_seasonality=True,
            changepoint_prior_scale=0.05,
        )
        prophet = Prophet(**settings)
        prophet.fit(history)

        # The future frame holds the history plus `days` extra rows; only the
        # trailing rows are the actual forecast.
        horizon = prophet.make_future_dataframe(periods=days)
        predicted = prophet.predict(horizon)
        return predicted['yhat'].tail(days).values
    except Exception as e:
        print(f"Prophet Error: {e}")
        return None

def make_lstm_forecast(data, days, model, scaler, seq_length=60):
    """Roll a sequence model forward *days* steps from the end of the history.

    Each prediction is fed back in as the newest element of the input window
    to produce the next one.

    Args:
        data: DataFrame with a ``Price`` column.
        days: Number of steps to predict.
        model: Object exposing ``predict(window, verbose=0)``.
        scaler: Object exposing ``transform`` and ``inverse_transform``.
        seq_length: Number of trailing observations used as the input window.

    Returns:
        A flat NumPy array of predictions on the original price scale, or
        ``None`` if anything went wrong (the error is printed).
    """
    try:
        scaled = scaler.transform(data[['Price']])
        window = scaled[-seq_length:].reshape(1, seq_length, 1).copy()

        steps = []
        for _ in range(days):
            step = model.predict(window, verbose=0)
            steps.append(step[0, 0])
            # Slide the window: drop the oldest value, append the prediction.
            window = np.concatenate(
                (window[:, 1:, :], step.reshape(1, 1, 1)), axis=1
            )

        restored = scaler.inverse_transform(np.array(steps).reshape(-1, 1))
        return restored.flatten()
    except Exception as e:
        print(f"LSTM Error: {e}")
        return None

# --------------------------
# Plot function
# --------------------------
def create_forecast_plot(historical_data, forecasts, ticker, model_names):
    """Build a Plotly figure of the price history plus each model's forecast.

    Args:
        historical_data: DataFrame with a ``Price`` column; its index provides
            the x-axis values and its last entry anchors the forecast dates.
        forecasts: Sequence of forecast arrays; ``None`` entries are skipped.
        ticker: Symbol shown in the chart title.
        model_names: Names matching *forecasts* position by position.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=historical_data.index,
        y=historical_data['Price'],
        mode='lines',
        name='Historical Price',
        line=dict(color='blue', width=2)
    ))

    # Forecasts start on the calendar day after the last historical point.
    first_future_day = pd.to_datetime(historical_data.index[-1]) + timedelta(days=1)

    colors = ['red', 'purple', 'orange']
    for i, (forecast, name) in enumerate(zip(forecasts, model_names)):
        if forecast is None:
            continue
        fig.add_trace(go.Scatter(
            # Dates are derived per forecast, so an empty list or a leading
            # None entry no longer breaks the plot.
            x=pd.date_range(start=first_future_day, periods=len(forecast)),
            y=forecast,
            mode='lines+markers',
            name=f'{name} Forecast',
            # Cycle through the palette so more than three traces cannot
            # raise IndexError.
            line=dict(color=colors[i % len(colors)], width=2, dash='dash'),
            marker=dict(size=6)
        ))

    fig.update_layout(
        title=f'{ticker} Stock Price Forecast',
        xaxis_title='Date',
        yaxis_title='Price ($)',
        hovermode='x unified',
        template='plotly_white',
        height=600,
        showlegend=True
    )
    return fig

# --------------------------
# Main prediction function
# --------------------------
def predict_stock(ticker, forecast_days, model_choice):
    """Run the selected forecasting model(s) for *ticker* and build the UI outputs.

    Args:
        ticker: Stock symbol typed by the user; surrounding whitespace is
            ignored and the symbol is upper-cased.
        forecast_days: Number of future days to predict (coerced to ``int``).
        model_choice: "All Models", "ARIMA", "Prophet" or "LSTM".

    Returns:
        A ``(figure, summary_markdown, forecast_table)`` tuple. On failure the
        figure and table are ``None`` and the middle element is the message.
    """
    # Normalise once so blank input is rejected and the title/summary show
    # the same symbol that is used for the data lookup.
    ticker = (ticker or "").strip().upper()
    if not ticker:
        return None, "Please enter a stock ticker symbol", None

    forecast_days = int(forecast_days)

    data, error = fetch_stock_data(ticker, days=730)
    if error:
        return None, f"Error: {error}", None

    # (name, runner) pairs in display order; the LSTM is only offered when
    # its model was loaded.
    candidates = [
        ("ARIMA", lambda: make_arima_forecast(data, forecast_days)),
        ("Prophet", lambda: make_prophet_forecast(data, forecast_days)),
    ]
    if lstm_model is not None:
        candidates.append((
            "LSTM",
            lambda: make_lstm_forecast(data, forecast_days, lstm_model, scaler, SEQ_LENGTH),
        ))

    forecasts = []
    model_names = []
    for name, run in candidates:
        if model_choice not in ("All Models", name):
            continue
        forecast = run()
        if forecast is not None:
            forecasts.append(forecast)
            model_names.append(name)

    if not forecasts:
        return None, "Failed to generate forecasts.", None

    fig = create_forecast_plot(data, forecasts, ticker, model_names)

    # Forecast table
    future_dates = pd.date_range(
        start=pd.to_datetime(data.index[-1]) + timedelta(days=1),
        periods=forecast_days
    )
    forecast_df = pd.DataFrame({'Date': future_dates.strftime('%Y-%m-%d')})
    for forecast, name in zip(forecasts, model_names):
        forecast_df[f'{name} Prediction ($)'] = np.round(forecast, 2)

    # Summary. The leading emoji (bar chart) is written as an escape so it
    # survives re-encoding of this file; it had previously turned to mojibake.
    last_price = data['Price'].iloc[-1]
    summary_lines = [
        f"\U0001F4CA **Forecast Summary for {ticker}**\n",
        f"- Current Price: ${last_price:.2f}",
        f"- Forecast Period: {forecast_days} days",
        f"- Models Used: {', '.join(model_names)}\n",
        f"**Predicted Price Range (Day {forecast_days}):**",
    ]
    for forecast, name in zip(forecasts, model_names):
        final_price = forecast[-1]
        change = ((final_price - last_price) / last_price) * 100
        summary_lines.append(f"- {name}: ${final_price:.2f} ({change:+.2f}%)")
    summary = "\n".join(summary_lines)

    return fig, summary, forecast_df

# --------------------------
# Gradio Interface
# --------------------------
# UI layout: inputs on the left, plot on the right, summary and table below.
# Emoji in user-facing strings are written as \U escapes (chart, crystal ball)
# so they survive re-encoding of this file; they had turned to mojibake.
with gr.Blocks(title="Stock Price Forecasting", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# \U0001F4C8 Stock Price Forecasting App\nPredict future stock prices using ARIMA, Prophet, and LSTM models.\nUpload CSV files in the Space root for offline use.")
    with gr.Row():
        with gr.Column(scale=1):
            ticker_input = gr.Textbox(label="Stock Ticker Symbol", placeholder="e.g., AAPL", value="AAPL")
            forecast_days = gr.Slider(minimum=1, maximum=90, value=30, step=1, label="Forecast Days")
            model_choice = gr.Radio(choices=["All Models", "ARIMA", "Prophet", "LSTM"], value="All Models", label="Select Model(s)")
            predict_btn = gr.Button("\U0001F52E Generate Forecast", variant="primary")
        with gr.Column(scale=2):
            output_plot = gr.Plot(label="Forecast Visualization")
    output_summary = gr.Markdown(label="Forecast Summary")
    output_table = gr.Dataframe(label="Detailed Forecast", interactive=False)

    # predict_stock returns (figure, summary, table) in this output order.
    predict_btn.click(fn=predict_stock, inputs=[ticker_input, forecast_days, model_choice],
                      outputs=[output_plot, output_summary, output_table])

# Launch
if __name__ == "__main__":
    demo.launch()