# new_model/selfplay/visualization.py
# Ayu
# feat: RecallTrace Tasks 1-9 complete - belief calibration + curriculum + plots
# d19137b
"""Visualization for RecallTrace adversarial self-play training.
Two main functions:
- show_training_curves(): 2x2 panel with F1, adversary reward, quarantined, steps
- show_episode_comparison(): side-by-side early vs late episode comparison
"""
from __future__ import annotations
import os
from typing import Any, Dict, List
import numpy as np
def _rolling_average(data: List[float], window: int = 20) -> List[float]:
"""Compute rolling average with the given window size."""
result = []
for i in range(len(data)):
start = max(0, i - window + 1)
result.append(sum(data[start:i+1]) / (i - start + 1))
return result
def show_training_curves(
    stats: List[Dict[str, Any]],
    save_path: str = "plots/selfplay_training.png",
    window: int = 20,
) -> None:
    """Create a 2x2 publication-quality training curves figure.

    Panels:
        Top left: Investigator F1 over episodes (raw + rolling avg)
        Top right: Adversary reward over episodes
        Bottom left: Nodes quarantined over episodes
        Bottom right: Steps to finalize over episodes

    Uses a dark theme for hackathon-ready visuals. The theme is applied with
    ``plt.style.context``, so it does not leak into the global matplotlib
    rcParams used by later plots in the same process.

    Args:
        stats: One dict per episode. Each must provide the keys ``"episode"``,
            ``"investigator_f1"``, ``"adversary_reward"``, ``"num_quarantined"``
            and ``"steps_taken"``.
        save_path: Destination image path. Missing parent directories are
            created; a bare filename is written to the current directory.
        window: Episode window for the rolling-average curves (default 20).

    Raises:
        KeyError: If an entry of ``stats`` is missing one of the keys above.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    episodes = [s["episode"] for s in stats]
    f1_scores = [s["investigator_f1"] for s in stats]
    adv_rewards = [s["adversary_reward"] for s in stats]
    quarantined = [s["num_quarantined"] for s in stats]
    steps = [s["steps_taken"] for s in stats]

    bg_color = "#161b22"
    grid_color = "#30363d"
    text_color = "#e6edf3"
    avg_label = f"{window}-ep rolling avg"

    def _plot_panel(ax, values, raw_color, avg_color, title, ylabel, legend_loc):
        """Style one axes; plot raw values as faint dots and their rolling mean as a bold line."""
        ax.set_facecolor(bg_color)
        ax.tick_params(colors=text_color, labelsize=10)
        ax.spines["bottom"].set_color(grid_color)
        ax.spines["left"].set_color(grid_color)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.grid(True, alpha=0.15, color=grid_color)
        ax.scatter(episodes, values, c=raw_color, alpha=0.15, s=8, zorder=2)
        ax.plot(episodes, _rolling_average(values, window), color=avg_color,
                linewidth=2.5, zorder=3, label=avg_label)
        ax.set_title(title, fontsize=14, color=text_color, fontweight="bold", pad=12)
        ax.set_xlabel("Episode", color=text_color, fontsize=11)
        ax.set_ylabel(ylabel, color=text_color, fontsize=11)
        # Only the rolling-average line carries a label, so reference lines
        # added by the caller afterwards do not appear in the legend.
        ax.legend(loc=legend_loc, fontsize=9, facecolor=bg_color, edgecolor=grid_color)

    # Scoped style: plt.style.use() would permanently change global rcParams.
    with plt.style.context("dark_background"):
        fig, axes = plt.subplots(2, 2, figsize=(16, 10))
        try:
            fig.patch.set_facecolor("#0d1117")

            # --- Top Left: Investigator F1 ---
            ax = axes[0, 0]
            _plot_panel(ax, f1_scores, "#3b82f6", "#60a5fa",
                        "Investigator F1 Score", "F1 Score", "lower right")
            ax.axhline(y=0.5, color="#ef4444", linestyle="--", alpha=0.4, linewidth=1)
            ax.axhline(y=0.8, color="#22c55e", linestyle="--", alpha=0.4, linewidth=1)
            ax.set_ylim(-0.05, 1.05)
            ax.text(0.02, 0.95, "Adversary wins ↓", transform=ax.transAxes,
                    fontsize=8, color="#ef4444", alpha=0.7, va="top")
            ax.text(0.02, 0.05, "Investigator wins ↑", transform=ax.transAxes,
                    fontsize=8, color="#22c55e", alpha=0.7, va="bottom")

            # --- Top Right: Adversary Reward ---
            ax = axes[0, 1]
            _plot_panel(ax, adv_rewards, "#ef4444", "#f87171",
                        "Adversary Reward", "Reward", "upper right")
            ax.axhline(y=0, color=text_color, linestyle="-", alpha=0.2, linewidth=1)
            ax.set_ylim(-1.3, 1.3)

            # --- Bottom Left: Nodes Quarantined ---
            _plot_panel(axes[1, 0], quarantined, "#22c55e", "#4ade80",
                        "Nodes Quarantined per Episode", "Count", "upper right")

            # --- Bottom Right: Steps Taken ---
            _plot_panel(axes[1, 1], steps, "#f59e0b", "#fbbf24",
                        "Steps to Finalize", "Steps", "upper right")

            # --- Main title ---
            fig.suptitle(
                "RecallTrace — Adversarial Self-Play Training",
                fontsize=18, color=text_color, fontweight="bold", y=0.98,
            )
            # Report the actual number of plotted episodes instead of a fixed 200.
            fig.text(
                0.5, 0.935,
                f"Investigator vs Adversary co-evolution over {len(stats)} episodes",
                ha="center", fontsize=11, color="#8b949e",
            )
            fig.tight_layout(rect=[0, 0, 1, 0.92])

            # os.makedirs("") raises, so only create a directory when one is given.
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            fig.savefig(save_path, dpi=200, bbox_inches="tight", facecolor=fig.get_facecolor())
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)
    print(f" Saved training curves to {save_path}")
def show_episode_comparison(
    early_stats: Dict[str, Any],
    late_stats: Dict[str, Any],
    save_path: str = "plots/episode_comparison.png",
) -> None:
    """Create a side-by-side comparison of early vs late episode behavior.

    Each side is a card showing:
      - F1 score
      - nodes visited
      - nodes quarantined
      - steps taken
      - belief confidence
      - intervention type
      - whether the intervention was correctly identified
      - quarantine threshold
      - exploration rate
      - up to six quarantined node ids

    Args:
        early_stats: Stats for the early episode, drawn with a red border.
            Must provide ``"episode"``, ``"investigator_f1"``,
            ``"num_quarantined"``, ``"steps_taken"``, ``"belief_confidence"``,
            ``"intervention_type"``, ``"intervention_correctly_identified"``,
            ``"quarantine_threshold"`` and ``"exploration_rate"``. May provide
            ``"nodes_visited"`` and ``"nodes_quarantined_list"``.
        late_stats: Same keys, for the late episode, drawn with a green border.
        save_path: Destination image path. Missing parent directories are
            created; a bare filename is written to the current directory.

    Raises:
        KeyError: If a required key is missing from either stats dict.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from matplotlib.patches import FancyBboxPatch

    bg_color = "#161b22"
    text_color = "#e6edf3"
    dim_color = "#8b949e"
    good_color = "#22c55e"
    bad_color = "#ef4444"
    warn_color = "#f59e0b"
    max_listed_nodes = 6
    # Row spacing chosen so the 8 stat rows plus the two quarantined-node
    # lines stay inside the 0-10 card. A step of 0.7 pushed the last line to
    # y≈-0.3, outside the card.
    row_step = 0.6

    def _draw_episode_card(ax, stats, title, is_good):
        """Render one episode's stats as a titled card in 0-10 data units."""
        ax.set_facecolor(bg_color)
        ax.set_xlim(0, 10)
        ax.set_ylim(0, 10)
        # Hide ticks and spines instead of calling ax.axis("off"). Turning the
        # axis off also stops matplotlib drawing the axes background patch,
        # so the card colour set above would never appear.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

        # Title bar: green border for the "good" card, red otherwise.
        border_color = good_color if is_good else bad_color
        title_bg = "#1a3a2a" if is_good else "#3a1a1a"
        ax.add_patch(FancyBboxPatch(
            (0.3, 8.5), 9.4, 1.2,
            boxstyle="round,pad=0.15",
            facecolor=title_bg, edgecolor=border_color, linewidth=2,
        ))
        ax.text(5, 9.1, title, fontsize=16, fontweight="bold",
                color=text_color, ha="center", va="center")

        # F1 score (large): green above 0.7, amber above 0.4, red otherwise.
        f1 = stats["investigator_f1"]
        f1_color = good_color if f1 > 0.7 else warn_color if f1 > 0.4 else bad_color
        ax.text(5, 7.5, f"F1 Score: {f1:.3f}", fontsize=28, fontweight="bold",
                color=f1_color, ha="center", va="center")

        # Stats grid: label left-aligned, value right-aligned on the same row.
        info_lines = [
            ("Nodes Visited", str(len(stats.get("nodes_visited", [])))),
            ("Nodes Quarantined", str(stats["num_quarantined"])),
            ("Steps Taken", str(stats["steps_taken"])),
            ("Belief Confidence", f"{stats['belief_confidence']:.2f}"),
            ("Intervention Type", stats["intervention_type"]),
            ("Correctly Identified", "YES" if stats["intervention_correctly_identified"] else "NO"),
            ("Quarantine Threshold", f"{stats['quarantine_threshold']:.3f}"),
            ("Exploration Rate", f"{stats['exploration_rate']:.3f}"),
        ]
        y_pos = 6.2
        for label, value in info_lines:
            ax.text(1.0, y_pos, label + ":", fontsize=11, color=dim_color,
                    ha="left", va="center", fontfamily="monospace")
            v_color = text_color
            if label == "Correctly Identified":
                v_color = good_color if value == "YES" else bad_color
            ax.text(9.0, y_pos, value, fontsize=12, fontweight="bold",
                    color=v_color, ha="right", va="center", fontfamily="monospace")
            y_pos -= row_step

        # Quarantined nodes list, truncated with a "+N more" suffix.
        q_nodes = stats.get("nodes_quarantined_list", [])
        if q_nodes:
            ax.text(1.0, y_pos - 0.3, "Quarantined:", fontsize=10, color=dim_color,
                    ha="left", va="center")
            # str() so non-string node ids (e.g. ints) do not break join().
            node_text = ", ".join(str(node) for node in q_nodes[:max_listed_nodes])
            if len(q_nodes) > max_listed_nodes:
                node_text += f" +{len(q_nodes) - max_listed_nodes} more"
            ax.text(1.0, y_pos - 0.9, node_text, fontsize=9, color=warn_color,
                    ha="left", va="center", fontfamily="monospace")

    fig, (ax_early, ax_late) = plt.subplots(1, 2, figsize=(18, 9))
    try:
        fig.patch.set_facecolor("#0d1117")

        _draw_episode_card(ax_early, early_stats,
                           f"Episode {early_stats['episode']} (Early)", is_good=False)
        _draw_episode_card(ax_late, late_stats,
                           f"Episode {late_stats['episode']} (Late)", is_good=True)

        # Arrow between cards
        fig.text(0.5, 0.5, "→", fontsize=48, color="#8b949e",
                 ha="center", va="center", fontweight="bold")
        fig.suptitle(
            "RecallTrace — Before / After Self-Play Training",
            fontsize=18, color=text_color, fontweight="bold", y=0.97,
        )
        fig.text(
            0.5, 0.92,
            "Investigator behavior change: spray & pray → precision targeting",
            ha="center", fontsize=12, color=dim_color,
        )
        fig.tight_layout(rect=[0, 0, 1, 0.90])

        # os.makedirs("") raises, so only create a directory when one is given.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=200, bbox_inches="tight", facecolor=fig.get_facecolor())
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print(f" Saved episode comparison to {save_path}")