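# Spaces: Sleeping
# NOTE(review): the line above looks like page-scrape residue rather than part of the module — confirm and remove.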
"""Visualization for RecallTrace adversarial self-play training.

Two main functions:
- show_training_curves(): 2x2 panel with F1, adversary reward, quarantined, steps
- show_episode_comparison(): side-by-side early vs late episode comparison
"""
| from __future__ import annotations | |
| import os | |
| from typing import Any, Dict, List | |
| import numpy as np | |
| def _rolling_average(data: List[float], window: int = 20) -> List[float]: | |
| """Compute rolling average with the given window size.""" | |
| result = [] | |
| for i in range(len(data)): | |
| start = max(0, i - window + 1) | |
| result.append(sum(data[start:i+1]) / (i - start + 1)) | |
| return result | |
def show_training_curves(
    stats: List[Dict[str, Any]],
    save_path: str = "plots/selfplay_training.png",
) -> None:
    """Render a 2x2 figure of self-play training metrics and save it as an image.

    Panels (each shows raw per-episode points plus a rolling average):
        Top left:     Investigator F1 over episodes
        Top right:    Adversary reward over episodes
        Bottom left:  Nodes quarantined over episodes
        Bottom right: Steps to finalize over episodes

    Side effects: selects the non-interactive "Agg" matplotlib backend, applies
    the global "dark_background" style, creates the parent directory of
    ``save_path`` if it has one, writes the image, and prints a confirmation.

    Args:
        stats: One dict per episode. Each must provide the keys ``"episode"``,
            ``"investigator_f1"``, ``"adversary_reward"``,
            ``"num_quarantined"`` and ``"steps_taken"``.
        save_path: Destination image path. A bare filename (no directory
            component) is written to the current working directory.

    Raises:
        KeyError: If an entry of ``stats`` lacks one of the required keys.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    # Single source of truth for the smoothing window and its legend label.
    window = 20
    avg_label = f"{window}-ep rolling avg"

    episodes = [s["episode"] for s in stats]
    f1_scores = [s["investigator_f1"] for s in stats]
    adv_rewards = [s["adversary_reward"] for s in stats]
    quarantined = [s["num_quarantined"] for s in stats]
    steps = [s["steps_taken"] for s in stats]

    bg_color = "#161b22"
    grid_color = "#30363d"
    text_color = "#e6edf3"

    def _draw_panel(ax, values, raw_color, avg_color, title, ylabel, legend_loc):
        """Plot raw points plus rolling average on *ax* and label the panel."""
        ax.scatter(episodes, values, c=raw_color, alpha=0.15, s=8, zorder=2)
        ax.plot(
            episodes, _rolling_average(values, window),
            color=avg_color, linewidth=2.5, zorder=3, label=avg_label,
        )
        ax.set_title(title, fontsize=14, color=text_color, fontweight="bold", pad=12)
        ax.set_xlabel("Episode", color=text_color, fontsize=11)
        ax.set_ylabel(ylabel, color=text_color, fontsize=11)
        ax.legend(loc=legend_loc, fontsize=9, facecolor=bg_color, edgecolor=grid_color)

    # --- Dark theme setup ---
    plt.style.use("dark_background")
    fig, axes = plt.subplots(2, 2, figsize=(16, 10))
    try:
        fig.patch.set_facecolor("#0d1117")

        for ax in axes.flat:
            ax.set_facecolor(bg_color)
            ax.tick_params(colors=text_color, labelsize=10)
            ax.spines["bottom"].set_color(grid_color)
            ax.spines["left"].set_color(grid_color)
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)
            ax.grid(True, alpha=0.15, color=grid_color)

        # --- Top Left: Investigator F1 (blue / light blue) ---
        ax = axes[0, 0]
        _draw_panel(ax, f1_scores, "#3b82f6", "#60a5fa",
                    "Investigator F1 Score", "F1 Score", "lower right")
        # Reference lines at F1 = 0.5 (red) and F1 = 0.8 (green).
        ax.axhline(y=0.5, color="#ef4444", linestyle="--", alpha=0.4, linewidth=1)
        ax.axhline(y=0.8, color="#22c55e", linestyle="--", alpha=0.4, linewidth=1)
        ax.set_ylim(-0.05, 1.05)
        ax.text(0.02, 0.95, "Adversary wins ↓", transform=ax.transAxes,
                fontsize=8, color="#ef4444", alpha=0.7, va="top")
        ax.text(0.02, 0.05, "Investigator wins ↑", transform=ax.transAxes,
                fontsize=8, color="#22c55e", alpha=0.7, va="bottom")

        # --- Top Right: Adversary Reward (red / light red) ---
        ax = axes[0, 1]
        _draw_panel(ax, adv_rewards, "#ef4444", "#f87171",
                    "Adversary Reward", "Reward", "upper right")
        ax.axhline(y=0, color=text_color, linestyle="-", alpha=0.2, linewidth=1)
        ax.set_ylim(-1.3, 1.3)

        # --- Bottom Left: Nodes Quarantined (green / light green) ---
        _draw_panel(axes[1, 0], quarantined, "#22c55e", "#4ade80",
                    "Nodes Quarantined per Episode", "Count", "upper right")

        # --- Bottom Right: Steps Taken (amber / light amber) ---
        _draw_panel(axes[1, 1], steps, "#f59e0b", "#fbbf24",
                    "Steps to Finalize", "Steps", "upper right")

        # --- Main title ---
        fig.suptitle(
            "RecallTrace — Adversarial Self-Play Training",
            fontsize=18, color=text_color, fontweight="bold", y=0.98,
        )
        fig.text(
            0.5, 0.935,
            # Report the real episode count instead of a hard-coded 200.
            f"Investigator vs Adversary co-evolution over {len(stats)} episodes",
            ha="center", fontsize=11, color="#8b949e",
        )
        plt.tight_layout(rect=[0, 0, 1, 0.92])

        # Save. os.makedirs("") raises, so only create a directory when the
        # path actually has one.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=200, bbox_inches="tight", facecolor=fig.get_facecolor())
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)
    print(f" Saved training curves to {save_path}")
def show_episode_comparison(
    early_stats: Dict[str, Any],
    late_stats: Dict[str, Any],
    save_path: str = "plots/episode_comparison.png",
) -> None:
    """Render two side-by-side "cards" comparing an early and a late episode.

    Each card shows the episode's F1 score, a table of stats (nodes visited,
    nodes quarantined, steps, belief confidence, intervention type, whether the
    intervention was correctly identified, quarantine threshold, exploration
    rate) and, when present, up to six quarantined node ids. The early card is
    drawn with a red title bar and the late card with a green one.

    Side effects: selects the non-interactive "Agg" matplotlib backend, creates
    the parent directory of ``save_path`` if it has one, writes the image, and
    prints a confirmation.

    Args:
        early_stats: Stats dict for the early episode. Required keys:
            ``"episode"``, ``"investigator_f1"``, ``"num_quarantined"``,
            ``"steps_taken"``, ``"belief_confidence"``, ``"intervention_type"``,
            ``"intervention_correctly_identified"``, ``"quarantine_threshold"``,
            ``"exploration_rate"``. Optional keys: ``"nodes_visited"`` and
            ``"nodes_quarantined_list"``.
        late_stats: Stats dict for the late episode, same schema.
        save_path: Destination image path. A bare filename (no directory
            component) is written to the current working directory.

    Raises:
        KeyError: If either stats dict lacks one of the required keys.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from matplotlib.patches import FancyBboxPatch

    bg_color = "#161b22"
    text_color = "#e6edf3"
    dim_color = "#8b949e"
    good_color = "#22c55e"
    warn_color = "#f59e0b"
    bad_color = "#ef4444"
    max_nodes_shown = 6

    def _draw_episode_card(ax, stats, title, is_good):
        """Draw one episode summary card on *ax* (a 10x10 axis-less canvas)."""
        ax.set_facecolor(bg_color)
        ax.set_xlim(0, 10)
        ax.set_ylim(0, 10)
        ax.axis("off")

        # Title bar
        border_color = good_color if is_good else bad_color
        title_bg = "#1a3a2a" if is_good else "#3a1a1a"
        rect = FancyBboxPatch(
            (0.3, 8.5), 9.4, 1.2,
            boxstyle="round,pad=0.15",
            facecolor=title_bg, edgecolor=border_color, linewidth=2,
        )
        ax.add_patch(rect)
        ax.text(5, 9.1, title, fontsize=16, fontweight="bold",
                color=text_color, ha="center", va="center")

        # F1 Score (large), colour-coded: > 0.7 green, > 0.4 amber, else red.
        f1 = stats["investigator_f1"]
        if f1 > 0.7:
            f1_color = good_color
        elif f1 > 0.4:
            f1_color = warn_color
        else:
            f1_color = bad_color
        ax.text(5, 7.5, f"F1 Score: {f1:.3f}", fontsize=28, fontweight="bold",
                color=f1_color, ha="center", va="center")

        # Stats grid
        info_lines = [
            ("Nodes Visited", str(len(stats.get("nodes_visited", [])))),
            ("Nodes Quarantined", str(stats["num_quarantined"])),
            ("Steps Taken", str(stats["steps_taken"])),
            ("Belief Confidence", f"{stats['belief_confidence']:.2f}"),
            ("Intervention Type", stats["intervention_type"]),
            ("Correctly Identified", "YES" if stats["intervention_correctly_identified"] else "NO"),
            ("Quarantine Threshold", f"{stats['quarantine_threshold']:.3f}"),
            ("Exploration Rate", f"{stats['exploration_rate']:.3f}"),
        ]
        y_pos = 6.2
        for label, value in info_lines:
            # Label
            ax.text(1.0, y_pos, label + ":", fontsize=11, color=dim_color,
                    ha="left", va="center", fontfamily="monospace")
            # Value (the YES/NO row is colour-coded, everything else is plain)
            v_color = text_color
            if label == "Correctly Identified":
                v_color = good_color if value == "YES" else bad_color
            ax.text(9.0, y_pos, value, fontsize=12, fontweight="bold",
                    color=v_color, ha="right", va="center", fontfamily="monospace")
            y_pos -= 0.7

        # Quarantined nodes list
        q_nodes = stats.get("nodes_quarantined_list", [])
        if q_nodes:
            ax.text(1.0, y_pos - 0.3, "Quarantined:", fontsize=10, color=dim_color,
                    ha="left", va="center")
            # str() each id so non-string node ids (e.g. ints) do not break join.
            node_text = ", ".join(str(node) for node in q_nodes[:max_nodes_shown])
            if len(q_nodes) > max_nodes_shown:
                node_text += f" +{len(q_nodes)-max_nodes_shown} more"
            ax.text(1.0, y_pos - 0.9, node_text, fontsize=9, color=warn_color,
                    ha="left", va="center", fontfamily="monospace")

    fig, (ax_early, ax_late) = plt.subplots(1, 2, figsize=(18, 9))
    try:
        fig.patch.set_facecolor("#0d1117")

        _draw_episode_card(ax_early, early_stats,
                           f"Episode {early_stats['episode']} (Early)", is_good=False)
        _draw_episode_card(ax_late, late_stats,
                           f"Episode {late_stats['episode']} (Late)", is_good=True)

        # Arrow between cards
        fig.text(0.5, 0.5, "→", fontsize=48, color="#8b949e",
                 ha="center", va="center", fontweight="bold")

        fig.suptitle(
            "RecallTrace — Before / After Self-Play Training",
            fontsize=18, color=text_color, fontweight="bold", y=0.97,
        )
        fig.text(
            0.5, 0.92,
            "Investigator behavior change: spray & pray → precision targeting",
            ha="center", fontsize=12, color=dim_color,
        )
        plt.tight_layout(rect=[0, 0, 1, 0.90])

        # os.makedirs("") raises, so only create a directory when the path
        # actually has one.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=200, bbox_inches="tight", facecolor=fig.get_facecolor())
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)
    print(f" Saved episode comparison to {save_path}")