| from __future__ import annotations | |
| from collections import Counter, defaultdict | |
| from typing import Callable | |
| from trenches_env.agents import AGENT_IDS | |
| from trenches_env.env import FogOfWarDiplomacyEnv | |
| from trenches_env.models import ( | |
| BenchmarkEntityScorecard, | |
| BenchmarkRunRequest, | |
| BenchmarkRunResponse, | |
| BenchmarkScenarioResult, | |
| StepSessionRequest, | |
| ) | |
| from trenches_env.scenarios import benchmark_scenario_ids, get_scenario_definition, scenario_signals_for_turn | |
| from trenches_env.source_ingestion import SourceHarvester | |
def _default_env_factory() -> FogOfWarDiplomacyEnv:
    """Build the environment used when the runner is given no custom factory.

    The environment receives a ``SourceHarvester`` constructed with
    ``auto_start=False``.
    """
    harvester = SourceHarvester(auto_start=False)
    return FogOfWarDiplomacyEnv(source_harvester=harvester)
class ScenarioBenchmarkRunner:
    """Run benchmark scenarios and summarise per-agent and per-scenario results.

    Each scenario is played in its own environment obtained from
    ``env_factory``; the environment is shut down once the scenario finishes,
    whether it completed or raised.
    """

    # A finished session at or above this tension is reported as
    # "tension_threshold"; any other finished session is reported as "max_turns".
    _TENSION_DONE_THRESHOLD = 95.0
    # Final tension at or above this adds the "runaway_escalation" warning.
    _RUNAWAY_TENSION_THRESHOLD = 90.0
    # Share of all counted actions at which one action type is a monoculture.
    _MONOCULTURE_SHARE = 0.75
    # Rounded asset pressure at which a "hold"-dominant agent is flagged passive.
    _PASSIVE_ASSET_PRESSURE = 0.45
    # Final reward at or below this, with an escalatory dominant action,
    # adds the "negative_escalation_bias" warning.
    _NEGATIVE_REWARD_THRESHOLD = -0.35
    _ESCALATORY_ACTIONS = frozenset({"strike", "mobilize", "deceive", "sanction"})
    # Agents inspected by the "global_passivity" check.
    # NOTE(review): this list is hard-coded rather than derived from AGENT_IDS;
    # confirm whether the two are meant to differ.
    _PASSIVITY_AGENT_IDS = ("us", "israel", "iran", "hezbollah", "gulf")

    def __init__(self, env_factory: Callable[[], FogOfWarDiplomacyEnv] | None = None) -> None:
        """Store the environment factory.

        Args:
            env_factory: Zero-argument callable returning a new environment.
                When omitted (or falsy), ``_default_env_factory`` is used.
        """
        self._env_factory = env_factory or _default_env_factory

    def run(self, request: BenchmarkRunRequest) -> BenchmarkRunResponse:
        """Play every requested scenario and aggregate the outcome.

        Args:
            request: Benchmark parameters. ``scenario_ids`` falls back to
                ``benchmark_scenario_ids()`` when empty; ``seed`` (if set) is
                offset by the scenario's position so each scenario gets its
                own seed; ``steps_per_scenario`` (if set) overrides the
                scenario's own ``benchmark_turns``.

        Returns:
            A response holding one result per scenario plus, per agent, the
            mean of the per-scenario reward totals rounded to 3 decimals.
        """
        scenario_ids = request.scenario_ids or benchmark_scenario_ids()
        results: list[BenchmarkScenarioResult] = []
        aggregate_reward_totals: dict[str, float] = {agent_id: 0.0 for agent_id in AGENT_IDS}

        for index, scenario_id in enumerate(scenario_ids):
            scenario_seed = None if request.seed is None else request.seed + index
            scenario_result, reward_totals = self._run_scenario(request, scenario_id, scenario_seed)
            for agent_id in AGENT_IDS:
                aggregate_reward_totals[agent_id] += reward_totals[agent_id]
            results.append(scenario_result)

        # Guard the division for the (unexpected) case of zero scenarios.
        scenario_count = max(len(results), 1)
        aggregate_mean_total_rewards = {
            agent_id: round(total / scenario_count, 3)
            for agent_id, total in aggregate_reward_totals.items()
        }
        return BenchmarkRunResponse(
            seed=request.seed,
            training_stage=request.training_stage,
            scenario_ids=[result.scenario_id for result in results],
            scenario_count=len(results),
            results=results,
            aggregate_mean_total_rewards=aggregate_mean_total_rewards,
        )

    def _run_scenario(
        self,
        request: BenchmarkRunRequest,
        scenario_id: str,
        scenario_seed: int | None,
    ) -> tuple[BenchmarkScenarioResult, dict[str, float]]:
        """Play one scenario in a fresh environment.

        Returns:
            The scenario result and the unrounded per-agent reward totals
            accumulated over the executed turns.
        """
        scenario = get_scenario_definition(scenario_id)
        turn_limit = request.steps_per_scenario or scenario.benchmark_turns
        env = self._env_factory()
        try:
            session = env.create_session(
                seed=scenario_seed,
                training_stage=request.training_stage,
                max_turns=turn_limit,
                scenario_id=scenario.id,
            )
            reward_totals: dict[str, float] = {agent_id: 0.0 for agent_id in AGENT_IDS}
            goal_term_totals: dict[str, dict[str, float]] = {
                agent_id: defaultdict(float) for agent_id in AGENT_IDS
            }
            action_counters: dict[str, Counter[str]] = {agent_id: Counter() for agent_id in AGENT_IDS}
            oversight_trigger_count = 0
            done = False
            done_reason: str | None = None

            for turn in range(1, turn_limit + 1):
                signals = scenario_signals_for_turn(scenario.id, turn)
                actions = env.resolve_policy_actions(session, signals)
                result = env.step_session(
                    session,
                    StepSessionRequest(actions=actions, external_signals=signals),
                )
                session = result.session
                # The newest trace is read as the record of the step just taken.
                trace = session.recent_traces[-1]
                if result.oversight.triggered:
                    oversight_trigger_count += 1
                for agent_id, action in trace.actions.items():
                    action_counters[agent_id][action.type] += 1
                for agent_id, reward in trace.rewards.items():
                    reward_totals[agent_id] += reward.total
                    for name, value in reward.goal_terms.items():
                        goal_term_totals[agent_id][name] += value
                if result.done:
                    done = True
                    if session.world.tension_level >= self._TENSION_DONE_THRESHOLD:
                        done_reason = "tension_threshold"
                    else:
                        done_reason = "max_turns"
                    break

            scorecards: dict[str, BenchmarkEntityScorecard] = {}
            for agent_id in AGENT_IDS:
                scorecards[agent_id] = self._build_scorecard(
                    env,
                    session,
                    agent_id,
                    reward_totals[agent_id],
                    goal_term_totals[agent_id],
                    action_counters[agent_id],
                )

            scenario_warnings = self._scenario_warnings(
                oversight_trigger_count,
                turn_limit,
                session.world.tension_level,
                {agent_id: card.dominant_action for agent_id, card in scorecards.items()},
            )
            summary = (
                f"{scenario.name}: {session.world.turn} turns, tension {session.world.tension_level:.1f}, "
                f"oversight triggers {oversight_trigger_count}."
            )
            scenario_result = BenchmarkScenarioResult(
                scenario_id=scenario.id,
                scenario_name=scenario.name,
                seed=scenario_seed,
                training_stage=request.training_stage,
                turns_executed=session.world.turn,
                done=done,
                done_reason=done_reason,
                oversight_trigger_count=oversight_trigger_count,
                final_tension=session.world.tension_level,
                final_market_stress=session.world.market_stress,
                final_oil_pressure=session.world.oil_pressure,
                summary=summary,
                warnings=scenario_warnings,
                scorecards=scorecards,
            )
            return scenario_result, reward_totals
        finally:
            # Runs on success and on failure, so the environment never leaks.
            env.shutdown()

    def _build_scorecard(
        self,
        env: FogOfWarDiplomacyEnv,
        session,
        agent_id: str,
        total_reward: float,
        goal_terms: dict[str, float],
        action_counter: Counter[str],
    ) -> BenchmarkEntityScorecard:
        """Assemble one agent's scorecard from the final session and run totals.

        ``session`` is the session object returned by the last environment
        step (or by ``create_session`` when no step ran).
        """
        final_reward = session.rewards[agent_id]
        action_counts = dict(action_counter)
        dominant_action = self._dominant_action(action_counts)
        damaged_asset_count = sum(
            1
            for asset in session.world.asset_state.get(agent_id, {}).values()
            if asset.status != "operational"
        )
        asset_pressure = round(env._asset_pressure(session.world, agent_id), 3)
        warnings = self._scorecard_warnings(
            action_counts,
            dominant_action,
            asset_pressure,
            final_reward.total,
        )
        return BenchmarkEntityScorecard(
            agent_id=agent_id,
            total_reward=round(total_reward, 3),
            # Guard the division in case the world reports turn 0.
            mean_reward=round(total_reward / max(session.world.turn, 1), 3),
            final_reward=final_reward.total,
            final_goal_terms=final_reward.goal_terms,
            aggregated_goal_terms={name: round(value, 3) for name, value in goal_terms.items()},
            final_state=session.world.latent_state.get(agent_id, {}).copy(),
            damaged_asset_count=damaged_asset_count,
            asset_pressure=asset_pressure,
            action_counts=action_counts,
            dominant_action=dominant_action,
            warnings=warnings,
        )

    @staticmethod
    def _dominant_action(action_counts: dict[str, int]) -> str | None:
        """Return the most frequent action type, or ``None`` if none were counted.

        Ties resolve to the action that was inserted first.
        """
        if not action_counts:
            return None
        return max(action_counts, key=action_counts.get)

    @classmethod
    def _scorecard_warnings(
        cls,
        action_counts: dict[str, int],
        dominant_action: str | None,
        asset_pressure: float,
        final_reward_total: float,
    ) -> list[str]:
        """Return the per-agent warning labels, in a fixed order."""
        warnings: list[str] = []
        if dominant_action is not None:
            dominant_share = action_counts[dominant_action] / max(sum(action_counts.values()), 1)
            if dominant_share >= cls._MONOCULTURE_SHARE:
                warnings.append(f"action_monoculture:{dominant_action}")
        if asset_pressure >= cls._PASSIVE_ASSET_PRESSURE and dominant_action == "hold":
            warnings.append("passive_under_asset_pressure")
        if final_reward_total <= cls._NEGATIVE_REWARD_THRESHOLD and dominant_action in cls._ESCALATORY_ACTIONS:
            warnings.append("negative_escalation_bias")
        return warnings

    @classmethod
    def _scenario_warnings(
        cls,
        oversight_trigger_count: int,
        turn_limit: int,
        final_tension: float,
        dominant_actions: dict[str, str | None],
    ) -> list[str]:
        """Return the scenario-level warning labels, in a fixed order.

        ``dominant_actions`` maps agent id to that agent's dominant action and
        must contain every id in ``_PASSIVITY_AGENT_IDS``.
        """
        warnings: list[str] = []
        # At least two triggers are always required, even for very short runs.
        if oversight_trigger_count >= max(2, turn_limit // 2):
            warnings.append("frequent_oversight")
        if final_tension >= cls._RUNAWAY_TENSION_THRESHOLD:
            warnings.append("runaway_escalation")
        if all(dominant_actions[agent_id] == "hold" for agent_id in cls._PASSIVITY_AGENT_IDS):
            warnings.append("global_passivity")
        return warnings