# Author: esandorfi
# Commit 68f48a7: "Domain features first reorganisation"
from __future__ import annotations
from dataclasses import dataclass
from api.classify.banks import EmbeddingBank, LabelSetBank
from api.label_sets.hash import stable_hash
from api.classify.results import ClassificationResult, StageTimings
from api.label_sets.schemas import LabelSet
class FakeClipStore:
    """
    Torch-free, transformers-free stand-in for the CLIP store.

    Banks are built deterministically from the label-set payload alone: ids are
    copied from the label set and no embeddings are computed.
    """

    @staticmethod
    def _placeholder_bank(ids) -> EmbeddingBank:
        """Wrap *ids* (as a tuple) in an EmbeddingBank with no features."""
        # The fake classifier never reads feats, so None is a deliberate placeholder.
        return EmbeddingBank(ids=tuple(ids), feats=None)  # type: ignore[arg-type]

    def build_bank(self, label_set: LabelSet) -> LabelSetBank:
        """Build a feature-less LabelSetBank for *label_set*.

        The bank hash is ``stable_hash(label_set.model_dump())``. The domain bank
        holds the domain ids in ``label_set.domains`` order. Each entry of
        ``label_set.labels_by_domain`` becomes a bank holding its label ids in
        their original order.
        """
        label_set_hash = stable_hash(label_set.model_dump())

        domain_bank = self._placeholder_bank(domain.id for domain in label_set.domains)

        label_banks: dict[str, EmbeddingBank] = {
            domain_id: self._placeholder_bank(label.id for label in labels)
            for domain_id, labels in label_set.labels_by_domain.items()
        }

        return LabelSetBank(
            label_set_hash=label_set_hash,
            name=label_set.name,
            domains=domain_bank,
            labels_by_domain=label_banks,
        )
@dataclass(slots=True)
class FakeTwoStageClassifier:
    """
    Deterministic stand-in for the two-stage classifier. The image is ignored.

    It keeps the first ``domain_top_n`` domain ids of ``bank.domains`` in bank
    order. It then concatenates the label ids of those domains, in the same
    order, and keeps the first ``top_k``. Every hit scores 1.0, and the
    reported timings are fixed constants.
    """

    def classify(self, bank: LabelSetBank, image, *, domain_top_n: int, top_k: int) -> ClassificationResult:
        """Return a canned ClassificationResult derived only from *bank* ordering.

        ``domain_top_n`` and ``top_k`` are used directly as slice bounds, so
        they follow Python slicing semantics (negative values count from the end).
        """
        picked = list(bank.domains.ids[:domain_top_n])

        pooled: list[str] = []
        for domain_id in picked:
            domain_labels = bank.labels_by_domain.get(domain_id)
            # Domains with no (or a falsy) label bank contribute no labels.
            if not domain_labels:
                continue
            pooled += domain_labels.ids

        fixed_timings = StageTimings(total_ms=1, domain_ms=1, labels_ms=0)
        return ClassificationResult(
            domain_hits=[(domain_id, 1.0) for domain_id in picked],
            chosen_domains=picked,
            label_hits=[(label_id, 1.0) for label_id in pooled[:top_k]],
            timings=fixed_timings,
        )