# FATHOM-DM / agents / master / snapshots.py
# Uploaded by aarushgupta
# Commit 2803d7e (verified): Deploy FATHOM-DM Space bundle
from __future__ import annotations
import json
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any, Protocol
from pydantic import Field
from .base import ARTIFACTS_ROOT
from .schema import CompiledWorld, DMFeedback, DMObservation, StrictModel, Turn, WorldDefinition
if TYPE_CHECKING:
from .session import EpisodeSession
# File name of the live state snapshot inside the live directory.
STATE_FILENAME = "state.json"
# File name of the serialized world definition inside the live directory.
WORLD_FILENAME = "world.json"
# Default value of LiveStateSnapshot.schema_version.
LIVE_SCHEMA_VERSION = 1
# Directory used by LiveSnapshotWriter when no live_dir is passed.
DEFAULT_LIVE_DIR = ARTIFACTS_ROOT / "live"
class LiveMetrics(StrictModel):
    """Progress and scoring figures embedded in a live state snapshot.

    Instances are produced by ``LiveSnapshotWriter._metrics``. Field order is
    kept as-is because it determines the order of keys in the dumped model.
    """

    # Steps the session has taken so far.
    steps_taken: int = 0
    # Reference step count: ``len(compiled.solver_policy)`` while running,
    # ``observation.min_steps`` once the episode completes.
    min_steps: int | None = None
    # ``steps_taken / min_steps`` unless an explicit ratio is supplied;
    # stays ``None`` when ``min_steps`` is missing or zero.
    ratio: float | None = None
    # Only populated by the completion hook.
    reward: float | None = None
    # Only populated by the completion hook.
    player_won: bool | None = None
class LiveRuntime(StrictModel):
    """Flattened view of the session's mutable state for a live snapshot.

    Instances are produced by ``LiveSnapshotWriter._runtime``; every id list
    it fills is sorted. All fields default to "empty" so a snapshot can be
    written before any session exists.
    """

    # Id of the room the session is currently in.
    current_room_id: str | None = None
    # Sorted copy of ``session.inventory``.
    inventory_item_ids: list[str] = Field(default_factory=list)
    # Sorted copy of ``session.discovered_clues``.
    discovered_clue_ids: list[str] = Field(default_factory=list)
    # Sorted copy of ``session.traded_npcs``.
    traded_npc_ids: list[str] = Field(default_factory=list)
    # Visited nodes restricted to those of type "location" or "junction".
    visited_room_ids: list[str] = Field(default_factory=list)
    # Result of ``session.available_commands()``; empty once the session is done.
    available_commands: list[str] = Field(default_factory=list)
    # Copied from ``session.invalid_command_count``.
    invalid_command_count: int = 0
    # Copied from ``session.wrong_submit_count``.
    wrong_submit_count: int = 0
    # Sorted copy of ``session.open_nodes``.
    open_node_ids: list[str] = Field(default_factory=list)
    # Sorted copy of ``session.locked_nodes``.
    locked_node_ids: list[str] = Field(default_factory=list)
class LiveCurrentRoom(StrictModel):
    """Description of the room the session currently occupies.

    Instances are produced by ``LiveSnapshotWriter._current_room``.
    """

    # Id, label and description are copied from the room's world node.
    id: str | None = None
    label: str | None = None
    description: str | None = None
    # Sorted, de-duplicated ids of nodes parented to the room (readables only
    # once revealed) plus doors whose ``door_rooms`` entry contains the room.
    visible_node_ids: list[str] = Field(default_factory=list)
    # Sorted ids of items whose recorded location is the room.
    visible_item_ids: list[str] = Field(default_factory=list)
class LiveStateSnapshot(StrictModel):
    """Complete payload written to the live state file after each event.

    ``LiveSnapshotWriter`` builds a fresh instance in every hook and dumps it
    as JSON, so the file always holds one self-contained snapshot.
    """

    # Version tag for the payload layout; defaults to LIVE_SCHEMA_VERSION.
    schema_version: int = LIVE_SCHEMA_VERSION
    episode_id: str
    # Values written in this module: "compiling", "running", "complete",
    # "failed" and "compile_error".
    status: str
    # ISO-8601 timestamp in UTC, taken when the snapshot is built.
    updated_at: str
    # World title, when one could be determined.
    title: str | None = None
    # Optional label of the runner that owns the writer.
    runner: str | None = None
    # Error text; only set by the error hook.
    error: str | None = None
    # Copy of the session transcript at snapshot time.
    transcript: list[Turn] = Field(default_factory=list)
    metrics: LiveMetrics = Field(default_factory=LiveMetrics)
    # Only set by the completion hook, from ``observation.feedback``.
    feedback: DMFeedback | None = None
    runtime: LiveRuntime = Field(default_factory=LiveRuntime)
    current_room: LiveCurrentRoom | None = None
class LiveObserver(Protocol):
    """Structural interface for objects that want episode lifecycle callbacks.

    Every hook returns ``None``; the protocol itself carries no behaviour.
    """

    def on_run_start(self, episode_id: str, world_input: WorldDefinition | dict[str, Any]) -> None:
        """Hook taking the episode id and the raw or parsed world input."""

    def on_compile_success(self, compiled: CompiledWorld, session: EpisodeSession) -> None:
        """Hook taking the compiled world and the session created for it."""

    def on_turn(self, session: EpisodeSession, turn: Turn) -> None:
        """Hook taking the session and a single turn."""

    def on_complete(self, compiled: CompiledWorld, session: EpisodeSession, observation: DMObservation) -> None:
        """Hook taking the compiled world, the session and the final observation."""

    def on_error(
        self,
        *,
        episode_id: str,
        error: str,
        world_input: WorldDefinition | dict[str, Any],
        compiled: CompiledWorld | None = None,
        session: EpisodeSession | None = None,
    ) -> None:
        """Hook for failures; keyword-only, ``compiled``/``session`` are optional."""
class LiveSnapshotWriter:
    """Observer that mirrors an episode's progress into JSON files on disk.

    The class provides the same hooks as ``LiveObserver``. Every hook builds a
    complete ``LiveStateSnapshot`` and overwrites the state file in
    ``live_dir``; the world definition is written once, after a successful
    compile.
    """

    def __init__(self, live_dir: Path | None = None, runner_name: str | None = None) -> None:
        """Create the writer and make sure its output directory exists.

        Args:
            live_dir: Directory that receives the live files; falls back to
                ``DEFAULT_LIVE_DIR`` when omitted.
            runner_name: Optional label copied into every snapshot's
                ``runner`` field.
        """
        self.live_dir = live_dir or DEFAULT_LIVE_DIR
        self.runner_name = runner_name
        self.live_dir.mkdir(parents=True, exist_ok=True)

    def on_run_start(self, episode_id: str, world_input: WorldDefinition | dict[str, Any]) -> None:
        """Publish a "compiling" snapshot for a new episode.

        Any world file already present in ``live_dir`` is removed first.
        """
        self._remove_world()
        snapshot = LiveStateSnapshot(
            episode_id=episode_id,
            status="compiling",
            updated_at=self._timestamp(),
            title=self._extract_title(world_input),
            runner=self.runner_name,
        )
        self._write_state_snapshot(snapshot)

    def on_compile_success(self, compiled: CompiledWorld, session: EpisodeSession) -> None:
        """Write the compiled world file and publish the first "running" snapshot."""
        self._write_world(compiled.world)
        snapshot = LiveStateSnapshot(
            episode_id=compiled.episode_id,
            status="running",
            updated_at=self._timestamp(),
            title=compiled.world.meta.title,
            runner=self.runner_name,
            metrics=self._metrics(min_steps=len(compiled.solver_policy), steps_taken=session.steps_taken),
            runtime=self._runtime(session),
            current_room=self._current_room(session),
        )
        self._write_state_snapshot(snapshot)

    def on_turn(self, session: EpisodeSession, turn: Turn) -> None:
        """Publish a "running" snapshot reflecting the session after a turn."""
        # The turn argument is not needed: the snapshot carries the whole
        # transcript taken from the session.
        del turn
        snapshot = LiveStateSnapshot(
            episode_id=session.compiled.episode_id,
            status="running",
            updated_at=self._timestamp(),
            title=session.compiled.world.meta.title,
            runner=self.runner_name,
            transcript=list(session.transcript),
            metrics=self._metrics(
                min_steps=len(session.compiled.solver_policy),
                steps_taken=session.steps_taken,
            ),
            runtime=self._runtime(session),
            current_room=self._current_room(session),
        )
        self._write_state_snapshot(snapshot)

    def on_complete(self, compiled: CompiledWorld, session: EpisodeSession, observation: DMObservation) -> None:
        """Publish the final snapshot: "complete" on a win, "failed" otherwise."""
        status = "complete" if observation.player_won else "failed"
        snapshot = LiveStateSnapshot(
            episode_id=compiled.episode_id,
            status=status,
            updated_at=self._timestamp(),
            title=compiled.world.meta.title,
            runner=self.runner_name,
            transcript=list(session.transcript),
            metrics=self._metrics(
                min_steps=observation.min_steps,
                # Fall back to the session counter when the observation
                # reports no (or zero) steps.
                steps_taken=observation.steps_taken or session.steps_taken,
                ratio=observation.ratio,
                reward=observation.reward,
                player_won=observation.player_won,
            ),
            feedback=observation.feedback,
            runtime=self._runtime(session),
            current_room=self._current_room(session),
        )
        self._write_state_snapshot(snapshot)

    def on_error(
        self,
        *,
        episode_id: str,
        error: str,
        world_input: WorldDefinition | dict[str, Any],
        compiled: CompiledWorld | None = None,
        session: EpisodeSession | None = None,
    ) -> None:
        """Publish an error snapshot with whatever context is available.

        Args:
            episode_id: Id of the episode that failed.
            error: Error text stored in the snapshot.
            world_input: Raw or parsed world; used for the title when no
                compiled world is given.
            compiled: Compiled world, if compilation had succeeded.
            session: Session, if one had been created.
        """
        title = compiled.world.meta.title if compiled is not None else self._extract_title(world_input)
        snapshot = LiveStateSnapshot(
            episode_id=episode_id,
            # NOTE(review): the status is "compile_error" even when a compiled
            # world or session is supplied - confirm consumers expect that
            # before introducing a separate runtime-error status.
            status="compile_error",
            updated_at=self._timestamp(),
            title=title,
            runner=self.runner_name,
            error=error,
            transcript=list(session.transcript) if session is not None else [],
            metrics=self._metrics(
                min_steps=len(compiled.solver_policy) if compiled is not None else None,
                steps_taken=session.steps_taken if session is not None else 0,
            ),
            runtime=self._runtime(session),
            current_room=self._current_room(session),
        )
        self._write_state_snapshot(snapshot)

    def _write_world(self, world: WorldDefinition) -> None:
        """Dump the world definition as indented JSON into the world file."""
        self._write_json(self.live_dir / WORLD_FILENAME, world.model_dump_json(indent=2))

    def _write_state_snapshot(self, snapshot: LiveStateSnapshot) -> None:
        """Dump the snapshot as indented JSON into the state file."""
        self._write_json(self.live_dir / STATE_FILENAME, snapshot.model_dump_json(indent=2))

    def _write_json(self, path: Path, payload: str) -> None:
        """Write ``payload`` plus a trailing newline to ``path`` via a temp file.

        The text goes to a sibling ``<name>.tmp`` file first and is then
        renamed over ``path``, so ``path`` is only ever replaced by a fully
        written file.

        Raises:
            OSError: If writing or renaming fails. The temp file is removed
                before the error propagates.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = path.with_suffix(path.suffix + ".tmp")
        try:
            tmp_path.write_text(payload + "\n", encoding="utf-8")
            tmp_path.replace(path)
        except OSError:
            # Best-effort cleanup so a failed write does not leave a stale or
            # partial temp file next to the live files.
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                pass  # Keep the original failure as the reported error.
            raise

    def _remove_world(self) -> None:
        """Delete the world file if present; a missing file is not an error."""
        # missing_ok avoids the exists()/unlink() race with other removers.
        (self.live_dir / WORLD_FILENAME).unlink(missing_ok=True)

    @staticmethod
    def _timestamp() -> str:
        """Return the current UTC time as an ISO-8601 string."""
        return datetime.now(timezone.utc).isoformat()

    @staticmethod
    def _extract_title(world_input: WorldDefinition | dict[str, Any]) -> str | None:
        """Return the world title from a model or a raw dict, else ``None``.

        For dict input the title is read from ``world_input["meta"]["title"]``
        and only returned when every level has the expected type.
        """
        if isinstance(world_input, WorldDefinition):
            return world_input.meta.title
        meta = world_input.get("meta") if isinstance(world_input, dict) else None
        title = meta.get("title") if isinstance(meta, dict) else None
        return title if isinstance(title, str) else None

    @staticmethod
    def _metrics(
        *,
        min_steps: int | None,
        steps_taken: int,
        ratio: float | None = None,
        reward: float | None = None,
        player_won: bool | None = None,
    ) -> LiveMetrics:
        """Build a ``LiveMetrics``, deriving ``ratio`` when it is not given.

        The derived ratio is ``steps_taken / min_steps``; it is skipped when
        ``min_steps`` is ``None`` or zero, which also avoids dividing by zero.
        """
        computed_ratio = ratio
        if computed_ratio is None and min_steps:
            computed_ratio = steps_taken / min_steps
        return LiveMetrics(
            steps_taken=steps_taken,
            min_steps=min_steps,
            ratio=computed_ratio,
            reward=reward,
            player_won=player_won,
        )

    @staticmethod
    def _runtime(session: EpisodeSession | None) -> LiveRuntime:
        """Build a ``LiveRuntime`` from the session; defaults when it is ``None``."""
        if session is None:
            return LiveRuntime()
        # Only nodes of these types count as rooms for the "visited" list.
        room_ids = {
            node.id
            for node in session.compiled.world.nodes
            if node.type in {"location", "junction"}
        }
        # A finished session is reported with no available commands.
        commands = [] if session.done else session.available_commands()
        return LiveRuntime(
            current_room_id=session.current_room_id,
            inventory_item_ids=sorted(session.inventory),
            discovered_clue_ids=sorted(session.discovered_clues),
            traded_npc_ids=sorted(session.traded_npcs),
            visited_room_ids=sorted(room_ids & session.visited_nodes),
            available_commands=commands,
            invalid_command_count=session.invalid_command_count,
            wrong_submit_count=session.wrong_submit_count,
            open_node_ids=sorted(session.open_nodes),
            locked_node_ids=sorted(session.locked_nodes),
        )

    @staticmethod
    def _current_room(session: EpisodeSession | None) -> LiveCurrentRoom | None:
        """Describe the session's current room and what is visible in it.

        Returns ``None`` when there is no session or the current room id does
        not match any world node.
        """
        if session is None:
            return None
        node_by_id = {node.id: node for node in session.compiled.world.nodes}
        room = node_by_id.get(session.current_room_id)
        if room is None:
            return None
        # Nodes parented to the room; readables are listed only once revealed.
        visible_nodes = [
            node.id
            for node in session.compiled.world.nodes
            if getattr(node, "parent_id", None) == session.current_room_id
            and (node.type != "readable" or node.id in session.revealed_readables)
        ]
        # Doors are looked up through door_rooms: a door is listed when the
        # current room is among the rooms recorded for it.
        visible_nodes.extend(
            sorted(
                door_id
                for door_id, rooms in session.compiled.door_rooms.items()
                if session.current_room_id in rooms
            )
        )
        visible_items = sorted(
            item_id
            for item_id, location in session.item_locations.items()
            if location == session.current_room_id
        )
        return LiveCurrentRoom(
            id=room.id,
            label=room.label,
            description=room.description,
            visible_node_ids=sorted(set(visible_nodes)),
            visible_item_ids=visible_items,
        )
def load_live_payload(live_dir: Path, filename: str) -> bytes | None:
    """Return the raw bytes of ``live_dir / filename``, or ``None`` if absent.

    The file is read directly instead of being checked with ``exists()``
    first, so a file that is removed or replaced between the check and the
    read cannot raise ``FileNotFoundError`` here.

    Args:
        live_dir: Directory holding the live files.
        filename: Name of the file inside ``live_dir``.

    Returns:
        The file content, or ``None`` when the file (or a parent component)
        does not exist.
    """
    try:
        return (live_dir / filename).read_bytes()
    except (FileNotFoundError, NotADirectoryError):
        # NotADirectoryError covers a ``live_dir`` that is not a directory,
        # which the previous exists() check also reported as "absent".
        return None
def load_live_state(live_dir: Path) -> dict[str, Any] | None:
    """Load and JSON-decode the live state file from ``live_dir``.

    Returns ``None`` when the state file is absent; otherwise the decoded
    JSON document.
    """
    raw = load_live_payload(live_dir, STATE_FILENAME)
    return None if raw is None else json.loads(raw)