File size: 9,595 Bytes
15b68db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
"""
Chart generation for forecast visualization
"""

import plotly.graph_objs as go
from plotly.subplots import make_subplots
import pandas as pd
from typing import List
from config.constants import COLORS, CHART_CONFIG


def create_forecast_chart(
    historical_data: pd.DataFrame,
    forecast_data: pd.DataFrame,
    confidence_levels: List[int],
    title: str = "Time Series Forecast",
    y_axis_label: str = "Value",
    backtest_data: pd.DataFrame = None
) -> go.Figure:
    """
    Create an interactive forecast chart with confidence intervals

    Traces are added in this order: historical line, optional backtest
    actual/predicted lines, confidence bands (widest first, so narrower
    bands are drawn on top), and finally the forecast line.

    Args:
        historical_data: DataFrame with columns ['ds', 'y']
        forecast_data: DataFrame with columns 'ds' and 'forecast', plus an
            optional 'lower_<cl>' / 'upper_<cl>' column pair for each
            confidence level to be drawn
        confidence_levels: List of confidence levels to plot; a level is
            silently skipped when its lower/upper columns are missing from
            forecast_data
        title: Chart title
        y_axis_label: Label for y-axis (variable name being forecasted)
        backtest_data: Optional DataFrame with columns 'timestamp',
            'actual' and 'predicted'; ignored when None or empty

    Returns:
        Plotly figure

    Raises:
        KeyError: If a required column is missing, or if a plotted
            confidence level has no entry in COLORS['confidence']
    """
    def _hover(suffix: str = "") -> str:
        # Shared hover template; only the label suffix differs per trace.
        return (
            f'<b>Date:</b> %{{x}}<br>'
            f'<b>{y_axis_label}{suffix}:</b> %{{y:.2f}}<extra></extra>'
        )

    fig = go.Figure()

    # Add historical data
    fig.add_trace(go.Scatter(
        x=historical_data['ds'],
        y=historical_data['y'],
        mode='lines',
        name='Historical',
        line=dict(color=COLORS['historical'], width=2),
        hovertemplate=_hover()
    ))

    # Add backtest data if provided (shows model performance on historical data)
    if backtest_data is not None and len(backtest_data) > 0:
        # Add actual values from backtest period
        fig.add_trace(go.Scatter(
            x=backtest_data['timestamp'],
            y=backtest_data['actual'],
            mode='lines',
            name='Backtest Actual',
            line=dict(color='rgba(100, 100, 100, 0.6)', width=2, dash='dot'),
            hovertemplate=_hover(' (Actual)')
        ))

        # Add predicted values from backtest period
        fig.add_trace(go.Scatter(
            x=backtest_data['timestamp'],
            y=backtest_data['predicted'],
            mode='lines',
            name='Backtest Predicted',
            line=dict(color='rgba(255, 100, 100, 0.8)', width=2),
            hovertemplate=_hover(' (Predicted)')
        ))

    # The band outline walks the dates forwards along the upper bound and
    # backwards along the lower bound; the x path is the same for every
    # level, so build it once outside the loop.
    forecast_dates = forecast_data['ds'].tolist()
    band_x = forecast_dates + forecast_dates[::-1]

    # Add confidence bands (from widest to narrowest)
    for cl in sorted(confidence_levels, reverse=True):
        lower_col = f'lower_{cl}'
        upper_col = f'upper_{cl}'

        if lower_col in forecast_data.columns and upper_col in forecast_data.columns:
            # Add filled area for confidence interval
            fig.add_trace(go.Scatter(
                x=band_x,
                y=forecast_data[upper_col].tolist() + forecast_data[lower_col].tolist()[::-1],
                # Explicit mode: Plotly defaults to 'lines+markers' for
                # traces with fewer than 20 points, which would draw
                # markers along the band outline on short forecasts.
                mode='lines',
                fill='toself',
                fillcolor=COLORS['confidence'][cl],
                line=dict(width=0),
                name=f'{cl}% Confidence',
                showlegend=True,
                hoverinfo='skip'
            ))

    # Add forecast line
    fig.add_trace(go.Scatter(
        x=forecast_data['ds'],
        y=forecast_data['forecast'],
        mode='lines',
        name='Forecast',
        line=dict(color=COLORS['forecast'], width=2),
        hovertemplate=_hover(' (Forecast)')
    ))

    # Add vertical separator line
    if len(historical_data) > 0:
        # Takes the last row as the forecast boundary; presumably
        # historical_data is sorted by 'ds' ascending - verify against caller.
        last_historical_date = historical_data['ds'].iloc[-1]
        # Use add_shape instead of add_vline to avoid Timestamp arithmetic issues
        fig.add_shape(
            type="line",
            x0=last_historical_date,
            x1=last_historical_date,
            y0=0,
            y1=1,
            yref="paper",  # y0/y1 span the full plot height regardless of data range
            line=dict(color=COLORS['separator'], dash="dash", width=1)
        )
        # Add annotation
        fig.add_annotation(
            x=last_historical_date,
            y=1.0,
            yref="paper",
            text="Forecast Start",
            showarrow=False,
            yanchor="bottom"
        )

    # Update layout
    fig.update_layout(
        title=dict(text=title, x=0.5, xanchor='center'),
        xaxis_title="Date",
        yaxis_title=y_axis_label,
        hovermode='x unified',
        template='plotly_white',
        height=700,  # Increased height to accommodate rangeslider
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ),
        margin=dict(l=50, r=50, t=80, b=150),  # Increased bottom margin for larger rangeslider
        xaxis=dict(
            rangeslider=dict(
                visible=True,
                thickness=0.12  # Wider slider (12% of chart height)
            ),
            type='date'
        )
    )

    # Add extra buttons to the modebar
    fig.update_layout(
        modebar_add=['v1hovermode', 'toggleSpikelines']
    )

    return fig


def create_empty_chart(message: str = "No data available") -> go.Figure:
    """
    Build a blank placeholder figure that shows only a centred message

    Args:
        message: Text rendered in the middle of the plot area

    Returns:
        Plotly figure with both axes hidden
    """
    hidden_axis = {"visible": False}

    placeholder = go.Figure()

    # Strip the chart down to an empty white canvas.
    placeholder.update_layout(
        template='plotly_white',
        height=600,
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
    )

    # Paper coordinates place the text at the centre of the plot area
    # independently of any axis range.
    placeholder.add_annotation(
        text=message,
        x=0.5,
        y=0.5,
        xref="paper",
        yref="paper",
        showarrow=False,
        font={"size": 20, "color": 'gray'},
    )

    return placeholder


def create_metrics_display(metrics: dict, inference_time: float = None) -> list:
    """
    Build one Bootstrap card column per available metric

    Args:
        metrics: Dictionary of metric values; only the keys 'MAE', 'RMSE',
            'MAPE' and 'R2' are read, and entries that are missing or None
            are skipped
        inference_time: Time taken for inference in seconds; when given,
            its card is placed first

    Returns:
        List of Dash components (dbc.Col elements, each wrapping a card)
    """
    import dash_bootstrap_components as dbc
    from dash import html

    def _card_column(heading, display_text):
        # One centred card: muted heading above the formatted value.
        body = dbc.CardBody([
            html.H6(heading, className="text-muted mb-2"),
            html.H4(display_text)
        ])
        return dbc.Col([dbc.Card([body], className="text-center")], md=2)

    # (metrics key, card heading, format template) in display order.
    metric_specs = (
        ('MAE', 'Mean Absolute Error', "{:.2f}"),
        ('RMSE', 'Root Mean Squared Error', "{:.2f}"),
        ('MAPE', 'Mean Absolute % Error', "{:.2f}%"),
        ('R2', 'R-Squared', "{:.4f}"),
    )

    columns = []

    if inference_time is not None:
        columns.append(_card_column("Inference Time", f"{inference_time:.2f}s"))

    for key, heading, template in metric_specs:
        if key not in metrics:
            continue
        value = metrics[key]
        if value is None:
            continue
        columns.append(_card_column(heading, template.format(value)))

    return columns


def create_backtest_metrics_display(metrics: dict) -> list:
    """
    Create backtest metrics display components

    Args:
        metrics: Dictionary of backtest metric values (MAE, RMSE, MAPE, R2).
            A missing key is displayed as zero; a key whose value is None
            is displayed as "N/A" instead of raising during formatting.

    Returns:
        Dash component card (a single dbc.Card). NOTE: the `list` return
        annotation does not match this; it is left unchanged here because
        the dbc name is only imported inside the function body.
    """
    import dash_bootstrap_components as dbc
    from dash import html

    def _format_metric(key: str, spec: str, suffix: str = "") -> str:
        # .get() only falls back to 0 when the key is absent; an explicit
        # None value would otherwise raise TypeError in the format call.
        value = metrics.get(key, 0)
        if value is None:
            return "N/A"
        return f"{value:{spec}}{suffix}"

    return dbc.Card([
        dbc.CardHeader([
            html.I(className="fas fa-chart-bar me-2"),
            html.Span("Backtest Performance Metrics", className="fw-bold")
        ]),
        dbc.CardBody([
            html.P("Model performance on historical data validation:", className="text-muted small mb-3"),
            dbc.Row([
                dbc.Col([
                    html.Small("MAE", className="text-muted"),
                    html.H5(_format_metric('MAE', '.2f'), className="mb-0")
                ], md=3),
                dbc.Col([
                    html.Small("RMSE", className="text-muted"),
                    html.H5(_format_metric('RMSE', '.2f'), className="mb-0")
                ], md=3),
                dbc.Col([
                    html.Small("MAPE", className="text-muted"),
                    html.H5(_format_metric('MAPE', '.2f', '%'), className="mb-0")
                ], md=3),
                dbc.Col([
                    html.Small("R²", className="text-muted"),
                    html.H5(_format_metric('R2', '.4f'), className="mb-0")
                ], md=3),
            ]),
            html.Hr(),
            html.Small([
                html.I(className="fas fa-info-circle me-1"),
                "Lower MAE/RMSE/MAPE and higher R² (closer to 1.0) indicate better model performance"
            ], className="text-muted")
        ])
    ], className="mt-3")


def decimate_data(df: pd.DataFrame, max_points: int = 10000) -> pd.DataFrame:
    """
    Reduce number of data points for visualization

    Keeps every `step`-th row starting from the first, where `step` is the
    smallest stride for which the result has at most `max_points` rows.

    Args:
        df: Input DataFrame
        max_points: Maximum number of points to keep

    Returns:
        The input DataFrame itself when it already has at most
        `max_points` rows; otherwise a decimated copy with a fresh
        0-based index

    Raises:
        ValueError: If decimation is needed and `max_points` is less than 1
    """
    total_rows = len(df)
    if total_rows <= max_points:
        return df

    if max_points < 1:
        raise ValueError(f"max_points must be at least 1, got {max_points}")

    # Use systematic sampling. The stride is rounded UP: floor division
    # would give e.g. step=1 for 15 rows / 10 points and return all 15
    # rows, exceeding the requested maximum.
    step = (total_rows + max_points - 1) // max_points
    return df.iloc[::step].reset_index(drop=True)