# File size: 8,389 Bytes
from __future__ import annotations
from collections import Counter, defaultdict
from typing import Callable
from trenches_env.agents import AGENT_IDS
from trenches_env.env import FogOfWarDiplomacyEnv
from trenches_env.models import (
BenchmarkEntityScorecard,
BenchmarkRunRequest,
BenchmarkRunResponse,
BenchmarkScenarioResult,
StepSessionRequest,
)
from trenches_env.scenarios import benchmark_scenario_ids, get_scenario_definition, scenario_signals_for_turn
from trenches_env.source_ingestion import SourceHarvester
def _default_env_factory() -> FogOfWarDiplomacyEnv:
    """Build the environment used when the runner is given no factory.

    The environment is wired to a ``SourceHarvester`` constructed with
    ``auto_start=False``.
    """
    harvester = SourceHarvester(auto_start=False)
    return FogOfWarDiplomacyEnv(source_harvester=harvester)
class ScenarioBenchmarkRunner:
    """Run benchmark scenarios end to end and score every agent.

    Each scenario is played in its own environment obtained from the
    configured factory. The environment is shut down when the scenario
    finishes, whether it completed normally or raised.
    """

    # A finished session whose tension is at or above this value is reported
    # as "tension_threshold"; otherwise it is reported as "max_turns".
    _TENSION_DONE_THRESHOLD: float = 95.0
    # Final tension at or above this value adds the "runaway_escalation" warning.
    _RUNAWAY_TENSION: float = 90.0
    # Share of an agent's actions taken by a single action type at or above
    # which "action_monoculture:<type>" is raised.
    _MONOCULTURE_SHARE: float = 0.75
    # Rounded asset pressure at or above which a hold-dominant agent is
    # flagged as "passive_under_asset_pressure".
    _PASSIVE_ASSET_PRESSURE: float = 0.45
    # Final reward total at or below which an escalation-dominant agent is
    # flagged as "negative_escalation_bias".
    _NEGATIVE_REWARD: float = -0.35
    _ESCALATORY_ACTIONS: frozenset[str] = frozenset(
        {"strike", "mobilize", "deceive", "sanction"}
    )
    # Agents that must all be hold-dominant for "global_passivity".
    # NOTE(review): this is a fixed list rather than AGENT_IDS — presumably
    # deliberate; confirm before unifying the two.
    _PASSIVITY_AGENTS: tuple[str, ...] = ("us", "israel", "iran", "hezbollah", "gulf")

    def __init__(self, env_factory: Callable[[], FogOfWarDiplomacyEnv] | None = None) -> None:
        """Store the environment factory.

        Args:
            env_factory: Zero-argument callable returning a fresh environment.
                When omitted (or falsy), ``_default_env_factory`` is used.
        """
        self._env_factory = env_factory or _default_env_factory

    def run(self, request: BenchmarkRunRequest) -> BenchmarkRunResponse:
        """Play every requested scenario and aggregate the results.

        Args:
            request: Supplies ``scenario_ids`` (falls back to
                ``benchmark_scenario_ids()`` when empty or ``None``), ``seed``,
                ``steps_per_scenario`` and ``training_stage``.

        Returns:
            A response holding one result per scenario plus, per agent, the
            mean of the scenario reward totals rounded to 3 decimals.
        """
        scenario_ids = request.scenario_ids or benchmark_scenario_ids()
        results: list[BenchmarkScenarioResult] = []
        aggregate_reward_totals: dict[str, float] = dict.fromkeys(AGENT_IDS, 0.0)

        for index, scenario_id in enumerate(scenario_ids):
            result, reward_totals = self._run_scenario(request, scenario_id, index)
            results.append(result)
            for agent_id in AGENT_IDS:
                aggregate_reward_totals[agent_id] += reward_totals[agent_id]

        # Guard the division so an empty scenario list yields 0.0 means.
        scenario_count = max(len(results), 1)
        return BenchmarkRunResponse(
            seed=request.seed,
            training_stage=request.training_stage,
            scenario_ids=[result.scenario_id for result in results],
            scenario_count=len(results),
            results=results,
            aggregate_mean_total_rewards={
                agent_id: round(total / scenario_count, 3)
                for agent_id, total in aggregate_reward_totals.items()
            },
        )

    def _run_scenario(
        self,
        request: BenchmarkRunRequest,
        scenario_id: str,
        index: int,
    ) -> tuple[BenchmarkScenarioResult, dict[str, float]]:
        """Play one scenario in a fresh environment and score it.

        Args:
            request: The benchmark request being served.
            scenario_id: Identifier passed to ``get_scenario_definition``.
            index: Position of the scenario in the run; added to
                ``request.seed`` so each scenario gets a distinct seed.

        Returns:
            The scenario result and the unrounded per-agent reward totals
            accumulated over the executed turns.
        """
        scenario = get_scenario_definition(scenario_id)
        scenario_seed = None if request.seed is None else request.seed + index
        turn_limit = request.steps_per_scenario or scenario.benchmark_turns
        env = self._env_factory()
        try:
            session = env.create_session(
                seed=scenario_seed,
                training_stage=request.training_stage,
                max_turns=turn_limit,
                scenario_id=scenario.id,
            )
            reward_totals: dict[str, float] = dict.fromkeys(AGENT_IDS, 0.0)
            goal_term_totals: dict[str, dict[str, float]] = {
                agent_id: defaultdict(float) for agent_id in AGENT_IDS
            }
            action_counters: dict[str, Counter[str]] = {
                agent_id: Counter() for agent_id in AGENT_IDS
            }
            oversight_trigger_count = 0
            done = False
            done_reason: str | None = None

            for turn in range(1, turn_limit + 1):
                signals = scenario_signals_for_turn(scenario.id, turn)
                actions = env.resolve_policy_actions(session, signals)
                step = env.step_session(
                    session,
                    StepSessionRequest(actions=actions, external_signals=signals),
                )
                session = step.session
                # Presumably the trace of the step just taken — confirm
                # against the environment's trace bookkeeping.
                trace = session.recent_traces[-1]
                if step.oversight.triggered:
                    oversight_trigger_count += 1
                for agent_id, action in trace.actions.items():
                    action_counters[agent_id][action.type] += 1
                for agent_id, reward in trace.rewards.items():
                    reward_totals[agent_id] += reward.total
                    for name, value in reward.goal_terms.items():
                        goal_term_totals[agent_id][name] += value
                if step.done:
                    done = True
                    done_reason = self._done_reason(session.world.tension_level)
                    break

            scorecards: dict[str, BenchmarkEntityScorecard] = {
                agent_id: self._build_scorecard(
                    env,
                    session,
                    agent_id,
                    reward_totals[agent_id],
                    goal_term_totals[agent_id],
                    action_counters[agent_id],
                )
                for agent_id in AGENT_IDS
            }

            world = session.world
            summary = (
                f"{scenario.name}: {world.turn} turns, tension {world.tension_level:.1f}, "
                f"oversight triggers {oversight_trigger_count}."
            )
            result = BenchmarkScenarioResult(
                scenario_id=scenario.id,
                scenario_name=scenario.name,
                seed=scenario_seed,
                training_stage=request.training_stage,
                turns_executed=world.turn,
                done=done,
                done_reason=done_reason,
                oversight_trigger_count=oversight_trigger_count,
                final_tension=world.tension_level,
                final_market_stress=world.market_stress,
                final_oil_pressure=world.oil_pressure,
                summary=summary,
                warnings=self._scenario_warnings(
                    oversight_trigger_count,
                    turn_limit,
                    world.tension_level,
                    scorecards,
                ),
                scorecards=scorecards,
            )
            return result, reward_totals
        finally:
            env.shutdown()

    @classmethod
    def _done_reason(cls, tension_level: float) -> str:
        """Infer why a finished session ended from its final tension.

        Any termination below the tension threshold is reported as
        ``"max_turns"``.
        """
        if tension_level >= cls._TENSION_DONE_THRESHOLD:
            return "tension_threshold"
        return "max_turns"

    @classmethod
    def _build_scorecard(
        cls,
        env: FogOfWarDiplomacyEnv,
        session,
        agent_id: str,
        reward_total: float,
        goal_terms: dict[str, float],
        action_counter: Counter[str],
    ) -> BenchmarkEntityScorecard:
        """Assemble one agent's scorecard from the final session state.

        Args:
            env: Environment the session was played in; used for its asset
                pressure computation.
            session: Final session object of the scenario.
            agent_id: Agent being scored.
            reward_total: Unrounded sum of the agent's per-turn reward totals.
            goal_terms: Unrounded per-goal-term sums for the agent.
            action_counter: How often the agent took each action type.
        """
        final_reward = session.rewards[agent_id]
        action_counts = dict(action_counter)
        dominant_action = cls._dominant_action(action_counts)
        damaged_asset_count = sum(
            1
            for asset in session.world.asset_state.get(agent_id, {}).values()
            if asset.status != "operational"
        )
        # Warnings are evaluated against the rounded pressure value.
        asset_pressure = round(env._asset_pressure(session.world, agent_id), 3)
        warnings = cls._entity_warnings(
            action_counts, dominant_action, asset_pressure, final_reward.total
        )
        return BenchmarkEntityScorecard(
            agent_id=agent_id,
            total_reward=round(reward_total, 3),
            mean_reward=round(reward_total / max(session.world.turn, 1), 3),
            final_reward=final_reward.total,
            final_goal_terms=final_reward.goal_terms,
            aggregated_goal_terms={
                name: round(value, 3) for name, value in goal_terms.items()
            },
            final_state=session.world.latent_state.get(agent_id, {}).copy(),
            damaged_asset_count=damaged_asset_count,
            asset_pressure=asset_pressure,
            action_counts=action_counts,
            dominant_action=dominant_action,
            warnings=warnings,
        )

    @staticmethod
    def _dominant_action(action_counts: dict[str, int]) -> str | None:
        """Return the most frequent action type, or ``None`` if there is none.

        Ties go to the action type that appears first in ``action_counts``.
        """
        if not action_counts:
            return None
        return max(action_counts, key=action_counts.get)

    @classmethod
    def _entity_warnings(
        cls,
        action_counts: dict[str, int],
        dominant_action: str | None,
        asset_pressure: float,
        final_reward_total: float,
    ) -> list[str]:
        """Return the per-agent warning tags, in a fixed order."""
        warnings: list[str] = []
        if dominant_action is not None:
            total_actions = max(sum(action_counts.values()), 1)
            dominant_share = action_counts[dominant_action] / total_actions
            if dominant_share >= cls._MONOCULTURE_SHARE:
                warnings.append(f"action_monoculture:{dominant_action}")
        if asset_pressure >= cls._PASSIVE_ASSET_PRESSURE and dominant_action == "hold":
            warnings.append("passive_under_asset_pressure")
        if (
            final_reward_total <= cls._NEGATIVE_REWARD
            and dominant_action in cls._ESCALATORY_ACTIONS
        ):
            warnings.append("negative_escalation_bias")
        return warnings

    @classmethod
    def _scenario_warnings(
        cls,
        oversight_trigger_count: int,
        turn_limit: int,
        tension_level: float,
        scorecards: dict[str, BenchmarkEntityScorecard],
    ) -> list[str]:
        """Return the scenario-level warning tags, in a fixed order."""
        warnings: list[str] = []
        # NOTE(review): the threshold is derived from the configured turn
        # limit, not from the number of turns actually executed — confirm
        # that is intended for scenarios that end early.
        if oversight_trigger_count >= max(2, turn_limit // 2):
            warnings.append("frequent_oversight")
        if tension_level >= cls._RUNAWAY_TENSION:
            warnings.append("runaway_escalation")
        if all(
            scorecards[agent_id].dominant_action == "hold"
            for agent_id in cls._PASSIVITY_AGENTS
        ):
            warnings.append("global_passivity")
        return warnings
|