File size: 10,413 Bytes
a8deee9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ae45f3
aad7819
a8deee9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aad7819
a8deee9
 
 
 
 
 
 
 
 
 
 
 
 
 
aad7819
a8deee9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ae45f3
 
 
 
 
 
a8deee9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ae45f3
 
a8deee9
 
 
 
 
1ae45f3
 
 
a8deee9
1ae45f3
 
 
 
a8deee9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ae45f3
a8deee9
 
 
 
1ae45f3
a8deee9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
from __future__ import annotations

import argparse
import json
import random
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Callable

# Make the repository root importable so the sibling modules below
# (environment, mission_context, sentinel_config) resolve when this script
# is executed directly rather than as part of a package.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from environment import SentinelEnv, _GROUND_TRUTH_RELIABILITY
from mission_context import build_orchestrator_prompt, mission_for_task, problem_statement
from sentinel_config import ADVERSARIAL_AWARENESS_STAKES


# A policy maps (env, observation, rng) to the JSON action dict passed to env.step().
Policy = Callable[[SentinelEnv, dict, random.Random], dict]


@dataclass
class TraceRow:
    """One row of the printed step trace: the chosen action plus post-step state."""

    step_count: int  # step counter from result["info"] after this action
    subtask_index: int  # subtask index from the post-step observation
    stakes: float  # stakes level of the PRE-step observation the policy acted on
    action: str  # "<action_type>:<specialist_id or SELF>"
    reward: float  # scalar reward value for this step
    score: float  # running episode score from result["info"]
    completed: int  # subtasks completed so far (graph summary)
    remaining: int  # subtasks still remaining (graph summary)
    trust: dict[str, float]  # post-step trust snapshot per specialist
    reason: str  # reward reason string reported by the env
    detections: int  # cumulative adversarial detections (graph summary)
    poisonings: int  # cumulative adversarial poisonings (graph summary)


def blind_trust_policy(env: SentinelEnv, obs: dict, rng: random.Random) -> dict:
    """Pre-SENTINEL baseline: always delegate to the highest-trust specialist.

    Unknown specialists default to a trust of 0.5; the policy never verifies.
    """
    snapshot = obs["trust_snapshot"]
    best = max(obs["available_specialists"], key=lambda sid: snapshot.get(sid, 0.5))
    return action(obs, "delegate", best, f"blind-trust best={best}")


def sentinel_heuristic_policy(env: SentinelEnv, obs: dict, rng: random.Random) -> dict:
    """Simple SENTINEL-style baseline.

    Still flawed, but demonstrates the intended skill: when stakes cross the
    adversarial-awareness threshold and trust in the best specialist is low,
    verify instead of delegating.
    """
    snapshot = obs["trust_snapshot"]
    best = max(obs["available_specialists"], key=lambda sid: snapshot.get(sid, 0.5))
    confidence = snapshot.get(best, 0.5)
    risky = obs["stakes_level"] >= ADVERSARIAL_AWARENESS_STAKES and confidence < 0.65
    move = "verify" if risky else "delegate"
    reason = f"sentinel-heuristic best={best} trust={confidence:.3f} stakes={obs['stakes_level']:.2f}"
    return action(obs, move, best, reason)


def oracle_lite_policy(env: SentinelEnv, obs: dict, rng: random.Random) -> dict:
    """Upper-bound policy for demos and evaluator comparison.

    Peeks at hidden builder-only reliability info, so it is NOT deployable.
    """
    reliability = env._pool.public_ground_truth_reliability(_GROUND_TRUTH_RELIABILITY)
    high_stakes = obs["stakes_level"] >= ADVERSARIAL_AWARENESS_STAKES
    if obs["task_type"] == "task3" and high_stakes:
        return action(obs, "verify", env._pool.adversarial_slot, "oracle-lite verifies adversarial slot")
    best = max(obs["available_specialists"], key=lambda sid: reliability.get(sid, 0.5))
    return action(obs, "delegate", best, f"oracle-lite best={best}")


# Registry mapping CLI policy names to implementations; main() derives its
# --policy choices from these keys.
POLICIES: dict[str, Policy] = {
    "blind": blind_trust_policy,
    "heuristic": sentinel_heuristic_policy,
    "oracle": oracle_lite_policy,
}


def action(obs: dict, action_type: str, specialist_id: str | None, reason: str) -> dict:
    """Build the JSON action payload the environment's step() expects.

    Copies session/task identity from the observation; subtask_response is
    only populated for the "solve_independently" action.
    """
    payload = {
        "session_id": obs["session_id"],
        "task_type": obs["task_type"],
        "action_type": action_type,
        "specialist_id": specialist_id,
        "subtask_response": None,
        "reasoning": reason,
    }
    if action_type == "solve_independently":
        payload["subtask_response"] = "SELF_SOLVED"
    return payload


def compact_reset(result: dict) -> dict:
    """Project a reset result down to the compact agent-facing JSON shape."""
    obs = result["observation"]
    passthrough = (
        "session_id",
        "scenario_id",
        "task_type",
        "current_subtask",
        "available_specialists",
        "trust_snapshot",
        "stakes_level",
        "step_count",
        "max_steps",
    )
    compact = {key: obs[key] for key in passthrough}
    compact["done"] = result["done"]
    compact["reward"] = result["reward"]
    return compact


def run_episode(
    policy_name: str,
    task_type: str,
    seed: int,
    show_hidden: bool,
    max_rows: int | None,
) -> tuple[SentinelEnv, dict, list[TraceRow]]:
    """Run one full episode with the named policy, printing a step-by-step trace.

    Args:
        policy_name: Key into POLICIES ("blind", "heuristic", or "oracle").
        task_type: Task identifier passed to env.reset (e.g. "task3").
        seed: Seeds both the environment reset and the policy RNG.
        show_hidden: If True, also print the builder-only hidden profile.
        max_rows: Cap on printed trace rows; None prints every row.

    Returns:
        The environment, the final step result dict, and all TraceRows
        (including rows suppressed from printing by max_rows).
    """
    policy = POLICIES[policy_name]
    rng = random.Random(seed)
    env = SentinelEnv()
    result = env.reset(task_type=task_type, seed=seed)
    rows: list[TraceRow] = []

    print_header(policy_name, task_type, seed)
    print("RESET JSON - compact agent-facing shape")
    print(json.dumps(compact_reset(result), indent=2))
    print()
    print("LLM ORCHESTRATOR PROMPT - first 28 lines")
    prompt_lines = build_orchestrator_prompt(result["observation"]).splitlines()
    print("\n".join(prompt_lines[:28]))
    if len(prompt_lines) > 28:
        print("...")
    print()
    if show_hidden:
        # NOTE(review): this debug script deliberately peeks at private env
        # state (env._pool); agents never see this output.
        print("BUILDER-ONLY HIDDEN PROFILE - agent never sees this")
        print(json.dumps({
            "public_slot_to_internal_behavior": env._pool.internal_profile(),
            "adversarial_public_slot": env._pool.adversarial_slot,
        }, indent=2))
        print()

    print_trace_header()
    guard = 0  # hard cap so a non-terminating env/policy cannot loop forever
    while not result["done"] and guard < 100:
        obs = result["observation"]
        chosen = policy(env, obs, rng)
        result = env.step(chosen)
        graph_summary = env._graph.summary()
        row = TraceRow(
            step_count=result["info"]["step_count"],
            subtask_index=result["observation"]["subtask_index"],
            # stakes come from the PRE-step observation the policy acted on,
            # while the other fields reflect post-step state.
            stakes=obs["stakes_level"],
            action=f"{chosen['action_type']}:{chosen.get('specialist_id') or 'SELF'}",
            reward=result["reward"]["value"],
            score=result["info"]["score"],
            completed=graph_summary["subtasks_completed"],
            remaining=graph_summary["subtasks_remaining"],
            trust=result["observation"]["trust_snapshot"],
            reason=result["reward"]["reason"],
            detections=graph_summary["adversarial_detections"],
            poisonings=graph_summary["adversarial_poisonings"],
        )
        rows.append(row)
        # Every row is collected; printing stops once max_rows have been shown.
        if max_rows is None or len(rows) <= max_rows:
            print_trace_row(row)
        guard += 1

    if max_rows is not None and len(rows) > max_rows:
        print(f"... {len(rows) - max_rows} more rows hidden by --max-rows")

    print()
    print("FINAL INFO")
    print(json.dumps(result["info"], indent=2))
    print("FINAL REWARD")
    print(json.dumps(result["reward"], indent=2))
    print()
    return env, result, rows


def print_header(policy_name: str, task_type: str, seed: int) -> None:
    """Print the walkthrough banner, run parameters, and mission framing."""
    problem = problem_statement()["problem"]
    mission = mission_for_task(task_type)
    banner = "=" * 92
    for line in (
        banner,
        "SENTINEL BACKEND WALKTHROUGH",
        banner,
        f"policy={policy_name} task={task_type} seed={seed}",
        "",
        "REAL USER PROMPT EXAMPLE",
        problem["real_user_prompt_example"],
        "",
        "REAL-WORLD MAPPING",
        problem["not_a_simple_prompt_solver"],
        f"Task mission: {mission['judge_friendly_story']}",
        "The JSON action is the next internal control move, not the final user answer.",
        "SENTINEL trains the transferable behavior: trust, verify, recover, finish.",
        "",
    ):
        print(line)


def print_trace_header() -> None:
    """Print the column legend and separator for the per-step trace table."""
    legend = (
        "step | node | stake | action          | reward | score | "
        "done/rem | adv det/poison | trust snapshot"
    )
    print("STEP TRACE")
    print(legend)
    print("-" * 132)


def print_trace_row(row: TraceRow) -> None:
    """Render one TraceRow as a fixed-width table line plus its reward reason."""
    trust_text = " ".join(f"{sid}:{score:.3f}" for sid, score in row.trust.items())
    columns = (
        f"{row.step_count:>4}",
        f"{row.subtask_index:>4}",
        f"{row.stakes:>5.2f}",
        f"{row.action:<15}",
        f"{row.reward:>6.3f}",
        f"{row.score:>5.3f}",
        f"{row.completed:>2}/{row.completed + row.remaining:<2}",
        f"{row.detections:>2}/{row.poisonings:<2}",
        trust_text,
    )
    print(" | ".join(columns))
    print(f"     reason: {row.reason}")


def compare_policies(task_type: str, seed: int, show_hidden: bool) -> None:
    """Run all three policies on the same seed and print a comparison table.

    Each policy gets a fresh SentinelEnv and a fresh RNG seeded identically,
    so row differences reflect policy behavior only.
    """
    mission = mission_for_task(task_type)
    print("=" * 92)
    print("BEFORE / AFTER BACKEND COMPARISON")
    print("=" * 92)
    print("before=blind trust, middle=heuristic trust, target=oracle-lite upper bound")
    print(f"mission={mission['name']} - {mission['real_life_example']}")
    print()
    results = []
    for policy_name in ("blind", "heuristic", "oracle"):
        env = SentinelEnv()
        result = env.reset(task_type=task_type, seed=seed)
        rng = random.Random(seed)
        # NOTE(review): unlike run_episode, there is no step guard here --
        # this relies on the env always terminating on its own.
        while not result["done"]:
            chosen = POLICIES[policy_name](env, result["observation"], rng)
            result = env.step(chosen)
        info = result["info"]
        results.append({
            "policy": policy_name,
            "score": info.get("score", 0.0),
            "completion": info.get("completion_rate", 0.0),
            "detections": info.get("adversarial_detections", 0),
            "poisonings": info.get("adversarial_poisonings", 0),
            "steps": info.get("step_count", 0),
            "status": "failed" if info.get("forced_end") else "completed",
        })
        # Hidden profile is printed only once, after the first ("blind") run.
        if show_hidden and policy_name == "blind":
            print("Hidden profile for this comparison seed:")
            print(json.dumps({
                "public_slot_to_internal_behavior": env._pool.internal_profile(),
                "adversarial_public_slot": env._pool.adversarial_slot,
            }, indent=2))
            print()

    print("policy    | score | completion | detections | poisonings | steps | status")
    print("-" * 78)
    for item in results:
        print(
            f"{item['policy']:<9} | {item['score']:.3f} | "
            f"{item['completion']:.3f}      | {item['detections']:<10} | "
            f"{item['poisonings']:<10} | {item['steps']:<5} | {item['status']}"
        )
    print()


def main() -> None:
    """CLI entry point: parse flags, optionally compare policies, run one episode."""
    parser = argparse.ArgumentParser(description="Explain SENTINEL backend behavior from terminal.")
    parser.add_argument("--task", default="task3", choices=["task1", "task2", "task3"])
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--policy", default="heuristic", choices=sorted(POLICIES))
    parser.add_argument("--hide-hidden", action="store_true", help="Do not print builder-only hidden profile.")
    parser.add_argument("--max-rows", type=int, default=None, help="Limit printed trace rows.")
    parser.add_argument("--compare", action="store_true", help="Compare blind vs heuristic vs oracle-lite.")
    opts = parser.parse_args()

    reveal_hidden = not opts.hide_hidden
    # The comparison (if requested) runs first; the single-episode walkthrough
    # always runs afterwards.
    if opts.compare:
        compare_policies(opts.task, opts.seed, reveal_hidden)
    run_episode(opts.policy, opts.task, opts.seed, reveal_hidden, opts.max_rows)


if __name__ == "__main__":
    main()