puck / server /recognizer.py
vu1n's picture
Puck — desktop fairy familiar (HF Build Small)
3c124f3
Raw
History Blame Contribute Delete
4.73 kB
"""Visual fingerprint recognizer β€” a CLIP image encoder in ONNX, run in the daemon
(no torch / no sidecar; onnxruntime works on Python 3.14). Embeds the labeled refs
in recognize/refs/ and matches a peek crop by cosine similarity β†’ a tool/site label.
Refs are input-area crops for terminals (see recognize/README), so a peek of plain
scrollback stays BELOW threshold β†’ "unknown" rather than confusing claude-code with
codex. Build the index once (downloads ~350MB model first time), then recognize()
is fast and offline.
"""
import base64
import io
import os
import urllib.request
from pathlib import Path
import numpy as np
from PIL import Image
ROOT = Path(__file__).resolve().parent.parent
REFS = ROOT / "recognize" / "refs"
MODEL_DIR = ROOT / "recognize" / "model"
MODEL_PATH = MODEL_DIR / "clip-vit-b32-vision.onnx"
INDEX_PATH = MODEL_DIR / "index.npz"
# Qdrant's ONNX export of CLIP ViT-B/32's image encoder (512-d). Image-only, no torch.
MODEL_URL = "https://huggingface.co/Qdrant/clip-ViT-B-32-vision/resolve/main/model.onnx"
# standard CLIP preprocessing constants
_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)
# cosine floor: below this a match is "unknown" (CLIP UI shots run high; tune via env)
THRESHOLD = float(os.environ.get("PUCK_RECOGNIZE_THRESHOLD", "0.82"))
_session = None
_in_name = ""
_out_name = ""
_index: tuple[np.ndarray, np.ndarray] | None = None # (vectors[N,512], labels[N])
def _ensure_model() -> None:
    """Load the ONNX CLIP image encoder once, downloading it on first use.

    Populates the module globals _session, _in_name and _out_name; later calls
    return immediately. The model is streamed to a ".part" file next to
    MODEL_PATH and renamed into place only after the download completes, so an
    interrupted download can't leave a truncated MODEL_PATH that would then fail
    to load on every subsequent start.
    """
    global _session, _in_name, _out_name
    if _session is not None:
        return
    import shutil

    import onnxruntime as ort
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    if not MODEL_PATH.exists():
        print(f"recognizer: downloading model → {MODEL_PATH} (~350MB, one time)…")
        tmp = MODEL_PATH.with_name(MODEL_PATH.name + ".part")
        req = urllib.request.Request(MODEL_URL, headers={"User-Agent": "puck/1.0"})
        try:
            # timeout bounds each blocking socket operation so a stalled connection
            # can't hang the daemon; copyfileobj streams in chunks instead of
            # holding the whole model in memory.
            with urllib.request.urlopen(req, timeout=60) as r, tmp.open("wb") as f:
                shutil.copyfileobj(r, f)
            tmp.replace(MODEL_PATH)  # os.replace semantics: swap in the finished file
        except BaseException:
            # Also covers KeyboardInterrupt: never leave a partial download behind.
            tmp.unlink(missing_ok=True)
            raise
    sess = ort.InferenceSession(str(MODEL_PATH), providers=["CPUExecutionProvider"])
    _in_name = sess.get_inputs()[0].name
    _out_name = sess.get_outputs()[0].name
    # Publish _session last: it is the "already loaded" flag checked at the top.
    _session = sess
def _preprocess(img: Image.Image) -> np.ndarray:
    """Turn a PIL image into the CLIP input tensor: shape [1, 3, 224, 224], float32.

    The image is converted to RGB and its shorter side is scaled to 224 (bicubic).
    The result is center-cropped to a 224×224 square. Pixels are then mapped to
    [0, 1] and normalized per channel with _MEAN/_STD.
    """
    side = 224
    rgb = img.convert("RGB")
    width, height = rgb.size
    scale = side / min(width, height)
    new_size = (max(side, round(width * scale)), max(side, round(height * scale)))
    resized = rgb.resize(new_size, Image.BICUBIC)
    rw, rh = resized.size
    x0 = (rw - side) // 2
    y0 = (rh - side) // 2
    square = resized.crop((x0, y0, x0 + side, y0 + side))
    pixels = np.asarray(square, dtype=np.float32) / 255.0
    normalized = (pixels - _MEAN) / _STD
    chw = np.transpose(normalized, (2, 0, 1))  # HWC -> CHW
    return np.expand_dims(chw, axis=0).astype(np.float32)
def _embed(img: Image.Image) -> np.ndarray:
    """Encode one image with the CLIP vision model into an L2-normalized float32 vector.

    Loads the model on first use. The small epsilon keeps an all-zero output
    from dividing by zero.
    """
    _ensure_model()
    pixels = _preprocess(img)
    outputs = _session.run([_out_name], {_in_name: pixels})
    vec = outputs[0][0].astype(np.float32)  # first requested output, first batch item
    length = np.linalg.norm(vec) + 1e-8
    return vec / length
def build_index() -> dict:
    """Embed every labeled ref → cache vectors + labels. Call after collecting/labeling.

    Expected layout is REFS/<label>/<image>.{png,jpg,jpeg}; each subdirectory
    name becomes the label of the images inside it. Unreadable images are
    skipped with a message rather than aborting the build. The result is saved
    to INDEX_PATH and cached in memory. When no refs are found, any previously
    saved index file is removed, so _load_index() cannot bring a stale index back.

    Returns {"refs": number of embedded images, "labels": sorted distinct labels}.
    """
    global _index
    vecs: list[np.ndarray] = []
    labels: list[str] = []
    if REFS.exists():
        for d in sorted(p for p in REFS.iterdir() if p.is_dir()):
            # sorted → deterministic index order across builds/filesystems
            for f in sorted(d.iterdir()):
                if not f.is_file() or f.suffix.lower() not in (".png", ".jpg", ".jpeg"):
                    continue
                try:
                    # context manager closes the file handle Image.open holds
                    with Image.open(f) as im:
                        vec = _embed(im)
                except Exception as e:  # noqa: BLE001 — skip a bad image, keep building
                    print(f"recognizer: skip {f.name}: {e}")
                else:
                    # append together so vectors and labels always stay aligned
                    vecs.append(vec)
                    labels.append(d.name)
    if vecs:
        vectors, label_arr = np.stack(vecs), np.array(labels)
        MODEL_DIR.mkdir(parents=True, exist_ok=True)
        np.savez(INDEX_PATH, vectors=vectors, labels=label_arr)
        _index = (vectors, label_arr)
    else:
        # Drop any stale on-disk index; otherwise _load_index() would reload it.
        INDEX_PATH.unlink(missing_ok=True)
        _index = None
    return {"refs": len(vecs), "labels": sorted(set(labels))}
def _load_index() -> tuple[np.ndarray, np.ndarray] | None:
global _index
if _index is None and INDEX_PATH.exists():
d = np.load(INDEX_PATH, allow_pickle=True)
_index = (d["vectors"], d["labels"])
return _index
def recognize(image_data_url: str) -> tuple[str | None, float]:
    """Match a crop against the fingerprint library → (label, score) or (None, score).
    Nearest reference by cosine; below THRESHOLD → unknown (so scrollback doesn't lie).

    image_data_url: a data URL whose payload after the first comma is base64
    image bytes. A bare base64 payload with no comma is accepted as well.
    Returns (None, 0.0) when no (non-empty) index is available.
    Raises whatever base64/PIL raise for an undecodable payload.
    """
    idx = _load_index()
    if idx is None or len(idx[1]) == 0:
        return None, 0.0
    vectors, labels = idx
    head, sep, tail = image_data_url.partition(",")
    b64 = tail if sep else head  # no comma → treat the whole string as the payload
    with Image.open(io.BytesIO(base64.b64decode(b64))) as img:
        q = _embed(img)
    sims = vectors @ q  # both unit-normalized → cosine
    i = int(np.argmax(sims))
    score = float(sims[i])
    return (str(labels[i]), score) if score >= THRESHOLD else (None, score)