Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import json | |
| from pathlib import Path | |
| from agents.hero.policy import HeroPolicy | |
| from agents.hero.runner import HeroRunner | |
| from agents.master.env import DMEnvironment | |
| from agents.master.interface import InterfaceAdapter, StrictCliInterfaceAdapter | |
| from agents.master.policy import DMRepairContext, DungeonMasterPolicy, DungeonMasterPolicyError | |
| from agents.master.schema import DMObservation, DMRewardBreakdown, WorldDefinition | |
| from agents.master.snapshots import LiveObserver, LiveSnapshotWriter | |
| from .schema import ( | |
| ClosedLoopAggregateReport, | |
| ClosedLoopEpisodeArtifacts, | |
| ClosedLoopEpisodeRecord, | |
| ClosedLoopEpisodeSummary, | |
| ) | |
# Default root directory for closed-loop episode artifacts; ClosedLoopRunner
# falls back to it when no ``artifacts_root`` is supplied. ``parents[2]`` is
# the grandparent of the directory that contains this module.
DEFAULT_CLOSED_LOOP_ROOT = Path(__file__).resolve().parents[2] / ".play_runs" / "closed_loop"
class ClosedLoopRunner:
    """Drive closed-loop episodes: a DM policy authors a world, a hero policy plays it.

    For each episode the runner asks the dungeon-master policy for a world,
    compiles it in the DM environment (re-prompting the policy with the
    previous error when generation or compilation fails), lets a
    ``HeroRunner`` play the compiled world, and writes the resulting record,
    world definition, attempt log, hero trace and transcript under a
    per-episode directory.
    """

    def __init__(
        self,
        *,
        dm_env: DMEnvironment,
        dm_policy: DungeonMasterPolicy,
        hero_policy: HeroPolicy,
        artifacts_root: Path | None = None,
        live_dir: Path | None = None,
        max_dm_repair_attempts: int = 2,
        hero_runner_kwargs: dict[str, object] | None = None,
        hero_interface_adapter: InterfaceAdapter | None = None,
    ) -> None:
        """Store collaborators and fill in defaults.

        Args:
            dm_env: Environment that compiles worlds and runs the hero in them.
            dm_policy: Policy that generates candidate worlds.
            hero_policy: Policy handed to the ``HeroRunner`` for each episode.
            artifacts_root: Directory that receives one sub-directory per
                episode; falls back to ``DEFAULT_CLOSED_LOOP_ROOT``.
            live_dir: Directory passed to the live snapshot writer when an
                episode is run with ``live=True``.
            max_dm_repair_attempts: Number of repair attempts allowed after
                the initial world-generation attempt.
            hero_runner_kwargs: Extra keyword arguments for ``HeroRunner``;
                a falsy value selects the built-in step and tool-call limits.
            hero_interface_adapter: Adapter installed on ``dm_env`` while the
                hero plays; falls back to ``StrictCliInterfaceAdapter()``.
        """
        self.dm_env = dm_env
        self.dm_policy = dm_policy
        self.hero_policy = hero_policy
        self.artifacts_root = artifacts_root or DEFAULT_CLOSED_LOOP_ROOT
        self.live_dir = live_dir
        self.max_dm_repair_attempts = max_dm_repair_attempts
        self.hero_runner_kwargs = hero_runner_kwargs or {"max_game_steps": 40, "max_tool_calls": 80}
        self.hero_interface_adapter = hero_interface_adapter or StrictCliInterfaceAdapter()

    def run_episode(
        self,
        *,
        seed: int | None = None,
        target_ratio: float | None = None,
        live: bool = False,
    ) -> ClosedLoopEpisodeRecord:
        """Run one full episode and persist its artifacts.

        Args:
            seed: Seed forwarded to ``dm_env.reset``.
            target_ratio: Forwarded to ``dm_env.reset`` as ``difficulty_hint``.
            live: When true, a ``LiveSnapshotWriter`` observes the hero run.

        Returns:
            The episode record. Its status is ``"compile_failed"`` when no
            world compiled within the attempt budget, ``"policy_error"`` when
            the hero runner reported an error, ``"complete"`` when the player
            won, and ``"failed"`` otherwise.

        Raises:
            RuntimeError: If the environment has no episode id after reset.
        """
        self.dm_env.reset(seed=seed, difficulty_hint=target_ratio)
        episode_id = self.dm_env.state.episode_id
        if episode_id is None:
            raise RuntimeError("DM environment did not assign an episode id.")

        episode_dir = self.artifacts_root / episode_id
        episode_dir.mkdir(parents=True, exist_ok=True)
        artifacts = ClosedLoopEpisodeArtifacts.from_episode_dir(episode_dir)
        observer = self._observer(live)

        world: WorldDefinition | None = None
        errors: list[str] = []
        compile_attempts = 0
        repair_context: DMRepairContext | None = None
        previous_candidate_json: str | None = None
        attempt_rows: list[dict[str, object]] = []

        # One initial attempt plus up to ``max_dm_repair_attempts`` repairs.
        for attempt in range(1, self.max_dm_repair_attempts + 2):
            compile_attempts = attempt
            try:
                candidate = self.dm_policy.generate_world(
                    target_ratio=self.dm_env.state.target_ratio,
                    repair_context=repair_context,
                )
                # Written before compiling so the latest candidate stays on
                # disk even when compilation rejects it.
                previous_candidate_json = candidate.model_dump_json(indent=2)
                self._write_json(Path(artifacts.world_definition_path), previous_candidate_json)
                self.dm_env.compile_world(candidate, episode_id=episode_id)
                world = candidate
                attempt_rows.append(
                    {
                        "attempt_number": attempt,
                        "status": "compiled",
                        "world_title": candidate.meta.title,
                        "difficulty_target": candidate.meta.difficulty_target,
                    }
                )
                break
            except Exception as exc:
                # Any generation or compilation failure feeds the next repair
                # attempt instead of aborting the episode.
                normalized_error = self._normalize_error(exc)
                errors.append(normalized_error)
                attempt_rows.append(
                    {
                        "attempt_number": attempt,
                        "status": "failed",
                        "error": normalized_error,
                    }
                )
                repair_context = DMRepairContext(
                    attempt_number=attempt,
                    error_message=normalized_error,
                    previous_candidate_json=previous_candidate_json,
                )

        # Log every attempt, including the successful one, once the loop ends.
        self._write_jsonl(Path(artifacts.world_generation_attempts_path), attempt_rows)

        if world is None:
            observation = self._compile_failure_observation(errors[-1] if errors else "world compilation failed")
            record = ClosedLoopEpisodeRecord(
                episode_id=episode_id,
                status="compile_failed",
                target_ratio=self.dm_env.state.target_ratio,
                compile_attempts=compile_attempts,
                dm_repair_errors=errors,
                world_definition=None,
                declared_difficulty_target=None,
                difficulty_target_matches_target_ratio=None,
                observation=observation,
                artifacts=artifacts,
            )
            self._persist_record(record)
            # The hero never ran, so its trace and transcript are empty files.
            self._write_jsonl(Path(artifacts.hero_trace_path), [])
            self._write_jsonl(Path(artifacts.transcript_path), [])
            return record

        hero_runner = HeroRunner(policy=self.hero_policy, **self.hero_runner_kwargs)
        # Swap in the hero's interface adapter only for the duration of the step.
        previous_adapter = self.dm_env.interface_adapter
        self.dm_env.interface_adapter = self.hero_interface_adapter
        try:
            result = self.dm_env.step(world, runner=hero_runner, observer=observer)
        finally:
            self.dm_env.interface_adapter = previous_adapter

        observation = result.observation
        status = "policy_error" if hero_runner.last_error else ("complete" if observation.player_won else "failed")
        record = ClosedLoopEpisodeRecord(
            episode_id=episode_id,
            status=status,
            target_ratio=self.dm_env.state.target_ratio,
            compile_attempts=compile_attempts,
            dm_repair_errors=errors,
            hero_policy_error=hero_runner.last_error,
            hero_episode_stats=hero_runner.episode_stats,
            world_definition=world,
            declared_difficulty_target=world.meta.difficulty_target,
            difficulty_target_matches_target_ratio=(world.meta.difficulty_target == self.dm_env.state.target_ratio),
            observation=observation,
            artifacts=artifacts,
        )
        self._persist_record(record)
        self._write_jsonl(
            Path(artifacts.hero_trace_path),
            [event.model_dump(mode="json") for event in self.hero_policy.trace_events],
        )
        self._write_jsonl(
            Path(artifacts.transcript_path),
            [turn.model_dump(mode="json") for turn in observation.episode_transcript],
        )
        return record

    @staticmethod
    def summary(record: ClosedLoopEpisodeRecord) -> ClosedLoopEpisodeSummary:
        """Condense an episode record into its headline fields."""
        return ClosedLoopEpisodeSummary(
            episode_id=record.episode_id,
            status=record.status,
            reward=record.observation.reward,
            player_won=record.observation.player_won,
            ratio=record.observation.ratio,
            compile_error=record.observation.compile_error,
            hero_policy_error=record.hero_policy_error,
        )

    @staticmethod
    def aggregate(records: list[ClosedLoopEpisodeRecord]) -> ClosedLoopAggregateReport:
        """Compute rates and mean hero statistics over ``records``.

        Rates are fractions of all records (0.0 for an empty list). The means
        only consider records that carry ``hero_episode_stats`` and are 0.0
        when none do.
        """
        episodes = len(records)
        # Collect the available hero stats once; every mean draws from them.
        hero_stats = [
            record.hero_episode_stats
            for record in records
            if record.hero_episode_stats is not None
        ]
        dense_returns = [stats.dense_return for stats in hero_stats]
        invalid_penalties = [stats.invalid_action_penalty_total for stats in hero_stats]
        repeat_penalties = [stats.repeat_noop_penalty_total for stats in hero_stats]
        return ClosedLoopAggregateReport(
            episodes=episodes,
            compile_valid_rate=_rate(sum(record.status != "compile_failed" for record in records), episodes),
            policy_error_rate=_rate(sum(record.status == "policy_error" for record in records), episodes),
            playable_rate=_rate(sum(record.world_definition is not None for record in records), episodes),
            solve_rate=_rate(sum(record.status == "complete" for record in records), episodes),
            mean_dense_return=_mean(dense_returns),
            mean_invalid_action_penalty=_mean(invalid_penalties),
            mean_repeat_noop_penalty=_mean(repeat_penalties),
        )

    def _compile_failure_observation(self, error: str) -> DMObservation:
        """Build the terminal, zero-reward observation for an episode whose world never compiled."""
        breakdown = DMRewardBreakdown(
            reward_mode="compile_failure_penalty",
            player_won=False,
            target_ratio=self.dm_env.state.target_ratio,
            quality_score=0.0,
            reward=0.0,
        )
        return DMObservation(
            player_won=False,
            compile_error=error,
            reward=0.0,
            done=True,
            reward_breakdown=breakdown,
            target_ratio_used=self.dm_env.state.target_ratio,
        )

    def _observer(self, live: bool) -> LiveObserver | None:
        """Return a live snapshot writer when ``live`` is true, otherwise ``None``."""
        if not live:
            return None
        return LiveSnapshotWriter(live_dir=self.live_dir, runner_name="hero_llm")

    def _persist_record(self, record: ClosedLoopEpisodeRecord) -> None:
        """Write ``record`` as indented JSON to its own ``run_record_path``."""
        self._write_json(Path(record.artifacts.run_record_path), record.model_dump_json(indent=2))

    @staticmethod
    def _write_json(path: Path, payload: str) -> None:
        """Write the already-serialised ``payload`` plus a trailing newline, creating parent directories."""
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(payload + "\n", encoding="utf-8")

    @staticmethod
    def _write_jsonl(path: Path, rows: list[dict[str, object]]) -> None:
        """Write ``rows`` as JSON Lines (one object per line), creating parent directories.

        An empty ``rows`` list produces an empty file.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        payload = "".join(json.dumps(row) + "\n" for row in rows)
        path.write_text(payload, encoding="utf-8")

    @staticmethod
    def _normalize_error(exc: Exception) -> str:
        """Turn an exception into a message for the attempt log and repair context.

        ``DungeonMasterPolicyError`` messages are returned verbatim. Any other
        message has its whitespace runs collapsed to single spaces, and the
        exception class name is used when that leaves nothing.
        """
        if isinstance(exc, DungeonMasterPolicyError):
            return str(exc)
        return " ".join(str(exc).split()) or exc.__class__.__name__
| def _mean(values: list[float]) -> float: | |
| if not values: | |
| return 0.0 | |
| return sum(values) / len(values) | |
| def _rate(count: int, total: int) -> float: | |
| if total <= 0: | |
| return 0.0 | |
| return count / total | |