from __future__ import annotations import json from pathlib import Path from agents.hero.policy import HeroPolicy from agents.hero.runner import HeroRunner from agents.master.env import DMEnvironment from agents.master.interface import InterfaceAdapter, StrictCliInterfaceAdapter from agents.master.policy import DMRepairContext, DungeonMasterPolicy, DungeonMasterPolicyError from agents.master.schema import DMObservation, DMRewardBreakdown, WorldDefinition from agents.master.snapshots import LiveObserver, LiveSnapshotWriter from .schema import ( ClosedLoopAggregateReport, ClosedLoopEpisodeArtifacts, ClosedLoopEpisodeRecord, ClosedLoopEpisodeSummary, ) DEFAULT_CLOSED_LOOP_ROOT = Path(__file__).resolve().parents[2] / ".play_runs" / "closed_loop" class ClosedLoopRunner: def __init__( self, *, dm_env: DMEnvironment, dm_policy: DungeonMasterPolicy, hero_policy: HeroPolicy, artifacts_root: Path | None = None, live_dir: Path | None = None, max_dm_repair_attempts: int = 2, hero_runner_kwargs: dict[str, object] | None = None, hero_interface_adapter: InterfaceAdapter | None = None, ) -> None: self.dm_env = dm_env self.dm_policy = dm_policy self.hero_policy = hero_policy self.artifacts_root = artifacts_root or DEFAULT_CLOSED_LOOP_ROOT self.live_dir = live_dir self.max_dm_repair_attempts = max_dm_repair_attempts self.hero_runner_kwargs = hero_runner_kwargs or {"max_game_steps": 40, "max_tool_calls": 80} self.hero_interface_adapter = hero_interface_adapter or StrictCliInterfaceAdapter() def run_episode( self, *, seed: int | None = None, target_ratio: float | None = None, live: bool = False, ) -> ClosedLoopEpisodeRecord: self.dm_env.reset(seed=seed, difficulty_hint=target_ratio) episode_id = self.dm_env.state.episode_id if episode_id is None: raise RuntimeError("DM environment did not assign an episode id.") episode_dir = self.artifacts_root / episode_id episode_dir.mkdir(parents=True, exist_ok=True) artifacts = ClosedLoopEpisodeArtifacts.from_episode_dir(episode_dir) observer = self._observer(live) world: WorldDefinition | None = None errors: list[str] = [] compile_attempts = 0 repair_context: DMRepairContext | None = None previous_candidate_json: str | None = None attempt_rows: list[dict[str, object]] = [] for attempt in range(1, self.max_dm_repair_attempts + 2): compile_attempts = attempt try: candidate = self.dm_policy.generate_world( target_ratio=self.dm_env.state.target_ratio, repair_context=repair_context, ) previous_candidate_json = candidate.model_dump_json(indent=2) self._write_json(Path(artifacts.world_definition_path), previous_candidate_json) self.dm_env.compile_world(candidate, episode_id=episode_id) world = candidate attempt_rows.append( { "attempt_number": attempt, "status": "compiled", "world_title": candidate.meta.title, "difficulty_target": candidate.meta.difficulty_target, } ) break except Exception as exc: normalized_error = self._normalize_error(exc) errors.append(normalized_error) attempt_rows.append( { "attempt_number": attempt, "status": "failed", "error": normalized_error, } ) repair_context = DMRepairContext( attempt_number=attempt, error_message=normalized_error, previous_candidate_json=previous_candidate_json, ) self._write_jsonl(Path(artifacts.world_generation_attempts_path), attempt_rows) if world is None: observation = self._compile_failure_observation(errors[-1] if errors else "world compilation failed") record = ClosedLoopEpisodeRecord( episode_id=episode_id, status="compile_failed", target_ratio=self.dm_env.state.target_ratio, compile_attempts=compile_attempts, dm_repair_errors=errors, world_definition=None, declared_difficulty_target=None, difficulty_target_matches_target_ratio=None, observation=observation, artifacts=artifacts, ) self._persist_record(record) self._write_jsonl(Path(artifacts.hero_trace_path), []) self._write_jsonl(Path(artifacts.transcript_path), []) return record hero_runner = HeroRunner(policy=self.hero_policy, **self.hero_runner_kwargs) previous_adapter = self.dm_env.interface_adapter self.dm_env.interface_adapter = self.hero_interface_adapter try: result = self.dm_env.step(world, runner=hero_runner, observer=observer) finally: self.dm_env.interface_adapter = previous_adapter observation = result.observation status = "policy_error" if hero_runner.last_error else ("complete" if observation.player_won else "failed") record = ClosedLoopEpisodeRecord( episode_id=episode_id, status=status, target_ratio=self.dm_env.state.target_ratio, compile_attempts=compile_attempts, dm_repair_errors=errors, hero_policy_error=hero_runner.last_error, hero_episode_stats=hero_runner.episode_stats, world_definition=world, declared_difficulty_target=world.meta.difficulty_target, difficulty_target_matches_target_ratio=(world.meta.difficulty_target == self.dm_env.state.target_ratio), observation=observation, artifacts=artifacts, ) self._persist_record(record) self._write_jsonl( Path(artifacts.hero_trace_path), [event.model_dump(mode="json") for event in self.hero_policy.trace_events], ) self._write_jsonl( Path(artifacts.transcript_path), [turn.model_dump(mode="json") for turn in observation.episode_transcript], ) return record @staticmethod def summary(record: ClosedLoopEpisodeRecord) -> ClosedLoopEpisodeSummary: return ClosedLoopEpisodeSummary( episode_id=record.episode_id, status=record.status, reward=record.observation.reward, player_won=record.observation.player_won, ratio=record.observation.ratio, compile_error=record.observation.compile_error, hero_policy_error=record.hero_policy_error, ) @staticmethod def aggregate(records: list[ClosedLoopEpisodeRecord]) -> ClosedLoopAggregateReport: episodes = len(records) dense_returns = [ record.hero_episode_stats.dense_return for record in records if record.hero_episode_stats is not None ] invalid_penalties = [ record.hero_episode_stats.invalid_action_penalty_total for record in records if record.hero_episode_stats is not None ] repeat_penalties = [ record.hero_episode_stats.repeat_noop_penalty_total for record in records if record.hero_episode_stats is not None ] return ClosedLoopAggregateReport( episodes=episodes, compile_valid_rate=_rate(sum(record.status != "compile_failed" for record in records), episodes), policy_error_rate=_rate(sum(record.status == "policy_error" for record in records), episodes), playable_rate=_rate(sum(record.world_definition is not None for record in records), episodes), solve_rate=_rate(sum(record.status == "complete" for record in records), episodes), mean_dense_return=_mean(dense_returns), mean_invalid_action_penalty=_mean(invalid_penalties), mean_repeat_noop_penalty=_mean(repeat_penalties), ) def _compile_failure_observation(self, error: str) -> DMObservation: breakdown = DMRewardBreakdown( reward_mode="compile_failure_penalty", player_won=False, target_ratio=self.dm_env.state.target_ratio, quality_score=0.0, reward=0.0, ) return DMObservation( player_won=False, compile_error=error, reward=0.0, done=True, reward_breakdown=breakdown, target_ratio_used=self.dm_env.state.target_ratio, ) def _observer(self, live: bool) -> LiveObserver | None: if not live: return None return LiveSnapshotWriter(live_dir=self.live_dir, runner_name="hero_llm") def _persist_record(self, record: ClosedLoopEpisodeRecord) -> None: self._write_json(Path(record.artifacts.run_record_path), record.model_dump_json(indent=2)) @staticmethod def _write_json(path: Path, payload: str) -> None: path.parent.mkdir(parents=True, exist_ok=True) path.write_text(payload + "\n", encoding="utf-8") @staticmethod def _write_jsonl(path: Path, rows: list[dict[str, object]]) -> None: path.parent.mkdir(parents=True, exist_ok=True) payload = "".join(json.dumps(row) + "\n" for row in rows) path.write_text(payload, encoding="utf-8") @staticmethod def _normalize_error(exc: Exception) -> str: if isinstance(exc, DungeonMasterPolicyError): return str(exc) return " ".join(str(exc).split()) or exc.__class__.__name__ def _mean(values: list[float]) -> float: if not values: return 0.0 return sum(values) / len(values) def _rate(count: int, total: int) -> float: if total <= 0: return 0.0 return count / total