Spaces:
Sleeping
Sleeping
File size: 2,578 Bytes
0f38bd9 68f48a7 0f38bd9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
from __future__ import annotations
import time
from dataclasses import dataclass
import torch
from api.classify.banks import EmbeddingBank, LabelSetBank
from api.model.clip_store import ClipStore
from api.classify.results import ClassificationResult, StageTimings
@dataclass(slots=True)
class ClipScorer:
scale: float
def probs(self, image_feat: torch.Tensor, text_feats: torch.Tensor) -> torch.Tensor:
logits = (image_feat @ text_feats.T) * self.scale
return torch.softmax(logits - logits.max(), dim=-1)
@dataclass(slots=True)
class TwoStageClassifier:
store: ClipStore
def classify(self, bank: LabelSetBank, image, *, domain_top_n: int, top_k: int) -> ClassificationResult:
t0 = time.time()
image_feat = self.store.encode_image(image)
scorer = ClipScorer(scale=self.store.logit_scale())
t_dom = time.time()
domain_probs = scorer.probs(image_feat, bank.domains.feats)
domain_hits, chosen_domains = self._top_hits(bank.domains.ids, domain_probs, k=domain_top_n)
domain_ms = int((time.time() - t_dom) * 1000)
t_lab = time.time()
labels_bank = self._merge_label_banks(bank, chosen_domains)
label_hits: list[tuple[str, float]] = []
if labels_bank is not None:
label_probs = scorer.probs(image_feat, labels_bank.feats)
label_hits, _ = self._top_hits(labels_bank.ids, label_probs, k=top_k)
labels_ms = int((time.time() - t_lab) * 1000)
total_ms = int((time.time() - t0) * 1000)
return ClassificationResult(
domain_hits=domain_hits,
chosen_domains=chosen_domains,
label_hits=label_hits,
timings=StageTimings(total_ms=total_ms, domain_ms=domain_ms, labels_ms=labels_ms),
)
@staticmethod
def _top_hits(ids: tuple[str, ...], probs: torch.Tensor, *, k: int):
k = min(k, probs.numel())
values, indices = torch.topk(probs, k)
hits = [(ids[i], float(values[j])) for j, i in enumerate(indices.tolist())]
chosen = [i for i, _ in hits]
return hits, chosen
@staticmethod
def _merge_label_banks(bank: LabelSetBank, chosen_domains: list[str]) -> EmbeddingBank | None:
banks = [bank.labels_by_domain[d] for d in chosen_domains if d in bank.labels_by_domain]
if not banks:
return None
merged_ids = tuple(x for b in banks for x in b.ids)
merged_feats = torch.cat([b.feats for b in banks], dim=0)
return EmbeddingBank(ids=merged_ids, feats=merged_feats)
|