Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,590 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import yfinance as yf
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import plotly.graph_objects as go
|
| 6 |
+
from plotly.subplots import make_subplots
|
| 7 |
+
from statsmodels.tsa.api import VAR
|
| 8 |
+
from statsmodels.tsa.stattools import adfuller
|
| 9 |
+
from sklearn.preprocessing import MinMaxScaler
|
| 10 |
+
from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score
|
| 11 |
+
from datetime import datetime, timedelta
|
| 12 |
+
|
| 13 |
+
# Helper functions remain unchanged
|
| 14 |
+
def download_data(tickers, start_date, end_date):
|
| 15 |
+
data = {}
|
| 16 |
+
for name, ticker in tickers.items():
|
| 17 |
+
data[name] = yf.download(ticker, start=start_date, end=end_date)
|
| 18 |
+
return data
|
| 19 |
+
|
| 20 |
+
def calculate_returns_and_volatility(data, rolling_window):
|
| 21 |
+
stock_data = data['stock']
|
| 22 |
+
stock_data['Log_Returns'] = np.log(stock_data['Adj Close'] / stock_data['Adj Close'].shift(1))
|
| 23 |
+
stock_data['Volatility'] = stock_data['Log_Returns'].rolling(window=rolling_window).std() * np.sqrt(252)
|
| 24 |
+
stock_data = stock_data.dropna()
|
| 25 |
+
data['stock'] = stock_data
|
| 26 |
+
|
| 27 |
+
sp500_data = data['sp500']
|
| 28 |
+
sp500_data['Log_Returns'] = np.log(sp500_data['Adj Close'] / sp500_data['Adj Close'].shift(1))
|
| 29 |
+
sp500_data['SP500_Volatility'] = sp500_data['Log_Returns'].rolling(window=rolling_window).std() * np.sqrt(252)
|
| 30 |
+
sp500_data = sp500_data.dropna()
|
| 31 |
+
data['sp500'] = sp500_data
|
| 32 |
+
|
| 33 |
+
return data
|
| 34 |
+
|
| 35 |
+
def merge_data(data):
|
| 36 |
+
merged_data = data['stock'][['Volatility']].copy()
|
| 37 |
+
merged_data['SP500'] = data['sp500']['Adj Close']
|
| 38 |
+
merged_data['SP500_Volatility'] = data['sp500']['SP500_Volatility']
|
| 39 |
+
merged_data['VIX'] = np.log(data['vix']['Adj Close'])
|
| 40 |
+
merged_data['SP500_Returns'] = data['sp500']['Log_Returns']
|
| 41 |
+
merged_data['Volume'] = data['stock']['Volume']
|
| 42 |
+
merged_data['Stock_Returns'] = data['stock']['Log_Returns']
|
| 43 |
+
merged_data = merged_data.dropna()
|
| 44 |
+
return merged_data
|
| 45 |
+
|
| 46 |
+
def check_stationarity_and_difference(df):
|
| 47 |
+
"""
|
| 48 |
+
Perform ADF test for stationarity and apply differencing if necessary.
|
| 49 |
+
"""
|
| 50 |
+
for column in df.columns:
|
| 51 |
+
result = adfuller(df[column].dropna())
|
| 52 |
+
p_value = result[1]
|
| 53 |
+
if p_value > 0.05:
|
| 54 |
+
# Non-stationary series; apply differencing
|
| 55 |
+
df[column] = df[column].diff()
|
| 56 |
+
else:
|
| 57 |
+
pass # Series is stationary; no differencing needed
|
| 58 |
+
|
| 59 |
+
def normalize_data(df):
|
| 60 |
+
scaler = MinMaxScaler(feature_range=(0, 1))
|
| 61 |
+
scaled_data = pd.DataFrame(scaler.fit_transform(df), columns=df.columns, index=df.index)
|
| 62 |
+
return scaled_data, scaler
|
| 63 |
+
|
| 64 |
+
def fit_var_model(scaled_data, max_lags=30):
|
| 65 |
+
model = VAR(scaled_data)
|
| 66 |
+
lag_order_results = model.select_order(maxlags=max_lags)
|
| 67 |
+
optimal_lag = lag_order_results.aic
|
| 68 |
+
results = model.fit(optimal_lag)
|
| 69 |
+
return results, optimal_lag
|
| 70 |
+
|
| 71 |
+
def forecast_future_values(results, scaled_data, scaler, steps, optimal_lag):
|
| 72 |
+
forecast_95, lower_95, upper_95 = results.forecast_interval(
|
| 73 |
+
scaled_data.values[-optimal_lag:], steps=steps, alpha=0.05)
|
| 74 |
+
forecast_68, lower_68, upper_68 = results.forecast_interval(
|
| 75 |
+
scaled_data.values[-optimal_lag:], steps=steps, alpha=0.32)
|
| 76 |
+
|
| 77 |
+
forecast_original = scaler.inverse_transform(forecast_95)
|
| 78 |
+
lower_95_original = scaler.inverse_transform(lower_95)
|
| 79 |
+
upper_95_original = scaler.inverse_transform(upper_95)
|
| 80 |
+
lower_68_original = scaler.inverse_transform(lower_68)
|
| 81 |
+
upper_68_original = scaler.inverse_transform(upper_68)
|
| 82 |
+
|
| 83 |
+
return forecast_original, lower_95_original, upper_95_original, lower_68_original, upper_68_original
|
| 84 |
+
|
| 85 |
+
# Plotting functions remain unchanged
|
| 86 |
+
def plot_forecast(merged_data, future_dates, volatility_predictions, lower_volatility_95, upper_volatility_95, lower_volatility_68, upper_volatility_68):
|
| 87 |
+
fig = go.Figure()
|
| 88 |
+
|
| 89 |
+
# Plot historical volatility
|
| 90 |
+
fig.add_trace(go.Scatter(x=merged_data.index, y=merged_data['Volatility'], mode='lines', name='Historical Volatility'))
|
| 91 |
+
|
| 92 |
+
# Plot 95% confidence intervals
|
| 93 |
+
fig.add_trace(go.Scatter(x=future_dates, y=upper_volatility_95, fill=None, mode='lines', line_color='lightgray', name='95% CI Upper'))
|
| 94 |
+
fig.add_trace(go.Scatter(x=future_dates, y=lower_volatility_95, fill='tonexty', mode='lines', line_color='lightgray', name='95% CI Lower'))
|
| 95 |
+
|
| 96 |
+
# Plot 68% confidence intervals
|
| 97 |
+
fig.add_trace(go.Scatter(x=future_dates, y=upper_volatility_68, fill=None, mode='lines', line_color='blue', name='68% CI Upper'))
|
| 98 |
+
fig.add_trace(go.Scatter(x=future_dates, y=lower_volatility_68, fill='tonexty', mode='lines', line_color='blue', name='68% CI Lower'))
|
| 99 |
+
|
| 100 |
+
# Plot predicted volatility
|
| 101 |
+
fig.add_trace(go.Scatter(x=future_dates, y=volatility_predictions, mode='lines',line_color='orange' ,name='Predicted Volatility', line=dict(dash='dot', width=4)))
|
| 102 |
+
|
| 103 |
+
fig.update_layout(title='VAR Predicted Volatility with Confidence Intervals',
|
| 104 |
+
xaxis_title='Date', yaxis_title='Volatility',
|
| 105 |
+
template='plotly_white')
|
| 106 |
+
|
| 107 |
+
return fig
|
| 108 |
+
|
| 109 |
+
def plot_extended_forecast(forecast_data_extended, future_dates, volatility_predictions):
|
| 110 |
+
"""
|
| 111 |
+
Plot extended actual historical volatility and predicted future volatility using Plotly.
|
| 112 |
+
"""
|
| 113 |
+
# Align the length of future dates and predicted values
|
| 114 |
+
future_dates = future_dates[:len(volatility_predictions)]
|
| 115 |
+
volatility_predictions = volatility_predictions[:len(future_dates)]
|
| 116 |
+
|
| 117 |
+
fig = go.Figure()
|
| 118 |
+
|
| 119 |
+
# Plot extended actual historical volatility
|
| 120 |
+
fig.add_trace(go.Scatter(x=forecast_data_extended.index, y=forecast_data_extended['Volatility'], mode='lines', name='Extended Historical Volatility'))
|
| 121 |
+
|
| 122 |
+
# Plot predicted future volatility
|
| 123 |
+
fig.add_trace(go.Scatter(x=future_dates, y=volatility_predictions, mode='lines', name='Predicted Future Volatility', line=dict(dash='dash')))
|
| 124 |
+
|
| 125 |
+
fig.update_layout(title='Predicted Volatility with Extended Actual Data',
|
| 126 |
+
xaxis_title='Date',
|
| 127 |
+
yaxis_title='Volatility',
|
| 128 |
+
template='plotly_white')
|
| 129 |
+
|
| 130 |
+
return fig
|
| 131 |
+
|
| 132 |
+
def calculate_performance_metrics(forecast_data_extended, future_dates, volatility_predictions):
|
| 133 |
+
"""
|
| 134 |
+
Calculate performance metrics and return as markdown text.
|
| 135 |
+
"""
|
| 136 |
+
# Ensure future_dates are in the same format as the forecast_data index
|
| 137 |
+
new_future_dates = pd.to_datetime(future_dates)
|
| 138 |
+
|
| 139 |
+
# Create a DataFrame for future dates and predicted values
|
| 140 |
+
predicted_df = pd.DataFrame({
|
| 141 |
+
'Date': new_future_dates,
|
| 142 |
+
'Predicted Volatility': volatility_predictions
|
| 143 |
+
}).set_index('Date')
|
| 144 |
+
|
| 145 |
+
# Extract the actual future volatility values for the prediction period
|
| 146 |
+
actual_volatility = forecast_data_extended.loc[new_future_dates, 'Volatility']
|
| 147 |
+
|
| 148 |
+
# Create DataFrame for actual values
|
| 149 |
+
actual_df = pd.DataFrame({
|
| 150 |
+
'Date': actual_volatility.index,
|
| 151 |
+
'Actual Volatility': actual_volatility.values
|
| 152 |
+
}).set_index('Date')
|
| 153 |
+
|
| 154 |
+
# Join the actual and predicted DataFrames on the Date index
|
| 155 |
+
results_df = actual_df.join(predicted_df, how='inner')
|
| 156 |
+
|
| 157 |
+
# Metrics calculation
|
| 158 |
+
rmse = np.sqrt(mean_squared_error(results_df['Actual Volatility'], results_df['Predicted Volatility']))
|
| 159 |
+
mape = mean_absolute_percentage_error(results_df['Actual Volatility'], results_df['Predicted Volatility'])
|
| 160 |
+
mae = mean_absolute_error(results_df['Actual Volatility'], results_df['Predicted Volatility'])
|
| 161 |
+
mse = mean_squared_error(results_df['Actual Volatility'], results_df['Predicted Volatility'])
|
| 162 |
+
r2 = r2_score(results_df['Actual Volatility'], results_df['Predicted Volatility'])
|
| 163 |
+
|
| 164 |
+
metrics = f"""
|
| 165 |
+
**RMSE**: {rmse:.4f}
|
| 166 |
+
**MAPE**: {mape:.2%}
|
| 167 |
+
**MAE**: {mae:.4f}
|
| 168 |
+
**MSE**: {mse:.4f}
|
| 169 |
+
**R²**: {r2:.4f}
|
| 170 |
+
"""
|
| 171 |
+
|
| 172 |
+
return metrics
|
| 173 |
+
|
| 174 |
+
def plot_residuals_plotly(results):
|
| 175 |
+
"""
|
| 176 |
+
Plot residuals of VAR model using Plotly.
|
| 177 |
+
"""
|
| 178 |
+
residuals = results.resid
|
| 179 |
+
fig = go.Figure()
|
| 180 |
+
for col in residuals.columns:
|
| 181 |
+
fig.add_trace(go.Scatter(x=residuals.index, y=residuals[col], mode='lines', name=f'Residuals: {col}'))
|
| 182 |
+
|
| 183 |
+
fig.update_layout(title='Residuals of VAR Model',
|
| 184 |
+
xaxis_title='Date', yaxis_title='Residuals',
|
| 185 |
+
template='plotly_white')
|
| 186 |
+
return fig
|
| 187 |
+
|
| 188 |
+
def calculate_metrics_and_plot_errors_plotly(forecast_data_extended, future_dates, volatility_predictions):
|
| 189 |
+
"""
|
| 190 |
+
Calculate performance metrics and plot prediction errors using Plotly.
|
| 191 |
+
"""
|
| 192 |
+
# Ensure future_dates are in the same format as the forecast_data index
|
| 193 |
+
new_future_dates = pd.to_datetime(future_dates)
|
| 194 |
+
|
| 195 |
+
# Create a DataFrame for future dates and predicted values
|
| 196 |
+
predicted_df = pd.DataFrame({
|
| 197 |
+
'Date': new_future_dates,
|
| 198 |
+
'Predicted Volatility': volatility_predictions
|
| 199 |
+
}).set_index('Date')
|
| 200 |
+
|
| 201 |
+
# Extract the actual future volatility values for the prediction period
|
| 202 |
+
actual_volatility = forecast_data_extended.loc[new_future_dates, 'Volatility']
|
| 203 |
+
|
| 204 |
+
# Create DataFrame for actual values
|
| 205 |
+
actual_df = pd.DataFrame({
|
| 206 |
+
'Date': actual_volatility.index,
|
| 207 |
+
'Actual Volatility': actual_volatility.values
|
| 208 |
+
}).set_index('Date')
|
| 209 |
+
|
| 210 |
+
# Join the actual and predicted DataFrames on the Date index
|
| 211 |
+
results_df = actual_df.join(predicted_df, how='inner')
|
| 212 |
+
|
| 213 |
+
# Calculate errors over time
|
| 214 |
+
results_df['Error'] = results_df['Actual Volatility'] - results_df['Predicted Volatility']
|
| 215 |
+
|
| 216 |
+
# Create a Plotly figure with two subplots
|
| 217 |
+
fig = make_subplots(rows=2, cols=1, subplot_titles=("Scatter Plot of Predicted vs Actual Volatility", "Prediction Error Over Time"))
|
| 218 |
+
|
| 219 |
+
# Scatter plot of predicted vs actual values
|
| 220 |
+
fig.add_trace(
|
| 221 |
+
go.Scatter(
|
| 222 |
+
x=results_df['Actual Volatility'],
|
| 223 |
+
y=results_df['Predicted Volatility'],
|
| 224 |
+
mode='markers',
|
| 225 |
+
name='Predicted vs Actual'
|
| 226 |
+
),
|
| 227 |
+
row=1, col=1
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
# Add a line y = x
|
| 231 |
+
min_vol = min(results_df['Actual Volatility'].min(), results_df['Predicted Volatility'].min())
|
| 232 |
+
max_vol = max(results_df['Actual Volatility'].max(), results_df['Predicted Volatility'].max())
|
| 233 |
+
fig.add_trace(
|
| 234 |
+
go.Scatter(
|
| 235 |
+
x=[min_vol, max_vol],
|
| 236 |
+
y=[min_vol, max_vol],
|
| 237 |
+
mode='lines',
|
| 238 |
+
name='Perfect Prediction',
|
| 239 |
+
line=dict(dash='dash', color='red')
|
| 240 |
+
),
|
| 241 |
+
row=1, col=1
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
# Error plot over time
|
| 245 |
+
fig.add_trace(
|
| 246 |
+
go.Scatter(
|
| 247 |
+
x=results_df.index,
|
| 248 |
+
y=results_df['Error'],
|
| 249 |
+
mode='lines+markers',
|
| 250 |
+
name='Prediction Error'
|
| 251 |
+
),
|
| 252 |
+
row=2, col=1
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
fig.update_layout(height=700, title="Model Performance: Prediction Errors", template='plotly_white')
|
| 256 |
+
fig.update_xaxes(title_text='Actual Volatility', row=1, col=1)
|
| 257 |
+
fig.update_yaxes(title_text='Predicted Volatility', row=1, col=1)
|
| 258 |
+
fig.update_xaxes(title_text='Date', row=2, col=1)
|
| 259 |
+
fig.update_yaxes(title_text='Error (Actual - Predicted)', row=2, col=1)
|
| 260 |
+
|
| 261 |
+
return fig
|
| 262 |
+
|
| 263 |
+
def extended_forecast_evaluation(tickers, rolling_window, forecast_start_date,
|
| 264 |
+
forecast_end_date, future_dates, volatility_predictions):
|
| 265 |
+
"""
|
| 266 |
+
Extend forecast evaluation by comparing with actual data over an extended period.
|
| 267 |
+
"""
|
| 268 |
+
# Derive extended_start_date to ensure we have enough data for the rolling window
|
| 269 |
+
extended_start_date = (forecast_start_date - timedelta(days=rolling_window * 3)).strftime('%Y-%m-%d')
|
| 270 |
+
|
| 271 |
+
# Extended end date includes extra days for comparison
|
| 272 |
+
extended_end_date = forecast_end_date + timedelta(days=extra_days)
|
| 273 |
+
|
| 274 |
+
# Download the extended actual data for the stock
|
| 275 |
+
extended_actual_data = yf.download(tickers['stock'], start=extended_start_date, end=extended_end_date.strftime('%Y-%m-%d'))
|
| 276 |
+
|
| 277 |
+
# Calculate daily returns and rolling volatility for the extended data
|
| 278 |
+
extended_actual_data['Returns'] = extended_actual_data['Adj Close'].pct_change()
|
| 279 |
+
extended_actual_data['Volatility'] = extended_actual_data['Returns'].rolling(window=rolling_window).std() * np.sqrt(252)
|
| 280 |
+
|
| 281 |
+
# Create forecast horizon DataFrame
|
| 282 |
+
forecast_horizon = pd.DataFrame(index=future_dates)
|
| 283 |
+
forecast_horizon['Volatility'] = np.nan
|
| 284 |
+
|
| 285 |
+
# Combine extended actual data with forecast horizon
|
| 286 |
+
forecast_data_extended = pd.concat([extended_actual_data, forecast_horizon], axis=0).sort_index()
|
| 287 |
+
forecast_data_extended['Volatility'] = forecast_data_extended['Volatility'].fillna(method='ffill')
|
| 288 |
+
forecast_data_extended = forecast_data_extended.dropna(subset=['Volatility'])
|
| 289 |
+
|
| 290 |
+
return forecast_data_extended
|
| 291 |
+
|
| 292 |
+
# Set page configuration for a wide layout
|
| 293 |
+
st.set_page_config(layout="wide")
|
| 294 |
+
|
| 295 |
+
st.title("Stock Volatility Prediction")
|
| 296 |
+
|
| 297 |
+
st.sidebar.title("Input Parameters")
|
| 298 |
+
|
| 299 |
+
# How-to-use instructions in an expander
|
| 300 |
+
with st.sidebar.expander("How to Use the App", expanded=False):
|
| 301 |
+
st.markdown("""
|
| 302 |
+
**Step 1**: Select the page you want to use (Real-time Predictions or Model Performance).
|
| 303 |
+
**Step 2**: Enter the stock ticker symbol you wish to analyze.
|
| 304 |
+
**Step 3**: Adjust the start and end dates for your analysis.
|
| 305 |
+
**Step 4**: Configure additional parameters like rolling window and forecast horizon.
|
| 306 |
+
**Step 5**: Click the **Run Model** button to generate the forecasts and view the results.
|
| 307 |
+
""")
|
| 308 |
+
|
| 309 |
+
# Pages
|
| 310 |
+
page = st.sidebar.radio("Choose Page", ("Real-time Predictions", "Model Performance"))
|
| 311 |
+
|
| 312 |
+
# Common Sidebar inputs within an expander (opened by default)
|
| 313 |
+
with st.sidebar.expander("Ticker and Date Selection", expanded=True):
|
| 314 |
+
stock_ticker = st.text_input("Stock Ticker", value="ASML", help="Enter the ticker symbol of the stock you want to analyze (e.g., AAPL for Apple Inc.).")
|
| 315 |
+
|
| 316 |
+
# Hide VIX and SP500 tickers by using default values internally
|
| 317 |
+
tickers = {"stock": stock_ticker, "sp500": "^GSPC", "vix": "^VIX"}
|
| 318 |
+
|
| 319 |
+
# Additional parameters within another expander (opened by default)
|
| 320 |
+
with st.sidebar.expander("Model Parameters", expanded=True):
|
| 321 |
+
rolling_window = st.number_input(
|
| 322 |
+
"Rolling Window",
|
| 323 |
+
min_value=1,
|
| 324 |
+
value=21,
|
| 325 |
+
help="The number of days to use for calculating the rolling volatility."
|
| 326 |
+
)
|
| 327 |
+
n_days = st.number_input(
|
| 328 |
+
"Forecast Horizon (Days)",
|
| 329 |
+
min_value=1,
|
| 330 |
+
value=30,
|
| 331 |
+
help="The number of future days over which to forecast volatility."
|
| 332 |
+
)
|
| 333 |
+
if page == "Model Performance":
|
| 334 |
+
extra_days = st.number_input(
|
| 335 |
+
"Extra Days of Actual Data for Comparison",
|
| 336 |
+
min_value=1,
|
| 337 |
+
value=15,
|
| 338 |
+
help="Additional days of actual future data to include for comparison with the forecast."
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
# Separate Start and End Dates for each page within the expander
|
| 342 |
+
if page == "Real-time Predictions":
|
| 343 |
+
with st.sidebar.expander("Date Range Selection", expanded=True):
|
| 344 |
+
start_date_rt = st.date_input(
|
| 345 |
+
"Start Date",
|
| 346 |
+
value=datetime(2020, 1, 1),
|
| 347 |
+
key='start_date_rt',
|
| 348 |
+
help="The start date for getting the historical data."
|
| 349 |
+
)
|
| 350 |
+
end_date_rt = st.date_input(
|
| 351 |
+
"End Date",
|
| 352 |
+
value=datetime.now(),
|
| 353 |
+
key='end_date_rt',
|
| 354 |
+
max_value=datetime.now(),
|
| 355 |
+
help="The end date for getting the historical data."
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
# Context description in the main body
|
| 359 |
+
st.markdown("""
|
| 360 |
+
### Real-time Predictions
|
| 361 |
+
This apps allows you to generate real-time forecasts of stock price volatility using an advanced multi-variate deep learning learning model using external factors. Volatility is calculated as the rolling standard deviation of the stock's daily log returns. The model provides confidence intervals (68% and 95%) to represent uncertainty in the predictions.
|
| 362 |
+
""")
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
# Run button
|
| 366 |
+
run_button = st.sidebar.button("Run Model", key='run_button_rt')
|
| 367 |
+
|
| 368 |
+
# Placeholder for plots
|
| 369 |
+
plot_placeholder = st.empty()
|
| 370 |
+
|
| 371 |
+
if run_button:
|
| 372 |
+
with st.spinner("Downloading data and processing..."):
|
| 373 |
+
data = download_data(tickers, start_date_rt, end_date_rt)
|
| 374 |
+
data = calculate_returns_and_volatility(data, rolling_window)
|
| 375 |
+
merged_data = merge_data(data)
|
| 376 |
+
|
| 377 |
+
# Preprocess the data
|
| 378 |
+
scaled_data, scaler = normalize_data(merged_data)
|
| 379 |
+
|
| 380 |
+
# Fit the VAR model
|
| 381 |
+
results, optimal_lag = fit_var_model(scaled_data)
|
| 382 |
+
|
| 383 |
+
# Forecast future values
|
| 384 |
+
forecast_original, lower_95_original, upper_95_original, lower_68_original, upper_68_original = forecast_future_values(
|
| 385 |
+
results, scaled_data, scaler, n_days, optimal_lag)
|
| 386 |
+
|
| 387 |
+
volatility_predictions = forecast_original[:, 0]
|
| 388 |
+
lower_volatility_95 = lower_95_original[:, 0]
|
| 389 |
+
upper_volatility_95 = upper_95_original[:, 0]
|
| 390 |
+
lower_volatility_68 = lower_68_original[:, 0]
|
| 391 |
+
upper_volatility_68 = upper_68_original[:, 0]
|
| 392 |
+
|
| 393 |
+
future_dates = pd.date_range(start=end_date_rt + timedelta(days=1), periods=n_days, freq='B')
|
| 394 |
+
|
| 395 |
+
# Display the forecast plot
|
| 396 |
+
forecast_fig = plot_forecast(merged_data, future_dates, volatility_predictions,
|
| 397 |
+
lower_volatility_95, upper_volatility_95,
|
| 398 |
+
lower_volatility_68, upper_volatility_68)
|
| 399 |
+
|
| 400 |
+
# Store results in session_state
|
| 401 |
+
st.session_state['rt_results'] = {'forecast_fig': forecast_fig}
|
| 402 |
+
|
| 403 |
+
# Display the plot using the placeholder
|
| 404 |
+
with plot_placeholder:
|
| 405 |
+
st.subheader("Forecasted Volatility")
|
| 406 |
+
st.plotly_chart(forecast_fig)
|
| 407 |
+
elif 'rt_results' in st.session_state:
|
| 408 |
+
# Display stored plot using the placeholder
|
| 409 |
+
with plot_placeholder:
|
| 410 |
+
st.subheader("Forecasted Volatility")
|
| 411 |
+
st.plotly_chart(st.session_state['rt_results']['forecast_fig'])
|
| 412 |
+
|
| 413 |
+
elif page == "Model Performance":
|
| 414 |
+
with st.sidebar.expander("Date Range Selection", expanded=True):
|
| 415 |
+
# Model Performance page date inputs
|
| 416 |
+
start_date_mp = st.date_input(
|
| 417 |
+
"Start Date",
|
| 418 |
+
value=datetime(2020, 1, 1),
|
| 419 |
+
key='start_date_mp',
|
| 420 |
+
help="The start date for downloading historical data."
|
| 421 |
+
)
|
| 422 |
+
# Calculate the maximum allowable end date for model performance
|
| 423 |
+
today = datetime.now().date()
|
| 424 |
+
max_end_date_mp = today - timedelta(days=int(n_days + extra_days))
|
| 425 |
+
end_date_mp = st.date_input(
|
| 426 |
+
"End Date",
|
| 427 |
+
value=max_end_date_mp,
|
| 428 |
+
max_value=max_end_date_mp,
|
| 429 |
+
key='end_date_mp',
|
| 430 |
+
help="The end date for training the model. Cannot exceed the maximum allowed date."
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
# Context description in the main body
|
| 434 |
+
st.markdown("""
|
| 435 |
+
### Model Performance
|
| 436 |
+
Here you assess how well the model forecasts volatility by comparing predicted values with actual historical (unseen) data. djust the parameters in the sidebar and click **Run Model** to assess performance.
|
| 437 |
+
""")
|
| 438 |
+
|
| 439 |
+
with st.expander("The following analyses are performed", expanded=False):
|
| 440 |
+
st.markdown("""
|
| 441 |
+
1. **Predicted vs Actual Volatility**: The app compares the predicted stock volatility with actual volatility over a given time period. Volatility is calculated as the rolling standard deviation of daily log returns. The forecasted values are plotted alongside actual values to visualize performance.
|
| 442 |
+
|
| 443 |
+
2. **Residual Analysis**: Residuals represent the difference between the actual and predicted values. A plot of the residuals helps identify patterns or systematic errors in the predictions, such as under or overestimation.
|
| 444 |
+
|
| 445 |
+
3. **Error Metrics**: The app calculates several error metrics to quantify the accuracy of the predictions:
|
| 446 |
+
- **RMSE (Root Mean Squared Error)**: Measures the average magnitude of errors in the predictions, penalizing larger errors.
|
| 447 |
+
- **MAE (Mean Absolute Error)**: Represents the average absolute difference between predicted and actual volatility.
|
| 448 |
+
- **MAPE (Mean Absolute Percentage Error)**: Shows the prediction accuracy as a percentage, providing a relative measure of performance.
|
| 449 |
+
- **R² (R-squared)**: Indicates how well the predicted values explain the variability in the actual volatility, with a value closer to 1 indicating better performance.
|
| 450 |
+
|
| 451 |
+
4. **Confidence Intervals**: The model provides 68% and 95% confidence intervals to quantify uncertainty around the predictions. Wider intervals indicate more uncertainty, while narrower ones suggest more confidence in the forecasts.
|
| 452 |
+
|
| 453 |
+
**Instructions:**
|
| 454 |
+
- **Adjust Parameters**: Set the rolling window, forecast horizon, and extra days for comparison in the sidebar.
|
| 455 |
+
- **Run the Model**: Click **Run Model** to download data, train the model, and evaluate its performance using actual market data.
|
| 456 |
+
- **Evaluate Results**: The app visualizes the results with performance metrics, residual plots, and error analysis to help gauge how well the model performs.
|
| 457 |
+
""")
|
| 458 |
+
|
| 459 |
+
# Run button
|
| 460 |
+
run_button = st.sidebar.button("Run Model", key='run_button_mp')
|
| 461 |
+
|
| 462 |
+
# Placeholders for plots and metrics
|
| 463 |
+
forecast_placeholder = st.empty()
|
| 464 |
+
extended_forecast_placeholder = st.empty()
|
| 465 |
+
metrics_placeholder = st.empty()
|
| 466 |
+
residual_placeholder = st.empty()
|
| 467 |
+
error_placeholder = st.empty()
|
| 468 |
+
|
| 469 |
+
if run_button:
|
| 470 |
+
with st.spinner("Downloading data and processing..."):
|
| 471 |
+
# Convert end_date_mp to datetime if necessary
|
| 472 |
+
adjusted_end_date = pd.to_datetime(end_date_mp)
|
| 473 |
+
|
| 474 |
+
# Extended end date includes n_days forecast plus extra_days for comparison
|
| 475 |
+
extended_end_date = adjusted_end_date + timedelta(days=n_days + extra_days)
|
| 476 |
+
|
| 477 |
+
data = download_data(tickers, start_date_mp, extended_end_date)
|
| 478 |
+
data = calculate_returns_and_volatility(data, rolling_window)
|
| 479 |
+
merged_data = merge_data(data)
|
| 480 |
+
|
| 481 |
+
# Ensure that the data is up to adjusted_end_date for training
|
| 482 |
+
merged_data_train = merged_data[merged_data.index <= adjusted_end_date]
|
| 483 |
+
|
| 484 |
+
# Check stationarity and difference if necessary
|
| 485 |
+
merged_data_diff = merged_data_train.copy()
|
| 486 |
+
check_stationarity_and_difference(merged_data_diff)
|
| 487 |
+
merged_data_diff = merged_data_diff.dropna()
|
| 488 |
+
|
| 489 |
+
# Normalize data
|
| 490 |
+
scaled_data, scaler = normalize_data(merged_data_diff)
|
| 491 |
+
|
| 492 |
+
# Fit VAR model
|
| 493 |
+
results, optimal_lag = fit_var_model(scaled_data)
|
| 494 |
+
|
| 495 |
+
# Forecast future values
|
| 496 |
+
forecast_original, lower_95_original, upper_95_original, lower_68_original, upper_68_original = forecast_future_values(
|
| 497 |
+
results, scaled_data, scaler, n_days, optimal_lag)
|
| 498 |
+
|
| 499 |
+
volatility_predictions = forecast_original[:, 0]
|
| 500 |
+
lower_volatility_95 = lower_95_original[:, 0]
|
| 501 |
+
upper_volatility_95 = upper_95_original[:, 0]
|
| 502 |
+
lower_volatility_68 = lower_68_original[:, 0]
|
| 503 |
+
upper_volatility_68 = upper_68_original[:, 0]
|
| 504 |
+
|
| 505 |
+
# Generate future dates
|
| 506 |
+
future_dates = pd.date_range(start=adjusted_end_date + timedelta(days=1), periods=n_days, freq='B')
|
| 507 |
+
|
| 508 |
+
# Extended forecast evaluation
|
| 509 |
+
forecast_start_date = future_dates[0]
|
| 510 |
+
forecast_end_date = future_dates[-1]
|
| 511 |
+
forecast_data_extended = extended_forecast_evaluation(
|
| 512 |
+
tickers, rolling_window, forecast_start_date,
|
| 513 |
+
forecast_end_date, future_dates, volatility_predictions)
|
| 514 |
+
|
| 515 |
+
# Plot forecast with confidence intervals
|
| 516 |
+
forecast_fig = plot_forecast(merged_data_train, future_dates, volatility_predictions,
|
| 517 |
+
lower_volatility_95, upper_volatility_95,
|
| 518 |
+
lower_volatility_68, upper_volatility_68)
|
| 519 |
+
|
| 520 |
+
# Plot extended forecast comparison
|
| 521 |
+
extended_forecast_fig = plot_extended_forecast(forecast_data_extended, future_dates, volatility_predictions)
|
| 522 |
+
|
| 523 |
+
# Calculate and display performance metrics
|
| 524 |
+
performance_metrics = calculate_performance_metrics(forecast_data_extended, future_dates, volatility_predictions)
|
| 525 |
+
|
| 526 |
+
# Plot residuals using Plotly
|
| 527 |
+
residual_fig = plot_residuals_plotly(results)
|
| 528 |
+
|
| 529 |
+
# Calculate metrics and plot errors using Plotly
|
| 530 |
+
error_fig = calculate_metrics_and_plot_errors_plotly(forecast_data_extended, future_dates, volatility_predictions)
|
| 531 |
+
|
| 532 |
+
# Store results in session_state
|
| 533 |
+
st.session_state['mp_results'] = {
|
| 534 |
+
'forecast_fig': forecast_fig,
|
| 535 |
+
'extended_forecast_fig': extended_forecast_fig,
|
| 536 |
+
'performance_metrics': performance_metrics,
|
| 537 |
+
'residual_fig': residual_fig,
|
| 538 |
+
'error_fig': error_fig
|
| 539 |
+
}
|
| 540 |
+
|
| 541 |
+
# Display plots and metrics using placeholders
|
| 542 |
+
with forecast_placeholder:
|
| 543 |
+
st.subheader("Forecast with Confidence Intervals")
|
| 544 |
+
st.plotly_chart(forecast_fig)
|
| 545 |
+
|
| 546 |
+
with extended_forecast_placeholder:
|
| 547 |
+
st.subheader("Extended Forecast Evaluation")
|
| 548 |
+
st.plotly_chart(extended_forecast_fig)
|
| 549 |
+
|
| 550 |
+
with metrics_placeholder:
|
| 551 |
+
st.markdown("#### Performance Metrics")
|
| 552 |
+
st.markdown(performance_metrics)
|
| 553 |
+
|
| 554 |
+
with residual_placeholder:
|
| 555 |
+
st.subheader("Residuals of Model")
|
| 556 |
+
st.plotly_chart(residual_fig)
|
| 557 |
+
|
| 558 |
+
with error_placeholder:
|
| 559 |
+
st.subheader("Prediction Errors")
|
| 560 |
+
st.plotly_chart(error_fig)
|
| 561 |
+
|
| 562 |
+
elif 'mp_results' in st.session_state:
|
| 563 |
+
# Display stored results using placeholders
|
| 564 |
+
with forecast_placeholder:
|
| 565 |
+
st.subheader("Forecast with Confidence Intervals")
|
| 566 |
+
st.plotly_chart(st.session_state['mp_results']['forecast_fig'])
|
| 567 |
+
|
| 568 |
+
with extended_forecast_placeholder:
|
| 569 |
+
st.subheader("Extended Forecast Evaluation")
|
| 570 |
+
st.plotly_chart(st.session_state['mp_results']['extended_forecast_fig'])
|
| 571 |
+
|
| 572 |
+
with metrics_placeholder:
|
| 573 |
+
st.markdown("#### Performance Metrics")
|
| 574 |
+
st.markdown(st.session_state['mp_results']['performance_metrics'])
|
| 575 |
+
|
| 576 |
+
with residual_placeholder:
|
| 577 |
+
st.subheader("Residuals of Model")
|
| 578 |
+
st.plotly_chart(st.session_state['mp_results']['residual_fig'])
|
| 579 |
+
|
| 580 |
+
with error_placeholder:
|
| 581 |
+
st.subheader("Prediction Errors")
|
| 582 |
+
st.plotly_chart(st.session_state['mp_results']['error_fig'])
|
| 583 |
+
|
| 584 |
+
hide_streamlit_style = """
|
| 585 |
+
<style>
|
| 586 |
+
#MainMenu {visibility: hidden;}
|
| 587 |
+
footer {visibility: hidden;}
|
| 588 |
+
</style>
|
| 589 |
+
"""
|
| 590 |
+
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
|