"""Chart helper functions for Gradio 6 native plots.
Generates pandas DataFrames from episode replay data for use with
gr.LinePlot, gr.BarPlot, and styled HTML verdicts.
"""
from __future__ import annotations
import pandas as pd
def format_comparison_scores_html(untrained: dict, trained: dict) -> str:
"""Format comparative scores for untrained vs trained."""
# Builds one HTML fragment per agent key of `untrained`, showing the untrained
# score, the trained score and their signed difference, each to one decimal.
# Returns the concatenated fragments as a single string.
# NOTE(review): `trained[agent]` is indexed directly, so an agent present in
# `untrained` but missing from `trained` raises KeyError.
# NOTE(review): the markup inside the string literals below appears to have
# been stripped in this view of the file - confirm against the real source.
# Per-agent accent colours, given as CSS var() references.
colors = {
"attacker": "var(--sentinel-red)",
"worker": "var(--sentinel-blue)",
"oversight": "var(--sentinel-green)",
}
html = "
"
# Iteration order (and therefore output order) follows the keys of `untrained`.
for agent in untrained.keys():
# Agents without an entry in `colors` fall back to "#888".
color = colors.get(agent, "#888")
u_score = untrained[agent]
t_score = trained[agent]
# Positive diff -> "#44bb44" with an explicit "+"; negative -> "#ff4444"
# (the minus sign comes from number formatting); zero -> "#888".
diff = t_score - u_score
diff_color = "#44bb44" if diff > 0 else ("#ff4444" if diff < 0 else "#888")
diff_sign = "+" if diff > 0 else ""
html += (
f"
"
f"
{agent}
"
f"
"
f"
"
f"UNTRAINED:"
f"{u_score:.1f}"
f"
"
f"
"
f"TRAINED:"
f"{t_score:.1f}"
f"
"
f"
"
f"{diff_sign}{diff:.1f}"
f"
"
f"
"
f"
"
)
html += "
"
return html
def format_scores_html(scores: dict) -> str:
"""Format final scores as a styled HTML widget."""
# Emits one fragment per (agent, score) pair in `scores`, in dict order, with
# the score rendered to one decimal place, and returns the joined string.
# NOTE(review): the markup inside the string literals below appears to have
# been stripped in this view of the file - confirm against the real source.
# Per-agent accent colours, given as CSS var() references.
colors = {
"attacker": "var(--sentinel-red)",
"worker": "var(--sentinel-blue)",
"oversight": "var(--sentinel-green)",
}
html = ""
for agent, score in scores.items():
# Agents without an entry in `colors` fall back to "#888".
color = colors.get(agent, "#888")
html += (
f"
"
f"{agent}"
f"{score:.1f}"
f"
"
)
html += "
"
return html
def build_score_progression_df(log: list[dict]) -> pd.DataFrame:
    """Track cumulative scores for each agent at each tick.

    Walks the log in order, accumulating each entry's ``reward`` onto its
    agent, and records one snapshot per tick holding the running totals
    *after the last entry seen for that tick*. The final row for every agent
    therefore equals that agent's total reward over the whole log.

    Args:
        log: Replay entries. Each entry must have ``"agent"`` and ``"tick"``
            keys; ``"reward"`` is optional and a missing or ``None`` reward
            counts as 0. Entries are assumed to be in chronological order.

    Returns:
        A DataFrame with columns ``tick``, ``agent``, ``score`` and one row
        per tracked agent per distinct tick, ordered by first appearance of
        the tick. An empty log yields an empty frame that still carries the
        three columns, so plot components can bind to them.

    Raises:
        KeyError: If an entry lacks the ``"agent"`` or ``"tick"`` key.
    """
    agents = ["attacker", "worker", "oversight"]
    cumulative = {a: 0.0 for a in agents}

    # tick -> copy of the running totals. Re-assigning an existing key keeps
    # its original position, so ticks stay in first-seen order while the
    # stored snapshot is refreshed by every later entry of the same tick.
    snapshots: dict[int, dict[str, float]] = {}

    for entry in log:
        agent = entry["agent"]
        # Rewards from agents outside the tracked set are ignored rather than
        # crashing the chart; the tick is still recorded.
        if agent in cumulative:
            cumulative[agent] += entry.get("reward", 0) or 0
        snapshots[entry["tick"]] = dict(cumulative)

    rows = [
        {"tick": tick, "agent": a, "score": totals[a]}
        for tick, totals in snapshots.items()
        for a in agents
    ]
    return pd.DataFrame(rows, columns=["tick", "agent", "score"])
def build_attack_timeline_df(log: list[dict]) -> pd.DataFrame:
    """Extract attack events from the log.

    Only entries whose ``action_type`` is ``"launch_attack"`` are kept. The
    attack type and target system are recovered by substring search in the
    entry's ``details``, which is stringified first so that both a
    pre-stringified parameters dict and a raw dict are handled.

    Args:
        log: Replay entries. Each entry must have an ``"action_type"`` key;
            attack entries must also have ``"tick"``. ``"details"`` is
            optional.

    Returns:
        A DataFrame with columns ``tick``, ``attack_type``, ``target`` and
        ``count`` (always 1, so rows can be summed when plotted). When no
        known attack type or target is found, the field is an empty string.
        With no attack entries, the frame is empty but keeps the columns.

    Raises:
        KeyError: If an entry lacks ``"action_type"``, or an attack entry
            lacks ``"tick"``.
    """
    # The first match wins, in this order, for both lookups.
    attack_types = ("schema_drift", "policy_drift", "social_engineering", "rate_limit")
    target_systems = ("crm", "billing", "ticketing")

    rows: list[dict] = []
    for entry in log:
        if entry["action_type"] != "launch_attack":
            continue
        details = str(entry.get("details", ""))
        rows.append({
            "tick": entry["tick"],
            "attack_type": next((t for t in attack_types if t in details), ""),
            "target": next((s for s in target_systems if s in details), ""),
            "count": 1,
        })
    return pd.DataFrame(rows, columns=["tick", "attack_type", "target", "count"])
def build_comparison_df(untrained_scores: dict, trained_scores: dict) -> pd.DataFrame:
    """Lay out two score tables as long-form rows for a grouped bar chart.

    Every (agent, score) pair becomes one row tagged with the run it came
    from. All untrained rows come first, followed by all trained rows, each
    group in the iteration order of its source dict.

    Returns a DataFrame with columns: agent, score, type
    where type is "untrained" or "trained".
    """
    labelled_tables = (
        ("untrained", untrained_scores),
        ("trained", trained_scores),
    )
    records = [
        {"agent": name, "score": value, "type": label}
        for label, table in labelled_tables
        for name, value in table.items()
    ]
    return pd.DataFrame(records)
def build_verdict_html(untrained_log: list, trained_log: list) -> str:
"""Build styled HTML verdict comparing untrained vs trained episodes.
Counts: attacks launched, attacks detected (get_schema/get_current_policy),
social engineering resisted. Returns HTML with large numbers showing
the difference.
"""
# NOTE(review): the markup inside the string literals below appears to have
# been stripped in this view of the file - confirm against the real source.
# Tallies the three statistics for a single episode log. Every entry must
# carry "action_type" and "agent" keys (both are indexed directly).
def _count_stats(log: list) -> dict:
attacks_launched = 0
attacks_detected = 0
social_eng_resisted = 0
for entry in log:
if entry["action_type"] == "launch_attack":
attacks_launched += 1
# Any get_schema / get_current_policy action counts as a detection,
# whichever agent issued it.
if entry["action_type"] in ("get_schema", "get_current_policy"):
attacks_detected += 1
# Social engineering resisted: worker responds with refusal
# (matched case-insensitively on the stringified "details" field).
if (
entry["agent"] == "worker"
and entry["action_type"] == "respond"
and "social engineering" in str(entry.get("details", "")).lower()
):
social_eng_resisted += 1
return {
"attacks_launched": attacks_launched,
"attacks_detected": attacks_detected,
"social_eng_resisted": social_eng_resisted,
}
untrained_stats = _count_stats(untrained_log)
trained_stats = _count_stats(trained_log)
# Renders one card for a statistic: label, untrained value, trained value and
# the signed difference (trained minus untrained).
def _stat_card(label: str, untrained_val: int, trained_val: int) -> str:
diff = trained_val - untrained_val
# Positive diff -> "#44bb44" with an explicit "+"; negative -> "#ff4444";
# zero -> "#888".
diff_color = "#44bb44" if diff > 0 else ("#ff4444" if diff < 0 else "#888")
diff_sign = "+" if diff > 0 else ""
return (
f""
f"
{label}
"
f"
"
f"
"
f"
{untrained_val}
"
f"
Untrained
"
f"
"
f"
"
f"
{trained_val}
"
f"
Trained
"
f"
"
f"
"
f"
Difference: {diff_sign}{diff}
"
f"
"
)
# Assemble the output: opening wrapper, three stat cards in fixed order, then
# the closing wrapper.
html = (
""
"
"
)
html += _stat_card(
"Attacks Launched",
untrained_stats["attacks_launched"],
trained_stats["attacks_launched"],
)
html += _stat_card(
"Attacks Detected",
untrained_stats["attacks_detected"],
trained_stats["attacks_detected"],
)
html += _stat_card(
"Social Eng. Resisted",
untrained_stats["social_eng_resisted"],
trained_stats["social_eng_resisted"],
)
html += "
"
return html