File size: 4,654 Bytes
9fca766
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""Memory shards: distilled per-run facts about the player.

Tiny by design — a handful of one-line observations, ranked by tag overlap
and recency, rendered into a few hundred tokens at most. Distillation
(turning a fight's event log into shards) is one cheap no-think LLM call,
with a deterministic extractor as fallback.
"""

from __future__ import annotations

import asyncio
import re
from collections import Counter
from dataclasses import dataclass, field

from scrypt.engine.combat import CombatState, Result


@dataclass
class Shard:
    """One distilled, one-line fact about the player plus its retrieval tags."""

    text: str
    tags: frozenset[str]
    age: int = 0  # bumped each fight (see ShardStore.tick)


@dataclass
class ShardStore:
    """Bounded, insertion-ordered collection of shards.

    When the store is over ``max_shards``, the earliest-added shards are
    evicted first. Retrieval ranks shards by tag overlap with the query, then
    by recency (lower ``age`` first).
    """

    shards: list[Shard] = field(default_factory=list)
    max_shards: int = 12

    def add(self, text: str, tags: set[str]) -> None:
        """Append a fresh (age 0) shard and trim the store back to capacity.

        Trims *all* excess rather than one shard, so a store constructed from
        a longer ``shards`` list, or whose ``max_shards`` was lowered,
        converges to its bound instead of staying oversized. A non-positive
        ``max_shards`` keeps the store empty.
        """
        self.shards.append(Shard(text=text, tags=frozenset(tags)))
        excess = len(self.shards) - max(self.max_shards, 0)
        if excess > 0:
            # Oldest-added shards sit at the front of the list.
            del self.shards[:excess]

    def tick(self) -> None:
        """Age every stored shard by one (called once per fight)."""
        for s in self.shards:
            s.age += 1

    def select(self, query_tags: set[str], k: int = 4) -> list[Shard]:
        """Return up to ``k`` shards, best first.

        Shards are ordered by the number of tags shared with ``query_tags``
        (any iterable of tags), ties broken by lower age. Among exact ties,
        the sort is stable, so insertion order is kept. ``k <= 0`` returns
        an empty list.
        """
        if k <= 0:
            return []
        wanted = frozenset(query_tags)

        def score(s: Shard) -> tuple[int, int]:
            return (len(s.tags & wanted), -s.age)

        return sorted(self.shards, key=score, reverse=True)[:k]

    def render(self, query_tags: set[str], k: int = 4) -> str:
        """Render the selected shards as ``- text`` lines joined by newlines.

        Returns an empty string when nothing is selected.
        """
        return "\n".join(f"- {s.text}" for s in self.select(query_tags, k))


MAX_FACT_LEN = 100
# One bullet line: optional indent, a marker ("-", "*", "•", "1." or "1)"),
# optional spacing, then the fact text (captured as group 1).
_BULLET = re.compile(r"^\s*(?:[-*•]|\d+[.)])\s*(.+)$")


def parse_bullets(text: str, limit: int = 3) -> list[str]:
    """Extract up to ``limit`` terse bullet facts from model output.

    Only lines shaped like a bullet (per ``_BULLET``) are considered; any
    other line (prose, headings, blanks) is skipped. Each fact has Markdown
    bold markers removed, surrounding ``*``/``_`` and trailing periods
    stripped, and is kept only if it contains at least one letter or digit
    and is at most ``MAX_FACT_LEN`` characters long. As a result, Markdown
    rules such as ``---`` or ``***`` never become facts.

    Args:
        text: raw model reply.
        limit: maximum number of facts to return; ``<= 0`` returns none.

    Returns:
        The accepted facts, in order of appearance.
    """
    facts: list[str] = []
    for line in text.splitlines():
        m = _BULLET.match(line)
        if not m:
            continue
        # Small models often bold their bullets ("- **fact**", "**fact**").
        fact = m.group(1).replace("**", "").replace("__", "")
        fact = fact.strip().strip("*_").rstrip(".").strip()
        # Reject content-free bullets: "---" parses as marker "-" + text "--".
        if not any(ch.isalnum() for ch in fact):
            continue
        if len(fact) <= MAX_FACT_LEN:
            facts.append(fact)
    return facts[: max(limit, 0)]


# Keyword cues are matched at the start of a word, so "studied" doesn't read
# as "died", "return" as "turn", or "skill" as "kill"; suffixed forms
# ("turns", "sacrificed", "kills") still match.
_OUTCOME_CUES = re.compile(r"\b(?:won|lost|died|survived|turn)")
_STYLE_CUES = re.compile(r"\b(?:sacrific|kill|feed)")


def _tags_for(fact: str, state: CombatState) -> set[str]:
    """Cheap tag assignment so LLM-phrased facts still rank in retrieval.

    Tags assigned:
      - "outcome" when the fact mentions winning/losing/dying/surviving/turns;
      - "style" when it mentions sacrificing, killing or feeding;
      - "deck" plus the card id for every card seen in ``state.events`` whose
        id appears as a word in the fact (case-insensitive, "-"/"_" may be
        written as spaces, optional plural "s"/"es").
    Falls back to {"style"} so every shard carries at least one tag.
    """
    tags: set[str] = set()
    lowered = fact.lower()
    if _OUTCOME_CUES.search(lowered):
        tags.add("outcome")
    if _STYLE_CUES.search(lowered):
        tags.add("style")
    seen_cards = {e.data["card"] for e in state.events if "card" in e.data}
    for card_id in seen_cards:
        # Lowercase the id too: `lowered` is lowercase, so a mixed-case id
        # would otherwise never match.
        raw = str(card_id).lower()
        spelled = raw.replace("-", " ").replace("_", " ")
        if any(
            re.search(rf"\b{re.escape(name)}(?:e?s)?\b", lowered)
            for name in {raw, spelled}
        ):
            # The tag keeps the original id so it matches ids used elsewhere.
            tags |= {"deck", card_id}
    return tags or {"style"}


async def distill_with_voice(
    backend, state: CombatState, *, timeout_s: float = 8.0
) -> list[tuple[str, set[str]]]:
    """The Warden writes its own memory. One no-think call; if the output
    isn't a clean bullet list in time, the deterministic extractor's facts
    stand — memory is too load-bearing to trust an unvalidated paragraph.

    Args:
        backend: inference backend passed to ``complete``; ``None`` skips the
            model call and uses ``distill_fight`` directly.
        state: the finished fight whose event log is distilled.
        timeout_s: wall-clock budget for the model call.

    Returns:
        ``(fact, tags)`` pairs, either from the model (tagged via
        ``_tags_for``) or from ``distill_fight``.
    """
    # Decide on the deterministic path before importing the inference stack,
    # so the fallback works even when those modules can't be imported.
    if backend is None:
        return distill_fight(state)

    from scrypt.inference.backend import complete

    from .context import build_messages, combat_digest
    from .moments import DISTILL_FRAME

    try:
        async with asyncio.timeout(timeout_s):
            reply = await complete(
                backend,
                build_messages(DISTILL_FRAME, digest=combat_digest(state)),
                max_tokens=90,
            )
    except Exception:
        # Deliberate best-effort: a timeout (TimeoutError), backend error, or
        # prompt-building error all fall back to the deterministic facts.
        # asyncio.CancelledError is a BaseException and still propagates.
        return distill_fight(state)
    # NOTE(review): complete()'s return type isn't visible here; guard so a
    # None/non-text reply falls back instead of crashing parse_bullets.
    if not isinstance(reply, str):
        return distill_fight(state)
    facts = parse_bullets(reply)
    if not facts:
        return distill_fight(state)
    return [(fact, _tags_for(fact, state)) for fact in facts]


def distill_fight(state: CombatState) -> list[tuple[str, set[str]]]:
    """Extract shards from a finished fight's event log, deterministically.

    Emits up to three ``(text, tags)`` facts, in this order:
      - the player's most-played card, when it was played at least twice;
      - a heavy-sacrifice note, when there were 4+ "sacrificed" events;
      - the outcome, with the turn shown as ``state.turn + 1``.

    (An LLM pass can phrase these better later; facts first.)
    """
    facts: list[tuple[str, set[str]]] = []

    # Only "played" events flagged as the player's count toward the favorite.
    player_cards = Counter(
        ev.data["card"]
        for ev in state.events
        if ev.kind == "played" and ev.data.get("player")
    )
    leader = player_cards.most_common(1)
    if leader:
        card, times = leader[0]
        if times >= 2:
            facts.append(
                (f"the player leans on {card} ({times} plays)", {"deck", card})
            )

    # NOTE(review): unlike plays, sacrifices are not filtered by "player";
    # confirm only the player's side emits "sacrificed" events.
    sacrifice_count = sum(ev.kind == "sacrificed" for ev in state.events)
    if sacrifice_count >= 4:
        facts.append(
            (
                f"the player kills their own freely ({sacrifice_count} sacrifices)",
                {"style"},
            )
        )

    # Anything other than a player win is reported as a loss.
    # The +1 presumably converts a 0-based turn counter for display; verify.
    if state.result is not Result.PLAYER_WIN:
        facts.append((f"the player lost on turn {state.turn + 1}", {"outcome"}))
        return facts

    pace = "barely" if state.turn >= 8 else "fast"
    facts.append(
        (
            f"the player won {pace} (turn {state.turn + 1}, "
            f"{state.overkill_cycles} overkill)",
            {"outcome"},
        )
    )
    return facts