File size: 3,276 Bytes
9fca766
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""Inference backends: one streaming protocol, several providers.

Everything the Warden says or decides flows through Backend.stream().
The game must never block on a dead backend — callers always have a
scripted fallback and a timeout.
"""

from __future__ import annotations

import hashlib
import json
from pathlib import Path
from typing import AsyncIterator, Protocol

# One chat message as a plain string-to-string mapping:
# {"role": ..., "content": ...}. The subscripted builtin is evaluated at
# import time, so this line needs a Python version with builtin generics.
Message = dict[str, str]


class Backend(Protocol):
    """Structural interface shared by every inference backend.

    Any object with a compatible ``stream`` method satisfies this protocol;
    no inheritance is required.
    """

    # Declared with plain ``def`` on purpose: implementers are async
    # generators, so *calling* ``stream`` returns the async iterator directly
    # and callers write ``async for chunk in backend.stream(...)``. An
    # ``async def`` without a ``yield`` in the body would instead be typed as
    # a coroutine that has to be awaited before iterating, which matches
    # neither the implementations nor the call sites.
    def stream(
        self,
        messages: list[Message],
        *,
        max_tokens: int = 256,
        temperature: float = 0.6,
        thinking: bool = False,
    ) -> AsyncIterator[str]:
        """Yield response text chunks.

        Args:
            messages: Conversation so far, as role/content mappings.
            max_tokens: Upper bound on the length of the response.
            temperature: Sampling temperature for the response.
            thinking: Enables the model's reasoning block for this one call
                (slower; the caller pays).

        Returns:
            An async iterator over the response text, chunk by chunk.
        """
        ...


async def complete(backend: Backend, messages: list[Message], **kw) -> str:
    """Drain ``backend.stream`` and return the whole response as one string.

    Extra keyword arguments are forwarded to ``backend.stream`` unchanged.
    """
    pieces: list[str] = []
    async for piece in backend.stream(messages, **kw):
        pieces.append(piece)
    return "".join(pieces)


class ScriptedBackend:
    """Offline fallback and test double: answers from a playbook.

    The playbook maps a substring (matched against the content of the last
    message) to a response; the first matching entry in playbook order wins.
    Unmatched prompts, and calls with no messages at all, get the default
    line. Every call is recorded in ``calls`` for later inspection.
    """

    def __init__(self, playbook: dict[str, str] | None = None, default: str = "..."):
        # Test for None explicitly so that an (initially) empty dict passed by
        # the caller is kept rather than silently swapped for a private one.
        self.playbook = {} if playbook is None else playbook
        self.default = default
        # One entry per stream() call: a snapshot of the messages it received.
        self.calls: list[list[Message]] = []

    async def stream(self, messages, *, max_tokens=256, temperature=0.6, thinking=False):
        """Yield the scripted reply for ``messages`` in chunks of up to 8 characters.

        ``max_tokens``, ``temperature`` and ``thinking`` are accepted so the
        signature matches the other backends; they are not used here.
        """
        # Record a copy, not the caller's list: conversations are commonly
        # extended in place after the call, which would otherwise rewrite
        # what this test double reports having been asked.
        self.calls.append([dict(message) for message in messages])
        last = messages[-1]["content"] if messages else ""
        response = next(
            (reply for needle, reply in self.playbook.items() if needle in last),
            self.default,
        )
        # Chunked like a real stream so streaming consumers get exercised.
        for start in range(0, len(response), 8):
            yield response[start : start + 8]


def _fingerprint(messages: list[Message]) -> str:
    """Return a short key for ``messages``.

    The key is the first 16 hex characters of the SHA-256 digest of the
    messages serialised as JSON with sorted keys, so the key order inside
    each message does not affect the result.
    """
    canonical = json.dumps(messages, sort_keys=True)
    digest = hashlib.sha256(canonical.encode())
    return digest.hexdigest()[:16]


class RecordingBackend:
    """Pass-through wrapper that logs each exchange as a fixture line.

    Chunks from the wrapped backend are forwarded to the consumer as they
    arrive. Once the wrapped stream is exhausted, one JSON object holding
    the prompt fingerprint and the chunk list is appended to ``path``.
    """

    def __init__(self, inner: Backend, path: Path):
        self.path = path
        self.inner = inner

    async def stream(self, messages, **kw):
        """Relay the inner backend's chunks, then append the recording.

        The fingerprint is computed from ``messages`` only; the keyword
        arguments are forwarded to the inner backend but are not part of
        the recorded key.
        """
        seen: list[str] = []
        upstream = self.inner.stream(messages, **kw)
        async for piece in upstream:
            seen.append(piece)
            yield piece
        # Reached only after the inner stream has finished.
        entry = {"key": _fingerprint(messages), "chunks": seen}
        with self.path.open("a", encoding="utf-8") as sink:
            sink.write(json.dumps(entry) + "\n")


class ReplayBackend:
    """Replays recorded fixtures; raises on unknown prompts so tests fail
    loudly instead of silently drifting.

    The fixture file holds one JSON object per line, each with a ``"key"``
    (prompt fingerprint) and a ``"chunks"`` list. If a key appears more than
    once, the last occurrence wins.
    """

    def __init__(self, path: Path):
        """Load every recorded response from ``path`` into memory.

        Blank and whitespace-only lines are skipped so that a hand-edited or
        concatenated fixture file does not fail to load; any other malformed
        line still raises from ``json.loads``.
        """
        self.records: dict[str, list[str]] = {}
        for line in path.read_text(encoding="utf-8").splitlines():
            if not line.strip():
                continue
            rec = json.loads(line)
            self.records[rec["key"]] = rec["chunks"]

    async def stream(self, messages, **kw):
        """Yield the recorded chunks for ``messages``.

        Keyword arguments are accepted for interface compatibility and are
        not used in the lookup.

        Raises:
            KeyError: If no response was recorded for this prompt.
        """
        key = _fingerprint(messages)
        if key not in self.records:
            raise KeyError(f"no recorded response for prompt {key}")
        for chunk in self.records[key]:
            yield chunk