import logging import matplotlib.pyplot as plt import numpy as np import pandas as pd import torch import torchmetrics from matplotlib.figure import Figure import plotly.graph_objects as go import io import base64 from src.data.containers import BatchTimeSeriesContainer from src.data.frequency import Frequency logger = logging.getLogger(__name__) def matplotlib_to_plotly(fig_matplotlib): """Convert matplotlib figure to Plotly figure for Gradio compatibility.""" try: # Convert matplotlib figure to image bytes buf = io.BytesIO() fig_matplotlib.savefig(buf, format='png', dpi=100, bbox_inches='tight') buf.seek(0) img_str = base64.b64encode(buf.read()).decode('utf-8') buf.close() # Create a Plotly figure with the image fig_plotly = go.Figure() # Add image trace fig_plotly.add_trace(go.Image( source=f'data:image/png;base64,{img_str}' )) # Update layout to remove axes and make image fill the space fig_plotly.update_layout( xaxis=dict(visible=False), yaxis=dict(visible=False), margin=dict(l=0, r=0, t=0, b=0), width=800, height=400 ) # Close the matplotlib figure to free memory plt.close(fig_matplotlib) return fig_plotly except Exception as e: logger.error(f"Failed to convert matplotlib figure to Plotly: {e}") # Return a simple error message figure fig_plotly = go.Figure() fig_plotly.add_annotation( text="Error: Could not generate plot", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=14, color="red") ) fig_plotly.update_layout(width=600, height=300) return fig_plotly def calculate_smape(y_true: np.ndarray, y_pred: np.ndarray) -> float: """Calculate Symmetric Mean Absolute Percentage Error (SMAPE).""" pred_tensor = torch.from_numpy(y_pred).float() true_tensor = torch.from_numpy(y_true).float() return torchmetrics.SymmetricMeanAbsolutePercentageError()(pred_tensor, true_tensor).item() def _create_date_ranges( start: np.datetime64 | pd.Timestamp | None, frequency: Frequency | str | None, history_length: int, prediction_length: int, ) -> tuple[pd.DatetimeIndex, pd.DatetimeIndex]: """Create date ranges for history and future periods.""" if start is not None and frequency is not None: start_timestamp = pd.Timestamp(start) pandas_freq = frequency.to_pandas_freq(for_date_range=True) history_dates = pd.date_range(start=start_timestamp, periods=history_length, freq=pandas_freq) if prediction_length > 0: next_timestamp = history_dates[-1] + pd.tseries.frequencies.to_offset(pandas_freq) future_dates = pd.date_range(start=next_timestamp, periods=prediction_length, freq=pandas_freq) else: future_dates = pd.DatetimeIndex([]) else: # Fallback to default daily frequency history_dates = pd.date_range(end=pd.Timestamp.now(), periods=history_length, freq="D") if prediction_length > 0: future_dates = pd.date_range( start=history_dates[-1] + pd.Timedelta(days=1), periods=prediction_length, freq="D", ) else: future_dates = pd.DatetimeIndex([]) return history_dates, future_dates def _plot_single_channel( ax: plt.Axes, channel_idx: int, history_dates: pd.DatetimeIndex, future_dates: pd.DatetimeIndex, history_values: np.ndarray, future_values: np.ndarray | None = None, predicted_values: np.ndarray | None = None, lower_bound: np.ndarray | None = None, upper_bound: np.ndarray | None = None, ) -> None: """Plot a single channel's time series data.""" # Plot history ax.plot(history_dates, history_values[:, channel_idx], color="black", label="History") # Plot ground truth future if future_values is not None: ax.plot( future_dates, future_values[:, channel_idx], color="blue", label="Ground Truth", ) # Plot predictions if predicted_values is not None: ax.plot( future_dates, predicted_values[:, channel_idx], color="orange", linestyle="--", label="Prediction (Median)", ) # Plot uncertainty band if lower_bound is not None and upper_bound is not None: ax.fill_between( future_dates, lower_bound[:, channel_idx], upper_bound[:, channel_idx], color="orange", alpha=0.2, label="Uncertainty Band", ) ax.set_title(f"Channel {channel_idx + 1}") ax.grid(True, which="both", linestyle="--", linewidth=0.5) def _setup_figure(num_channels: int) -> tuple[Figure, list[plt.Axes]]: """Create and configure the matplotlib figure and axes.""" fig, axes = plt.subplots(num_channels, 1, figsize=(15, 3 * num_channels), sharex=True) if num_channels == 1: axes = [axes] return fig, axes def _finalize_plot( fig: Figure, axes: list[plt.Axes], title: str | None = None, smape_value: float | None = None, output_file: str | None = None, show: bool = True, ) -> None: """Add legend, title, and save/show the plot.""" # Create legend from first axis handles, labels = axes[0].get_legend_handles_labels() fig.legend(handles, labels, loc="upper right") # Set title with optional SMAPE if title: if smape_value is not None: title = f"{title} | SMAPE: {smape_value:.4f}" fig.suptitle(title, fontsize=16) # Adjust layout plt.tight_layout(rect=[0, 0.03, 1, 0.95] if title else None) # Save and/or show if output_file: plt.savefig(output_file, dpi=300) if show: plt.show() else: plt.close(fig) def plot_multivariate_timeseries( history_values: np.ndarray, future_values: np.ndarray | None = None, predicted_values: np.ndarray | None = None, start: np.datetime64 | pd.Timestamp | None = None, frequency: Frequency | str | None = None, title: str | None = None, output_file: str | None = None, show: bool = True, lower_bound: np.ndarray | None = None, upper_bound: np.ndarray | None = None, ) -> go.Figure: """Plot a multivariate time series with history, future, predictions, and uncertainty bands.""" # Calculate SMAPE if both predicted and true values are available smape_value = None if predicted_values is not None and future_values is not None: try: smape_value = calculate_smape(future_values, predicted_values) except Exception as e: logger.warning(f"Failed to calculate SMAPE: {str(e)}") # Extract dimensions num_channels = history_values.shape[1] history_length = history_values.shape[0] prediction_length = ( predicted_values.shape[0] if predicted_values is not None else (future_values.shape[0] if future_values is not None else 0) ) # Create date ranges history_dates, future_dates = _create_date_ranges(start, frequency, history_length, prediction_length) # Setup figure fig, axes = _setup_figure(num_channels) # Plot each channel for i in range(num_channels): _plot_single_channel( ax=axes[i], channel_idx=i, history_dates=history_dates, future_dates=future_dates, history_values=history_values, future_values=future_values, predicted_values=predicted_values, lower_bound=lower_bound, upper_bound=upper_bound, ) # Finalize plot _finalize_plot(fig, axes, title, smape_value, output_file, show) # Convert to Plotly for Gradio compatibility return matplotlib_to_plotly(fig) def _extract_quantile_predictions( predicted_values: np.ndarray, model_quantiles: list[float], ) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]: """Extract median, lower, and upper bound predictions from quantile output.""" try: median_idx = model_quantiles.index(0.5) lower_idx = model_quantiles.index(0.1) upper_idx = model_quantiles.index(0.9) median_preds = predicted_values[..., median_idx] lower_bound = predicted_values[..., lower_idx] upper_bound = predicted_values[..., upper_idx] return median_preds, lower_bound, upper_bound except (ValueError, IndexError): logger.warning("Could not find 0.1, 0.5, 0.9 quantiles for plotting. Using median of available quantiles.") median_preds = predicted_values[..., predicted_values.shape[-1] // 2] return median_preds, None, None def plot_from_container( batch: BatchTimeSeriesContainer, sample_idx: int, predicted_values: np.ndarray | None = None, model_quantiles: list[float] | None = None, title: str | None = None, output_file: str | None = None, show: bool = True, ) -> go.Figure: """Plot a single sample from a BatchTimeSeriesContainer with proper quantile handling.""" # Extract data for the specific sample history_values = batch.history_values[sample_idx].cpu().numpy() future_values = batch.future_values[sample_idx].cpu().numpy() # Process predictions if predicted_values is not None: # Handle batch vs single sample predictions if predicted_values.ndim >= 3 or ( predicted_values.ndim == 2 and predicted_values.shape[0] > future_values.shape[0] ): sample_preds = predicted_values[sample_idx] else: sample_preds = predicted_values # Extract quantile information if available if model_quantiles: median_preds, lower_bound, upper_bound = _extract_quantile_predictions(sample_preds, model_quantiles) else: median_preds = sample_preds lower_bound = None upper_bound = None else: median_preds = None lower_bound = None upper_bound = None # Create the plot return plot_multivariate_timeseries( history_values=history_values, future_values=future_values, predicted_values=median_preds, start=batch.start[sample_idx], frequency=batch.frequency[sample_idx], title=title, output_file=output_file, show=show, lower_bound=lower_bound, upper_bound=upper_bound, )