sentinel-env / training /evaluate.py
XcodeAddy's picture
Add GPU trust environment and GRPO replay pipeline
a36db1b
from __future__ import annotations
import argparse
import json
import random
import struct
import sys
import zlib
from pathlib import Path
from typing import Callable
# Directory one level above the folder that contains this file.
ROOT = Path(__file__).resolve().parents[1]
# Make that directory importable (presumably so the project modules imported
# below resolve when this file is run directly as a script — confirm).
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from difficulty_controller import GLOBAL_DIFFICULTY_CONTROLLER
from environment import SentinelEnv, _GROUND_TRUTH_RELIABILITY
from sentinel_config import ADVERSARIAL_AWARENESS_STAKES
from training.replay import replay_trained_policy
# A policy is called as ``policy(env, observation, rng)`` and returns the action
# dict that is then handed to ``env.step``.
Policy = Callable[[SentinelEnv, dict, random.Random], dict]
# NOTE(review): this registry starts empty and nothing visible in this file
# fills or reads it (main() builds its own table) — confirm whether other
# modules rely on it.
POLICIES: dict[str, Policy] = {}
def random_policy(env: SentinelEnv, obs: dict, rng: random.Random) -> dict:
    """Pick an action type at random with fixed weights.

    A specialist is drawn from ``obs["available_specialists"]`` first and a
    uniform roll second; that draw order is kept so seeded runs reproduce.
    The roll maps to: ``delegate`` below 0.65, ``verify`` below 0.85,
    ``solve_independently`` below 0.95, otherwise ``skip``. The last two
    actions carry no specialist.

    ``env`` is unused; it is accepted to satisfy the ``Policy`` signature.
    """
    candidate = rng.choice(obs["available_specialists"])
    draw = rng.random()

    # Check the bands from the top down so each guard needs only one bound.
    if draw >= 0.95:
        return _action(obs, "skip", None)
    if draw >= 0.85:
        return _action(obs, "solve_independently", None)
    if draw >= 0.65:
        return _action(obs, "verify", candidate)
    return _action(obs, "delegate", candidate)
def heuristic_policy(env: SentinelEnv, obs: dict, rng: random.Random) -> dict:
    """Choose the specialist with the highest trust score in the observation.

    Trust comes from ``obs["trust_snapshot"]``; specialists missing from the
    snapshot count as 0.5. The chosen specialist is verified when the stakes
    level is at or above ``ADVERSARIAL_AWARENESS_STAKES`` and its trust is
    below 0.65; otherwise the subtask is delegated to it.

    ``env`` and ``rng`` are unused; they are accepted to satisfy ``Policy``.
    """
    snapshot = obs["trust_snapshot"]

    def trust_of(specialist_id):
        return snapshot.get(specialist_id, 0.5)

    most_trusted = max(obs["available_specialists"], key=trust_of)

    needs_check = (
        obs["stakes_level"] >= ADVERSARIAL_AWARENESS_STAKES
        and trust_of(most_trusted) < 0.65
    )
    if needs_check:
        return _action(obs, "verify", most_trusted)
    return _action(obs, "delegate", most_trusted)
def oracle_lite_policy(env: SentinelEnv, obs: dict, rng: random.Random) -> dict:
    """Baseline that reads reliability figures straight from the environment.

    The reliability mapping is obtained from
    ``env._pool.public_ground_truth_reliability(_GROUND_TRUTH_RELIABILITY)``.
    On ``task3`` observations whose stakes level is at or above
    ``ADVERSARIAL_AWARENESS_STAKES`` the policy verifies
    ``env._pool.adversarial_slot``. In every other case it delegates to the
    available specialist with the highest reliability (0.5 when missing).

    ``rng`` is unused; it is accepted to satisfy the ``Policy`` signature.
    """
    # Looked up before the branch, as in every call, even if the branch below
    # ends up not needing it.
    known_reliability = env._pool.public_ground_truth_reliability(_GROUND_TRUTH_RELIABILITY)

    is_task3 = obs["task_type"] == "task3"
    if is_task3 and obs["stakes_level"] >= ADVERSARIAL_AWARENESS_STAKES:
        return _action(obs, "verify", env._pool.adversarial_slot)

    def reliability_of(specialist_id):
        return known_reliability.get(specialist_id, 0.5)

    top_pick = max(obs["available_specialists"], key=reliability_of)
    return _action(obs, "delegate", top_pick)
def _action(obs: dict, action_type: str, specialist_id: str | None) -> dict:
    """Build the action dict for one step.

    ``session_id`` and ``task_type`` are copied from the observation.
    ``subtask_response`` is ``"SELF_SOLVED"`` only for the
    ``solve_independently`` action and ``None`` otherwise. ``reasoning`` is a
    short ``"<action> via <specialist or SELF>"`` string.
    """
    # Seed the payload from the observation, then add fields in a fixed order.
    payload = {field: obs[field] for field in ("session_id", "task_type")}
    payload["action_type"] = action_type
    payload["specialist_id"] = specialist_id
    if action_type == "solve_independently":
        payload["subtask_response"] = "SELF_SOLVED"
    else:
        payload["subtask_response"] = None
    payload["reasoning"] = f"{action_type} via {specialist_id or 'SELF'}"
    return payload
def run_episode(policy_name: str, policy: Policy, task_type: str, seed: int, adaptive: bool = False) -> dict:
    """Play one episode with ``policy`` and return a flat result row.

    A fresh ``SentinelEnv`` is reset with ``task_type``, ``seed`` and
    ``adaptive``; the policy is then called until the environment reports
    ``done``. Policies that define ``set_episode`` are told the (task, seed)
    pair before the episode starts.

    The returned dict holds the policy name, task, seed, step count, rounded
    score / total reward / completion rate / detection rate / trust
    calibration, raw adversarial detection and poisoning counts, a ``status``
    of ``"failed"`` when ``info["forced_end"]`` is truthy (else
    ``"completed"``), the difficulty profile, and the per-step rewards.
    """
    episode_rng = random.Random(seed)
    if hasattr(policy, "set_episode"):
        policy.set_episode(task_type, seed)

    env = SentinelEnv()
    outcome = env.reset(task_type=task_type, seed=seed, adaptive=adaptive)
    reward_trace: list[float] = []
    while not outcome["done"]:
        chosen_action = policy(env, outcome["observation"], episode_rng)
        outcome = env.step(chosen_action)
        reward_trace.append(outcome["reward"]["value"])

    info = outcome["info"]
    breakdown = outcome["reward"]["signal_breakdown"]

    caught = info.get("adversarial_detections", 0)
    poisoned = info.get("adversarial_poisonings", 0)
    adversarial_total = caught + poisoned
    if adversarial_total:
        detection_rate = caught / adversarial_total
    else:
        # No adversarial events were counted: use the environment's own figure.
        detection_rate = breakdown.get("detection_rate", 1.0)

    def r4(number):
        return round(number, 4)

    return {
        "policy": policy_name,
        "task_type": task_type,
        "seed": seed,
        "steps": info.get("step_count", 0),
        "score": r4(info.get("score", 0.0)),
        "total_reward": r4(info.get("total_reward", 0.0)),
        "completion_rate": r4(info.get("completion_rate", 0.0)),
        "detection_rate": r4(detection_rate),
        "trust_calibration": r4(breakdown.get("trust_calibration", 0.0)),
        "adversarial_detections": caught,
        "adversarial_poisonings": poisoned,
        "status": "failed" if info.get("forced_end") else "completed",
        "difficulty_profile": info.get(
            "difficulty_profile",
            outcome["observation"].get("difficulty_profile", {}),
        ),
        "rewards": [r4(step_reward) for step_reward in reward_trace],
    }
def summarize(rows: list[dict]) -> dict:
    """Aggregate episode rows into per-policy averages.

    Rows are bucketed by their ``"policy"`` value in first-seen order. Each
    bucket yields its episode count plus the ``_avg`` of score, completion
    rate, detection rate, trust calibration and steps.
    """
    # Output key -> row key it averages; the order here is the output order.
    averaged_fields = (
        ("avg_score", "score"),
        ("avg_completion_rate", "completion_rate"),
        ("avg_detection_rate", "detection_rate"),
        ("avg_trust_calibration", "trust_calibration"),
        ("avg_steps", "steps"),
    )

    buckets: dict[str, list[dict]] = {}
    for row in rows:
        name = row["policy"]
        if name not in buckets:
            buckets[name] = []
        buckets[name].append(row)

    report = {}
    for name, bucket in buckets.items():
        entry = {"episodes": len(bucket)}
        for out_key, row_key in averaged_fields:
            entry[out_key] = _avg(bucket, row_key)
        report[name] = entry
    return report
def _avg(rows: list[dict], key: str) -> float:
    """Return the mean of ``row[key]`` over ``rows``, rounded to 4 decimals.

    Rows lacking ``key`` contribute 0.0. An empty ``rows`` yields 0.0 because
    the divisor is never allowed to drop below 1.
    """
    samples = [float(row.get(key, 0.0)) for row in rows]
    divisor = len(rows) or 1
    mean = sum(samples) / divisor
    return round(mean, 4)
def summarize_by_task(rows: list[dict]) -> dict:
    """Split rows by ``"task_type"`` and summarize each group.

    The result maps each task name, in sorted order, to the per-policy
    summary produced by ``summarize`` for that task's rows.
    """
    per_task: dict[str, list[dict]] = {}
    for row in rows:
        task_name = row["task_type"]
        if task_name not in per_task:
            per_task[task_name] = []
        per_task[task_name].append(row)

    report = {}
    for task_name in sorted(per_task):
        report[task_name] = summarize(per_task[task_name])
    return report
# Bitmap font used by the chart renderer. Every glyph is seven row strings of
# five characters each, where "1" marks a filled cell and "0" an empty one.
# Covers space, "-", ".", ":", the digits and uppercase A-Z.
FONT_5X7 = {
    " ": ["00000", "00000", "00000", "00000", "00000", "00000", "00000"],
    "-": ["00000", "00000", "00000", "11111", "00000", "00000", "00000"],
    ".": ["00000", "00000", "00000", "00000", "00000", "01100", "01100"],
    ":": ["00000", "01100", "01100", "00000", "01100", "01100", "00000"],
    "0": ["01110", "10001", "10011", "10101", "11001", "10001", "01110"],
    "1": ["00100", "01100", "00100", "00100", "00100", "00100", "01110"],
    "2": ["01110", "10001", "00001", "00010", "00100", "01000", "11111"],
    "3": ["11110", "00001", "00001", "01110", "00001", "00001", "11110"],
    "4": ["00010", "00110", "01010", "10010", "11111", "00010", "00010"],
    "5": ["11111", "10000", "10000", "11110", "00001", "00001", "11110"],
    "6": ["01110", "10000", "10000", "11110", "10001", "10001", "01110"],
    "7": ["11111", "00001", "00010", "00100", "01000", "01000", "01000"],
    "8": ["01110", "10001", "10001", "01110", "10001", "10001", "01110"],
    "9": ["01110", "10001", "10001", "01111", "00001", "00001", "01110"],
    "A": ["01110", "10001", "10001", "11111", "10001", "10001", "10001"],
    "B": ["11110", "10001", "10001", "11110", "10001", "10001", "11110"],
    "C": ["01110", "10001", "10000", "10000", "10000", "10001", "01110"],
    "D": ["11110", "10001", "10001", "10001", "10001", "10001", "11110"],
    "E": ["11111", "10000", "10000", "11110", "10000", "10000", "11111"],
    "F": ["11111", "10000", "10000", "11110", "10000", "10000", "10000"],
    "G": ["01110", "10001", "10000", "10111", "10001", "10001", "01110"],
    "H": ["10001", "10001", "10001", "11111", "10001", "10001", "10001"],
    "I": ["01110", "00100", "00100", "00100", "00100", "00100", "01110"],
    "J": ["00001", "00001", "00001", "00001", "10001", "10001", "01110"],
    "K": ["10001", "10010", "10100", "11000", "10100", "10010", "10001"],
    "L": ["10000", "10000", "10000", "10000", "10000", "10000", "11111"],
    "M": ["10001", "11011", "10101", "10101", "10001", "10001", "10001"],
    "N": ["10001", "11001", "10101", "10011", "10001", "10001", "10001"],
    "O": ["01110", "10001", "10001", "10001", "10001", "10001", "01110"],
    "P": ["11110", "10001", "10001", "11110", "10000", "10000", "10000"],
    "Q": ["01110", "10001", "10001", "10001", "10101", "10010", "01101"],
    "R": ["11110", "10001", "10001", "11110", "10100", "10010", "10001"],
    "S": ["01111", "10000", "10000", "01110", "00001", "00001", "11110"],
    "T": ["11111", "00100", "00100", "00100", "00100", "00100", "00100"],
    "U": ["10001", "10001", "10001", "10001", "10001", "10001", "01110"],
    "V": ["10001", "10001", "10001", "10001", "10001", "01010", "00100"],
    "W": ["10001", "10001", "10001", "10101", "10101", "10101", "01010"],
    "X": ["10001", "10001", "01010", "00100", "01010", "10001", "10001"],
    "Y": ["10001", "10001", "01010", "00100", "00100", "00100", "00100"],
    "Z": ["11111", "00001", "00010", "00100", "01000", "10000", "11111"],
}
def write_baseline_chart(payload: dict, path: Path) -> None:
    """Write a dependency-free PNG chart for README and onsite demos.

    The chart is a 1200x720 grouped bar chart: one group per task in
    ``payload["by_task"]`` and one bar per policy, showing that policy's
    ``avg_score`` (0.0 when the entry is missing). Only the policies
    ``random``, ``heuristic``, ``oracle_lite`` and ``trained`` are drawn, in
    that order, and only if at least one task has an entry for them.

    Args:
        payload: Evaluation payload; only ``payload["by_task"]`` is read. It
            maps task name -> policy name -> summary dict.
        path: Destination file. Its parent directory is created if needed.
    """
    by_task = payload["by_task"]
    tasks = list(by_task.keys())
    policies = [
        name for name in ("random", "heuristic", "oracle_lite", "trained")
        if any(name in by_task[t] for t in tasks)
    ]
    colors = {
        "random": (239, 68, 68),
        "heuristic": (59, 130, 246),
        "oracle_lite": (16, 185, 129),
        "trained": (168, 85, 247),
    }
    labels = {
        "random": "RANDOM",
        "heuristic": "HEURISTIC",
        "oracle_lite": "ORACLE LITE",
        "trained": "GRPO",
    }
    width, height = 1200, 720
    # Packed RGB, three bytes per pixel, initialised to white.
    canvas = bytearray(b"\xff" * (width * height * 3))

    def rect(x0: int, y0: int, x1: int, y1: int, color: tuple[int, int, int]) -> None:
        """Fill the half-open box [x0, x1) x [y0, y1), clipped to the canvas."""
        x0, y0 = max(0, x0), max(0, y0)
        x1, y1 = min(width, x1), min(height, y1)
        if x1 <= x0 or y1 <= y0:
            return
        # Build one scanline's worth of colour once and copy it into each row
        # instead of assigning pixel by pixel.
        span = bytes(color) * (x1 - x0)
        for y in range(y0, y1):
            start = (y * width + x0) * 3
            canvas[start : start + len(span)] = span

    def text(x: int, y: int, value: str, color: tuple[int, int, int] = (20, 20, 20), scale: int = 2) -> None:
        """Draw ``value`` uppercased with FONT_5X7; unknown characters render blank."""
        cursor = x
        for ch in value.upper():
            glyph = FONT_5X7.get(ch, FONT_5X7[" "])
            for gy, line in enumerate(glyph):
                for gx, bit in enumerate(line):
                    if bit == "1":
                        rect(cursor + gx * scale, y + gy * scale, cursor + (gx + 1) * scale, y + (gy + 1) * scale, color)
            # Five glyph columns plus one column of spacing.
            cursor += 6 * scale

    def line_h(y: int, x0: int, x1: int, color: tuple[int, int, int]) -> None:
        """Draw a one-pixel-high horizontal line."""
        rect(x0, y, x1, y + 1, color)

    def line_v(x: int, y0: int, y1: int, color: tuple[int, int, int]) -> None:
        """Draw a one-pixel-wide vertical line."""
        rect(x, y0, x + 1, y1, color)

    margin_left, margin_top, margin_right, margin_bottom = 100, 115, 40, 115
    plot_x0, plot_y0 = margin_left, margin_top
    plot_x1, plot_y1 = width - margin_right, height - margin_bottom
    plot_w, plot_h = plot_x1 - plot_x0, plot_y1 - plot_y0

    text(50, 28, "SENTINEL BASELINE COMPARISON", (17, 24, 39), 3)
    text(52, 70, "EPISODE SCORE 0.0 TO 1.0 - RANDOM VS TRUST WEIGHTED VS ORACLE LITE", (75, 85, 99), 2)

    # Horizontal grid lines and their y-axis labels.
    for tick in (0.0, 0.25, 0.5, 0.75, 1.0):
        y = int(plot_y1 - tick * plot_h)
        line_h(y, plot_x0, plot_x1, (226, 232, 240))
        text(32, y - 7, f"{tick:.2f}", (100, 116, 139), 2)
    line_v(plot_x0, plot_y0, plot_y1, (148, 163, 184))
    line_h(plot_y1, plot_x0, plot_x1, (148, 163, 184))

    group_w = plot_w / max(1, len(tasks))
    bar_w = max(34, min(76, int((group_w - 80) / max(1, len(policies)))))
    for task_idx, task in enumerate(tasks):
        group_center = int(plot_x0 + group_w * task_idx + group_w / 2)
        # Bars in a group are separated by an 18-pixel gap.
        start_x = group_center - int((len(policies) * bar_w + (len(policies) - 1) * 18) / 2)
        for policy_idx, policy in enumerate(policies):
            value = float(by_task[task].get(policy, {}).get("avg_score", 0.0))
            x0 = start_x + policy_idx * (bar_w + 18)
            y0 = int(plot_y1 - value * plot_h)
            # Grey copy offset by 3 pixels acts as a drop shadow behind the bar.
            rect(x0 + 3, y0 + 3, x0 + bar_w + 3, plot_y1 + 3, (203, 213, 225))
            rect(x0, y0, x0 + bar_w, plot_y1, colors[policy])
            text(x0 - 4, max(plot_y0 - 2, y0 - 24), f"{value:.2f}", (15, 23, 42), 2)
        text(group_center - 36, plot_y1 + 30, task.upper(), (15, 23, 42), 2)

    legend_x, legend_y = 780, 32
    for idx, policy in enumerate(policies):
        x = legend_x
        y = legend_y + idx * 24
        rect(x, y, x + 16, y + 16, colors[policy])
        text(x + 24, y + 1, labels[policy], (51, 65, 85), 2)

    path.parent.mkdir(parents=True, exist_ok=True)
    _write_png(path, width, height, canvas)
def _write_png(path: Path, width: int, height: int, rgb: bytearray) -> None:
    """Encode a packed RGB buffer as a PNG file at ``path``.

    ``rgb`` is read as ``height`` rows of ``width * 3`` bytes. The output is
    the PNG signature followed by three chunks: an IHDR declaring bit depth 8
    and colour type 2, a single zlib-compressed (level 9) IDAT in which every
    row is prefixed with filter byte 0, and an empty IEND.
    """

    def chunk(tag: bytes, data: bytes) -> bytes:
        # Layout: big-endian length, tag, payload, CRC-32 over tag + payload.
        tagged = tag + data
        length_field = struct.pack(">I", len(data))
        crc_field = struct.pack(">I", zlib.crc32(tagged) & 0xFFFFFFFF)
        return length_field + tagged + crc_field

    row_bytes = width * 3
    scanlines = bytearray()
    for row_index in range(height):
        begin = row_index * row_bytes
        scanlines += b"\x00"
        scanlines += rgb[begin : begin + row_bytes]

    header = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)
    pieces = [
        b"\x89PNG\r\n\x1a\n",
        chunk(b"IHDR", header),
        chunk(b"IDAT", zlib.compress(bytes(scanlines), 9)),
        chunk(b"IEND", b""),
    ]
    path.write_bytes(b"".join(pieces))
def main() -> None:
    """Command-line entry point: evaluate policies and write the results.

    For every selected task and policy, runs ``--episodes`` episodes with
    seeds ``0 .. episodes - 1``, then writes a JSON payload (overall summary,
    per-task summary and the raw episode rows) to ``--out`` and, unless
    ``--no-plot`` is given, a PNG chart to ``--plot``. A compact summary is
    printed to stdout. Relative ``--out`` / ``--plot`` / ``--replay`` paths
    are resolved against ``ROOT``.

    Raises:
        SystemExit: if ``--policies`` names a policy that is not known.
    """
    parser = argparse.ArgumentParser(description="Evaluate SENTINEL policies.")
    parser.add_argument("--episodes", type=int, default=20, help="Episodes per policy.")
    parser.add_argument("--task", default="task3", choices=["task1", "task2", "task3", "all"])
    parser.add_argument("--out", default="outputs/evaluation_results.json")
    parser.add_argument("--plot", default="outputs/baseline_comparison.png")
    parser.add_argument("--no-plot", action="store_true")
    parser.add_argument("--adaptive", action="store_true", help="Enable adaptive curriculum during evaluation.")
    parser.add_argument("--reset-difficulty", action="store_true", help="Reset adaptive controller before running.")
    parser.add_argument(
        "--policies",
        default="random,heuristic,oracle_lite",
        help="Comma-separated policies: random,heuristic,oracle_lite,trained.",
    )
    parser.add_argument("--replay", default="outputs/trained_policy_replay.jsonl", help="Replay JSONL for --policies trained.")
    args = parser.parse_args()

    if args.reset_difficulty:
        GLOBAL_DIFFICULTY_CONTROLLER.reset()

    builtin_policies: dict[str, Policy] = {
        "random": random_policy,
        "heuristic": heuristic_policy,
        "oracle_lite": oracle_lite_policy,
    }
    known_names = set(builtin_policies) | {"trained"}
    requested = [name.strip() for name in args.policies.split(",") if name.strip()]
    unknown = sorted(set(requested) - known_names)
    if unknown:
        raise SystemExit(f"Unknown policies: {', '.join(unknown)}")

    policies: dict[str, Policy] = {}
    for name in requested:
        if name in policies:
            continue
        if name == "trained":
            # Built only on request so the baseline-only default run does not
            # depend on the replay file.
            policies[name] = replay_trained_policy(ROOT / args.replay)
        else:
            policies[name] = builtin_policies[name]

    tasks = ["task1", "task2", "task3"] if args.task == "all" else [args.task]
    rows = []
    controller_by_task_policy: dict[str, dict[str, dict]] = {}
    for task_type in tasks:
        for policy_name, policy in policies.items():
            if args.adaptive:
                # Each (task, policy) pair starts from a freshly reset controller.
                GLOBAL_DIFFICULTY_CONTROLLER.reset()
            policy_rows = [
                run_episode(policy_name, policy, task_type, seed, adaptive=args.adaptive)
                for seed in range(args.episodes)
            ]
            rows.extend(policy_rows)
            controller_by_task_policy.setdefault(task_type, {})[policy_name] = (
                GLOBAL_DIFFICULTY_CONTROLLER.state() if args.adaptive else {}
            )

    payload = {
        "task": args.task,
        "tasks": tasks,
        "episodes_per_policy": args.episodes,
        "adaptive": args.adaptive,
        "difficulty_controller": GLOBAL_DIFFICULTY_CONTROLLER.state(),
        "difficulty_controller_by_task_policy": controller_by_task_policy,
        "summary": summarize(rows),
        "by_task": summarize_by_task(rows),
        "episodes": rows,
    }
    out_path = ROOT / args.out
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Written before charting so the results exist even if rendering fails.
    out_path.write_text(json.dumps(payload, indent=2) + "\n")

    if not args.no_plot:
        chart_path = ROOT / args.plot
        write_baseline_chart(payload, chart_path)
        try:
            payload["chart"] = str(chart_path.relative_to(ROOT))
        except ValueError:
            # --plot was an absolute path outside ROOT; record it as given.
            payload["chart"] = str(chart_path)
        out_path.write_text(json.dumps(payload, indent=2) + "\n")

    print(json.dumps({"summary": payload["summary"], "by_task": payload["by_task"], "chart": payload.get("chart")}, indent=2))
# Run the evaluation CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()