""" Visualisation utilities — training curves and comparison plots. Uses only Matplotlib (no Seaborn / Plotly dependency). All plots are saved to disk (non-interactive backend). """ from __future__ import annotations from pathlib import Path import numpy as np try: import matplotlib matplotlib.use("Agg") # Non-interactive backend (safe on servers) import matplotlib.pyplot as plt _MPL_OK = True except ImportError: _MPL_OK = False plt = None # type: ignore # ── Helper ──────────────────────────────────────────────────────────────────── def _check_mpl(): if not _MPL_OK: raise ImportError( "matplotlib is required for plotting.\n" "Install with: pip install matplotlib" ) def _moving_average(values: list, window: int) -> list: """Simple unweighted moving average.""" result = [] for i in range(len(values)): start = max(0, i - window + 1) result.append(float(np.mean(values[start : i + 1]))) return result # ── Public functions ────────────────────────────────────────────────────────── def plot_training_curves(metrics, save_path: str | Path | None = None) -> bool: """ Plot four training metrics in a 2×2 grid and save to *save_path*. Args: metrics: A :class:`MetricsTracker` instance. save_path: Destination PNG path. Shown interactively if None. Returns: True on success, False on failure. """ try: _check_mpl() panel_cfg = [ ("episode_reward", "Episode Reward", "blue", "Reward"), ("average_waiting_time", "Avg Waiting Time", "orange", "Waiting Time (s)"), ("average_queue_length", "Avg Queue Length", "red", "Queue Length"), ("throughput", "Throughput", "green", "Vehicles Passed"), ] has_any = any(metrics.has(k) for k, *_ in panel_cfg) if not has_any: print("[WARN] No data available for plotting.") return False fig, axes = plt.subplots(2, 2, figsize=(15, 10)) fig.suptitle("Training Progress", fontsize=16, fontweight="bold") axes_flat = axes.flatten() for ax, (key, title, colour, ylabel) in zip(axes_flat, panel_cfg): ax.set_title(title, fontsize=12, fontweight="bold") ax.set_xlabel("Episode", fontsize=10) ax.set_ylabel(ylabel, fontsize=10) ax.grid(True, alpha=0.3) if not metrics.has(key): ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes, color="grey") continue vals = metrics.get(key) eps = range(1, len(vals) + 1) ax.plot(eps, vals, alpha=0.4, color=colour, linewidth=1, label="Raw") if len(vals) >= 10: w = min(50, max(10, len(vals) // 10)) ma = _moving_average(vals, w) ax.plot(eps, ma, color=colour, linewidth=2, label=f"MA-{w}") ax.legend(loc="best", fontsize=8) plt.tight_layout() if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="white") print(f"[OK] Plot saved -> {save_path}") else: plt.show() plt.close(fig) return True except ImportError as exc: print(f"[WARN] {exc}") return False except Exception as exc: print(f"[WARN] Plotting error: {exc}") try: plt.close("all") except Exception: pass return False def plot_comparison( results_dict: dict[str, list], metric_name: str, save_path: str | Path | None = None, ) -> bool: """ Overlay multiple result series on a single axes. Args: results_dict: ``{"Method Name": [values, …], …}`` metric_name: Y-axis label / title suffix. save_path: Destination PNG path. Returns: True on success. """ try: _check_mpl() if not results_dict: print("[WARN] No data for comparison plot.") return False fig, ax = plt.subplots(figsize=(12, 6)) colours = ["blue", "green", "red", "orange", "purple"] for i, (name, vals) in enumerate(results_dict.items()): if vals: ax.plot(range(1, len(vals) + 1), vals, label=name, linewidth=2, alpha=0.8, color=colours[i % len(colours)]) ax.set_xlabel("Episode", fontsize=12) ax.set_ylabel(metric_name, fontsize=12) ax.set_title(f"{metric_name} - Method Comparison", fontsize=14, fontweight="bold") ax.legend(loc="best") ax.grid(True, alpha=0.3) plt.tight_layout() if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="white") print(f"[OK] Comparison plot saved -> {save_path}") else: plt.show() plt.close(fig) return True except ImportError as exc: print(f"[WARN] {exc}") return False except Exception as exc: print(f"[WARN] Comparison plot error: {exc}") try: plt.close("all") except Exception: pass return False def plot_bar_comparison( method_scores: dict[str, float], title: str = "Method Comparison", ylabel: str = "Mean Reward", save_path: str | Path | None = None, ) -> bool: """ Bar chart comparing scalar scores for different methods. Args: method_scores: {"Method": score, ...} title: Chart title. ylabel: Y-axis label. save_path: Destination PNG path. Returns: True on success. """ try: _check_mpl() if not method_scores: return False names = list(method_scores.keys()) scores = [method_scores[n] for n in names] colours = ["#4472C4", "#ED7D31", "#A9D18E"] fig, ax = plt.subplots(figsize=(8, 5)) bars = ax.bar(names, scores, color=colours[: len(names)], edgecolor="white", linewidth=1.5) # Value labels for bar, score in zip(bars, scores): ax.text( bar.get_x() + bar.get_width() / 2, bar.get_height() + (max(scores) - min(scores)) * 0.01, f"{score:.2f}", ha="center", va="bottom", fontsize=11, fontweight="bold", ) ax.set_title(title, fontsize=14, fontweight="bold") ax.set_ylabel(ylabel, fontsize=12) ax.grid(axis="y", alpha=0.3) ax.set_ylim(min(scores) * 1.05, max(scores) * 0.95) # Tight y-range plt.tight_layout() if save_path: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="white") print(f"[OK] Bar chart saved -> {save_path}") else: plt.show() plt.close(fig) return True except ImportError as exc: print(f"[WARN] {exc}") return False except Exception as exc: print(f"[WARN] Bar chart error: {exc}") return False