File size: 4,729 Bytes
3c124f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""Visual fingerprint recognizer β€” a CLIP image encoder in ONNX, run in the daemon
(no torch / no sidecar; onnxruntime works on Python 3.14). Embeds the labeled refs
in recognize/refs/ and matches a peek crop by cosine similarity β†’ a tool/site label.

Refs are input-area crops for terminals (see recognize/README), so a peek of plain
scrollback stays BELOW threshold → "unknown" rather than confusing claude-code with
codex. Build the index once (downloads ~350MB model first time), then recognize()
is fast and offline.
"""

import base64
import io
import os
import urllib.request
from pathlib import Path

import numpy as np
from PIL import Image

# Parent of the directory that contains this file; every path below hangs off it.
ROOT = Path(__file__).resolve().parent.parent
# Reference images: build_index() treats each sub-directory name as a label.
REFS = ROOT / "recognize" / "refs"
# Holds the downloaded ONNX model and the cached embedding index.
MODEL_DIR = ROOT / "recognize" / "model"
MODEL_PATH = MODEL_DIR / "clip-vit-b32-vision.onnx"
INDEX_PATH = MODEL_DIR / "index.npz"
# Qdrant's ONNX export of CLIP ViT-B/32's image encoder (512-d). Image-only, no torch.
MODEL_URL = "https://huggingface.co/Qdrant/clip-ViT-B-32-vision/resolve/main/model.onnx"

# standard CLIP preprocessing constants
# (per-channel RGB mean / std, applied after pixels are scaled to [0, 1])
_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)
# cosine floor: below this a match is "unknown" (CLIP UI shots run high; tune via env)
# Read once at import time; a non-numeric PUCK_RECOGNIZE_THRESHOLD raises ValueError here.
THRESHOLD = float(os.environ.get("PUCK_RECOGNIZE_THRESHOLD", "0.82"))

# Lazily-initialised module state, filled in by _ensure_model().
_session = None  # onnxruntime InferenceSession once the model is loaded
_in_name = ""  # name of the session's first input
_out_name = ""  # name of the session's first output
# In-memory copy of the index, set by build_index() / _load_index().
_index: tuple[np.ndarray, np.ndarray] | None = None  # (vectors[N,512], labels[N])


def _ensure_model() -> None:
    """Create the shared ONNX inference session, downloading the model if needed.

    Does nothing when the session already exists. Otherwise it makes sure
    MODEL_DIR exists, fetches MODEL_URL into MODEL_PATH when that file is
    missing, opens a CPU-only onnxruntime session and stores the session plus
    the names of its first input and first output in the module globals.

    The download is streamed to a sibling ``.part`` file and renamed into place
    only after it completed, so an interrupted transfer never leaves a truncated
    file at MODEL_PATH that later runs would mistake for a valid model.

    Raises:
        OSError: if the download or the file write fails (urllib's URLError is
            an OSError subclass), or if fewer bytes arrived than the server
            announced.
    """
    global _session, _in_name, _out_name
    if _session is not None:
        return
    import shutil

    import onnxruntime as ort

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    if not MODEL_PATH.exists():
        print(f"recognizer: downloading model → {MODEL_PATH} (~350MB, one time)…")
        req = urllib.request.Request(MODEL_URL, headers={"User-Agent": "puck/1.0"})
        tmp_path = MODEL_PATH.with_name(MODEL_PATH.name + ".part")
        try:
            with urllib.request.urlopen(req) as r, tmp_path.open("wb") as f:
                expected = r.headers.get("Content-Length")
                # copy in chunks instead of holding the whole model in memory
                shutil.copyfileobj(r, f)
            if expected is not None and expected.isdigit():
                received = tmp_path.stat().st_size
                if received != int(expected):
                    raise OSError(
                        f"recognizer: incomplete model download "
                        f"({received} of {expected} bytes)"
                    )
            # atomic on the same filesystem: MODEL_PATH is either absent or complete
            os.replace(tmp_path, MODEL_PATH)
        finally:
            # no-op after a successful rename; removes the partial file on failure
            tmp_path.unlink(missing_ok=True)
    _session = ort.InferenceSession(str(MODEL_PATH), providers=["CPUExecutionProvider"])
    _in_name = _session.get_inputs()[0].name
    _out_name = _session.get_outputs()[0].name


def _preprocess(img: Image.Image) -> np.ndarray:
    """Turn a PIL image into the float32 batch the encoder is fed.

    Steps: convert to RGB, resize with bicubic filtering so the shorter side
    becomes 224 pixels (neither side is allowed to drop below 224), take the
    centered 224x224 crop, scale pixels to [0, 1], normalize each channel with
    _MEAN / _STD, and reorder to channels-first with a leading batch axis.

    Returns:
        A float32 array of shape [1, 3, 224, 224].
    """
    side = 224
    rgb = img.convert("RGB")

    width, height = rgb.size
    scale = side / min(width, height)
    target = (max(side, round(width * scale)), max(side, round(height * scale)))
    rgb = rgb.resize(target, Image.BICUBIC)

    # centered square crop of the resized image
    width, height = rgb.size
    x0 = (width - side) // 2
    y0 = (height - side) // 2
    patch = rgb.crop((x0, y0, x0 + side, y0 + side))

    pixels = np.asarray(patch, dtype=np.float32) / 255.0
    normalized = (pixels - _MEAN) / _STD
    # HWC -> CHW, then prepend the batch dimension
    return normalized.transpose(2, 0, 1)[None].astype(np.float32)


def _embed(img: Image.Image) -> np.ndarray:
    """Run *img* through the encoder and return its L2-normalized embedding.

    Loads the model on first use via _ensure_model(). The result is the first
    row of the session's first output, cast to float32 and divided by its norm.
    """
    _ensure_model()
    batch = _preprocess(img)
    outputs = _session.run([_out_name], {_in_name: batch})
    vec = outputs[0][0].astype(np.float32)
    # the tiny epsilon keeps an all-zero output from dividing by zero
    length = np.linalg.norm(vec) + 1e-8
    return vec / length


def build_index() -> dict:
    """Embed every labeled ref → cache vectors + labels. Call after collecting/labeling.

    Walks each sub-directory of REFS (the directory name is the label), embeds
    every .png / .jpg / .jpeg file in it, and stores the result both in the
    module-level ``_index`` and on disk at INDEX_PATH. Images that fail to open
    or embed are reported and skipped. Directories and files are visited in
    sorted order so the index layout is reproducible.

    When no ref could be embedded the in-memory index is cleared and any
    previously saved index file is removed, so a later _load_index() cannot
    resurrect labels that no longer exist.

    Returns:
        ``{"refs": <number of embedded images>, "labels": <sorted unique labels>}``.
    """
    global _index
    vecs: list[np.ndarray] = []
    labels: list[str] = []
    if REFS.is_dir():
        for d in sorted(p for p in REFS.iterdir() if p.is_dir()):
            for f in sorted(d.iterdir()):
                if f.suffix.lower() not in (".png", ".jpg", ".jpeg"):
                    continue
                try:
                    # context manager closes the file handle Pillow keeps open
                    with Image.open(f) as img:
                        vec = _embed(img)
                except Exception as e:  # noqa: BLE001 — skip a bad image, keep building
                    print(f"recognizer: skip {f.name}: {e}")
                    continue
                vecs.append(vec)
                labels.append(d.name)
    if vecs:
        vectors = np.stack(vecs)
        label_arr = np.array(labels)
        MODEL_DIR.mkdir(parents=True, exist_ok=True)
        np.savez(INDEX_PATH, vectors=vectors, labels=label_arr)
        _index = (vectors, label_arr)
    else:
        # drop the stale cache too, otherwise _load_index() would reload it
        INDEX_PATH.unlink(missing_ok=True)
        _index = None
    return {"refs": len(vecs), "labels": sorted(set(labels))}


def _load_index() -> tuple[np.ndarray, np.ndarray] | None:
    """Return the cached ``(vectors, labels)`` index, reading INDEX_PATH if needed.

    The in-memory ``_index`` is used when present. Otherwise, if INDEX_PATH
    exists, its ``vectors`` and ``labels`` arrays are loaded and cached.

    Returns:
        The ``(vectors, labels)`` tuple, or None when no index has been built.
    """
    global _index
    if _index is None and INDEX_PATH.exists():
        # NOTE(review): allow_pickle=True can execute code from a tampered file;
        # build_index() appears to write only plain numeric / string arrays, so
        # allow_pickle=False is probably sufficient — confirm before changing.
        # The context manager closes the archive; indexing reads the arrays
        # into memory first, so they stay valid afterwards.
        with np.load(INDEX_PATH, allow_pickle=True) as d:
            _index = (d["vectors"], d["labels"])
    return _index


def recognize(image_data_url: str) -> tuple[str | None, float]:
    """Match a crop against the fingerprint library → (label, score) or (None, score).

    Nearest reference by cosine; below THRESHOLD → unknown (so scrollback doesn't lie).

    Args:
        image_data_url: the crop as a ``data:<mime>;base64,<payload>`` URL. A bare
            base64 payload without the ``data:`` prefix is accepted as well.

    Returns:
        ``(label, score)`` when the best cosine similarity reaches THRESHOLD,
        ``(None, score)`` when it does not, and ``(None, 0.0)`` when there is no
        index (or it holds no references).

    Errors raised while decoding the base64 payload or the image propagate to
    the caller.
    """
    idx = _load_index()
    if idx is None:
        return None, 0.0
    vectors, labels = idx
    if len(labels) == 0:
        # argmax over an empty similarity vector would raise
        return None, 0.0
    head, sep, tail = image_data_url.partition(",")
    # base64 never contains a comma, so no separator means a bare payload
    b64 = tail if sep else head
    with Image.open(io.BytesIO(base64.b64decode(b64))) as img:
        q = _embed(img)
    sims = vectors @ q  # both unit-normalized → cosine
    i = int(np.argmax(sims))
    score = float(sims[i])
    return (str(labels[i]), score) if score >= THRESHOLD else (None, score)