Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import time | |
| from dataclasses import dataclass | |
| import torch | |
| from api.classify.banks import EmbeddingBank, LabelSetBank | |
| from api.model.clip_store import ClipStore | |
| from api.classify.results import ClassificationResult, StageTimings | |
@dataclass
class ClipScorer:
    """Convert image/text embedding similarities into probabilities.

    Declared as a dataclass so it can be constructed with
    ``ClipScorer(scale=...)``.

    Attributes:
        scale: Multiplier applied to the raw similarities before the softmax.
    """

    scale: float

    def probs(self, image_feat: torch.Tensor, text_feats: torch.Tensor) -> torch.Tensor:
        """Return the softmax over scaled similarities of image vs. texts.

        Args:
            image_feat: Image embedding. Presumably a 1-D vector of the same
                width as each text embedding -- TODO confirm with the caller.
            text_feats: Text embeddings; transposed before the matrix
                product, so presumably shaped (n_texts, dim) -- TODO confirm.

        Returns:
            Probabilities normalised along the last dimension.
        """
        logits = (image_feat @ text_feats.T) * self.scale
        # Softmax is invariant to a constant shift; subtracting the maximum
        # keeps the exponentials in a numerically safe range.
        return torch.softmax(logits - logits.max(), dim=-1)
@dataclass
class TwoStageClassifier:
    """Classify an image in two stages: pick domains, then labels within them.

    Declared as a dataclass so it can be constructed with
    ``TwoStageClassifier(store=...)``.

    Attributes:
        store: Provides ``encode_image(image)`` and ``logit_scale()``.
    """

    store: ClipStore

    def classify(self, bank: LabelSetBank, image, *, domain_top_n: int, top_k: int) -> ClassificationResult:
        """Score ``image`` against the domain bank, then against the labels.

        The image is encoded once. Its embedding is scored against
        ``bank.domains`` and the ``domain_top_n`` best domains are kept. The
        label banks of those domains are merged and scored, and the ``top_k``
        best labels are returned.

        Args:
            bank: Supplies ``domains`` (ids + feats) and ``labels_by_domain``.
            image: Whatever ``store.encode_image`` accepts.
            domain_top_n: Maximum number of domains to keep from stage one.
            top_k: Maximum number of labels to return from stage two.

        Returns:
            A ``ClassificationResult`` with the domain hits, the chosen domain
            ids, the label hits (empty when no chosen domain has a label
            bank), and per-stage timings in whole milliseconds.
        """
        # perf_counter is monotonic, so the measured durations cannot be
        # skewed by system clock adjustments the way time.time() can.
        started = time.perf_counter()
        image_feat = self.store.encode_image(image)
        scorer = ClipScorer(scale=self.store.logit_scale())

        # Stage 1: rank the domains (timing excludes the image encoding).
        domain_started = time.perf_counter()
        domain_probs = scorer.probs(image_feat, bank.domains.feats)
        domain_hits, chosen_domains = self._top_hits(bank.domains.ids, domain_probs, k=domain_top_n)
        domain_ms = self._elapsed_ms(domain_started)

        # Stage 2: rank the labels that belong to the chosen domains.
        labels_started = time.perf_counter()
        labels_bank = self._merge_label_banks(bank, chosen_domains)
        label_hits: list[tuple[str, float]] = []
        if labels_bank is not None:
            label_probs = scorer.probs(image_feat, labels_bank.feats)
            label_hits, _ = self._top_hits(labels_bank.ids, label_probs, k=top_k)
        labels_ms = self._elapsed_ms(labels_started)

        total_ms = self._elapsed_ms(started)
        return ClassificationResult(
            domain_hits=domain_hits,
            chosen_domains=chosen_domains,
            label_hits=label_hits,
            timings=StageTimings(total_ms=total_ms, domain_ms=domain_ms, labels_ms=labels_ms),
        )

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Return whole milliseconds elapsed since a ``perf_counter`` reading."""
        return int((time.perf_counter() - start) * 1000)

    @staticmethod
    def _top_hits(
        ids: tuple[str, ...], probs: torch.Tensor, *, k: int
    ) -> tuple[list[tuple[str, float]], list[str]]:
        """Return the ``k`` most probable ids, best first.

        Args:
            ids: Identifiers aligned index-for-index with ``probs``.
            probs: Scores to rank; indexed as a 1-D tensor here.
            k: Number of hits wanted; capped at the number of scores.

        Returns:
            ``(hits, chosen)`` where ``hits`` is a list of ``(id, score)``
            pairs and ``chosen`` is the list of just the ids, same order.
        """
        # topk raises if asked for more elements than exist, so cap k.
        k = min(k, probs.numel())
        values, indices = torch.topk(probs, k)
        hits = [(ids[i], float(values[j])) for j, i in enumerate(indices.tolist())]
        chosen = [hit_id for hit_id, _ in hits]
        return hits, chosen

    @staticmethod
    def _merge_label_banks(bank: LabelSetBank, chosen_domains: list[str]) -> EmbeddingBank | None:
        """Concatenate the label banks of ``chosen_domains`` into one bank.

        Domains without an entry in ``bank.labels_by_domain`` are skipped.

        Returns:
            A single ``EmbeddingBank`` whose ids and feature rows follow the
            order of ``chosen_domains``, or ``None`` when no chosen domain
            has a label bank.
        """
        banks = [bank.labels_by_domain[d] for d in chosen_domains if d in bank.labels_by_domain]
        if not banks:
            return None
        merged_ids = tuple(x for b in banks for x in b.ids)
        # Rows are stacked in the same order as the ids so they stay aligned.
        merged_feats = torch.cat([b.feats for b in banks], dim=0)
        return EmbeddingBank(ids=merged_ids, feats=merged_feats)