Spaces:
Sleeping
Sleeping
| """ | |
| Visualisation utilities — training curves and comparison plots. | |
| Uses only Matplotlib (no Seaborn / Plotly dependency). | |
| All plots are saved to disk (non-interactive backend). | |
| """ | |
| from __future__ import annotations | |
| from pathlib import Path | |
| import numpy as np | |
| try: | |
| import matplotlib | |
| matplotlib.use("Agg") # Non-interactive backend (safe on servers) | |
| import matplotlib.pyplot as plt | |
| _MPL_OK = True | |
| except ImportError: | |
| _MPL_OK = False | |
| plt = None # type: ignore | |
| # ── Helper ──────────────────────────────────────────────────────────────────── | |
def _check_mpl():
    """Fail fast with an install hint when matplotlib could not be imported.

    Reads the module-level ``_MPL_OK`` flag set by the guarded import at the
    top of the module and returns silently when it is truthy.

    Raises:
        ImportError: If matplotlib is unavailable.
    """
    if _MPL_OK:
        return
    hint = (
        "matplotlib is required for plotting.\n"
        "Install with: pip install matplotlib"
    )
    raise ImportError(hint)
def _moving_average(values: list, window: int) -> list:
    """Return the trailing (causal) unweighted moving average of *values*.

    Element ``i`` of the result is the mean of
    ``values[max(0, i - window + 1) : i + 1]``. The first ``window - 1``
    points therefore average over fewer samples instead of being dropped,
    so the output has the same length as the input.

    Args:
        values: 1-D numeric sequence (list or NumPy array).
        window: Number of trailing samples per average; must be >= 1.

    Returns:
        A list of Python floats with ``len(values)`` entries.

    Raises:
        ValueError: If *window* < 1. Every slice would be empty, so the
            result would be all NaN.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    # Convert once up front instead of re-converting a list slice per point.
    arr = np.asarray(values, dtype=float)
    return [
        float(arr[max(0, i - window + 1) : i + 1].mean())
        for i in range(arr.size)
    ]
| # ── Public functions ────────────────────────────────────────────────────────── | |
def plot_training_curves(metrics, save_path: str | Path | None = None) -> bool:
    """
    Plot four training metrics in a 2×2 grid and save to *save_path*.

    Panels: episode reward, average waiting time, average queue length and
    throughput. Each panel draws the raw per-episode series. Once a series
    has at least 10 points, a trailing moving average is added, with window
    ``len(values) // 10`` clamped to [10, 50]. Panels whose metric has no
    data show a "No data" placeholder.

    Args:
        metrics: A :class:`MetricsTracker` instance. Only its ``has(key)``
            and ``get(key)`` methods are used.
        save_path: Destination PNG path; missing parent directories are
            created. If falsy, ``plt.show()`` is called instead. This module
            selects the non-interactive "Agg" backend at import, so nothing
            is displayed unless the backend has been switched since.

    Returns:
        True on success. False when matplotlib is missing, no metric has
        data, or plotting raised (the error is printed, not re-raised).
    """
    fig = None  # once created, closed in ``finally`` on every exit path
    try:
        _check_mpl()
        panel_cfg = [
            ("episode_reward", "Episode Reward", "blue", "Reward"),
            ("average_waiting_time", "Avg Waiting Time", "orange", "Waiting Time (s)"),
            ("average_queue_length", "Avg Queue Length", "red", "Queue Length"),
            ("throughput", "Throughput", "green", "Vehicles Passed"),
        ]
        if not any(metrics.has(key) for key, *_ in panel_cfg):
            print("[WARN] No data available for plotting.")
            return False

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle("Training Progress", fontsize=16, fontweight="bold")
        for ax, (key, title, colour, ylabel) in zip(axes.flatten(), panel_cfg):
            ax.set_title(title, fontsize=12, fontweight="bold")
            ax.set_xlabel("Episode", fontsize=10)
            ax.set_ylabel(ylabel, fontsize=10)
            ax.grid(True, alpha=0.3)
            if not metrics.has(key):
                ax.text(0.5, 0.5, "No data", ha="center", va="center",
                        transform=ax.transAxes, color="grey")
                continue
            vals = metrics.get(key)
            eps = range(1, len(vals) + 1)
            ax.plot(eps, vals, alpha=0.4, color=colour, linewidth=1, label="Raw")
            if len(vals) >= 10:
                # Window is ~10% of the series length, clamped to [10, 50].
                w = min(50, max(10, len(vals) // 10))
                ax.plot(eps, _moving_average(vals, w), color=colour,
                        linewidth=2, label=f"MA-{w}")
            ax.legend(loc="best", fontsize=8)

        # Operate on our own figure explicitly, not pyplot's "current" one.
        fig.tight_layout()
        if save_path:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="white")
            print(f"[OK] Plot saved -> {save_path}")
        else:
            plt.show()
        return True
    except ImportError as exc:
        print(f"[WARN] {exc}")
        return False
    except Exception as exc:
        print(f"[WARN] Plotting error: {exc}")
        return False
    finally:
        # Close only the figure created here. ``plt.close("all")`` would also
        # destroy figures the caller may still be using.
        if fig is not None:
            plt.close(fig)
def plot_comparison(
    results_dict: dict[str, list],
    metric_name: str,
    save_path: str | Path | None = None,
) -> bool:
    """
    Overlay multiple result series on a single axes.

    Each non-empty series is drawn against a 1-based episode index. Colours
    cycle through a fixed five-colour palette in dict order. An empty series
    still consumes its palette slot, so the colours of the other methods do
    not shift.

    Args:
        results_dict: ``{"Method Name": [values, …], …}``. Values may be any
            sized sequence (lists or 1-D NumPy arrays). ``None`` or empty
            series are skipped.
        metric_name: Y-axis label and title prefix.
        save_path: Destination PNG path; missing parent directories are
            created. If falsy, ``plt.show()`` is called instead, which
            displays nothing under the "Agg" backend this module selects
            at import.

    Returns:
        True on success. False when matplotlib is missing, there is nothing
        to plot, or plotting raised (the error is printed, not re-raised).
    """
    fig = None  # once created, closed in ``finally`` on every exit path
    try:
        _check_mpl()
        if not results_dict:
            print("[WARN] No data for comparison plot.")
            return False

        colours = ["blue", "green", "red", "orange", "purple"]
        # Use ``len`` rather than truthiness: ``bool(np.array([...]))`` raises
        # for arrays with more than one element.
        series = [
            (name, vals, colours[i % len(colours)])
            for i, (name, vals) in enumerate(results_dict.items())
            if vals is not None and len(vals) > 0
        ]
        if not series:
            # Avoid saving a blank chart (and a legend with no labelled artists).
            print("[WARN] No data for comparison plot.")
            return False

        fig, ax = plt.subplots(figsize=(12, 6))
        for name, vals, colour in series:
            ax.plot(range(1, len(vals) + 1), vals,
                    label=name, linewidth=2, alpha=0.8, color=colour)
        ax.set_xlabel("Episode", fontsize=12)
        ax.set_ylabel(metric_name, fontsize=12)
        ax.set_title(f"{metric_name} - Method Comparison",
                     fontsize=14, fontweight="bold")
        ax.legend(loc="best")
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        if save_path:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="white")
            print(f"[OK] Comparison plot saved -> {save_path}")
        else:
            plt.show()
        return True
    except ImportError as exc:
        print(f"[WARN] {exc}")
        return False
    except Exception as exc:
        print(f"[WARN] Comparison plot error: {exc}")
        return False
    finally:
        # Close only our own figure, never the caller's.
        if fig is not None:
            plt.close(fig)
def plot_bar_comparison(
    method_scores: dict[str, float],
    title: str = "Method Comparison",
    ylabel: str = "Mean Reward",
    save_path: str | Path | None = None,
) -> bool:
    """
    Bar chart comparing scalar scores for different methods.

    Each bar is annotated with its score to two decimals. The y-axis is
    zoomed onto the data: it runs from 5% below the smallest score to 5%
    above the largest, measured relative to each value's magnitude. This
    keeps small differences visible for positive, negative or mixed-sign
    scores.

    Args:
        method_scores: {"Method": score, ...}
        title: Chart title.
        ylabel: Y-axis label.
        save_path: Destination PNG path; missing parent directories are
            created. If falsy, ``plt.show()`` is called instead, which
            displays nothing under the "Agg" backend this module selects
            at import.

    Returns:
        True on success. False when matplotlib is missing, *method_scores*
        is empty, or plotting raised (the error is printed, not re-raised).
    """
    fig = None  # once created, closed in ``finally`` on every exit path
    try:
        _check_mpl()
        if not method_scores:
            print("[WARN] No data for bar chart.")
            return False

        names = list(method_scores)
        scores = [method_scores[n] for n in names]
        palette = ["#4472C4", "#ED7D31", "#A9D18E"]
        # Cycle the palette explicitly so any number of methods gets a colour.
        bar_colours = [palette[i % len(palette)] for i in range(len(names))]

        fig, ax = plt.subplots(figsize=(8, 5))
        bars = ax.bar(names, scores, color=bar_colours,
                      edgecolor="white", linewidth=1.5)

        lo, hi = min(scores), max(scores)
        # Value labels, nudged by 1% of the score range.
        label_offset = (hi - lo) * 0.01
        for bar, score in zip(bars, scores):
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() + label_offset,
                f"{score:.2f}",
                ha="center", va="bottom", fontsize=11, fontweight="bold",
            )

        ax.set_title(title, fontsize=14, fontweight="bold")
        ax.set_ylabel(ylabel, fontsize=12)
        ax.grid(axis="y", alpha=0.3)

        # Tight y-range: pad each bound outward by 5% of its magnitude.
        # ``(min * 1.05, max * 0.95)`` only widens the range for negative
        # scores; for positive ones it clips bars or even inverts the axis.
        lower = lo - 0.05 * abs(lo)
        upper = hi + 0.05 * abs(hi)
        if lower == upper:  # only possible when every score is exactly 0
            lower, upper = -1.0, 1.0
        ax.set_ylim(lower, upper)

        fig.tight_layout()
        if save_path:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="white")
            print(f"[OK] Bar chart saved -> {save_path}")
        else:
            plt.show()
        return True
    except ImportError as exc:
        print(f"[WARN] {exc}")
        return False
    except Exception as exc:
        print(f"[WARN] Bar chart error: {exc}")
        return False
    finally:
        # Ensure the figure is released on the error path as well.
        if fig is not None:
            plt.close(fig)