from __future__ import annotations

import time
from dataclasses import dataclass

import torch

from api.classify.banks import EmbeddingBank, LabelSetBank
from api.model.clip_store import ClipStore
from api.classify.results import ClassificationResult, StageTimings


def _elapsed_ms(start: float) -> int:
    """Return whole milliseconds elapsed since *start* (a ``perf_counter`` reading)."""
    return int((time.perf_counter() - start) * 1000)


@dataclass(slots=True)
class ClipScorer:
    """Convert image/text feature similarities into softmax probabilities."""

    # Multiplier applied to the raw similarities before the softmax.
    scale: float

    def probs(self, image_feat: torch.Tensor, text_feats: torch.Tensor) -> torch.Tensor:
        """Return the softmax over the scaled similarities ``image_feat @ text_feats.T``.

        NOTE(review): callers in this file treat the result as a 1-D tensor
        with one entry per row of ``text_feats`` -- i.e. ``image_feat`` is
        presumably ``(D,)`` and ``text_feats`` ``(N, D)``; confirm against
        ``ClipStore.encode_image`` and the bank builders.
        """
        logits = (image_feat @ text_feats.T) * self.scale
        # Softmax is invariant to a constant shift; subtracting the maximum
        # keeps the exponentials from overflowing.
        return torch.softmax(logits - logits.max(), dim=-1)


@dataclass(slots=True)
class TwoStageClassifier:
    """Classify an image in two stages: pick domains first, then labels.

    Stage one scores the image against every domain in the bank and keeps the
    best ``domain_top_n``. Stage two scores the image only against the labels
    that belong to those domains and keeps the best ``top_k``.
    """

    # Provides the image encoder and the logit scale used for scoring.
    store: ClipStore

    def classify(
        self,
        bank: LabelSetBank,
        image,
        *,
        domain_top_n: int,
        top_k: int,
    ) -> ClassificationResult:
        """Run both stages on *image* and return the hits with stage timings.

        Args:
            bank: Domain embeddings plus per-domain label embeddings.
            image: Passed unchanged to ``store.encode_image``.
            domain_top_n: Maximum number of domains to keep from stage one.
            top_k: Maximum number of labels to return from stage two.

        Returns:
            A ``ClassificationResult`` with the domain hits, the ids of the
            chosen domains, the label hits (empty when none of the chosen
            domains has a label bank) and the timings in milliseconds.
            ``total_ms`` also covers image encoding; ``domain_ms`` and
            ``labels_ms`` cover only their own stage.
        """
        # perf_counter() is monotonic, so the measured intervals cannot go
        # negative or jump when the system wall clock is adjusted mid-request.
        started = time.perf_counter()
        image_feat = self.store.encode_image(image)
        scorer = ClipScorer(scale=self.store.logit_scale())

        # Stage 1: rank the domains.
        domain_started = time.perf_counter()
        domain_probs = scorer.probs(image_feat, bank.domains.feats)
        domain_hits, chosen_domains = self._top_hits(
            bank.domains.ids, domain_probs, k=domain_top_n
        )
        domain_ms = _elapsed_ms(domain_started)

        # Stage 2: rank only the labels that belong to the chosen domains.
        labels_started = time.perf_counter()
        labels_bank = self._merge_label_banks(bank, chosen_domains)
        label_hits: list[tuple[str, float]] = []
        if labels_bank is not None:
            label_probs = scorer.probs(image_feat, labels_bank.feats)
            label_hits, _ = self._top_hits(labels_bank.ids, label_probs, k=top_k)
        labels_ms = _elapsed_ms(labels_started)

        total_ms = _elapsed_ms(started)
        return ClassificationResult(
            domain_hits=domain_hits,
            chosen_domains=chosen_domains,
            label_hits=label_hits,
            timings=StageTimings(
                total_ms=total_ms,
                domain_ms=domain_ms,
                labels_ms=labels_ms,
            ),
        )

    @staticmethod
    def _top_hits(
        ids: tuple[str, ...],
        probs: torch.Tensor,
        *,
        k: int,
    ) -> tuple[list[tuple[str, float]], list[str]]:
        """Return the *k* best ``(id, probability)`` pairs and their ids.

        ``k`` is capped at the number of entries in *probs*. Both returned
        lists follow the order produced by ``torch.topk``.

        NOTE(review): *probs* is expected to be 1-D and aligned with *ids*;
        a batched tensor would yield nested index lists and fail on
        ``ids[...]``.
        """
        k = min(k, probs.numel())
        values, indices = torch.topk(probs, k)
        # One bulk conversion per tensor instead of one scalar extraction per
        # hit; float() keeps the element type identical to the previous code.
        hits = [
            (ids[index], float(score))
            for index, score in zip(indices.tolist(), values.tolist())
        ]
        chosen = [hit_id for hit_id, _ in hits]
        return hits, chosen

    @staticmethod
    def _merge_label_banks(
        bank: LabelSetBank,
        chosen_domains: list[str],
    ) -> EmbeddingBank | None:
        """Concatenate the label banks of *chosen_domains* into a single bank.

        Domains without an entry in ``bank.labels_by_domain`` are skipped.
        Returns ``None`` when no chosen domain has a label bank. Ids and
        feature rows are concatenated in the order of *chosen_domains*, so
        they stay aligned.
        """
        banks = [
            bank.labels_by_domain[domain]
            for domain in chosen_domains
            if domain in bank.labels_by_domain
        ]
        if not banks:
            return None
        merged_ids = tuple(label_id for b in banks for label_id in b.ids)
        merged_feats = torch.cat([b.feats for b in banks], dim=0)
        return EmbeddingBank(ids=merged_ids, feats=merged_feats)