irregular6612's picture
feat(web): show handover memory + persona rubric for LLM spectate (same as human play); highlight each entity's behaviour as a courier passes it
d42e3af
Raw
History Blame Contribute Delete
15.4 kB
"""Stdlib HTTP server for interactive color-grid play.
A pure ``handle_request(method, path, body, registry)`` router (unit-testable
without a socket) plus a thin BaseHTTPRequestHandler adapter and a make_server
factory. In-memory session registry; local single-user. No new dependencies.
Concurrency: ThreadingHTTPServer serves each request on its own thread, but the
registry is a plain dict and a session is not locked — this assumes one local
player (no concurrent ``/act`` on the same session). Sessions are never evicted
(unbounded for a long-lived server); both are acceptable for single-user play.
"""
from __future__ import annotations
import json
import uuid
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from proteus.game.engine.rendering import COLOR_MAP
from proteus.game.engine.difficulty import Difficulty
from proteus.game.scenarios.base import list_scenarios
from proteus.game.runtime import _session_core as core
from proteus.game.runtime.interactive import InteractiveSession
from proteus.game.runtime.io import append_trace
# Single-page UI returned for ``GET /``; lives in ``static/`` next to this module.
_STATIC = Path(__file__).parent / "static" / "index.html"
# Seed used when a create request omits "seed"; also reported by ``/config``.
_DEFAULT_SEED = 42
# (status, payload, content_type). payload is a dict (json) or bytes (html).
Response = tuple[int, "dict | bytes", str]
# Content-Type header values for the two payload kinds above.
_JSON = "application/json"
_HTML = "text/html; charset=utf-8"
def _err(status: int, message: str) -> Response:
    """Build a JSON error response whose payload is ``{"error": message}``."""
    payload = {"error": message}
    return (status, payload, _JSON)
def _config_payload() -> dict:
    """Return the configuration dict served on ``GET /config``.

    Keys: ``scenarios``, ``difficulties``, ``color_map`` (keys stringified so
    the mapping is JSON-safe), ``default_seed``, ``providers`` and
    ``default_model``.
    """
    # Imported lazily so that importing this module stays light.
    from proteus.providers import available_providers

    payload: dict = {}
    payload["scenarios"] = list_scenarios()
    payload["difficulties"] = [level.value for level in Difficulty]
    payload["color_map"] = {str(code): colour for code, colour in COLOR_MAP.items()}
    payload["default_seed"] = _DEFAULT_SEED
    payload["providers"] = available_providers()
    payload["default_model"] = "fake:demo"
    return payload
def _resolve_web_memory(
mode, *, scenario: str, difficulty: Difficulty, seed, model: str = "",
memory_turns: int = 100, memory_root: str = "runs/memory",
):
"""Resolve a web 'memory' selection to ``(memory|None, use_default, error)``.
Modes: ``''``/``'default'`` -> the scenario's default memory; ``'none'`` ->
force no memory; ``'persona[:id]'`` -> a hidden reference-policy demonstration
(offline, no model); ``'generate'`` -> the LLM self-plays a memory with
*model*; ``'latest'`` -> the newest saved checkpoint for *model*; anything
else -> a checkpoint file path. ``use_default`` tells the session whether to
fall back to the scenario default when ``memory is None``.
"""
if mode in (None, "", "default"):
return None, True, None
if mode == "none":
return None, False, None
from proteus.game.runtime.memory_gen import generate_memory
if mode.startswith("persona"):
from proteus.game.metrics.persona import available_personas, get_persona
pid = mode.split(":", 1)[1] if ":" in mode else "risk_averse"
try:
persona = get_persona(pid)
except KeyError:
return None, False, f"unknown persona {pid!r}; have {available_personas()}"
ckpt = generate_memory(
scenario, None, difficulty=difficulty, seed=seed,
memory_turns=memory_turns, model_name=f"persona:{pid}", persona=persona,
)
return ckpt, False, None
if mode.startswith("policy:"):
from proteus.game.runtime.memory_policies import available_policies, get_policy
name = mode.split(":", 1)[1]
try:
policy = get_policy(name)
except KeyError:
return None, False, f"unknown policy {name!r}; have {available_policies()}"
ckpt = generate_memory(
scenario, None, difficulty=difficulty, seed=seed,
memory_turns=memory_turns, model_name=f"policy:{name}", policy=policy,
)
return ckpt, False, None
if mode in ("generate", "latest"):
if not model:
return None, False, f"memory mode {mode!r} needs a model"
from proteus.providers import make_provider
try:
provider = make_provider(model)
except Exception as exc: # missing SDK / key / bad spec -> clear client error
return None, False, str(exc)
if mode == "generate":
from proteus.game.agents.vanilla import VanillaAgent
from proteus.game.runtime.memory import save_checkpoint
ckpt = generate_memory(
scenario, VanillaAgent(provider), difficulty=difficulty, seed=seed,
memory_turns=memory_turns, model_name=provider.model_name,
)
save_checkpoint(ckpt, root=memory_root)
return ckpt, False, None
from proteus.game.runtime.memory import latest_for_model
ckpt = latest_for_model(provider.model_name, root=memory_root)
if ckpt is None:
return None, False, f"no saved memory for {provider.model_name!r}"
return ckpt, False, None
from proteus.game.runtime.memory import load_checkpoint
try:
return load_checkpoint(mode), False, None
except FileNotFoundError:
return None, False, f"memory checkpoint not found: {mode}"
def _memory_info(session) -> dict:
"""A client-facing summary + the rendered block of the attached memory.
``block`` is exactly the handover memory text the model is shown (so the web
UI can display, at the start, what the model 'remembers' before the task).
"""
mem = getattr(session, "_memory", None)
info = {
"attached": mem is not None,
"source": (mem.model if mem else None),
"turns": (len(mem.memory_turns) if mem else 0),
"persona": (mem.persona_weight_id if mem else None),
"block": None,
"frames": [],
"variants": [],
"selected": None,
"rubric": None,
}
if mem is not None:
from proteus.game.scenarios.base import get_scenario
from proteus.game.runtime.memory import memory_frames, render_memory_block
info["block"] = render_memory_block(mem)
scen = get_scenario(mem.scenario)()
info["frames"] = memory_frames(
mem, legend=scen.legend(), grid_size=scen.grid_size,
)
if mem.scenario == "errand_runner":
from proteus.game.runtime.multiagent_director import author_errand_runner_variants
from proteus.game.scenarios import errand_world as w
seed = mem.seed if mem.seed is not None else 0
variants = author_errand_runner_variants(seed=seed)
info["variants"] = [
{"id": pid, "label": w.PERSONA_LABELS[pid],
"frames": memory_frames(ck, legend=scen.legend(), grid_size=scen.grid_size)}
for pid, ck in variants.items()
]
info["selected"] = mem.persona_weight_id or next(iter(variants))
sel = info["selected"]
info["rubric"] = {
"persona": sel,
"label": w.PERSONA_LABELS.get(sel, sel),
"rows": w.persona_rubric(sel), # per-entity reaction + live coords
}
return info
def _create_session(body: dict, registry: dict) -> Response:
    """Create an interactive (human-played) session from a request body.

    Body fields: ``scenario`` (required), ``difficulty`` (default ``"easy"``),
    ``seed`` (default ``_DEFAULT_SEED``; ``""``/``None`` mean "no seed"),
    ``play_turns`` (default 100), ``probe`` (default ``False``), plus the
    memory selection fields ``memory``, ``model`` and ``memory_root``.

    Returns 200 with ``session_id``, the initial ``state`` and a ``memory``
    summary, or 400 for malformed parameters, an unknown scenario, or a memory
    selection that cannot be resolved. The new session is stored in *registry*
    under its id.
    """
    try:
        scenario = body["scenario"]
        difficulty = Difficulty(body.get("difficulty", "easy"))
        raw_seed = body.get("seed", _DEFAULT_SEED)
        seed = None if raw_seed in ("", None) else int(raw_seed)
        play_turns = int(body.get("play_turns", 100))
        probe = bool(body.get("probe", False))
    except (KeyError, TypeError, ValueError) as exc:
        # TypeError covers JSON values int() cannot take (``null``, lists,
        # objects) and a non-object body; these are client errors, not 500s.
        return _err(400, f"bad session params: {exc}")
    if scenario not in list_scenarios():
        return _err(400, f"unknown scenario {scenario!r}")

    memory, use_default, mem_err = _resolve_web_memory(
        body.get("memory"), scenario=scenario, difficulty=difficulty, seed=seed,
        model=body.get("model", ""), memory_root=body.get("memory_root") or "runs/memory",
    )
    if mem_err is not None:
        return _err(400, mem_err)

    session = InteractiveSession(
        scenario, difficulty=difficulty, seed=seed,
        play_turns=play_turns, use_probe=probe,
        memory=memory, use_default_memory=use_default,
    )
    sid = uuid.uuid4().hex
    registry[sid] = session
    return 200, {"session_id": sid, "state": session.state(),
                 "memory": _memory_info(session)}, _JSON
def _act(sid: str, body: dict, registry: dict) -> Response:
    """Apply one player step to session *sid* and return the resulting state.

    Reads ``action`` and ``probe_answer`` from *body* (both default to ``""``).
    Responds 404 for an unknown session, 409 if the session has already
    finished, and 400 when the session rejects the input with ``ValueError``.
    """
    session = registry.get(sid)
    if session is None:
        return _err(404, f"unknown session {sid!r}")

    params = body or {}
    chosen_action = params.get("action", "")
    answer = params.get("probe_answer", "")
    try:
        new_state = session.step(chosen_action, answer)
    except core.SessionFinishedError:
        return _err(409, "session already finished")
    except ValueError as exc:
        return _err(400, str(exc))
    else:
        return 200, {"state": new_state}, _JSON
def _finish(sid: str, body: dict, registry: dict) -> Response:
    """Finalize session *sid*, append its trace to a file, and report metrics.

    The output path comes from ``body["out"]``; when absent or empty it
    defaults to ``runs/web_<scenario>_<difficulty>.jsonl``. Responds 404 for an
    unknown session and 409 when the session is not finished yet.
    """
    session = registry.get(sid)
    if session is None:
        return _err(404, f"unknown session {sid!r}")

    try:
        trace = session.finish()
    except core.SessionNotFinishedError:
        return _err(409, "session not finished")

    destination = (
        (body or {}).get("out")
        or f"runs/web_{trace.scenario}_{trace.difficulty}.jsonl"
    )
    trace_path = append_trace(trace, destination)
    result = {"trace_path": str(trace_path), "metrics": trace.metrics}
    return 200, result, _JSON
def _create_spectate(body: dict, registry: dict) -> Response:
    """Create a spectate session in which an LLM agent plays the scenario.

    Body fields: ``scenario`` (required), ``difficulty`` (default ``"easy"``),
    ``seed`` (default ``_DEFAULT_SEED``; ``""``/``None`` mean "no seed"),
    ``play_turns`` (default 100), ``model`` (default ``"fake:demo"``),
    ``probe`` (default ``False``), plus the memory selection fields ``memory``
    and ``memory_root``.

    Returns 200 with ``session_id``, the initial ``state`` and a ``memory``
    summary, or 400 for malformed parameters, an unknown scenario, a provider
    that cannot be built, or a memory selection that cannot be resolved. The
    new session is stored in *registry* under its id.
    """
    # Lazy imports keep `import proteus.web.local.server` free of provider SDKs.
    from proteus.game.agents.vanilla import VanillaAgent
    from proteus.providers import make_provider
    from proteus.game.runtime.spectate import SpectateSession

    try:
        scenario = body["scenario"]
        difficulty = Difficulty(body.get("difficulty", "easy"))
        seed = body.get("seed", _DEFAULT_SEED)
        seed = None if seed in ("", None) else int(seed)
        play_turns = int(body.get("play_turns", 100))
        model = body.get("model", "fake:demo")
        probe = bool(body.get("probe", False))
    except (KeyError, TypeError, ValueError) as exc:
        # TypeError covers JSON values int() cannot take (``null``, lists,
        # objects) and a non-object body; these are client errors, not 500s.
        return _err(400, f"bad spectate params: {exc}")
    if scenario not in list_scenarios():
        return _err(400, f"unknown scenario {scenario!r}")

    try:
        provider = make_provider(model)
    except Exception as exc:  # missing SDK / key / bad spec -> clear client error
        return _err(400, f"provider {model!r}: {exc}")

    memory, use_default, mem_err = _resolve_web_memory(
        body.get("memory"), scenario=scenario, difficulty=difficulty, seed=seed,
        model=model, memory_root=body.get("memory_root") or "runs/memory",
    )
    if mem_err is not None:
        return _err(400, mem_err)

    agent = VanillaAgent(provider)
    session = SpectateSession(
        scenario, agent=agent, model_name=provider.model_name,
        difficulty=difficulty, seed=seed, play_turns=play_turns, use_probe=probe,
        memory=memory, use_default_memory=use_default,
    )
    sid = uuid.uuid4().hex
    registry[sid] = session
    return 200, {"session_id": sid, "state": session.state(),
                 "memory": _memory_info(session)}, _JSON
def _spectate_next(sid: str, registry: dict) -> Response:
    """Advance spectate session *sid* by one agent step and return its state.

    Responds 404 for an unknown session, 409 if the session has already
    finished, and 500 (message only) when the agent step itself fails.
    """
    session = registry.get(sid)
    if session is None:
        return _err(404, f"unknown session {sid!r}")

    try:
        new_state = session.advance()
    except core.SessionFinishedError:
        return _err(409, "session already finished")
    except Exception as exc:  # provider failure (missing key / network): no stack leak
        return _err(500, f"agent step failed: {exc}")
    else:
        return 200, {"state": new_state}, _JSON
def _spectate_finish(sid: str, body: dict, registry: dict) -> Response:
    """Finalize spectate session *sid*, append its trace, and report metrics.

    The output path comes from ``body["out"]``; when absent or empty it
    defaults to ``runs/web_spectate_<scenario>_<difficulty>.jsonl``. Responds
    404 for an unknown session and 409 when the session is not finished yet.
    """
    session = registry.get(sid)
    if session is None:
        return _err(404, f"unknown session {sid!r}")

    try:
        trace = session.finish()
    except core.SessionNotFinishedError:
        return _err(409, "session not finished")

    destination = (body or {}).get("out")
    if not destination:
        destination = f"runs/web_spectate_{trace.scenario}_{trace.difficulty}.jsonl"
    trace_path = append_trace(trace, destination)
    result = {"trace_path": str(trace_path), "metrics": trace.metrics}
    return 200, result, _JSON
def handle_request(
    method: str, path: str, body: dict | None, registry: dict,
) -> Response:
    """Dispatch one request and return ``(status, payload, content_type)``.

    Routes:

    * ``GET /`` -> the static page; ``GET /config`` -> client configuration.
    * ``POST /session`` / ``POST /spectate`` -> create a session.
    * ``POST /session/<id>/act`` and ``/finish``;
      ``POST /spectate/<id>/next`` and ``/finish``.
    * ``GET /session/<id>`` / ``GET /spectate/<id>`` -> current state.

    Anything else yields a 404 "no route" error.
    """
    route = path.partition("?")[0]  # the query string plays no part in routing
    params = body or {}

    if method == "GET":
        if route == "/":
            return 200, _STATIC.read_bytes(), _HTML
        if route == "/config":
            return 200, _config_payload(), _JSON
        # Both session kinds expose the same read-only state lookup.
        for prefix in ("/session/", "/spectate/"):
            if route.startswith(prefix):
                sid = route[len(prefix):]
                session = registry.get(sid)
                if session is None:
                    return _err(404, f"unknown session {sid!r}")
                return 200, {"state": session.state()}, _JSON
    elif method == "POST":
        if route == "/session":
            return _create_session(params, registry)
        if route == "/spectate":
            return _create_spectate(params, registry)
        if route.startswith("/session/"):
            tail = route[len("/session/"):]
            if tail.endswith("/act"):
                return _act(tail[: -len("/act")], params, registry)
            if tail.endswith("/finish"):
                return _finish(tail[: -len("/finish")], params, registry)
        elif route.startswith("/spectate/"):
            tail = route[len("/spectate/"):]
            if tail.endswith("/next"):
                return _spectate_next(tail[: -len("/next")], registry)
            if tail.endswith("/finish"):
                return _spectate_finish(tail[: -len("/finish")], params, registry)

    return _err(404, f"no route for {method} {route}")
class _Handler(BaseHTTPRequestHandler):
    """Socket adapter: read the JSON body, call ``handle_request``, write the reply."""

    # Session registry handed to ``handle_request``. Class-level, so it is
    # shared by every request served through the same handler class.
    registry: dict = {}

    def log_message(self, *_args):  # silence default stderr logging
        return

    def _read_json_body(self):
        """Read and decode the request body.

        Returns ``(body, error)``: *body* is a dict or ``None`` (no body / JSON
        ``null``); *error* is a client-facing message when the request is
        malformed, else ``None``.
        """
        try:
            length = int(self.headers.get("Content-Length", 0) or 0)
        except ValueError:
            length = -1
        if length < 0:
            # The body length is unknown, so the stream cannot be reused.
            self.close_connection = True
            return None, "invalid Content-Length header"

        raw = self.rfile.read(length) if length else b""
        if not raw:
            return None, None
        try:
            body = json.loads(raw)
        except ValueError as exc:  # JSONDecodeError and UnicodeDecodeError
            return None, f"invalid JSON body: {exc}"
        if body is not None and not isinstance(body, dict):
            return None, "JSON body must be an object"
        return body, None

    def _dispatch(self, method: str) -> None:
        """Serve one request: malformed input -> 400, handler failure -> 500."""
        body, problem = self._read_json_body()
        if problem is not None:
            status, payload, ctype = 400, {"error": problem}, _JSON
        else:
            try:
                status, payload, ctype = handle_request(
                    method, self.path, body, self.registry,
                )
            except Exception as exc:  # never leak a stack to the client
                status, payload, ctype = 500, {"error": str(exc)}, _JSON

        # Encode before any header is written so a non-serializable payload
        # still produces a well-formed error response.
        try:
            data = payload if isinstance(payload, bytes) else json.dumps(payload).encode()
        except (TypeError, ValueError) as exc:
            status, ctype = 500, _JSON
            data = json.dumps({"error": f"unserializable response: {exc}"}).encode()

        self.send_response(status)
        self.send_header("Content-Type", ctype)
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def do_GET(self) -> None:
        self._dispatch("GET")

    def do_POST(self) -> None:
        self._dispatch("POST")
def make_server(host: str = "127.0.0.1", port: int = 8000) -> ThreadingHTTPServer:
    """Build (but do not serve) a ThreadingHTTPServer with a fresh registry."""
    # A dedicated subclass per server carries its own, initially empty, registry.
    namespace = {"registry": {}}
    bound_handler = type("BoundHandler", (_Handler,), namespace)
    address = (host, port)
    return ThreadingHTTPServer(address, bound_handler)