Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import json | |
| import mimetypes | |
| import threading | |
| from http import HTTPStatus | |
| from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer | |
| from pathlib import Path | |
| from typing import Any | |
| from urllib.parse import urlparse | |
| from .base import DMCompileError, DMInterfaceError | |
| from .build import WorldCompiler | |
| from .interface import GeminiInterfaceAdapter, SimpleInterfaceAdapter | |
| from .schema import CompiledWorld, WorldDefinition | |
| from .session import EpisodeSession | |
| from .snapshots import ( | |
| DEFAULT_LIVE_DIR, | |
| STATE_FILENAME, | |
| WORLD_FILENAME, | |
| LiveCurrentRoom, | |
| LiveMetrics, | |
| LiveRuntime, | |
| LiveStateSnapshot, | |
| load_live_payload, | |
| ) | |
# Location of the built front-end bundle: the ``www/dist`` directory under the
# third ancestor directory of this module (presumably the repository root).
WEB_DIST_DIR = Path(__file__).resolve().parents[2].joinpath("www", "dist")
class GameSessionManager:
    """Thread-safe container for an interactive play session.

    Every public method holds ``self._lock`` while it reads or mutates the
    session, so concurrent request threads see a consistent state. Whenever
    the session changes, the world definition and a ``LiveStateSnapshot`` are
    mirrored as JSON files into ``live_dir``.
    """

    def __init__(self, live_dir: Path, use_gemini: bool = False) -> None:
        """Create an idle manager that writes live JSON files into *live_dir*.

        Args:
            live_dir: Directory that receives the world/state JSON files.
            use_gemini: Prefer ``GeminiInterfaceAdapter`` for new sessions;
                ``SimpleInterfaceAdapter`` is used if it cannot be created.
        """
        self._lock = threading.Lock()
        self._session: EpisodeSession | None = None
        self._compiled: CompiledWorld | None = None
        self._compiler = WorldCompiler()
        self._live_dir = live_dir
        self._use_gemini = use_gemini
        # Files left behind by an earlier run would otherwise look like they
        # belong to this manager before any session has been started.
        self._clear_stale_files()

    def _clear_stale_files(self) -> None:
        """Remove leftover state/world JSON from a previous session."""
        for fname in (STATE_FILENAME, WORLD_FILENAME):
            (self._live_dir / fname).unlink(missing_ok=True)

    def start(self, world_input: WorldDefinition | dict[str, Any]) -> dict[str, Any]:
        """Compile *world_input* and begin a new episode, replacing any current one.

        Args:
            world_input: A ``WorldDefinition`` or its dict form.

        Returns:
            ``{"ok": True, "episode_id", "observation", "available_commands",
            "room"}`` describing the opening state of the new episode.

        Raises:
            Any exception raised while compiling the world or creating the
            session is propagated. The previous session has already been
            closed by then, so after a failure no session is active.
        """
        with self._lock:
            previous = self._session
            # Forget the old episode before closing it: if compiling the new
            # world fails below, we must not keep a *closed* session around for
            # later commands to step.
            self._session = None
            self._compiled = None
            if previous is not None:
                previous.close()
            compiled = self._compiler.compile(world_input)
            session = EpisodeSession(compiled, interface_adapter=self._make_adapter())
            self._compiled = compiled
            self._session = session
            self._write_world(compiled.world)
            self._write_state("running")
            return {
                "ok": True,
                "episode_id": compiled.episode_id,
                "observation": session.current_feedback(),
                "available_commands": session.available_commands(),
                "room": self._room_info(session),
            }

    def reset(self) -> dict[str, Any]:
        """Close and forget the active session and delete the live JSON files."""
        with self._lock:
            if self._session is not None:
                self._session.close()
            self._session = None
            self._compiled = None
            self._clear_stale_files()
            return {"ok": True}

    def command(self, raw_command: str) -> dict[str, Any]:
        """Execute one player command in the active session.

        Returns:
            On success: the step number, the executed command
            (``turn.textworld_command``), the observation, the ``done`` /
            ``player_won`` flags, the next available commands and room info.
            Otherwise ``{"ok": False, "error": ...}``, plus ``done`` and
            ``player_won`` when the episode has already ended.
        """
        with self._lock:
            session = self._session
            if session is None:
                return {"ok": False, "error": "No active session. POST /api/start first."}
            if session.done:
                return {
                    "ok": False,
                    "error": "Episode is complete.",
                    "done": True,
                    "player_won": session.player_won,
                }
            try:
                turn = session.step(raw_command)
            except (DMInterfaceError, RuntimeError) as exc:
                return {"ok": False, "error": str(exc)}
            self._write_state(self._status_for(session))
            return {
                "ok": True,
                "step": turn.step,
                "command": turn.textworld_command,
                "observation": turn.observation,
                "done": session.done,
                "player_won": session.player_won,
                "available_commands": [] if session.done else session.available_commands(),
                "room": self._room_info(session),
            }

    def get_state_payload(self) -> dict[str, Any] | None:
        """Return the current snapshot as a plain dict, or ``None`` with no session."""
        with self._lock:
            session = self._session
            compiled = self._compiled
            if session is None or compiled is None:
                return None
            return self._snapshot(session, compiled).model_dump()

    def _make_adapter(self) -> SimpleInterfaceAdapter | GeminiInterfaceAdapter:
        """Return a Gemini adapter when requested and constructible, else the simple one."""
        if self._use_gemini:
            try:
                return GeminiInterfaceAdapter(narrate_observations=True)
            except DMInterfaceError:
                # Gemini is optional: fall back to the simple adapter.
                pass
        return SimpleInterfaceAdapter()

    @staticmethod
    def _status_for(session: EpisodeSession) -> str:
        """Map a session's completion flags to the live status string."""
        if not session.done:
            return "running"
        return "complete" if session.player_won else "failed"

    def _write_world(self, world: WorldDefinition) -> None:
        """Persist the world definition JSON for the viewer."""
        self._write_json(WORLD_FILENAME, world.model_dump_json(indent=2))

    def _write_state(self, status: str) -> None:
        """Persist a state snapshot with *status*; no-op when no session is active."""
        session = self._session
        compiled = self._compiled
        if session is None or compiled is None:
            return
        snapshot = self._snapshot(session, compiled, status=status)
        self._write_json(STATE_FILENAME, snapshot.model_dump_json(indent=2))

    def _snapshot(
        self,
        session: EpisodeSession,
        compiled: CompiledWorld,
        status: str | None = None,
    ) -> LiveStateSnapshot:
        """Build a ``LiveStateSnapshot`` describing *session*.

        Args:
            session: The episode being described.
            compiled: The compiled world the episode runs on.
            status: Explicit status; derived from the session flags when ``None``.
        """
        from datetime import datetime, timezone

        # Only location/junction nodes count as rooms for the "visited" list.
        room_ids = {
            node.id for node in compiled.world.nodes if node.type in {"location", "junction"}
        }
        commands = [] if session.done else session.available_commands()
        if status is None:
            status = self._status_for(session)
        min_steps = len(compiled.solver_policy)
        return LiveStateSnapshot(
            episode_id=compiled.episode_id,
            status=status,
            updated_at=datetime.now(timezone.utc).isoformat(),
            title=compiled.world.meta.title,
            transcript=list(session.transcript),
            metrics=LiveMetrics(
                steps_taken=session.steps_taken,
                min_steps=min_steps,
                # Steps taken relative to len(solver_policy); undefined (None)
                # when the solver policy is empty.
                ratio=session.steps_taken / min_steps if min_steps else None,
                player_won=session.player_won if session.done else None,
            ),
            runtime=LiveRuntime(
                current_room_id=session.current_room_id,
                inventory_item_ids=sorted(session.inventory),
                discovered_clue_ids=sorted(session.discovered_clues),
                traded_npc_ids=sorted(session.traded_npcs),
                visited_room_ids=sorted(room_ids & session.visited_nodes),
                available_commands=commands,
                invalid_command_count=session.invalid_command_count,
                wrong_submit_count=session.wrong_submit_count,
                open_node_ids=sorted(session.open_nodes),
                locked_node_ids=sorted(session.locked_nodes),
            ),
            current_room=self._current_room_snapshot(session),
        )

    @staticmethod
    def _current_room_snapshot(session: EpisodeSession) -> LiveCurrentRoom | None:
        """Describe the player's current room, or ``None`` if it is not a world node.

        Visible nodes are the room's children (readables only once revealed)
        plus every door whose room pair includes the current room. Visible
        items are those whose current location is the room.
        """
        # Must be a staticmethod: it takes no ``self`` but is called as
        # ``self._current_room_snapshot(session)``.
        room_id = session.current_room_id
        node_by_id = {node.id: node for node in session.compiled.world.nodes}
        room = node_by_id.get(room_id)
        if room is None:
            return None
        visible_nodes = [
            node.id
            for node in session.compiled.world.nodes
            if getattr(node, "parent_id", None) == room_id
            and (node.type != "readable" or node.id in session.revealed_readables)
        ]
        visible_nodes.extend(
            sorted(
                door_id
                for door_id, rooms in session.compiled.door_rooms.items()
                if room_id in rooms
            )
        )
        visible_items = sorted(
            item_id
            for item_id, location in session.item_locations.items()
            if location == room_id
        )
        return LiveCurrentRoom(
            id=room.id,
            label=room.label,
            description=room.description,
            visible_node_ids=sorted(set(visible_nodes)),
            visible_item_ids=visible_items,
        )

    @staticmethod
    def _room_info(session: EpisodeSession) -> dict[str, Any]:
        """Return id/label/description of the current room.

        Falls back to the room id as label and an empty description when the
        id does not match any world node.
        """
        # Must be a staticmethod: it takes no ``self`` but is called as
        # ``self._room_info(session)``.
        node_by_id = {node.id: node for node in session.compiled.world.nodes}
        room = node_by_id.get(session.current_room_id)
        return {
            "id": session.current_room_id,
            "label": room.label if room else session.current_room_id,
            "description": room.description if room else "",
        }

    def _write_json(self, filename: str, payload: str) -> None:
        """Write *payload* plus a trailing newline to ``live_dir/filename``.

        The text goes to a sibling ``.tmp`` file first and is then moved over
        the target with ``Path.replace``, so readers should never see a
        half-written file (the rename is atomic on POSIX within one filesystem).
        """
        self._live_dir.mkdir(parents=True, exist_ok=True)
        path = self._live_dir / filename
        tmp_path = path.with_suffix(path.suffix + ".tmp")
        tmp_path.write_text(payload + "\n", encoding="utf-8")
        tmp_path.replace(path)
def create_server(
    *,
    live_dir: Path | None = None,
    host: str = "127.0.0.1",
    port: int = 8000,
    use_gemini: bool = False,
) -> ThreadingHTTPServer:
    """Build (without starting) the HTTP server for the live viewer and play API.

    Routes:
        ``GET /api/state``, ``GET /api/world``: live JSON files (204 when absent).
        ``GET /``: the built web app's ``index.html`` or the fallback template.
        Other ``GET`` paths: static files under ``WEB_DIST_DIR``; paths without
        an extension fall back to the index when the bundle directory exists.
        ``POST /api/start``: start an episode from a JSON world object.
        ``POST /api/command``: run ``{"command": "..."}`` in the active episode.
        ``POST /api/reset``: discard the active episode.

    Args:
        live_dir: Directory for live JSON files (``DEFAULT_LIVE_DIR`` if omitted).
        host: Address to bind.
        port: TCP port to bind.
        use_gemini: Forwarded to ``GameSessionManager``.

    Returns:
        A ``ThreadingHTTPServer``; the caller is responsible for serving/closing it.
    """
    resolved_live_dir = live_dir or DEFAULT_LIVE_DIR
    # One manager shared by all handler threads; it guards itself with a lock.
    game = GameSessionManager(resolved_live_dir, use_gemini=use_gemini)

    class LiveViewerHandler(BaseHTTPRequestHandler):
        server_version = "AgentsMasterLive/1.0"

        def do_GET(self) -> None:  # noqa: N802
            path = urlparse(self.path).path
            if path == "/api/state":
                self._serve_live_file(STATE_FILENAME)
                return
            if path == "/api/world":
                self._serve_live_file(WORLD_FILENAME)
                return
            if path == "/":
                self._serve_index()
                return
            if path == "/favicon.ico":
                self.send_response(HTTPStatus.NO_CONTENT)
                self.end_headers()
                return
            if self._serve_web_file(path):
                return
            # Extension-less unknown paths serve the index so the web app can
            # handle its own routes.
            if WEB_DIST_DIR.exists() and Path(path).suffix == "":
                self._serve_index()
                return
            self._respond(HTTPStatus.NOT_FOUND, b"Not found\n", "text/plain; charset=utf-8")

        def do_POST(self) -> None:  # noqa: N802
            path = urlparse(self.path).path
            body = self._read_body()
            if path == "/api/reset":
                self._json_respond(HTTPStatus.OK, game.reset())
                return
            if path == "/api/start":
                self._handle_start(body)
                return
            if path == "/api/command":
                self._handle_command(body)
                return
            self._respond(HTTPStatus.NOT_FOUND, b"Not found\n", "text/plain; charset=utf-8")

        def _handle_start(self, body: bytes) -> None:
            """Start a new episode from a JSON world object in *body*."""
            try:
                # ValueError covers JSONDecodeError and UnicodeDecodeError
                # (json.loads decodes bytes input itself).
                world_input = json.loads(body) if body else None
            except ValueError as exc:
                self._bad_request(str(exc))
                return
            if world_input is None:
                self._bad_request("Missing JSON body.")
                return
            if not isinstance(world_input, dict):
                self._bad_request("World definition must be a JSON object.")
                return
            try:
                result = game.start(world_input)
            except (DMCompileError, ValueError) as exc:
                self._bad_request(str(exc))
                return
            self._json_respond(HTTPStatus.OK, result)

        def _handle_command(self, body: bytes) -> None:
            """Run the ``command`` string from a JSON object in *body*."""
            try:
                data = json.loads(body) if body else {}
            except ValueError as exc:
                self._bad_request(str(exc))
                return
            # Non-object JSON or a non-string command previously crashed the
            # handler with AttributeError and no response was sent.
            if not isinstance(data, dict):
                self._bad_request("Request body must be a JSON object.")
                return
            raw_command = data.get("command")
            if raw_command is not None and not isinstance(raw_command, str):
                self._bad_request("'command' must be a string.")
                return
            command = (raw_command or "").strip()
            if not command:
                self._bad_request("Missing 'command' field.")
                return
            self._json_respond(HTTPStatus.OK, game.command(command))

        def log_message(self, format: str, *args: object) -> None:  # noqa: A003
            # Silence the default per-request logging to stderr.
            del format, args

        def _read_body(self) -> bytes:
            """Read the request body as declared by ``Content-Length``.

            A missing, malformed or non-positive header yields an empty body
            (start/command then answer 400) instead of raising inside the
            handler, which would abort the request without any response.
            """
            try:
                length = int(self.headers.get("Content-Length", 0))
            except ValueError:
                return b""
            return self.rfile.read(length) if length > 0 else b""

        def _serve_index(self) -> None:
            index_path = WEB_DIST_DIR / "index.html"
            if index_path.is_file():
                self._respond(HTTPStatus.OK, index_path.read_bytes(), "text/html; charset=utf-8")
            else:
                from .templates import render_index

                self._respond(HTTPStatus.OK, render_index().encode("utf-8"), "text/html; charset=utf-8")

        def _serve_live_file(self, filename: str) -> None:
            payload = load_live_payload(resolved_live_dir, filename)
            if payload is None:
                self.send_response(HTTPStatus.NO_CONTENT)
                self.send_header("Cache-Control", "no-store")
                self.end_headers()
                return
            self._respond(
                HTTPStatus.OK, payload, "application/json; charset=utf-8",
                extra_headers={"Cache-Control": "no-store"},
            )

        def _serve_web_file(self, path: str) -> bool:
            """Serve a file from ``WEB_DIST_DIR``; return False if not servable."""
            candidate = (WEB_DIST_DIR / path.lstrip("/")).resolve()
            try:
                # Reject anything that resolves outside the bundle (e.g. "..").
                candidate.relative_to(WEB_DIST_DIR.resolve())
            except ValueError:
                return False
            if not candidate.is_file():
                return False
            content_type = mimetypes.guess_type(candidate.name)[0] or "application/octet-stream"
            self._respond(HTTPStatus.OK, candidate.read_bytes(), content_type)
            return True

        def _bad_request(self, message: str) -> None:
            self._json_respond(HTTPStatus.BAD_REQUEST, {"ok": False, "error": message})

        def _json_respond(self, status: HTTPStatus, data: dict[str, Any]) -> None:
            payload = json.dumps(data).encode("utf-8")
            self._respond(status, payload, "application/json; charset=utf-8",
                          extra_headers={"Cache-Control": "no-store"})

        def _respond(
            self, status: HTTPStatus, payload: bytes, content_type: str,
            *, extra_headers: dict[str, str] | None = None,
        ) -> None:
            self.send_response(status)
            self.send_header("Content-Type", content_type)
            self.send_header("Content-Length", str(len(payload)))
            if extra_headers:
                for key, value in extra_headers.items():
                    self.send_header(key, value)
            self.end_headers()
            self.wfile.write(payload)

    return ThreadingHTTPServer((host, port), LiveViewerHandler)
def run_server(*, port: int = 8000, live_dir: Path | None = None, host: str = "127.0.0.1", use_gemini: bool = False) -> None:
    """Create the live-viewer server and serve requests until interrupted.

    The server is closed when ``serve_forever`` returns or raises (e.g. on
    ``KeyboardInterrupt``); exiting the ``with`` block calls ``server_close()``.
    """
    server = create_server(
        live_dir=live_dir,
        host=host,
        port=port,
        use_gemini=use_gemini,
    )
    bound_port = server.server_address[1]
    print(f"Serving live viewer on http://{host}:{bound_port}")
    with server:
        server.serve_forever()