Spaces:
Sleeping
Sleeping
feat(web): show handover memory + persona rubric for LLM spectate (same as human play); highlight each entity's behaviour as a courier passes it
d42e3af | """Stdlib HTTP server for interactive color-grid play. | |
| A pure ``handle_request(method, path, body, registry)`` router (unit-testable | |
| without a socket) plus a thin BaseHTTPRequestHandler adapter and a make_server | |
| factory. In-memory session registry; local single-user. No new dependencies. | |
| Concurrency: ThreadingHTTPServer serves each request on its own thread, but the | |
| registry is a plain dict and a session is not locked — this assumes one local | |
| player (no concurrent ``/act`` on the same session). Sessions are never evicted | |
| (unbounded for a long-lived server); both are acceptable for single-user play. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import uuid | |
| from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer | |
| from pathlib import Path | |
| from proteus.game.engine.rendering import COLOR_MAP | |
| from proteus.game.engine.difficulty import Difficulty | |
| from proteus.game.scenarios.base import list_scenarios | |
| from proteus.game.runtime import _session_core as core | |
| from proteus.game.runtime.interactive import InteractiveSession | |
| from proteus.game.runtime.io import append_trace | |
# Content-Type header values for the two payload kinds a route can return.
_JSON = "application/json"
_HTML = "text/html; charset=utf-8"

# (status, payload, content_type). payload is a dict (json) or bytes (html).
Response = tuple[int, "dict | bytes", str]

# Seed used when a request body carries no "seed" key; also reported by /config.
_DEFAULT_SEED = 42

# The page served at "/", located next to this module under static/.
_STATIC = Path(__file__).parent.joinpath("static", "index.html")
def _err(status: int, message: str) -> Response:
    """Build a JSON error response carrying *message* under the ``"error"`` key."""
    payload = {"error": message}
    return status, payload, _JSON
def _config_payload() -> dict:
    """Return the JSON payload served at ``GET /config``.

    Lists the known scenarios, the difficulty values, the colour map (with
    keys stringified so they are valid JSON object keys), the default seed,
    the available providers and the default model spec.
    """
    # Imported inside the function to keep module import light.
    from proteus.providers import available_providers

    scenarios = list_scenarios()
    difficulties = [level.value for level in Difficulty]
    palette = {str(code): colour for code, colour in COLOR_MAP.items()}
    providers = available_providers()
    return {
        "scenarios": scenarios,
        "difficulties": difficulties,
        "color_map": palette,
        "default_seed": _DEFAULT_SEED,
        "providers": providers,
        "default_model": "fake:demo",
    }
def _resolve_web_memory(
    mode, *, scenario: str, difficulty: Difficulty, seed, model: str = "",
    memory_turns: int = 100, memory_root: str = "runs/memory",
):
    """Resolve a web 'memory' selection to ``(memory|None, use_default, error)``.

    Modes:

    * ``None`` / ``''`` / ``'default'`` -> the scenario's default memory.
    * ``'none'`` -> force no memory.
    * ``'persona'`` or ``'persona:<id>'`` -> a hidden reference-policy
      demonstration (offline, no model); a bare ``'persona'`` means
      ``'persona:risk_averse'``.
    * ``'policy:<name>'`` -> a memory generated by the named memory policy.
    * ``'generate'`` -> the LLM self-plays a memory with *model*, which is also
      saved under *memory_root*.
    * ``'latest'`` -> the newest saved checkpoint for *model* under
      *memory_root*.
    * anything else -> treated as a checkpoint file path.

    ``use_default`` tells the session whether to fall back to the scenario
    default when ``memory is None``. ``error`` is ``None`` on success, otherwise
    a client-facing message (in which case ``memory`` is ``None``).

    Only the exact ``persona`` / ``persona:<id>`` spellings select a persona, so
    a checkpoint path that merely begins with ``persona`` (for example
    ``personas/run1.json``) is loaded as a file rather than being silently
    replaced by the ``risk_averse`` persona. A non-string *mode* is reported
    as an error instead of raising ``AttributeError``.
    """
    if mode in (None, "", "default"):
        return None, True, None
    if not isinstance(mode, str):
        # The value comes straight from a JSON body, so it may be any JSON type.
        return None, False, f"memory mode must be a string, got {type(mode).__name__}"
    if mode == "none":
        return None, False, None

    if mode == "persona" or mode.startswith("persona:"):
        from proteus.game.metrics.persona import available_personas, get_persona
        from proteus.game.runtime.memory_gen import generate_memory
        _, sep, pid = mode.partition(":")
        if not sep:
            pid = "risk_averse"
        try:
            persona = get_persona(pid)
        except KeyError:
            return None, False, f"unknown persona {pid!r}; have {available_personas()}"
        ckpt = generate_memory(
            scenario, None, difficulty=difficulty, seed=seed,
            memory_turns=memory_turns, model_name=f"persona:{pid}", persona=persona,
        )
        return ckpt, False, None

    if mode.startswith("policy:"):
        from proteus.game.runtime.memory_gen import generate_memory
        from proteus.game.runtime.memory_policies import available_policies, get_policy
        name = mode.split(":", 1)[1]
        try:
            policy = get_policy(name)
        except KeyError:
            return None, False, f"unknown policy {name!r}; have {available_policies()}"
        ckpt = generate_memory(
            scenario, None, difficulty=difficulty, seed=seed,
            memory_turns=memory_turns, model_name=f"policy:{name}", policy=policy,
        )
        return ckpt, False, None

    if mode in ("generate", "latest"):
        if not model:
            return None, False, f"memory mode {mode!r} needs a model"
        from proteus.providers import make_provider
        try:
            provider = make_provider(model)
        except Exception as exc:  # missing SDK / key / bad spec -> clear client error
            return None, False, str(exc)
        if mode == "generate":
            from proteus.game.agents.vanilla import VanillaAgent
            from proteus.game.runtime.memory import save_checkpoint
            from proteus.game.runtime.memory_gen import generate_memory
            ckpt = generate_memory(
                scenario, VanillaAgent(provider), difficulty=difficulty, seed=seed,
                memory_turns=memory_turns, model_name=provider.model_name,
            )
            save_checkpoint(ckpt, root=memory_root)
            return ckpt, False, None
        from proteus.game.runtime.memory import latest_for_model
        ckpt = latest_for_model(provider.model_name, root=memory_root)
        if ckpt is None:
            return None, False, f"no saved memory for {provider.model_name!r}"
        return ckpt, False, None

    # Fallback: the mode string is a checkpoint path supplied by the client.
    from proteus.game.runtime.memory import load_checkpoint
    try:
        return load_checkpoint(mode), False, None
    except FileNotFoundError:
        return None, False, f"memory checkpoint not found: {mode}"
def _memory_info(session) -> dict:
    """Summarise the memory attached to *session* for the web client.

    The result always has the same keys. ``block`` is the rendered handover
    memory text (so the web UI can display, at the start, what the model
    'remembers' before the task) and ``frames`` the memory frames; both stay
    empty when no memory is attached. For the ``errand_runner`` scenario the
    result additionally carries one entry per authored persona variant, the
    selected persona id and that persona's rubric.
    """
    mem = getattr(session, "_memory", None)
    attached = mem is not None
    # The summary fields key off truthiness of the memory object, while
    # "attached" (and the rendering below) key off ``is not None``.
    if mem:
        source = mem.model
        turn_count = len(mem.memory_turns)
        persona_id = mem.persona_weight_id
    else:
        source, turn_count, persona_id = None, 0, None

    info = {
        "attached": attached,
        "source": source,
        "turns": turn_count,
        "persona": persona_id,
        "block": None,
        "frames": [],
        "variants": [],
        "selected": None,
        "rubric": None,
    }
    if not attached:
        return info

    from proteus.game.scenarios.base import get_scenario
    from proteus.game.runtime.memory import memory_frames, render_memory_block

    info["block"] = render_memory_block(mem)
    scen = get_scenario(mem.scenario)()
    info["frames"] = memory_frames(
        mem, legend=scen.legend(), grid_size=scen.grid_size,
    )
    if mem.scenario != "errand_runner":
        return info

    from proteus.game.runtime.multiagent_director import author_errand_runner_variants
    from proteus.game.scenarios import errand_world as w

    seed = 0 if mem.seed is None else mem.seed
    variants = author_errand_runner_variants(seed=seed)
    variant_entries = []
    for pid, ck in variants.items():
        variant_entries.append({
            "id": pid,
            "label": w.PERSONA_LABELS[pid],
            "frames": memory_frames(ck, legend=scen.legend(), grid_size=scen.grid_size),
        })
    info["variants"] = variant_entries

    # Prefer the persona recorded on the memory; otherwise the first variant.
    sel = mem.persona_weight_id or next(iter(variants))
    info["selected"] = sel
    info["rubric"] = {
        "persona": sel,
        "label": w.PERSONA_LABELS.get(sel, sel),
        "rows": w.persona_rubric(sel),  # per-entity reaction + live coords
    }
    return info
def _create_session(body: dict, registry: dict) -> Response:
    """Handle ``POST /session``: create an interactive session and register it.

    Reads ``scenario`` (required), ``difficulty`` (default ``"easy"``),
    ``seed`` (default ``_DEFAULT_SEED``; ``""`` or ``null`` means no seed),
    ``play_turns`` (default 100), ``probe``, ``memory``, ``model`` and
    ``memory_root`` from *body*.

    Returns 200 with the new ``session_id``, the initial state and the memory
    summary, or 400 when a parameter is missing or malformed, the scenario is
    unknown, or the memory selection cannot be resolved.
    """
    try:
        scenario = body["scenario"]
        difficulty = Difficulty(body.get("difficulty", "easy"))
        raw_seed = body.get("seed", _DEFAULT_SEED)
        seed = None if raw_seed in ("", None) else int(raw_seed)
        play_turns = int(body.get("play_turns", 100))
        probe = bool(body.get("probe", False))
    except (KeyError, TypeError, ValueError) as exc:
        # TypeError covers JSON values int() cannot take at all (null, lists,
        # objects); without it those requests surface as a 500.
        return _err(400, f"bad session params: {exc}")
    if scenario not in list_scenarios():
        return _err(400, f"unknown scenario {scenario!r}")

    memory, use_default, mem_err = _resolve_web_memory(
        body.get("memory"), scenario=scenario, difficulty=difficulty, seed=seed,
        model=body.get("model", ""), memory_root=body.get("memory_root") or "runs/memory",
    )
    if mem_err is not None:
        return _err(400, mem_err)

    session = InteractiveSession(
        scenario, difficulty=difficulty, seed=seed,
        play_turns=play_turns, use_probe=probe,
        memory=memory, use_default_memory=use_default,
    )
    sid = uuid.uuid4().hex
    registry[sid] = session
    return 200, {"session_id": sid, "state": session.state(),
                 "memory": _memory_info(session)}, _JSON
def _act(sid: str, body: dict, registry: dict) -> Response:
    """Apply one player action to session *sid* and return the new state.

    Responds 404 for an unknown session, 409 when the session has already
    finished, and 400 when the session rejects the action with ``ValueError``.
    """
    session = registry.get(sid)
    if session is None:
        return _err(404, f"unknown session {sid!r}")

    params = body or {}
    action = params.get("action", "")
    probe_answer = params.get("probe_answer", "")
    try:
        new_state = session.step(action, probe_answer)
    except core.SessionFinishedError:
        return _err(409, "session already finished")
    except ValueError as exc:
        return _err(400, str(exc))
    else:
        return 200, {"state": new_state}, _JSON
def _finish(sid: str, body: dict, registry: dict) -> Response:
    """Finish session *sid*, append its trace to a file and report the metrics.

    The target file is ``body["out"]`` when given and non-empty, otherwise a
    path under ``runs/`` derived from the trace's scenario and difficulty.
    Responds 404 for an unknown session and 409 when it is not finished yet.
    """
    session = registry.get(sid)
    if session is None:
        return _err(404, f"unknown session {sid!r}")
    try:
        trace = session.finish()
    except core.SessionNotFinishedError:
        return _err(409, "session not finished")

    out = (body or {}).get("out") or f"runs/web_{trace.scenario}_{trace.difficulty}.jsonl"
    written = append_trace(trace, out)
    result = {"trace_path": str(written), "metrics": trace.metrics}
    return 200, result, _JSON
def _create_spectate(body: dict, registry: dict) -> Response:
    """Handle ``POST /spectate``: create a model-driven session and register it.

    Reads ``scenario`` (required), ``difficulty`` (default ``"easy"``),
    ``seed`` (default ``_DEFAULT_SEED``; ``""`` or ``null`` means no seed),
    ``play_turns`` (default 100), ``model`` (default ``"fake:demo"``),
    ``probe``, ``memory`` and ``memory_root`` from *body*.

    Returns 200 with the new ``session_id``, the initial state and the memory
    summary, or 400 when a parameter is missing or malformed, the scenario is
    unknown, the provider cannot be built, or the memory selection cannot be
    resolved.
    """
    # Lazy imports keep `import proteus.web.local.server` free of provider SDKs.
    from proteus.game.agents.vanilla import VanillaAgent
    from proteus.providers import make_provider
    from proteus.game.runtime.spectate import SpectateSession

    try:
        scenario = body["scenario"]
        difficulty = Difficulty(body.get("difficulty", "easy"))
        raw_seed = body.get("seed", _DEFAULT_SEED)
        seed = None if raw_seed in ("", None) else int(raw_seed)
        play_turns = int(body.get("play_turns", 100))
        model = body.get("model", "fake:demo")
        probe = bool(body.get("probe", False))
    except (KeyError, TypeError, ValueError) as exc:
        # TypeError covers JSON values int() cannot take at all (null, lists,
        # objects); without it those requests surface as a 500.
        return _err(400, f"bad spectate params: {exc}")
    if scenario not in list_scenarios():
        return _err(400, f"unknown scenario {scenario!r}")

    try:
        provider = make_provider(model)
    except Exception as exc:  # missing SDK / key / bad spec -> clear client error
        return _err(400, f"provider {model!r}: {exc}")

    memory, use_default, mem_err = _resolve_web_memory(
        body.get("memory"), scenario=scenario, difficulty=difficulty, seed=seed,
        model=model, memory_root=body.get("memory_root") or "runs/memory",
    )
    if mem_err is not None:
        return _err(400, mem_err)

    agent = VanillaAgent(provider)
    session = SpectateSession(
        scenario, agent=agent, model_name=provider.model_name,
        difficulty=difficulty, seed=seed, play_turns=play_turns, use_probe=probe,
        memory=memory, use_default_memory=use_default,
    )
    sid = uuid.uuid4().hex
    registry[sid] = session
    return 200, {"session_id": sid, "state": session.state(),
                 "memory": _memory_info(session)}, _JSON
def _spectate_next(sid: str, registry: dict) -> Response:
    """Advance spectate session *sid* by one agent step and return the new state.

    Responds 404 for an unknown session, 409 when the session has already
    finished, and 500 (message only) for any other failure of the step.
    """
    session = registry.get(sid)
    if session is None:
        return _err(404, f"unknown session {sid!r}")
    try:
        new_state = session.advance()
    except core.SessionFinishedError:
        return _err(409, "session already finished")
    except Exception as exc:  # provider failure (missing key / network): no stack leak
        return _err(500, f"agent step failed: {exc}")
    else:
        return 200, {"state": new_state}, _JSON
def _spectate_finish(sid: str, body: dict, registry: dict) -> Response:
    """Finish spectate session *sid*, append its trace to a file and report metrics.

    The target file is ``body["out"]`` when given and non-empty, otherwise a
    path under ``runs/`` derived from the trace's scenario and difficulty.
    Responds 404 for an unknown session and 409 when it is not finished yet.
    """
    session = registry.get(sid)
    if session is None:
        return _err(404, f"unknown session {sid!r}")
    try:
        trace = session.finish()
    except core.SessionNotFinishedError:
        return _err(409, "session not finished")

    out = (body or {}).get("out")
    if not out:
        out = f"runs/web_spectate_{trace.scenario}_{trace.difficulty}.jsonl"
    written = append_trace(trace, out)
    result = {"trace_path": str(written), "metrics": trace.metrics}
    return 200, result, _JSON
def handle_request(
    method: str, path: str, body: dict | None, registry: dict,
) -> Response:
    """Route one request to a (status, payload, content_type) response.

    Routes:

    * ``GET /`` -> the static page; ``GET /config`` -> the config payload.
    * ``POST /session`` and ``POST /spectate`` -> create a session.
    * ``POST /session/<sid>/act`` and ``POST /session/<sid>/finish``.
    * ``POST /spectate/<sid>/next`` and ``POST /spectate/<sid>/finish``.
    * ``GET /session/<sid>`` and ``GET /spectate/<sid>`` -> current state.

    Anything else is a 404. A missing body is treated as an empty dict.
    """
    path = path.partition("?")[0]  # ignore any query string
    params = body or {}
    is_get = method == "GET"
    is_post = method == "POST"

    if is_get and path == "/":
        return 200, _STATIC.read_bytes(), _HTML
    if is_get and path == "/config":
        return 200, _config_payload(), _JSON

    if is_post and path == "/session":
        return _create_session(params, registry)
    if path.startswith("/session/"):
        tail = path[len("/session/"):]
        # Split off the last segment: "<sid>/<verb>" -> (sid, "/", verb).
        sid, slash, verb = tail.rpartition("/")
        if is_post and slash and verb == "act":
            return _act(sid, params, registry)
        if is_post and slash and verb == "finish":
            return _finish(sid, params, registry)
        if is_get:
            session = registry.get(tail)
            if session is None:
                return _err(404, f"unknown session {tail!r}")
            return 200, {"state": session.state()}, _JSON

    if is_post and path == "/spectate":
        return _create_spectate(params, registry)
    if path.startswith("/spectate/"):
        tail = path[len("/spectate/"):]
        sid, slash, verb = tail.rpartition("/")
        if is_post and slash and verb == "next":
            return _spectate_next(sid, registry)
        if is_post and slash and verb == "finish":
            return _spectate_finish(sid, params, registry)
        if is_get:
            session = registry.get(tail)
            if session is None:
                return _err(404, f"unknown session {tail!r}")
            return 200, {"state": session.state()}, _JSON

    return _err(404, f"no route for {method} {path}")
class _Handler(BaseHTTPRequestHandler):
    """Thin socket adapter: parse the JSON body, call ``handle_request``, write the reply.

    Malformed input (a non-integer ``Content-Length``, invalid JSON, or a JSON
    value that is not an object) is answered with a 400 JSON error. Parsing
    used to happen outside any ``try``, so such requests raised out of the
    handler and the client received no response at all.
    """

    # Session registry shared by every request served by this handler class.
    # ``make_server`` binds a fresh dict on a subclass, one per server.
    registry: dict = {}

    def log_message(self, *_args):  # silence default stderr logging
        return

    def _read_body(self):
        """Read and decode the request body.

        Returns the decoded JSON object, or ``None`` when there is no body
        (missing, empty, zero or negative ``Content-Length``, or JSON ``null``).

        Raises ``ValueError`` for a non-integer ``Content-Length``, invalid
        JSON (``json.JSONDecodeError`` is a ``ValueError``), or a JSON value
        that is not an object.
        """
        length = int(self.headers.get("Content-Length", 0) or 0)
        if length <= 0:
            # A negative size would make ``read`` wait for end-of-stream.
            return None
        raw = self.rfile.read(length)
        if not raw:
            return None
        body = json.loads(raw)
        if body is not None and not isinstance(body, dict):
            raise ValueError("JSON body must be an object")
        return body

    def _dispatch(self, method: str) -> None:
        """Serve one request: parse the body, route it, and send the response."""
        try:
            body = self._read_body()
        except ValueError as exc:
            status, payload, ctype = 400, {"error": f"bad request body: {exc}"}, _JSON
        else:
            try:
                status, payload, ctype = handle_request(
                    method, self.path, body, self.registry,
                )
            except Exception as exc:  # never leak a stack to the client
                status, payload, ctype = 500, {"error": str(exc)}, _JSON
        self.send_response(status)
        self.send_header("Content-Type", ctype)
        data = payload if isinstance(payload, bytes) else json.dumps(payload).encode()
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def do_GET(self) -> None:
        self._dispatch("GET")

    def do_POST(self) -> None:
        self._dispatch("POST")
def make_server(host: str = "127.0.0.1", port: int = 8000) -> ThreadingHTTPServer:
    """Build (but do not serve) a ThreadingHTTPServer with a fresh registry.

    Each call derives a new handler subclass with its own empty ``registry``
    dict, so sessions are not shared between servers.
    """
    namespace = {"registry": {}}
    bound_handler = type("BoundHandler", (_Handler,), namespace)
    address = (host, port)
    return ThreadingHTTPServer(address, bound_handler)