File size: 8,389 Bytes
1794757
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
from __future__ import annotations

from collections import Counter, defaultdict
from typing import Callable

from trenches_env.agents import AGENT_IDS
from trenches_env.env import FogOfWarDiplomacyEnv
from trenches_env.models import (
    BenchmarkEntityScorecard,
    BenchmarkRunRequest,
    BenchmarkRunResponse,
    BenchmarkScenarioResult,
    StepSessionRequest,
)
from trenches_env.scenarios import benchmark_scenario_ids, get_scenario_definition, scenario_signals_for_turn
from trenches_env.source_ingestion import SourceHarvester


def _default_env_factory() -> FogOfWarDiplomacyEnv:
    """Build the default environment, with a source harvester that is not auto-started."""
    harvester = SourceHarvester(auto_start=False)
    return FogOfWarDiplomacyEnv(source_harvester=harvester)


class ScenarioBenchmarkRunner:
    """Runs a batch of benchmark scenarios and aggregates per-agent scorecards.

    Every scenario executes in a fresh environment built by the injected
    factory, so no state leaks between scenarios; each environment is shut
    down in a ``finally`` block, even when a scenario raises.
    """

    def __init__(self, env_factory: Callable[[], FogOfWarDiplomacyEnv] | None = None) -> None:
        """Store the environment factory, defaulting to the module-level factory."""
        self._env_factory = env_factory or _default_env_factory

    def run(self, request: BenchmarkRunRequest) -> BenchmarkRunResponse:
        """Execute the requested scenarios and return per-scenario plus aggregate results.

        When the request names no scenarios, the full benchmark scenario list
        is used. When a base seed is supplied, scenario ``index`` runs with
        ``request.seed + index`` so runs are deterministic yet distinct.
        """
        scenario_ids = request.scenario_ids or benchmark_scenario_ids()
        results: list[BenchmarkScenarioResult] = []
        aggregate_reward_totals: dict[str, float] = {agent_id: 0.0 for agent_id in AGENT_IDS}

        for index, scenario_id in enumerate(scenario_ids):
            scenario_result, reward_totals = self._run_scenario(request, scenario_id, index)
            results.append(scenario_result)
            for agent_id in AGENT_IDS:
                aggregate_reward_totals[agent_id] += reward_totals[agent_id]

        # max(..., 1) guards the division when the request resolved to zero scenarios.
        scenario_count = max(len(results), 1)
        aggregate_mean_total_rewards = {
            agent_id: round(total / scenario_count, 3)
            for agent_id, total in aggregate_reward_totals.items()
        }
        return BenchmarkRunResponse(
            seed=request.seed,
            training_stage=request.training_stage,
            scenario_ids=[result.scenario_id for result in results],
            scenario_count=len(results),
            results=results,
            aggregate_mean_total_rewards=aggregate_mean_total_rewards,
        )

    def _run_scenario(
        self,
        request: BenchmarkRunRequest,
        scenario_id: str,
        index: int,
    ) -> tuple[BenchmarkScenarioResult, dict[str, float]]:
        """Run a single scenario in a fresh environment.

        Returns the scenario result together with the per-agent cumulative
        reward totals so the caller can fold them into the batch aggregate.
        """
        scenario = get_scenario_definition(scenario_id)
        scenario_seed = None if request.seed is None else request.seed + index
        turn_limit = request.steps_per_scenario or scenario.benchmark_turns
        env = self._env_factory()

        try:
            session = env.create_session(
                seed=scenario_seed,
                training_stage=request.training_stage,
                max_turns=turn_limit,
                scenario_id=scenario.id,
            )
            reward_totals: dict[str, float] = {agent_id: 0.0 for agent_id in AGENT_IDS}
            goal_term_totals: dict[str, dict[str, float]] = {
                agent_id: defaultdict(float) for agent_id in AGENT_IDS
            }
            action_counters: dict[str, Counter[str]] = {agent_id: Counter() for agent_id in AGENT_IDS}
            oversight_trigger_count = 0
            done = False
            done_reason: str | None = None

            for turn in range(1, turn_limit + 1):
                signals = scenario_signals_for_turn(scenario.id, turn)
                actions = env.resolve_policy_actions(session, signals)
                result = env.step_session(
                    session,
                    StepSessionRequest(actions=actions, external_signals=signals),
                )
                # step_session returns an updated session; keep the latest one.
                session = result.session
                trace = session.recent_traces[-1]

                if result.oversight.triggered:
                    oversight_trigger_count += 1

                for agent_id, action in trace.actions.items():
                    action_counters[agent_id][action.type] += 1

                for agent_id, reward in trace.rewards.items():
                    reward_totals[agent_id] += reward.total
                    for name, value in reward.goal_terms.items():
                        goal_term_totals[agent_id][name] += value

                if result.done:
                    done = True
                    # Termination at very high tension is attributed to the
                    # tension threshold; anything else is treated as the turn cap.
                    done_reason = (
                        "tension_threshold" if session.world.tension_level >= 95.0 else "max_turns"
                    )
                    break

            scorecards: dict[str, BenchmarkEntityScorecard] = {
                agent_id: self._build_scorecard(
                    env,
                    session,
                    agent_id,
                    reward_totals[agent_id],
                    goal_term_totals[agent_id],
                    action_counters[agent_id],
                )
                for agent_id in AGENT_IDS
            }
            scenario_warnings = self._scenario_warnings(
                session, scorecards, oversight_trigger_count, turn_limit
            )

            summary = (
                f"{scenario.name}: {session.world.turn} turns, tension {session.world.tension_level:.1f}, "
                f"oversight triggers {oversight_trigger_count}."
            )
            scenario_result = BenchmarkScenarioResult(
                scenario_id=scenario.id,
                scenario_name=scenario.name,
                seed=scenario_seed,
                training_stage=request.training_stage,
                turns_executed=session.world.turn,
                done=done,
                done_reason=done_reason,
                oversight_trigger_count=oversight_trigger_count,
                final_tension=session.world.tension_level,
                final_market_stress=session.world.market_stress,
                final_oil_pressure=session.world.oil_pressure,
                summary=summary,
                warnings=scenario_warnings,
                scorecards=scorecards,
            )
            return scenario_result, reward_totals
        finally:
            # Always release environment resources, even when the scenario raised.
            env.shutdown()

    @staticmethod
    def _build_scorecard(
        env: FogOfWarDiplomacyEnv,
        session,
        agent_id: str,
        total_reward: float,
        goal_terms: dict[str, float],
        action_counter: Counter[str],
    ) -> BenchmarkEntityScorecard:
        """Build one agent's scorecard from the totals accumulated over the run."""
        final_reward = session.rewards[agent_id]
        action_counts = dict(action_counter)
        dominant_action = max(action_counts, key=action_counts.get) if action_counts else None
        damaged_asset_count = sum(
            1
            for asset in session.world.asset_state.get(agent_id, {}).values()
            if asset.status != "operational"
        )
        # NOTE(review): reaches into a private env helper; consider promoting
        # _asset_pressure to a public method on FogOfWarDiplomacyEnv.
        asset_pressure = round(env._asset_pressure(session.world, agent_id), 3)

        warnings: list[str] = []
        if dominant_action is not None:
            # Flag agents that spent >= 75% of their actions on a single move.
            dominant_share = action_counts[dominant_action] / max(sum(action_counts.values()), 1)
            if dominant_share >= 0.75:
                warnings.append(f"action_monoculture:{dominant_action}")
        if asset_pressure >= 0.45 and dominant_action == "hold":
            warnings.append("passive_under_asset_pressure")
        if final_reward.total <= -0.35 and dominant_action in {"strike", "mobilize", "deceive", "sanction"}:
            warnings.append("negative_escalation_bias")

        return BenchmarkEntityScorecard(
            agent_id=agent_id,
            total_reward=round(total_reward, 3),
            mean_reward=round(total_reward / max(session.world.turn, 1), 3),
            final_reward=final_reward.total,
            final_goal_terms=final_reward.goal_terms,
            aggregated_goal_terms={
                name: round(value, 3) for name, value in goal_terms.items()
            },
            final_state=session.world.latent_state.get(agent_id, {}).copy(),
            damaged_asset_count=damaged_asset_count,
            asset_pressure=asset_pressure,
            action_counts=action_counts,
            dominant_action=dominant_action,
            warnings=warnings,
        )

    @staticmethod
    def _scenario_warnings(
        session,
        scorecards: dict[str, BenchmarkEntityScorecard],
        oversight_trigger_count: int,
        turn_limit: int,
    ) -> list[str]:
        """Derive scenario-level warning flags from the finished run."""
        warnings: list[str] = []
        if oversight_trigger_count >= max(2, turn_limit // 2):
            warnings.append("frequent_oversight")
        if session.world.tension_level >= 90.0:
            warnings.append("runaway_escalation")
        # NOTE(review): hard-coded agent tuple — presumably mirrors AGENT_IDS;
        # confirm and consider reusing the shared constant.
        if all(
            scorecards[agent_id].dominant_action == "hold"
            for agent_id in ("us", "israel", "iran", "hezbollah", "gulf")
        ):
            warnings.append("global_passivity")
        return warnings