Spaces:
Running
Running
| """ | |
| src.utils.plotting | |
| ================== | |
| Research-grade visualization helpers for battery lifecycle analysis. | |
| All functions produce Matplotlib figures and optionally save them to | |
| ``artifacts/figures/``. Plotly interactive versions are generated | |
| where noted. | |
| """ | |
| from __future__ import annotations | |
| from pathlib import Path | |
| from typing import Sequence | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| import seaborn as sns | |
| from src.utils.config import ( | |
| CMAP_DIVERGING, | |
| CMAP_SEQUENTIAL, | |
| FIG_DPI, | |
| FIG_SIZE, | |
| FIGURES_DIR, | |
| MATPLOTLIB_STYLE, | |
| ) | |
# Apply the project's research-grade Matplotlib style when this module is imported.
try:
    plt.style.use(MATPLOTLIB_STYLE)
except OSError:
    # The configured style could not be loaded; fall back to the
    # "seaborn-v0_8" style name instead of failing the import.
    plt.style.use("seaborn-v0_8")
# Use seaborn's "paper" context with fonts scaled up by 1.3.
# NOTE(review): assumed to run unconditionally (not only in the fallback
# branch) — confirm against the original file's indentation.
sns.set_context("paper", font_scale=1.3)
def save_fig(
    fig: plt.Figure,
    name: str,
    tight: bool = True,
    directory: str | Path | None = None,
) -> Path:
    """Save a figure as a PNG file and return the path it was written to.

    Parameters
    ----------
    fig:
        Matplotlib figure to save.
    name:
        Base filename (without extension); ``.png`` is appended.
    tight:
        Whether to call ``fig.tight_layout()`` before saving.
    directory:
        Target directory, given as a ``Path`` or a path string. Defaults to
        ``FIGURES_DIR``. The directory (and any missing parents) is created.

    Returns
    -------
    Path
        ``<directory>/<name>.png``. The path is absolute only when the
        target directory itself is absolute.
    """
    if tight:
        fig.tight_layout()
    # Coerce to Path so plain string directories work as well as Path objects.
    out_dir = Path(FIGURES_DIR if directory is None else directory)
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / f"{name}.png"
    fig.savefig(path, dpi=FIG_DPI, bbox_inches="tight")
    return path
| # ── Capacity fade curves ───────────────────────────────────────────────────── | |
def plot_capacity_fade(
    cap_df: pd.DataFrame,
    battery_ids: list[str] | None = None,
    eol_threshold: float | None = 1.4,
    title: str = "Capacity Fade Curves",
    save_name: str | None = "capacity_fade",
) -> plt.Figure:
    """Plot capacity (Ah) against cycle number, one line per battery.

    Parameters
    ----------
    cap_df:
        Frame with ``battery_id``, ``cycle_number`` and ``Capacity`` columns.
    battery_ids:
        Batteries to draw. ``None`` (or an empty list) draws every battery
        found in ``cap_df``, in sorted order.
    eol_threshold:
        Capacity at which a horizontal end-of-life line is drawn. Pass
        ``None`` to omit the line; ``0`` is a valid threshold.
    title:
        Axes title.
    save_name:
        Base filename passed to ``save_fig``; ``None`` (or empty) skips saving.

    Returns
    -------
    plt.Figure
        The figure containing the capacity fade curves.
    """
    fig, ax = plt.subplots(figsize=FIG_SIZE)
    bats = battery_ids or sorted(cap_df["battery_id"].unique())
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap (reached via
    # plt.cm.get_cmap) was removed in Matplotlib 3.9.
    cmap = plt.get_cmap("tab20", max(len(bats), 1))
    for i, bid in enumerate(bats):
        sub = cap_df[cap_df["battery_id"] == bid]
        ax.plot(sub["cycle_number"], sub["Capacity"], label=bid, color=cmap(i), linewidth=1.2)
    # Compare against None explicitly so a threshold of 0 is still drawn.
    if eol_threshold is not None:
        ax.axhline(
            y=eol_threshold,
            color="crimson",
            linestyle="--",
            linewidth=1.5,
            label=f"EOL = {eol_threshold} Ah",
        )
    ax.set_xlabel("Cycle Number")
    ax.set_ylabel("Discharge Capacity (Ah)")
    ax.set_title(title)
    ax.legend(fontsize=7, ncol=4, loc="upper right")
    ax.grid(True, alpha=0.3)
    if save_name:
        save_fig(fig, save_name)
    return fig
| # ── SOH degradation with trend ─────────────────────────────────────────────── | |
def plot_soh_degradation(
    cap_df: pd.DataFrame,
    battery_id: str,
    save_name: str | None = None,
    nominal_capacity: float = 2.0,
) -> plt.Figure:
    """Plot SOH (%) against cycle number with linear and exponential trends.

    Parameters
    ----------
    cap_df:
        Frame with ``battery_id`` and ``cycle_number`` columns plus either a
        ``SoH`` column or a ``Capacity`` column.
    battery_id:
        Battery whose rows are plotted.
    save_name:
        Base filename passed to ``save_fig``; ``None`` (or empty) skips saving.
    nominal_capacity:
        Capacity corresponding to 100 % SOH. Used only when ``cap_df`` has no
        ``SoH`` column, in which case SOH is ``Capacity / nominal_capacity * 100``.

    Returns
    -------
    plt.Figure
        The figure with the measured points and any trend lines that could
        be fitted.
    """
    import warnings

    sub = cap_df[cap_df["battery_id"] == battery_id]
    x = sub["cycle_number"].to_numpy(dtype=float)
    if "SoH" in sub.columns:
        y = sub["SoH"].to_numpy(dtype=float)
    else:
        y = (sub["Capacity"].to_numpy(dtype=float) / nominal_capacity) * 100

    fig, ax = plt.subplots(figsize=FIG_SIZE)
    ax.scatter(x, y, s=12, alpha=0.6, label="Measured SOH")

    # Fit only on finite samples: NaN/inf would poison the least-squares
    # solution or make it fail outright.
    finite = np.isfinite(x) & np.isfinite(y)
    x_fit, y_fit = x[finite], y[finite]

    # Linear fit
    if x_fit.size > 2:
        coeffs_lin = np.polyfit(x_fit, y_fit, 1)
        ax.plot(
            x_fit,
            np.polyval(coeffs_lin, x_fit),
            "r--",
            linewidth=1.5,
            label=f"Linear: slope={coeffs_lin[0]:.4f}%/cycle",
        )

    # Exponential fit y = a * exp(b * x)  =>  log(y) = log(a) + b * x.
    # The log transform needs strictly positive SOH values, and a line needs
    # at least two points.
    positive = y_fit > 0
    if np.count_nonzero(positive) >= 2:
        try:
            rate, log_a = np.polyfit(x_fit[positive], np.log(y_fit[positive]), 1)
        except (np.linalg.LinAlgError, ValueError) as exc:
            # The exponential trend is best-effort: report the failure
            # instead of hiding it, but still return the figure.
            warnings.warn(
                f"Exponential fit skipped for {battery_id}: {exc}",
                RuntimeWarning,
                stacklevel=2,
            )
        else:
            ax.plot(
                x_fit,
                np.exp(log_a) * np.exp(rate * x_fit),
                "g-.",
                linewidth=1.5,
                label=f"Exponential: λ={rate:.6f}",
            )

    ax.set_xlabel("Cycle Number")
    ax.set_ylabel("State of Health (%)")
    ax.set_title(f"SOH Degradation — {battery_id}")
    ax.legend()
    ax.grid(True, alpha=0.3)
    if save_name:
        save_fig(fig, save_name)
    return fig
| # ── Correlation heatmap ────────────────────────────────────────────────────── | |
def plot_correlation_heatmap(
    df: pd.DataFrame,
    columns: list[str] | None = None,
    save_name: str | None = "correlation_heatmap",
) -> plt.Figure:
    """Draw a lower-triangle seaborn heatmap of pairwise correlations.

    Parameters
    ----------
    df:
        Source data.
    columns:
        Columns to correlate; rows with any missing value in them are
        dropped. When omitted, every numeric column is used and columns
        that are entirely missing are dropped.
    save_name:
        Base filename passed to ``save_fig``; ``None`` (or empty) skips saving.

    Returns
    -------
    plt.Figure
        The figure holding the annotated heatmap.
    """
    if not columns:
        numeric_only = df.select_dtypes(include=[np.number])
        data = numeric_only.dropna(axis=1, how="all")
    else:
        data = df[columns].dropna()

    corr = data.corr()
    # Hide the diagonal and upper triangle; the matrix is symmetric.
    hidden_cells = np.triu(np.ones_like(corr, dtype=bool))

    fig, ax = plt.subplots(figsize=(14, 10))
    heatmap_options = {
        "mask": hidden_cells,
        "annot": True,
        "fmt": ".2f",
        "cmap": "coolwarm",
        "center": 0,
        "square": True,
        "linewidths": 0.5,
        "ax": ax,
        "cbar_kws": {"shrink": 0.8},
    }
    sns.heatmap(corr, **heatmap_options)
    ax.set_title("Feature Correlation Matrix")

    if save_name:
        save_fig(fig, save_name)
    return fig
| # ── Box / violin plots ─────────────────────────────────────────────────────── | |
def plot_capacity_by_temperature(
    cap_df: pd.DataFrame,
    save_name: str | None = "capacity_by_temp",
) -> plt.Figure:
    """Draw a violin plot of discharge capacity per ambient temperature.

    Parameters
    ----------
    cap_df:
        Frame with ``ambient_temperature`` and ``Capacity`` columns.
    save_name:
        Base filename passed to ``save_fig``; ``None`` (or empty) skips saving.

    Returns
    -------
    plt.Figure
        The figure holding the violin plot.
    """
    fig, ax = plt.subplots(figsize=FIG_SIZE)
    # cut=0 stops each violin at the observed data range instead of
    # extending the density estimate past it.
    # NOTE(review): newer seaborn releases warn when `palette` is given
    # without `hue` — confirm against the pinned seaborn version.
    sns.violinplot(
        data=cap_df,
        x="ambient_temperature",
        y="Capacity",
        ax=ax,
        inner="quartile",
        palette="coolwarm",
        cut=0,
    )
    ax.set_xlabel("Ambient Temperature (°C)")
    ax.set_ylabel("Discharge Capacity (Ah)")
    ax.set_title("Capacity Distribution by Temperature")
    if save_name:
        save_fig(fig, save_name)
    return fig
| # ── Training loss curves ───────────────────────────────────────────────────── | |
def plot_training_curves(
    train_losses: list[float],
    val_losses: list[float] | None = None,
    title: str = "Training Loss",
    save_name: str | None = None,
) -> plt.Figure:
    """Plot training loss, and optionally validation loss, per epoch.

    Parameters
    ----------
    train_losses:
        Training loss per epoch; the first entry is plotted at epoch 1.
    val_losses:
        Validation loss per epoch, also starting at epoch 1. May be a list
        or an array and may be shorter or longer than ``train_losses``.
        ``None`` or an empty sequence omits the curve.
    title:
        Axes title.
    save_name:
        Base filename passed to ``save_fig``; ``None`` (or empty) skips saving.

    Returns
    -------
    plt.Figure
        The figure holding the loss curves.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    train_epochs = range(1, len(train_losses) + 1)
    ax.plot(train_epochs, train_losses, "b-", linewidth=1.5, label="Train Loss")
    # Test length rather than truthiness so NumPy arrays are accepted, and
    # give the validation curve its own epoch axis so a history of different
    # length does not raise a shape mismatch.
    if val_losses is not None and len(val_losses) > 0:
        val_epochs = range(1, len(val_losses) + 1)
        ax.plot(val_epochs, val_losses, "r-", linewidth=1.5, label="Val Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    if save_name:
        save_fig(fig, save_name)
    return fig
| # ── Actual vs Predicted scatter ────────────────────────────────────────────── | |
def plot_actual_vs_predicted(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    label: str = "SOH (%)",
    model_name: str = "",
    save_name: str | None = None,
) -> plt.Figure:
    """Scatter predictions against actual values with an identity line and R².

    Parameters
    ----------
    y_true:
        Ground-truth values (array-like; flattened to 1-D).
    y_pred:
        Predicted values (array-like; flattened to 1-D). Must contain the
        same number of elements as ``y_true``.
    label:
        Quantity name used in both axis labels.
    model_name:
        Appended to the title when non-empty.
    save_name:
        Base filename passed to ``save_fig``; ``None`` (or empty) skips saving.

    Returns
    -------
    plt.Figure
        The figure holding the scatter plot.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_pred`` differ in size.
    """
    from sklearn.metrics import r2_score

    # Accept lists and column vectors, not just 1-D ndarrays.
    actual = np.asarray(y_true, dtype=float).ravel()
    predicted = np.asarray(y_pred, dtype=float).ravel()
    if actual.shape != predicted.shape:
        raise ValueError(
            f"y_true and y_pred must have the same size, got {actual.size} and {predicted.size}"
        )

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.scatter(actual, predicted, s=15, alpha=0.5, edgecolors="none")
    # Identity line spanning the combined range of both series.
    lims = [min(actual.min(), predicted.min()), max(actual.max(), predicted.max())]
    ax.plot(lims, lims, "r--", linewidth=1.5, label="Perfect prediction")
    r2 = r2_score(actual, predicted)
    ax.annotate(
        f"R² = {r2:.4f}",
        xy=(0.05, 0.92),
        xycoords="axes fraction",
        fontsize=14,
        fontweight="bold",
        bbox=dict(boxstyle="round,pad=0.3", facecolor="lightyellow"),
    )
    ax.set_xlabel(f"Actual {label}")
    ax.set_ylabel(f"Predicted {label}")
    ax.set_title(f"Actual vs Predicted — {model_name}" if model_name else "Actual vs Predicted")
    ax.legend()
    ax.set_aspect("equal")
    ax.grid(True, alpha=0.3)
    if save_name:
        save_fig(fig, save_name)
    return fig
| # ── Residual analysis ──────────────────────────────────────────────────────── | |
def plot_residuals(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    model_name: str = "",
    save_name: str | None = None,
) -> plt.Figure:
    """Plot the residual histogram (with KDE) next to a normal Q-Q plot.

    Residuals are computed as ``y_true - y_pred``.

    Parameters
    ----------
    y_true:
        Ground-truth values (array-like; flattened to 1-D).
    y_pred:
        Predicted values (array-like; flattened to 1-D). Must contain the
        same number of elements as ``y_true``.
    model_name:
        Appended to both subplot titles when non-empty.
    save_name:
        Base filename passed to ``save_fig``; ``None`` (or empty) skips saving.

    Returns
    -------
    plt.Figure
        A figure with two side-by-side axes.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_pred`` differ in size.
    """
    import scipy.stats as stats

    # Flatten first: subtracting a (n, 1) column vector from a (n,) vector
    # would silently broadcast to an (n, n) matrix of bogus residuals.
    actual = np.asarray(y_true, dtype=float).ravel()
    predicted = np.asarray(y_pred, dtype=float).ravel()
    if actual.shape != predicted.shape:
        raise ValueError(
            f"y_true and y_pred must have the same size, got {actual.size} and {predicted.size}"
        )
    residuals = actual - predicted
    # Avoid a dangling separator in the titles when no model name is given.
    title_suffix = f" — {model_name}" if model_name else ""

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    # Histogram + KDE
    sns.histplot(residuals, kde=True, ax=axes[0], color="steelblue", bins=40)
    axes[0].axvline(x=0, color="red", linestyle="--")
    axes[0].set_xlabel("Residual")
    axes[0].set_title(f"Residual Distribution{title_suffix}")
    # Q-Q plot against a normal distribution
    stats.probplot(residuals, dist="norm", plot=axes[1])
    axes[1].set_title(f"Q-Q Plot{title_suffix}")
    fig.tight_layout()
    if save_name:
        save_fig(fig, save_name)
    return fig
| # ── Radar / Spider chart ───────────────────────────────────────────────────── | |
def plot_radar_chart(
    metrics_dict: dict[str, dict[str, float]],
    title: str = "Model Comparison (Normalized Metrics)",
    save_name: str | None = "radar_chart",
) -> plt.Figure:
    """Draw a radar chart comparing several models across several metrics.

    Values are plotted exactly as given; any normalization must be applied
    by the caller.

    Parameters
    ----------
    metrics_dict : dict
        ``{model_name: {metric_name: value, ...}, ...}``. The metric names
        of the first model define the axes, so every model must provide
        those metrics.
    title : str
        Chart title.
    save_name : str or None
        Base filename passed to ``save_fig``; ``None`` (or empty) skips saving.

    Returns
    -------
    plt.Figure
        The figure holding the polar axes.

    Raises
    ------
    ValueError
        If ``metrics_dict`` is empty.
    KeyError
        If a model lacks one of the first model's metrics.
    """
    if not metrics_dict:
        # Fail with a clear message instead of leaking StopIteration.
        raise ValueError("metrics_dict must contain at least one model")

    models = list(metrics_dict)
    metric_names = list(metrics_dict[models[0]])
    n_axes = len(metric_names)
    angles = np.linspace(0, 2 * np.pi, n_axes, endpoint=False).tolist()
    # Repeat the first angle so each polygon closes on itself.
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap (reached via
    # plt.cm.get_cmap) was removed in Matplotlib 3.9.
    cmap = plt.get_cmap("Set2", len(models))
    for i, model in enumerate(models):
        values = [metrics_dict[model][m] for m in metric_names]
        values += values[:1]
        ax.plot(angles, values, "o-", linewidth=2, label=model, color=cmap(i))
        ax.fill(angles, values, alpha=0.1, color=cmap(i))
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(metric_names, fontsize=10)
    ax.set_title(title, size=14, pad=20)
    ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.1))
    if save_name:
        save_fig(fig, save_name)
    return fig
| # ── Cumulative Error Distribution ──────────────────────────────────────────── | |
def plot_ced(
    errors_dict: dict[str, np.ndarray],
    title: str = "Cumulative Error Distribution",
    save_name: str | None = "ced_curve",
) -> plt.Figure:
    """Plot the empirical CDF of absolute errors, one curve per model.

    Parameters
    ----------
    errors_dict:
        ``{model_name: errors}``. Each ``errors`` is array-like of any shape
        and is flattened before use.
    title:
        Axes title.
    save_name:
        Base filename passed to ``save_fig``; ``None`` (or empty) skips saving.

    Returns
    -------
    plt.Figure
        The figure holding the CED curves.
    """
    fig, ax = plt.subplots(figsize=FIG_SIZE)
    for name, errors in errors_dict.items():
        # Flatten before sorting: np.sort works along the last axis, so a
        # (n, 1) column vector would otherwise come back unsorted.
        sorted_err = np.sort(np.abs(np.asarray(errors)).ravel())
        # The k-th smallest error sits at cumulative probability k / n.
        cdf = np.arange(1, len(sorted_err) + 1) / len(sorted_err)
        ax.plot(sorted_err, cdf, linewidth=1.8, label=name)
    ax.set_xlabel("Absolute Error")
    ax.set_ylabel("Cumulative Probability")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_xlim(left=0)
    ax.set_ylim(0, 1.05)
    if save_name:
        save_fig(fig, save_name)
    return fig
| # ── Error heatmap (models × batteries) ─────────────────────────────────────── | |
def plot_error_heatmap(
    error_df: pd.DataFrame,
    title: str = "Per-Battery Error Heatmap",
    save_name: str | None = "error_heatmap",
) -> plt.Figure:
    """Render an annotated heatmap of MAE for each model/battery pair.

    Parameters
    ----------
    error_df : pd.DataFrame
        Rows = models, Columns = batteries, Values = MAE.
    title : str
        Axes title.
    save_name : str or None
        Base filename passed to ``save_fig``; ``None`` (or empty) skips saving.

    Returns
    -------
    plt.Figure
        The figure holding the heatmap.
    """
    # Grow the figure with the number of rows, but never below 6 inches.
    row_count = len(error_df)
    height = max(6, row_count * 0.8)
    fig, ax = plt.subplots(figsize=(14, height))

    sns.heatmap(
        error_df,
        ax=ax,
        annot=True,
        fmt=".3f",
        cmap="YlOrRd",
        linewidths=0.5,
        cbar_kws={"label": "MAE"},
    )
    ax.set_title(title)
    ax.set_ylabel("Model")
    ax.set_xlabel("Battery ID")

    if not save_name:
        return fig
    save_fig(fig, save_name)
    return fig
| # ── Model comparison bar chart ─────────────────────────────────────────────── | |
def plot_model_comparison_bars(
    summary_df: pd.DataFrame,
    metric_cols: list[str],
    model_col: str = "model",
    title: str = "Model Comparison",
    save_name: str | None = "model_comparison_bars",
) -> plt.Figure:
    """Draw a grouped bar chart comparing models on several metrics.

    Parameters
    ----------
    summary_df:
        One row per model, with the model name in ``model_col`` and one
        column per metric.
    metric_cols:
        Metric columns to plot; each becomes one bar per model group.
    model_col:
        Column holding the labels shown on the x axis.
    title:
        Axes title.
    save_name:
        Base filename passed to ``save_fig``; ``None`` (or empty) skips saving.

    Returns
    -------
    plt.Figure
        The figure holding the bar chart.

    Raises
    ------
    ValueError
        If ``metric_cols`` is empty.
    """
    n_metrics = len(metric_cols)
    if n_metrics == 0:
        # Otherwise the bar-width computation below divides by zero.
        raise ValueError("metric_cols must contain at least one column name")
    n_models = len(summary_df)
    x = np.arange(n_models)
    # Each model group spans 0.8 of a unit on the x axis, split evenly.
    width = 0.8 / n_metrics

    fig, ax = plt.subplots(figsize=(max(12, n_models * 1.5), 7))
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap (reached via
    # plt.cm.get_cmap) was removed in Matplotlib 3.9.
    cmap = plt.get_cmap("Set2", n_metrics)
    for i, col in enumerate(metric_cols):
        # Centre the group of bars on the model's tick position.
        offset = (i - n_metrics / 2 + 0.5) * width
        ax.bar(x + offset, summary_df[col], width, label=col, color=cmap(i))
    ax.set_xticks(x)
    ax.set_xticklabels(summary_df[model_col], rotation=30, ha="right")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3, axis="y")
    fig.tight_layout()
    if save_name:
        save_fig(fig, save_name)
    return fig