File size: 2,403 Bytes
5850885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""Pure-function Jaccard top-k retrieval over tag sets.

No embeddings, no torch, no tokenizer. Deterministic — same inputs
produce the same ranking, same top-k, same tie-break.

A conservative Jaccard threshold (0.3) limits retrieval noise when
broad pre-seeds would otherwise match every scenario.
"""

from __future__ import annotations

from collections.abc import Iterable

from skill_library.entries import (
    DriftAdaptationCard,
    PlaybookEntry,
    RetrievalResult,
)

# Minimum Jaccard score a playbook entry must reach to be retrieved; used as
# the default ``min_overlap`` of ``top_k_playbook``.
JACCARD_MIN: float = 0.3


def jaccard(a: frozenset[str], b: frozenset[str]) -> float:
    """Return the Jaccard similarity of two sets: ``|a & b| / |a | b|``.

    Two empty sets are treated as identical (1.0). An empty set compared
    with a non-empty one scores 0.0.
    """
    if a and b:
        shared = a & b
        combined = a | b
        return len(shared) / len(combined)
    # At least one side is empty here: both empty is a perfect match,
    # exactly one empty means no overlap at all.
    if a or b:
        return 0.0
    return 1.0


def top_k_playbook(
    query_tags: frozenset[str],
    entries: Iterable[PlaybookEntry],
    k: int = 3,
    *,
    min_overlap: float = JACCARD_MIN,
) -> tuple[PlaybookEntry, ...]:
    """Return up to ``k`` playbook entries ranked by Jaccard overlap.

    Each entry's ``tag_set`` is scored against ``query_tags`` with
    ``jaccard``; entries scoring below ``min_overlap`` are discarded.

    Ordering is by descending Jaccard score, then descending
    ``avg_speedup``, then ascending ``before_snippet``. Entries equal on all
    three keys keep their relative input order, because ``list.sort`` is
    stable.

    A non-positive ``k`` yields an empty tuple.
    """
    if k <= 0:
        # Without this guard a negative ``k`` would reach the slice below
        # and drop entries from the tail instead of returning nothing.
        return ()

    scored = []
    for entry in entries:
        score = jaccard(query_tags, entry.tag_set)
        if score >= min_overlap:
            scored.append((score, entry))

    # Negate the numeric keys so a single ascending sort yields
    # "highest score first, highest speedup first, snippet A-Z".
    scored.sort(key=lambda t: (-t[0], -t[1].avg_speedup, t[1].before_snippet))
    return tuple(entry for _, entry in scored[:k])


def top_k_drift_cards(
    drift_kind: str | None,
    cards: Iterable[DriftAdaptationCard],
    k: int = 1,
) -> tuple[DriftAdaptationCard, ...]:
    """Return up to ``k`` cards whose ``drift_kind`` equals the query.

    Matching cards are ordered by descending ``success_rate``. Cards with
    equal ``success_rate`` keep their relative input order, because
    ``list.sort`` is stable.

    Returns an empty tuple when ``drift_kind`` is ``None`` or ``k`` is not
    positive.
    """
    if drift_kind is None or k <= 0:
        # A negative ``k`` would otherwise reach the slice below and drop
        # cards from the tail instead of returning nothing.
        return ()
    matches = [c for c in cards if c.drift_kind == drift_kind]
    # Every match has the same drift_kind, so success_rate is the only key
    # that can affect the order.
    matches.sort(key=lambda c: -c.success_rate)
    return tuple(matches[:k])


def retrieve(
    query_tags: frozenset[str],
    drift_kind: str | None,
    playbook: Iterable[PlaybookEntry],
    drift_cards: Iterable[DriftAdaptationCard],
    *,
    playbook_k: int = 3,
    drift_k: int = 1,
) -> RetrievalResult:
    """Run both retrievers and bundle their output in a ``RetrievalResult``.

    Playbook entries come from ``top_k_playbook`` (limited to
    ``playbook_k``, using its default overlap threshold) and drift cards
    from ``top_k_drift_cards`` (limited to ``drift_k``).
    """
    playbook_hits = top_k_playbook(query_tags, playbook, k=playbook_k)
    card_hits = top_k_drift_cards(drift_kind, drift_cards, k=drift_k)
    return RetrievalResult(playbook=playbook_hits, drift_cards=card_hits)


# Public API: the names exported by a star-import of this module.
__all__ = [
    "JACCARD_MIN",
    "jaccard",
    "retrieve",
    "top_k_drift_cards",
    "top_k_playbook",
]