Spaces:
Running
Running
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import random | |
| import struct | |
| import sys | |
| import zlib | |
| from pathlib import Path | |
| from typing import Callable | |
# Make project-root imports resolvable when this script is run directly.
# NOTE(review): parents[1] presumably is the repo root (script lives one
# directory down, e.g. training/) -- confirm if the file moves.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
| from difficulty_controller import GLOBAL_DIFFICULTY_CONTROLLER | |
| from environment import SentinelEnv, _GROUND_TRUTH_RELIABILITY | |
| from sentinel_config import ADVERSARIAL_AWARENESS_STAKES | |
| from training.replay import replay_trained_policy | |
# Signature shared by all baseline policies: (env, observation, rng) -> action dict.
Policy = Callable[[SentinelEnv, dict, random.Random], dict]
# NOTE(review): this registry is never read in this file (main() builds its own
# mapping); confirm whether other modules import POLICIES before removing it.
POLICIES: dict[str, Policy] = {}
def random_policy(env: SentinelEnv, obs: dict, rng: random.Random) -> dict:
    """Baseline that samples an action type from a fixed distribution.

    Split: delegate 65%, verify 20%, solve independently 10%, skip 5%.
    A specialist is drawn up front (keeping the RNG call order stable) and
    dropped again for the two specialist-less action types.
    """
    chosen = rng.choice(obs["available_specialists"])
    roll = rng.random()
    # Cumulative thresholds; the final bound is > 1.0 so it always matches.
    bands = (
        (0.65, "delegate"),
        (0.85, "verify"),
        (0.95, "solve_independently"),
        (2.0, "skip"),
    )
    kind = next(name for bound, name in bands if roll < bound)
    if kind in ("solve_independently", "skip"):
        chosen = None
    return _action(obs, kind, chosen)
def heuristic_policy(env: SentinelEnv, obs: dict, rng: random.Random) -> dict:
    """Trust-weighted baseline: delegate to the most-trusted specialist.

    Falls back to verifying that same specialist when the stakes reach the
    adversarial-awareness threshold and even the best trust score is < 0.65.
    Unknown specialists default to a neutral trust of 0.5.
    """
    trust = obs["trust_snapshot"]
    best = max(obs["available_specialists"], key=lambda sid: trust.get(sid, 0.5))
    stakes_high = obs["stakes_level"] >= ADVERSARIAL_AWARENESS_STAKES
    trust_low = trust.get(best, 0.5) < 0.65
    if stakes_high and trust_low:
        return _action(obs, "verify", best)
    return _action(obs, "delegate", best)
def oracle_lite_policy(env: SentinelEnv, obs: dict, rng: random.Random) -> dict:
    """Cheating baseline that peeks at ground-truth reliability.

    NOTE(review): reaches into env._pool (a private attribute) -- presumably
    acceptable for an oracle-style reference policy only.
    """
    reliability = env._pool.public_ground_truth_reliability(_GROUND_TRUTH_RELIABILITY)
    stakes_high = obs["stakes_level"] >= ADVERSARIAL_AWARENESS_STAKES
    if obs["task_type"] == "task3" and stakes_high:
        # High-stakes task3: verify the known adversarial slot directly.
        return _action(obs, "verify", env._pool.adversarial_slot)
    best = max(obs["available_specialists"], key=lambda sid: reliability.get(sid, 0.5))
    return _action(obs, "delegate", best)
| def _action(obs: dict, action_type: str, specialist_id: str | None) -> dict: | |
| return { | |
| "session_id": obs["session_id"], | |
| "task_type": obs["task_type"], | |
| "action_type": action_type, | |
| "specialist_id": specialist_id, | |
| "subtask_response": "SELF_SOLVED" if action_type == "solve_independently" else None, | |
| "reasoning": f"{action_type} via {specialist_id or 'SELF'}", | |
| } | |
def run_episode(policy_name: str, policy: Policy, task_type: str, seed: int, adaptive: bool = False) -> dict:
    """Run one seeded episode of `policy` and return a flat metrics row."""
    rng = random.Random(seed)
    # Replay-backed policies expose set_episode() to select their trajectory.
    if hasattr(policy, "set_episode"):
        policy.set_episode(task_type, seed)
    env = SentinelEnv()
    result = env.reset(task_type=task_type, seed=seed, adaptive=adaptive)
    step_rewards: list[float] = []
    while not result["done"]:
        result = env.step(policy(env, result["observation"], rng))
        step_rewards.append(result["reward"]["value"])
    info = result["info"]
    breakdown = result["reward"]["signal_breakdown"]
    detections = info.get("adversarial_detections", 0)
    poisonings = info.get("adversarial_poisonings", 0)
    encounters = detections + poisonings
    # With no adversarial encounters, fall back to the breakdown's own rate
    # (defaulting to a perfect 1.0).
    if encounters:
        detection_rate = detections / encounters
    else:
        detection_rate = breakdown.get("detection_rate", 1.0)
    row = {
        "policy": policy_name,
        "task_type": task_type,
        "seed": seed,
        "steps": info.get("step_count", 0),
        "score": round(info.get("score", 0.0), 4),
        "total_reward": round(info.get("total_reward", 0.0), 4),
        "completion_rate": round(info.get("completion_rate", 0.0), 4),
        "detection_rate": round(detection_rate, 4),
        "trust_calibration": round(breakdown.get("trust_calibration", 0.0), 4),
        "adversarial_detections": detections,
        "adversarial_poisonings": poisonings,
        "status": "failed" if info.get("forced_end") else "completed",
        "difficulty_profile": info.get(
            "difficulty_profile",
            result["observation"].get("difficulty_profile", {}),
        ),
        "rewards": [round(r, 4) for r in step_rewards],
    }
    return row
def summarize(rows: list[dict]) -> dict:
    """Aggregate per-episode rows into per-policy average metrics."""
    by_policy: dict[str, list[dict]] = {}
    for row in rows:
        by_policy.setdefault(row["policy"], []).append(row)
    return {
        name: {
            "episodes": len(items),
            "avg_score": _avg(items, "score"),
            "avg_completion_rate": _avg(items, "completion_rate"),
            "avg_detection_rate": _avg(items, "detection_rate"),
            "avg_trust_calibration": _avg(items, "trust_calibration"),
            "avg_steps": _avg(items, "steps"),
        }
        for name, items in by_policy.items()
    }
| def _avg(rows: list[dict], key: str) -> float: | |
| return round(sum(float(row.get(key, 0.0)) for row in rows) / max(1, len(rows)), 4) | |
def summarize_by_task(rows: list[dict]) -> dict:
    """Group rows by task type and summarize each group, keyed in sorted order."""
    by_task: dict[str, list[dict]] = {}
    for row in rows:
        by_task.setdefault(row["task_type"], []).append(row)
    result = {}
    for task in sorted(by_task):
        result[task] = summarize(by_task[task])
    return result
# 5x7 bitmap font used by write_baseline_chart's text renderer.
# Each glyph is 7 rows of a 5-character "0"/"1" string; "1" marks a lit pixel.
# Coverage: space, '-', '.', ':', digits 0-9, and uppercase A-Z (the renderer
# upper-cases its input and substitutes space for unknown characters).
FONT_5X7 = {
    " ": ["00000", "00000", "00000", "00000", "00000", "00000", "00000"],
    "-": ["00000", "00000", "00000", "11111", "00000", "00000", "00000"],
    ".": ["00000", "00000", "00000", "00000", "00000", "01100", "01100"],
    ":": ["00000", "01100", "01100", "00000", "01100", "01100", "00000"],
    "0": ["01110", "10001", "10011", "10101", "11001", "10001", "01110"],
    "1": ["00100", "01100", "00100", "00100", "00100", "00100", "01110"],
    "2": ["01110", "10001", "00001", "00010", "00100", "01000", "11111"],
    "3": ["11110", "00001", "00001", "01110", "00001", "00001", "11110"],
    "4": ["00010", "00110", "01010", "10010", "11111", "00010", "00010"],
    "5": ["11111", "10000", "10000", "11110", "00001", "00001", "11110"],
    "6": ["01110", "10000", "10000", "11110", "10001", "10001", "01110"],
    "7": ["11111", "00001", "00010", "00100", "01000", "01000", "01000"],
    "8": ["01110", "10001", "10001", "01110", "10001", "10001", "01110"],
    "9": ["01110", "10001", "10001", "01111", "00001", "00001", "01110"],
    "A": ["01110", "10001", "10001", "11111", "10001", "10001", "10001"],
    "B": ["11110", "10001", "10001", "11110", "10001", "10001", "11110"],
    "C": ["01110", "10001", "10000", "10000", "10000", "10001", "01110"],
    "D": ["11110", "10001", "10001", "10001", "10001", "10001", "11110"],
    "E": ["11111", "10000", "10000", "11110", "10000", "10000", "11111"],
    "F": ["11111", "10000", "10000", "11110", "10000", "10000", "10000"],
    "G": ["01110", "10001", "10000", "10111", "10001", "10001", "01110"],
    "H": ["10001", "10001", "10001", "11111", "10001", "10001", "10001"],
    "I": ["01110", "00100", "00100", "00100", "00100", "00100", "01110"],
    "J": ["00001", "00001", "00001", "00001", "10001", "10001", "01110"],
    "K": ["10001", "10010", "10100", "11000", "10100", "10010", "10001"],
    "L": ["10000", "10000", "10000", "10000", "10000", "10000", "11111"],
    "M": ["10001", "11011", "10101", "10101", "10001", "10001", "10001"],
    "N": ["10001", "11001", "10101", "10011", "10001", "10001", "10001"],
    "O": ["01110", "10001", "10001", "10001", "10001", "10001", "01110"],
    "P": ["11110", "10001", "10001", "11110", "10000", "10000", "10000"],
    "Q": ["01110", "10001", "10001", "10001", "10101", "10010", "01101"],
    "R": ["11110", "10001", "10001", "11110", "10100", "10010", "10001"],
    "S": ["01111", "10000", "10000", "01110", "00001", "00001", "11110"],
    "T": ["11111", "00100", "00100", "00100", "00100", "00100", "00100"],
    "U": ["10001", "10001", "10001", "10001", "10001", "10001", "01110"],
    "V": ["10001", "10001", "10001", "10001", "10001", "01010", "00100"],
    "W": ["10001", "10001", "10001", "10101", "10101", "10101", "01010"],
    "X": ["10001", "10001", "01010", "00100", "01010", "10001", "10001"],
    "Y": ["10001", "10001", "01010", "00100", "00100", "00100", "00100"],
    "Z": ["11111", "00001", "00010", "00100", "01000", "10000", "11111"],
}
def write_baseline_chart(payload: dict, path: Path) -> None:
    """Write a dependency-free PNG chart for README and onsite demos.

    Renders a grouped bar chart of avg_score per task/policy from
    payload["by_task"] onto a raw RGB canvas, then encodes it via _write_png.
    Parent directories of `path` are created as needed.
    """
    by_task = payload["by_task"]
    tasks = list(by_task.keys())
    # Fixed display order; only policies that appear in the results are drawn.
    policies = [
        name for name in ("random", "heuristic", "oracle_lite", "trained")
        if any(name in by_task[t] for t in tasks)
    ]
    colors = {
        "random": (239, 68, 68),
        "heuristic": (59, 130, 246),
        "oracle_lite": (16, 185, 129),
        "trained": (168, 85, 247),
    }
    labels = {
        "random": "RANDOM",
        "heuristic": "HEURISTIC",
        "oracle_lite": "ORACLE LITE",
        "trained": "GRPO",
    }
    width, height = 1200, 720
    # Flat RGB byte buffer (3 bytes per pixel), initialised to white.
    canvas = bytearray([255, 255, 255] * width * height)

    def rect(x0: int, y0: int, x1: int, y1: int, color: tuple[int, int, int]) -> None:
        # Fill an axis-aligned rectangle, clamped to the canvas bounds.
        x0, y0 = max(0, x0), max(0, y0)
        x1, y1 = min(width, x1), min(height, y1)
        for y in range(y0, y1):
            row = y * width * 3
            for x in range(x0, x1):
                idx = row + x * 3
                canvas[idx : idx + 3] = bytes(color)

    def text(x: int, y: int, value: str, color: tuple[int, int, int] = (20, 20, 20), scale: int = 2) -> None:
        # Draw upper-cased text using the FONT_5X7 bitmaps; unknown characters
        # fall back to the blank space glyph.
        cursor = x
        for ch in value.upper():
            glyph = FONT_5X7.get(ch, FONT_5X7[" "])
            for gy, line in enumerate(glyph):
                for gx, bit in enumerate(line):
                    if bit == "1":
                        rect(cursor + gx * scale, y + gy * scale, cursor + (gx + 1) * scale, y + (gy + 1) * scale, color)
            cursor += 6 * scale  # 5px glyph + 1px spacing, scaled

    def line_h(y: int, x0: int, x1: int, color: tuple[int, int, int]) -> None:
        # 1px-high horizontal line.
        rect(x0, y, x1, y + 1, color)

    def line_v(x: int, y0: int, y1: int, color: tuple[int, int, int]) -> None:
        # 1px-wide vertical line.
        rect(x, y0, x + 1, y1, color)

    # Plot area inside the margins; pixel y grows downward, so a score of 0.0
    # maps to plot_y1 (the bottom axis) and 1.0 to plot_y0.
    margin_left, margin_top, margin_right, margin_bottom = 100, 115, 40, 115
    plot_x0, plot_y0 = margin_left, margin_top
    plot_x1, plot_y1 = width - margin_right, height - margin_bottom
    plot_w, plot_h = plot_x1 - plot_x0, plot_y1 - plot_y0
    # Title and subtitle.
    text(50, 28, "SENTINEL BASELINE COMPARISON", (17, 24, 39), 3)
    text(52, 70, "EPISODE SCORE 0.0 TO 1.0 - RANDOM VS TRUST WEIGHTED VS ORACLE LITE", (75, 85, 99), 2)
    # Horizontal gridlines with y-axis tick labels.
    for tick in (0.0, 0.25, 0.5, 0.75, 1.0):
        y = int(plot_y1 - tick * plot_h)
        line_h(y, plot_x0, plot_x1, (226, 232, 240))
        text(32, y - 7, f"{tick:.2f}", (100, 116, 139), 2)
    # Axes.
    line_v(plot_x0, plot_y0, plot_y1, (148, 163, 184))
    line_h(plot_y1, plot_x0, plot_x1, (148, 163, 184))
    # One group of bars per task, centred within equal-width slots; bars are
    # spaced 18px apart inside a group.
    group_w = plot_w / max(1, len(tasks))
    bar_w = max(34, min(76, int((group_w - 80) / max(1, len(policies)))))
    for task_idx, task in enumerate(tasks):
        group_center = int(plot_x0 + group_w * task_idx + group_w / 2)
        start_x = group_center - int((len(policies) * bar_w + (len(policies) - 1) * 18) / 2)
        for policy_idx, policy in enumerate(policies):
            value = float(by_task[task].get(policy, {}).get("avg_score", 0.0))
            x0 = start_x + policy_idx * (bar_w + 18)
            y0 = int(plot_y1 - value * plot_h)
            # Drop shadow, then the bar itself, then the value label above it.
            rect(x0 + 3, y0 + 3, x0 + bar_w + 3, plot_y1 + 3, (203, 213, 225))
            rect(x0, y0, x0 + bar_w, plot_y1, colors[policy])
            text(x0 - 4, max(plot_y0 - 2, y0 - 24), f"{value:.2f}", (15, 23, 42), 2)
        text(group_center - 36, plot_y1 + 30, task.upper(), (15, 23, 42), 2)
    # Legend swatches in the top-right corner.
    legend_x, legend_y = 780, 32
    for idx, policy in enumerate(policies):
        x = legend_x
        y = legend_y + idx * 24
        rect(x, y, x + 16, y + 16, colors[policy])
        text(x + 24, y + 1, labels[policy], (51, 65, 85), 2)
    path.parent.mkdir(parents=True, exist_ok=True)
    _write_png(path, width, height, canvas)
| def _write_png(path: Path, width: int, height: int, rgb: bytearray) -> None: | |
| def chunk(tag: bytes, data: bytes) -> bytes: | |
| return struct.pack(">I", len(data)) + tag + data + struct.pack(">I", zlib.crc32(tag + data) & 0xFFFFFFFF) | |
| rows = [] | |
| stride = width * 3 | |
| for y in range(height): | |
| rows.append(b"\x00" + bytes(rgb[y * stride : (y + 1) * stride])) | |
| raw = b"".join(rows) | |
| png = ( | |
| b"\x89PNG\r\n\x1a\n" | |
| + chunk(b"IHDR", struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)) | |
| + chunk(b"IDAT", zlib.compress(raw, 9)) | |
| + chunk(b"IEND", b"") | |
| ) | |
| path.write_bytes(png) | |
def main() -> None:
    """CLI entry point: evaluate the requested policies over seeded episodes,
    write a JSON report, optionally render the comparison chart, and print a
    summary to stdout.

    Fix: policies are now built lazily. The original constructed
    replay_trained_policy(ROOT / args.replay) unconditionally, so every run --
    including the default one that never uses the "trained" policy -- loaded
    the replay file up front (and would fail if it was absent).
    """
    parser = argparse.ArgumentParser(description="Evaluate SENTINEL policies.")
    parser.add_argument("--episodes", type=int, default=20, help="Episodes per policy.")
    parser.add_argument("--task", default="task3", choices=["task1", "task2", "task3", "all"])
    parser.add_argument("--out", default="outputs/evaluation_results.json")
    parser.add_argument("--plot", default="outputs/baseline_comparison.png")
    parser.add_argument("--no-plot", action="store_true")
    parser.add_argument("--adaptive", action="store_true", help="Enable adaptive curriculum during evaluation.")
    parser.add_argument("--reset-difficulty", action="store_true", help="Reset adaptive controller before running.")
    parser.add_argument(
        "--policies",
        default="random,heuristic,oracle_lite",
        help="Comma-separated policies: random,heuristic,oracle_lite,trained.",
    )
    parser.add_argument("--replay", default="outputs/trained_policy_replay.jsonl", help="Replay JSONL for --policies trained.")
    args = parser.parse_args()
    if args.reset_difficulty:
        GLOBAL_DIFFICULTY_CONTROLLER.reset()
    # Zero-arg factories defer construction: the trained replay is only loaded
    # when "trained" is actually requested.
    policy_factories = {
        "random": lambda: random_policy,
        "heuristic": lambda: heuristic_policy,
        "oracle_lite": lambda: oracle_lite_policy,
        "trained": lambda: replay_trained_policy(ROOT / args.replay),
    }
    requested = [name.strip() for name in args.policies.split(",") if name.strip()]
    unknown = sorted(set(requested) - set(policy_factories))
    if unknown:
        raise SystemExit(f"Unknown policies: {', '.join(unknown)}")
    policies = {name: policy_factories[name]() for name in requested}
    tasks = ["task1", "task2", "task3"] if args.task == "all" else [args.task]
    rows = []
    # task -> policy -> adaptive-controller snapshot taken after that policy's run.
    controller_by_task_policy: dict[str, dict[str, dict]] = {}
    for task_type in tasks:
        for policy_name, policy in policies.items():
            if args.adaptive:
                # Fresh curriculum per (task, policy) so runs don't bleed into each other.
                GLOBAL_DIFFICULTY_CONTROLLER.reset()
            policy_rows = []
            for seed in range(args.episodes):
                policy_rows.append(run_episode(policy_name, policy, task_type, seed, adaptive=args.adaptive))
            rows.extend(policy_rows)
            controller_by_task_policy.setdefault(task_type, {})[policy_name] = (
                GLOBAL_DIFFICULTY_CONTROLLER.state() if args.adaptive else {}
            )
    payload = {
        "task": args.task,
        "tasks": tasks,
        "episodes_per_policy": args.episodes,
        "adaptive": args.adaptive,
        "difficulty_controller": GLOBAL_DIFFICULTY_CONTROLLER.state(),
        "difficulty_controller_by_task_policy": controller_by_task_policy,
        "summary": summarize(rows),
        "by_task": summarize_by_task(rows),
        "episodes": rows,
    }
    out_path = ROOT / args.out
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(payload, indent=2) + "\n")
    if not args.no_plot:
        chart_path = ROOT / args.plot
        write_baseline_chart(payload, chart_path)
        payload["chart"] = str(chart_path.relative_to(ROOT))
        # Rewrite the report so it records the chart location as well.
        out_path.write_text(json.dumps(payload, indent=2) + "\n")
    print(json.dumps({"summary": payload["summary"], "by_task": payload["by_task"], "chart": payload.get("chart")}, indent=2))
if __name__ == "__main__":
    main()