# trenches/backend/src/trenches_env/benchmark_runner.py
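"""Benchmark runner that replays scripted scenarios and aggregates per-agent scorecards."""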
from __future__ import annotations
from collections import Counter, defaultdict
from typing import Callable
from trenches_env.agents import AGENT_IDS
from trenches_env.env import FogOfWarDiplomacyEnv
from trenches_env.models import (
BenchmarkEntityScorecard,
BenchmarkRunRequest,
BenchmarkRunResponse,
BenchmarkScenarioResult,
StepSessionRequest,
)
from trenches_env.scenarios import (
    benchmark_scenario_ids,
    get_scenario_definition,
    scenario_signals_for_turn,
)
from trenches_env.source_ingestion import SourceHarvester
def _default_env_factory() -> FogOfWarDiplomacyEnv:
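    """Build a FogOfWarDiplomacyEnv whose SourceHarvester is created with auto_start=False."""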
return FogOfWarDiplomacyEnv(source_harvester=SourceHarvester(auto_start=False))
class ScenarioBenchmarkRunner:
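    """Runs each benchmark scenario in a fresh environment and scores every agent."""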
def __init__(self, env_factory: Callable[[], FogOfWarDiplomacyEnv] | None = None) -> None:
self._env_factory = env_factory or _default_env_factory
def run(self, request: BenchmarkRunRequest) -> BenchmarkRunResponse:
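        """Execute the requested scenarios and return per-scenario results plus aggregate reward means."""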
scenario_ids = request.scenario_ids or benchmark_scenario_ids()
results: list[BenchmarkScenarioResult] = []
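        # Per-agent reward totals accumulated across every scenario, used for the aggregate means below.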
aggregate_reward_totals: dict[str, float] = {agent_id: 0.0 for agent_id in AGENT_IDS}
for index, scenario_id in enumerate(scenario_ids):
scenario = get_scenario_definition(scenario_id)
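            # Derive a distinct but reproducible seed per scenario by offsetting the base seed.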
scenario_seed = None if request.seed is None else request.seed + index
turn_limit = request.steps_per_scenario or scenario.benchmark_turns
env = self._env_factory()
try:
session = env.create_session(
seed=scenario_seed,
training_stage=request.training_stage,
max_turns=turn_limit,
scenario_id=scenario.id,
)
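                # Per-scenario trackers: reward totals, goal-term breakdowns, action usage, and oversight triggers.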
reward_totals: dict[str, float] = {agent_id: 0.0 for agent_id in AGENT_IDS}
goal_term_totals: dict[str, dict[str, float]] = {
agent_id: defaultdict(float) for agent_id in AGENT_IDS
}
action_counters: dict[str, Counter[str]] = {agent_id: Counter() for agent_id in AGENT_IDS}
oversight_trigger_count = 0
done = False
done_reason: str | None = None
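                # Drive the session turn by turn with scripted scenario signals and the env's own policy actions.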
for turn in range(1, turn_limit + 1):
signals = scenario_signals_for_turn(scenario.id, turn)
actions = env.resolve_policy_actions(session, signals)
result = env.step_session(
session,
StepSessionRequest(actions=actions, external_signals=signals),
)
session = result.session
trace = session.recent_traces[-1]
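                    # Tally oversight triggers, per-agent action usage, and reward/goal-term totals from this turn's trace.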
if result.oversight.triggered:
oversight_trigger_count += 1
for agent_id, action in trace.actions.items():
action_counters[agent_id][action.type] += 1
for agent_id, reward in trace.rewards.items():
reward_totals[agent_id] += reward.total
for name, value in reward.goal_terms.items():
goal_term_totals[agent_id][name] += value
if result.done:
done = True
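                        # Attribute the episode end to a tension breach or to hitting the turn limit.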
if session.world.tension_level >= 95.0:
done_reason = "tension_threshold"
else:
done_reason = "max_turns"
break
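                # Build one scorecard per agent: reward totals, dominant action, asset damage, and heuristic warnings.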
scorecards: dict[str, BenchmarkEntityScorecard] = {}
for agent_id in AGENT_IDS:
final_reward = session.rewards[agent_id]
aggregate_reward_totals[agent_id] += reward_totals[agent_id]
action_counts = dict(action_counters[agent_id])
dominant_action = (
max(action_counts, key=action_counts.get)
if action_counts
else None
)
damaged_asset_count = sum(
1
for asset in session.world.asset_state.get(agent_id, {}).values()
if asset.status != "operational"
)
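                    # Uses the env's internal asset-pressure helper for the scorecard metric.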
asset_pressure = round(env._asset_pressure(session.world, agent_id), 3)
warnings: list[str] = []
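                    # Heuristic red flags: action monoculture, passivity under asset pressure, and costly escalation.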
if dominant_action is not None:
dominant_share = action_counts[dominant_action] / max(sum(action_counts.values()), 1)
if dominant_share >= 0.75:
warnings.append(f"action_monoculture:{dominant_action}")
if asset_pressure >= 0.45 and dominant_action == "hold":
warnings.append("passive_under_asset_pressure")
if final_reward.total <= -0.35 and dominant_action in {"strike", "mobilize", "deceive", "sanction"}:
warnings.append("negative_escalation_bias")
scorecards[agent_id] = BenchmarkEntityScorecard(
agent_id=agent_id,
total_reward=round(reward_totals[agent_id], 3),
mean_reward=round(reward_totals[agent_id] / max(session.world.turn, 1), 3),
final_reward=final_reward.total,
final_goal_terms=final_reward.goal_terms,
aggregated_goal_terms={
name: round(value, 3)
for name, value in goal_term_totals[agent_id].items()
},
final_state=session.world.latent_state.get(agent_id, {}).copy(),
damaged_asset_count=damaged_asset_count,
asset_pressure=asset_pressure,
action_counts=action_counts,
dominant_action=dominant_action,
warnings=warnings,
)
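                # Scenario-level warnings: frequent oversight triggers, runaway tension, or across-the-board passivity.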
scenario_warnings: list[str] = []
if oversight_trigger_count >= max(2, turn_limit // 2):
scenario_warnings.append("frequent_oversight")
if session.world.tension_level >= 90.0:
scenario_warnings.append("runaway_escalation")
if all(
scorecards[agent_id].dominant_action == "hold"
for agent_id in ("us", "israel", "iran", "hezbollah", "gulf")
):
scenario_warnings.append("global_passivity")
summary = (
f"{scenario.name}: {session.world.turn} turns, tension {session.world.tension_level:.1f}, "
f"oversight triggers {oversight_trigger_count}."
)
results.append(
BenchmarkScenarioResult(
scenario_id=scenario.id,
scenario_name=scenario.name,
seed=scenario_seed,
training_stage=request.training_stage,
turns_executed=session.world.turn,
done=done,
done_reason=done_reason,
oversight_trigger_count=oversight_trigger_count,
final_tension=session.world.tension_level,
final_market_stress=session.world.market_stress,
final_oil_pressure=session.world.oil_pressure,
summary=summary,
warnings=scenario_warnings,
scorecards=scorecards,
)
)
finally:
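                # Always release environment resources, even if the scenario raised mid-run.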
env.shutdown()
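        # Average each agent's accumulated reward over the scenarios that actually ran.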
scenario_count = max(len(results), 1)
aggregate_mean_total_rewards = {
agent_id: round(total / scenario_count, 3)
for agent_id, total in aggregate_reward_totals.items()
}
return BenchmarkRunResponse(
seed=request.seed,
training_stage=request.training_stage,
scenario_ids=[result.scenario_id for result in results],
scenario_count=len(results),
results=results,
aggregate_mean_total_rewards=aggregate_mean_total_rewards,
)