Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import json | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING, Any, Protocol | |
| from pydantic import Field | |
| from .base import ARTIFACTS_ROOT | |
| from .schema import CompiledWorld, DMFeedback, DMObservation, StrictModel, Turn, WorldDefinition | |
| if TYPE_CHECKING: | |
| from .session import EpisodeSession | |
# Names of the files written inside the live directory.
STATE_FILENAME: str = "state.json"
WORLD_FILENAME: str = "world.json"
# Default value of LiveStateSnapshot.schema_version.
LIVE_SCHEMA_VERSION: int = 1
# Output directory LiveSnapshotWriter uses when no ``live_dir`` is supplied.
DEFAULT_LIVE_DIR: Path = ARTIFACTS_ROOT / "live"
class LiveMetrics(StrictModel):
    """Step and score figures embedded in a live snapshot.

    Field declaration order is also the key order of the serialized JSON.
    """

    # Number of steps taken so far in the episode.
    steps_taken: int = 0
    # Reference step count; LiveSnapshotWriter fills it from the solver policy
    # length, or from the final observation on completion.
    min_steps: int | None = None
    # steps_taken / min_steps unless the writer was handed an explicit ratio.
    ratio: float | None = None
    # Reward reported by the final observation (only set by the writer on completion).
    reward: float | None = None
    # Win flag from the final observation (only set by the writer on completion).
    player_won: bool | None = None
class LiveRuntime(StrictModel):
    """Episode-session state mirrored into a live snapshot.

    As populated by LiveSnapshotWriter, the id lists are sorted; only
    ``available_commands`` keeps the order returned by the session.
    """

    # Id of the room the player is currently in.
    current_room_id: str | None = None
    inventory_item_ids: list[str] = Field(default_factory=list)
    discovered_clue_ids: list[str] = Field(default_factory=list)
    traded_npc_ids: list[str] = Field(default_factory=list)
    # Visited nodes, restricted by the writer to "location"/"junction" node types.
    visited_room_ids: list[str] = Field(default_factory=list)
    # Output of session.available_commands(); the writer leaves it empty once the session is done.
    available_commands: list[str] = Field(default_factory=list)
    invalid_command_count: int = 0
    wrong_submit_count: int = 0
    open_node_ids: list[str] = Field(default_factory=list)
    locked_node_ids: list[str] = Field(default_factory=list)
class LiveCurrentRoom(StrictModel):
    """The player's current room as presented in a live snapshot."""

    # Node id, label and description copied from the room's world node.
    id: str | None = None
    label: str | None = None
    description: str | None = None
    # As populated by LiveSnapshotWriter: child nodes of the room (readables only
    # once revealed) plus doors linked to the room, sorted and de-duplicated.
    visible_node_ids: list[str] = Field(default_factory=list)
    # Items whose current location is this room, sorted.
    visible_item_ids: list[str] = Field(default_factory=list)
class LiveStateSnapshot(StrictModel):
    """Complete contents of the live state file.

    ``status`` values written by LiveSnapshotWriter are ``"compiling"``,
    ``"running"``, ``"complete"``, ``"failed"`` and ``"compile_error"``.
    Field declaration order is also the key order of the serialized JSON.
    """

    # Format version; defaults to LIVE_SCHEMA_VERSION.
    schema_version: int = LIVE_SCHEMA_VERSION
    episode_id: str
    status: str
    # ISO-8601 timestamp; LiveSnapshotWriter uses the current UTC time.
    updated_at: str
    # World title taken from the world's meta section, when available.
    title: str | None = None
    # Free-form label identifying who produced the snapshot.
    runner: str | None = None
    # Error message; only set on error snapshots.
    error: str | None = None
    transcript: list[Turn] = Field(default_factory=list)
    metrics: LiveMetrics = Field(default_factory=LiveMetrics)
    # DM feedback from the final observation, if any.
    feedback: DMFeedback | None = None
    runtime: LiveRuntime = Field(default_factory=LiveRuntime)
    # None when there is no session or the current room id does not resolve to a node.
    current_room: LiveCurrentRoom | None = None
class LiveObserver(Protocol):
    """Structural interface for objects that receive live episode callbacks.

    Any object providing these five methods satisfies the protocol; explicit
    inheritance is not required.
    """

    def on_run_start(self, episode_id: str, world_input: WorldDefinition | dict[str, Any]) -> None:
        """Hook for the start of a run of ``episode_id`` with the raw world input."""

    def on_compile_success(self, compiled: CompiledWorld, session: EpisodeSession) -> None:
        """Hook for a successfully compiled world and its freshly created session."""

    def on_turn(self, session: EpisodeSession, turn: Turn) -> None:
        """Hook for a single turn played in ``session``."""

    def on_complete(self, compiled: CompiledWorld, session: EpisodeSession, observation: DMObservation) -> None:
        """Hook for the end of an episode, with the final observation."""

    def on_error(
        self,
        *,
        episode_id: str,
        error: str,
        world_input: WorldDefinition | dict[str, Any],
        compiled: CompiledWorld | None = None,
        session: EpisodeSession | None = None,
    ) -> None:
        """Hook for a failure; ``compiled``/``session`` are given when they exist."""
class LiveSnapshotWriter:
    """Write live episode progress to JSON files for an external viewer.

    Implements the ``LiveObserver`` hooks. Every hook builds a fresh
    ``LiveStateSnapshot`` and replaces ``STATE_FILENAME`` in ``live_dir``.
    The compiled world is written to ``WORLD_FILENAME`` after a successful
    compile and removed when a new run starts.
    """

    def __init__(self, live_dir: Path | None = None, runner_name: str | None = None) -> None:
        """Create the writer and make sure the output directory exists.

        Args:
            live_dir: Directory receiving the snapshot files; falls back to
                ``DEFAULT_LIVE_DIR`` when omitted.
            runner_name: Label copied into every snapshot's ``runner`` field.
        """
        self.live_dir = live_dir or DEFAULT_LIVE_DIR
        self.runner_name = runner_name
        self.live_dir.mkdir(parents=True, exist_ok=True)

    # -- LiveObserver hooks ---------------------------------------------------

    def on_run_start(self, episode_id: str, world_input: WorldDefinition | dict[str, Any]) -> None:
        """Publish a ``"compiling"`` snapshot and drop the previous run's world file."""
        self._remove_world()
        snapshot = self._build_snapshot(
            episode_id=episode_id,
            status="compiling",
            title=self._extract_title(world_input),
        )
        self._write_state_snapshot(snapshot)

    def on_compile_success(self, compiled: CompiledWorld, session: EpisodeSession) -> None:
        """Write the compiled world and publish the initial ``"running"`` snapshot."""
        self._write_world(compiled.world)
        snapshot = self._build_snapshot(
            episode_id=compiled.episode_id,
            status="running",
            title=compiled.world.meta.title,
            metrics=self._metrics(
                min_steps=len(compiled.solver_policy),
                steps_taken=session.steps_taken,
            ),
            runtime=self._runtime(session),
            current_room=self._current_room(session),
        )
        self._write_state_snapshot(snapshot)

    def on_turn(self, session: EpisodeSession, turn: Turn) -> None:
        """Publish a ``"running"`` snapshot reflecting ``session`` after a turn.

        ``turn`` itself is unused: the whole transcript is copied from ``session``.
        """
        del turn
        compiled = session.compiled
        snapshot = self._build_snapshot(
            episode_id=compiled.episode_id,
            status="running",
            title=compiled.world.meta.title,
            transcript=list(session.transcript),
            metrics=self._metrics(
                min_steps=len(compiled.solver_policy),
                steps_taken=session.steps_taken,
            ),
            runtime=self._runtime(session),
            current_room=self._current_room(session),
        )
        self._write_state_snapshot(snapshot)

    def on_complete(self, compiled: CompiledWorld, session: EpisodeSession, observation: DMObservation) -> None:
        """Publish the final snapshot: ``"complete"`` if the player won, else ``"failed"``."""
        snapshot = self._build_snapshot(
            episode_id=compiled.episode_id,
            status="complete" if observation.player_won else "failed",
            title=compiled.world.meta.title,
            transcript=list(session.transcript),
            metrics=self._metrics(
                min_steps=observation.min_steps,
                # A falsy observed count (0/None) falls back to the session's counter.
                steps_taken=observation.steps_taken or session.steps_taken,
                ratio=observation.ratio,
                reward=observation.reward,
                player_won=observation.player_won,
            ),
            feedback=observation.feedback,
            runtime=self._runtime(session),
            current_room=self._current_room(session),
        )
        self._write_state_snapshot(snapshot)

    def on_error(
        self,
        *,
        episode_id: str,
        error: str,
        world_input: WorldDefinition | dict[str, Any],
        compiled: CompiledWorld | None = None,
        session: EpisodeSession | None = None,
    ) -> None:
        """Publish an error snapshot carrying ``error`` plus whatever progress exists.

        ``compiled`` and ``session`` are optional so failures before compilation
        can be reported; when given, their title, transcript and metrics are used.
        """
        if compiled is not None:
            title = compiled.world.meta.title
            min_steps = len(compiled.solver_policy)
        else:
            title = self._extract_title(world_input)
            min_steps = None
        snapshot = self._build_snapshot(
            episode_id=episode_id,
            # NOTE(review): status stays "compile_error" even when ``compiled`` is
            # set (failure after compilation); kept because consumers may match
            # on this exact string — confirm whether a distinct status is wanted.
            status="compile_error",
            title=title,
            error=error,
            transcript=list(session.transcript) if session is not None else [],
            metrics=self._metrics(
                min_steps=min_steps,
                steps_taken=session.steps_taken if session is not None else 0,
            ),
            runtime=self._runtime(session),
            current_room=self._current_room(session),
        )
        self._write_state_snapshot(snapshot)

    # -- snapshot construction ------------------------------------------------

    def _build_snapshot(
        self,
        *,
        episode_id: str,
        status: str,
        title: str | None,
        **fields: Any,
    ) -> LiveStateSnapshot:
        """Create a snapshot stamped with the current UTC time and this writer's runner.

        Extra keyword arguments are passed through as ``LiveStateSnapshot``
        fields (transcript, metrics, runtime, ...).
        """
        return LiveStateSnapshot(
            episode_id=episode_id,
            status=status,
            updated_at=self._timestamp(),
            title=title,
            runner=self.runner_name,
            **fields,
        )

    # -- file output ------------------------------------------------------------

    def _write_world(self, world: WorldDefinition) -> None:
        """Serialize ``world`` to ``WORLD_FILENAME`` in the live directory."""
        self._write_json(self.live_dir / WORLD_FILENAME, world.model_dump_json(indent=2))

    def _write_state_snapshot(self, snapshot: LiveStateSnapshot) -> None:
        """Serialize ``snapshot`` to ``STATE_FILENAME`` in the live directory."""
        self._write_json(self.live_dir / STATE_FILENAME, snapshot.model_dump_json(indent=2))

    def _write_json(self, path: Path, payload: str) -> None:
        """Replace ``path`` with ``payload`` plus a trailing newline.

        The text goes to a sibling ``<name>.tmp`` file that is then moved over
        ``path`` with ``Path.replace``, so readers see either the old or the new
        contents (rename is atomic on POSIX within one filesystem).
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = path.with_suffix(path.suffix + ".tmp")
        try:
            tmp_path.write_text(payload + "\n", encoding="utf-8")
            tmp_path.replace(path)
        except BaseException:
            # Do not leave a stale temp file behind when the write or rename fails.
            tmp_path.unlink(missing_ok=True)
            raise

    def _remove_world(self) -> None:
        """Delete the world file if present (no error when it is already gone)."""
        # missing_ok avoids the exists()/unlink() race with concurrent cleanup.
        (self.live_dir / WORLD_FILENAME).unlink(missing_ok=True)

    # -- pure helpers -------------------------------------------------------------

    @staticmethod
    def _timestamp() -> str:
        """Return the current time as a timezone-aware ISO-8601 string in UTC."""
        return datetime.now(timezone.utc).isoformat()

    @staticmethod
    def _extract_title(world_input: WorldDefinition | dict[str, Any]) -> str | None:
        """Return ``meta.title`` from a world model or raw dict, or ``None`` if absent or not a string."""
        if isinstance(world_input, WorldDefinition):
            return world_input.meta.title
        meta = world_input.get("meta") if isinstance(world_input, dict) else None
        title = meta.get("title") if isinstance(meta, dict) else None
        return title if isinstance(title, str) else None

    @staticmethod
    def _metrics(
        *,
        min_steps: int | None,
        steps_taken: int,
        ratio: float | None = None,
        reward: float | None = None,
        player_won: bool | None = None,
    ) -> LiveMetrics:
        """Build ``LiveMetrics``, deriving ``ratio`` when it is not supplied.

        The derived ratio is ``steps_taken / min_steps``; it stays ``None`` when
        ``min_steps`` is ``None`` or 0 (avoiding division by zero).
        """
        computed_ratio = ratio
        if computed_ratio is None and min_steps:
            computed_ratio = steps_taken / min_steps
        return LiveMetrics(
            steps_taken=steps_taken,
            min_steps=min_steps,
            ratio=computed_ratio,
            reward=reward,
            player_won=player_won,
        )

    @staticmethod
    def _runtime(session: EpisodeSession | None) -> LiveRuntime:
        """Summarize ``session`` state; an empty ``LiveRuntime`` when there is no session."""
        if session is None:
            return LiveRuntime()
        # Only nodes of these types are reported in the visited-rooms list.
        room_ids = {
            node.id
            for node in session.compiled.world.nodes
            if node.type in {"location", "junction"}
        }
        # Commands are not queried once the session is done.
        commands = [] if session.done else session.available_commands()
        return LiveRuntime(
            current_room_id=session.current_room_id,
            inventory_item_ids=sorted(session.inventory),
            discovered_clue_ids=sorted(session.discovered_clues),
            traded_npc_ids=sorted(session.traded_npcs),
            visited_room_ids=sorted(room_ids & session.visited_nodes),
            available_commands=commands,
            invalid_command_count=session.invalid_command_count,
            wrong_submit_count=session.wrong_submit_count,
            open_node_ids=sorted(session.open_nodes),
            locked_node_ids=sorted(session.locked_nodes),
        )

    @staticmethod
    def _current_room(session: EpisodeSession | None) -> LiveCurrentRoom | None:
        """Describe the player's current room, or ``None`` if it cannot be resolved.

        Visible nodes are the room's children (readables only once revealed)
        plus every door whose ``door_rooms`` entry contains the current room.
        """
        if session is None:
            return None
        world = session.compiled.world
        current_id = session.current_room_id
        node_by_id = {node.id: node for node in world.nodes}
        room = node_by_id.get(current_id)
        if room is None:
            return None
        visible_nodes = {
            node.id
            for node in world.nodes
            if getattr(node, "parent_id", None) == current_id
            and (node.type != "readable" or node.id in session.revealed_readables)
        }
        # Doors are associated with rooms through ``compiled.door_rooms``.
        visible_nodes.update(
            door_id
            for door_id, rooms in session.compiled.door_rooms.items()
            if current_id in rooms
        )
        visible_items = sorted(
            item_id
            for item_id, location in session.item_locations.items()
            if location == current_id
        )
        return LiveCurrentRoom(
            id=room.id,
            label=room.label,
            description=room.description,
            visible_node_ids=sorted(visible_nodes),
            visible_item_ids=visible_items,
        )
| def load_live_payload(live_dir: Path, filename: str) -> bytes | None: | |
| path = live_dir / filename | |
| if not path.exists(): | |
| return None | |
| return path.read_bytes() | |
def load_live_state(live_dir: Path) -> dict[str, Any] | None:
    """Read and decode the live state file from ``live_dir``.

    Returns:
        The decoded JSON document, or ``None`` when ``load_live_payload``
        reports the state file as missing.
    """
    raw = load_live_payload(live_dir, STATE_FILENAME)
    return None if raw is None else json.loads(raw)