"""Inference backends: one streaming protocol, several providers.

Everything the Warden says or decides flows through Backend.stream().
The game must never block on a dead backend — callers always have a
scripted fallback and a timeout.
"""

from __future__ import annotations

import hashlib
import json
from pathlib import Path
from typing import AsyncIterator, Protocol

Message = dict[str, str]  # {"role": ..., "content": ...}


class Backend(Protocol):
    """Structural interface every inference backend implements."""

    # Declared with plain ``def`` on purpose: implementations are async
    # generators, and calling one returns an AsyncIterator directly.  An
    # ``async def`` stub without ``yield`` would be typed as a coroutine
    # *returning* an iterator, which no implementation here satisfies.
    def stream(
        self,
        messages: list[Message],
        *,
        max_tokens: int = 256,
        temperature: float = 0.6,
        thinking: bool = False,
    ) -> AsyncIterator[str]:
        """Yield response text chunks.

        thinking enables the model's reasoning block for this one call
        (slower; the caller pays).
        """
        ...


async def complete(backend: Backend, messages: list[Message], **kw) -> str:
    """Drain ``backend.stream(messages, **kw)`` and return the joined text."""
    return "".join([chunk async for chunk in backend.stream(messages, **kw)])


class ScriptedBackend:
    """Offline fallback and test double: answers from a playbook.

    The playbook maps a substring (matched against the last message) to a
    response; the first matching entry in insertion order wins.  Unmatched
    prompts get the default line.  Every call is recorded in ``calls``.
    """

    def __init__(self, playbook: dict[str, str] | None = None, default: str = "..."):
        self.playbook = playbook or {}
        self.default = default
        self.calls: list[list[Message]] = []

    async def stream(
        self,
        messages: list[Message],
        *,
        max_tokens: int = 256,
        temperature: float = 0.6,
        thinking: bool = False,
    ) -> AsyncIterator[str]:
        """Yield the scripted reply for ``messages`` in 8-character chunks.

        The sampling parameters are accepted for interface compatibility
        and otherwise ignored.
        """
        # Snapshot the conversation: callers typically keep appending to the
        # same history list, and storing it by reference would make earlier
        # recorded calls change after the fact.
        self.calls.append([dict(message) for message in messages])
        last = messages[-1]["content"] if messages else ""
        response = next(
            (reply for needle, reply in self.playbook.items() if needle in last),
            self.default,
        )
        # Chunked like a real stream so streaming consumers get exercised.
        for i in range(0, len(response), 8):
            yield response[i : i + 8]


def _fingerprint(messages: list[Message]) -> str:
    """Return a 16-hex-digit SHA-256 prefix identifying ``messages``.

    Only the messages are hashed (as JSON with sorted keys); call options
    such as temperature are not part of the key.
    """
    return hashlib.sha256(
        json.dumps(messages, sort_keys=True).encode()
    ).hexdigest()[:16]


class RecordingBackend:
    """Wraps a live backend and writes request→response fixtures."""

    def __init__(self, inner: Backend, path: Path):
        self.inner = inner
        self.path = path

    async def stream(self, messages: list[Message], **kw) -> AsyncIterator[str]:
        """Pass chunks through from the inner backend, then append a fixture.

        One JSON line ``{"key", "chunks"}`` is appended to ``path`` only
        after the inner stream has been fully consumed; a stream that is
        abandoned early or fails part-way records nothing.
        """
        chunks: list[str] = []
        async for chunk in self.inner.stream(messages, **kw):
            chunks.append(chunk)
            yield chunk
        record = {"key": _fingerprint(messages), "chunks": chunks}
        with self.path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")


class ReplayBackend:
    """Replays recorded fixtures; raises on unknown prompts so tests fail
    loudly instead of silently drifting."""

    def __init__(self, path: Path):
        """Load every fixture line from ``path``.

        If the same key was recorded more than once, the last record wins.
        """
        self.records: dict[str, list[str]] = {}
        for line in path.read_text(encoding="utf-8").splitlines():
            # Tolerate blank lines (hand-edited or concatenated fixture
            # files) instead of failing in json.loads.
            if not line.strip():
                continue
            rec = json.loads(line)
            self.records[rec["key"]] = rec["chunks"]

    async def stream(self, messages: list[Message], **kw) -> AsyncIterator[str]:
        """Yield the recorded chunks for ``messages``.

        Raises:
            KeyError: no fixture was recorded for this exact conversation.
                Raised when iteration starts, since this is a generator.
        """
        key = _fingerprint(messages)
        if key not in self.records:
            raise KeyError(f"no recorded response for prompt {key}")
        for chunk in self.records[key]:
            yield chunk