Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from api.classify.banks import EmbeddingBank, LabelSetBank | |
| from api.label_sets.hash import stable_hash | |
| from api.classify.results import ClassificationResult, StageTimings | |
| from api.label_sets.schemas import LabelSet | |
class FakeClipStore:
    """
    Torch-free, transformers-free stand-in for the CLIP embedding store.

    Produces deterministic ``LabelSetBank`` objects directly from a label-set
    payload: ids are copied over verbatim and no embeddings are computed.
    """

    def build_bank(self, label_set: LabelSet) -> LabelSetBank:
        """
        Build a placeholder bank for ``label_set``.

        The bank hash is ``stable_hash`` of ``label_set.model_dump()``. Domain
        ids keep the order of ``label_set.domains``. Each entry of
        ``label_set.labels_by_domain`` becomes an ``EmbeddingBank`` holding
        that domain's label ids in their original order.

        Every ``feats`` field is ``None``: the fake classifier only reads ids,
        so no feature tensors are built.
        """
        bank_hash = stable_hash(label_set.model_dump())

        domain_ids = tuple(domain.id for domain in label_set.domains)
        # Placeholder feats; the type checker expects real embeddings here.
        domain_bank = EmbeddingBank(ids=domain_ids, feats=None)  # type: ignore[arg-type]

        label_banks: dict[str, EmbeddingBank] = {
            domain_id: EmbeddingBank(ids=tuple(item.id for item in items), feats=None)  # type: ignore[arg-type]
            for domain_id, items in label_set.labels_by_domain.items()
        }

        return LabelSetBank(
            label_set_hash=bank_hash,
            name=label_set.name,
            domains=domain_bank,
            labels_by_domain=label_banks,
        )
class FakeTwoStageClassifier:
    """
    Deterministic stand-in for the two-stage (domain, then label) classifier.

    The image is ignored entirely. The classifier selects the first
    ``domain_top_n`` domains of the bank, then the first ``top_k`` labels found
    across those domains (in domain order), and scores every hit as 1.0.
    """

    def classify(self, bank: LabelSetBank, image, *, domain_top_n: int, top_k: int) -> ClassificationResult:
        """
        Return a fixed, order-based classification for ``bank``.

        Args:
            bank: Label-set bank whose ``domains.ids`` and
                ``labels_by_domain`` provide the candidate ids.
            image: Unused; accepted only to match the real classifier's call shape.
            domain_top_n: Slice bound applied to ``bank.domains.ids``.
            top_k: Slice bound applied to the concatenated label ids.

        Returns:
            A ``ClassificationResult`` whose domain and label hits all carry a
            score of 1.0, with constant dummy stage timings.
        """
        selected = list(bank.domains.ids[:domain_top_n])

        # Concatenate label ids domain by domain; domains with no (or a falsy)
        # bank contribute nothing.
        candidate_labels: list[str] = []
        for domain_id in selected:
            domain_labels = bank.labels_by_domain.get(domain_id)
            if not domain_labels:
                continue
            candidate_labels += domain_labels.ids

        return ClassificationResult(
            domain_hits=[(domain_id, 1.0) for domain_id in selected],
            chosen_domains=selected,
            label_hits=[(label_id, 1.0) for label_id in candidate_labels[:top_k]],
            timings=StageTimings(total_ms=1, domain_ms=1, labels_ms=0),
        )