Traffic-Control / utils /visualizer.py
Dhaerya's picture
Add files
b00d5d5
"""
Visualisation utilities — training curves and comparison plots.
Uses only Matplotlib (no Seaborn / Plotly dependency).
All plots are saved to disk (non-interactive backend).
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
try:
import matplotlib
matplotlib.use("Agg") # Non-interactive backend (safe on servers)
import matplotlib.pyplot as plt
_MPL_OK = True
except ImportError:
_MPL_OK = False
plt = None # type: ignore
# ── Helper ────────────────────────────────────────────────────────────────────
def _check_mpl():
    """
    Guard used by every plotting function.

    Returns None when the module-level matplotlib import succeeded;
    otherwise raises ImportError carrying installation instructions.
    """
    if _MPL_OK:
        return
    message = (
        "matplotlib is required for plotting.\n"
        "Install with: pip install matplotlib"
    )
    raise ImportError(message)
def _moving_average(values: list, window: int) -> list:
    """
    Trailing simple (unweighted) moving average of *values*.

    Element ``k`` of the result is the mean of
    ``values[max(0, k - window + 1) : k + 1]``. The first ``window - 1``
    entries therefore average over however many points exist so far. The
    output has the same length as *values*, and each entry is a Python float.
    """
    return [
        float(np.mean(values[max(0, stop - window) : stop]))
        for stop in range(1, len(values) + 1)
    ]
# ── Public functions ──────────────────────────────────────────────────────────
def plot_training_curves(metrics, save_path: str | Path | None = None) -> bool:
    """
    Plot four training metrics in a 2×2 grid and save to *save_path*.

    Panels (metric key -> title): ``episode_reward`` -> "Episode Reward",
    ``average_waiting_time`` -> "Avg Waiting Time",
    ``average_queue_length`` -> "Avg Queue Length",
    ``throughput`` -> "Throughput".

    Each panel plots the raw per-episode series against 1-based episode
    numbers. Once the series has at least 10 points, it also overlays a
    trailing moving average whose window is ``len(series) // 10`` clamped
    to [10, 50]. Metrics the tracker does not have are drawn as an empty
    "No data" panel.

    Args:
        metrics: A :class:`MetricsTracker` instance. Only its ``has(key)``
            and ``get(key)`` methods are used; ``get`` is assumed to return
            a sized, sliceable sequence of numbers.
        save_path: Destination PNG path; missing parent directories are
            created. If None (or empty), ``plt.show()`` is called instead.
            This module selects the non-interactive "Agg" backend at import,
            so in that case nothing is usually displayed and no file is
            written.

    Returns:
        True on success. False (after printing a warning) if matplotlib is
        unavailable, none of the four metrics has data, or plotting raised.
    """
    panel_cfg = [
        ("episode_reward", "Episode Reward", "blue", "Reward"),
        ("average_waiting_time", "Avg Waiting Time", "orange", "Waiting Time (s)"),
        ("average_queue_length", "Avg Queue Length", "red", "Queue Length"),
        ("throughput", "Throughput", "green", "Vehicles Passed"),
    ]
    fig = None
    try:
        _check_mpl()
        if not any(metrics.has(key) for key, *_ in panel_cfg):
            print("[WARN] No data available for plotting.")
            return False

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle("Training Progress", fontsize=16, fontweight="bold")

        for ax, (key, title, colour, ylabel) in zip(axes.flatten(), panel_cfg):
            ax.set_title(title, fontsize=12, fontweight="bold")
            ax.set_xlabel("Episode", fontsize=10)
            ax.set_ylabel(ylabel, fontsize=10)
            ax.grid(True, alpha=0.3)

            if not metrics.has(key):
                ax.text(0.5, 0.5, "No data", ha="center", va="center",
                        transform=ax.transAxes, color="grey")
                continue

            vals = metrics.get(key)
            eps = range(1, len(vals) + 1)
            ax.plot(eps, vals, alpha=0.4, color=colour, linewidth=1, label="Raw")
            # Smooth only once there are enough points for a visible trend.
            if len(vals) >= 10:
                w = min(50, max(10, len(vals) // 10))
                ax.plot(eps, _moving_average(vals, w), color=colour,
                        linewidth=2, label=f"MA-{w}")
            ax.legend(loc="best", fontsize=8)

        # Operate on this figure explicitly rather than on pyplot's implicit
        # "current" figure.
        fig.tight_layout()
        if save_path:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="white")
            print(f"[OK] Plot saved -> {save_path}")
        else:
            plt.show()
        return True
    except ImportError as exc:
        print(f"[WARN] {exc}")
        return False
    except Exception as exc:
        print(f"[WARN] Plotting error: {exc}")
        return False
    finally:
        # Close only the figure created here, on success and on failure.
        # Closing "all" would also destroy figures the caller still has open.
        if fig is not None:
            plt.close(fig)
def plot_comparison(
    results_dict: dict[str, list],
    metric_name: str,
    save_path: str | Path | None = None,
) -> bool:
    """
    Overlay multiple result series on a single axes.

    Each non-empty series is drawn against 1-based episode indices. Colours
    cycle through a fixed five-colour palette by the series' position in
    *results_dict*. Empty (or None) series are skipped.

    Args:
        results_dict: ``{"Method Name": [values, …], …}``. Values may be
            lists or 1-D NumPy arrays.
        metric_name: Y-axis label and title prefix
            (title is ``"<metric_name> - Method Comparison"``).
        save_path: Destination PNG path; missing parent directories are
            created. If None (or empty), ``plt.show()`` is called instead.
            Under the non-interactive "Agg" backend selected at import, this
            usually displays nothing.

    Returns:
        True on success. False (after printing a warning) if matplotlib is
        unavailable, *results_dict* is empty, or plotting raised.
    """
    colours = ["blue", "green", "red", "orange", "purple"]
    fig = None
    try:
        _check_mpl()
        if not results_dict:
            print("[WARN] No data for comparison plot.")
            return False

        fig, ax = plt.subplots(figsize=(12, 6))
        plotted = 0
        for i, (name, vals) in enumerate(results_dict.items()):
            # `if vals:` raises for a multi-element NumPy array ("truth value
            # ... is ambiguous"), so test emptiness via len() instead.
            if vals is None or len(vals) == 0:
                continue
            ax.plot(range(1, len(vals) + 1), vals,
                    label=name, linewidth=2, alpha=0.8,
                    color=colours[i % len(colours)])
            plotted += 1

        ax.set_xlabel("Episode", fontsize=12)
        ax.set_ylabel(metric_name, fontsize=12)
        ax.set_title(f"{metric_name} - Method Comparison",
                     fontsize=14, fontweight="bold")
        # A legend with no labelled artists only triggers a matplotlib warning.
        if plotted:
            ax.legend(loc="best")
        ax.grid(True, alpha=0.3)
        fig.tight_layout()

        if save_path:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="white")
            print(f"[OK] Comparison plot saved -> {save_path}")
        else:
            plt.show()
        return True
    except ImportError as exc:
        print(f"[WARN] {exc}")
        return False
    except Exception as exc:
        print(f"[WARN] Comparison plot error: {exc}")
        return False
    finally:
        # Close only our own figure; closing "all" would clobber the caller's.
        if fig is not None:
            plt.close(fig)
def plot_bar_comparison(
    method_scores: dict[str, float],
    title: str = "Method Comparison",
    ylabel: str = "Mean Reward",
    save_path: str | Path | None = None,
) -> bool:
    """
    Bar chart comparing scalar scores for different methods.

    Each bar is annotated with its score to two decimals. The y-axis is
    zoomed to the span of the scores, padded outward by 5 % of each end's
    magnitude, so small differences between methods stay visible. Bars are
    still drawn from 0, so the axis does not necessarily include 0.

    Args:
        method_scores: {"Method": score, ...}
        title: Chart title.
        ylabel: Y-axis label.
        save_path: Destination PNG path; missing parent directories are
            created. If None (or empty), ``plt.show()`` is called instead.
            Under the non-interactive "Agg" backend selected at import, this
            usually displays nothing.

    Returns:
        True on success. False (after printing a warning) if matplotlib is
        unavailable, *method_scores* is empty, or plotting raised.
    """
    palette = ["#4472C4", "#ED7D31", "#A9D18E"]
    fig = None
    try:
        _check_mpl()
        if not method_scores:
            print("[WARN] No data for bar chart.")
            return False

        names = list(method_scores.keys())
        scores = [method_scores[n] for n in names]
        lo, hi = min(scores), max(scores)
        # Cycle the palette explicitly so every bar gets a colour.
        colours = [palette[i % len(palette)] for i in range(len(names))]

        fig, ax = plt.subplots(figsize=(8, 5))
        bars = ax.bar(names, scores, color=colours,
                      edgecolor="white", linewidth=1.5)

        # Value labels, nudged just past each bar's tip.
        label_offset = (hi - lo) * 0.01
        for bar, score in zip(bars, scores):
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() + label_offset,
                f"{score:.2f}",
                ha="center", va="bottom", fontsize=11, fontweight="bold",
            )

        ax.set_title(title, fontsize=14, fontweight="bold")
        ax.set_ylabel(ylabel, fontsize=12)
        ax.grid(axis="y", alpha=0.3)

        # Tight y-range. Scaling by 1.05 / 0.95 only pads outward for
        # negative values: positive scores got clipped bars and, when all
        # scores were equal, an inverted axis. Padding away from each end by
        # 5 % of its magnitude is sign-agnostic and gives the same limits as
        # before when every score is negative.
        bottom = lo - 0.05 * abs(lo)
        top = hi + 0.05 * abs(hi)
        if bottom == top:  # only when every score is 0: avoid a singular range
            bottom, top = -1.0, 1.0
        ax.set_ylim(bottom, top)
        fig.tight_layout()

        if save_path:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, dpi=150, bbox_inches="tight", facecolor="white")
            print(f"[OK] Bar chart saved -> {save_path}")
        else:
            plt.show()
        return True
    except ImportError as exc:
        print(f"[WARN] {exc}")
        return False
    except Exception as exc:
        print(f"[WARN] Bar chart error: {exc}")
        return False
    finally:
        # Always release the figure; the error path previously leaked it.
        if fig is not None:
            plt.close(fig)