esandorfi committed on
Commit
0f38bd9
·
unverified ·
1 Parent(s): 591d38a

Implement ClipScorer and TwoStageClassifier classes

Browse files
Files changed (1) hide show
  1. clip_service.py +68 -0
clip_service.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from dataclasses import dataclass
5
+
6
+ import torch
7
+
8
+ from app.banks import EmbeddingBank, LabelSetBank
9
+ from app.clip_store import ClipStore
10
+ from app.results import ClassificationResult, StageTimings
11
+
12
+
13
+ @dataclass(slots=True)
14
+ class ClipScorer:
15
+ scale: float
16
+
17
+ def probs(self, image_feat: torch.Tensor, text_feats: torch.Tensor) -> torch.Tensor:
18
+ logits = (image_feat @ text_feats.T) * self.scale
19
+ return torch.softmax(logits - logits.max(), dim=-1)
20
+
21
+
22
+ @dataclass(slots=True)
23
+ class TwoStageClassifier:
24
+ store: ClipStore
25
+
26
+ def classify(self, bank: LabelSetBank, image, *, domain_top_n: int, top_k: int) -> ClassificationResult:
27
+ t0 = time.time()
28
+
29
+ image_feat = self.store.encode_image(image)
30
+ scorer = ClipScorer(scale=self.store.logit_scale())
31
+
32
+ t_dom = time.time()
33
+ domain_probs = scorer.probs(image_feat, bank.domains.feats)
34
+ domain_hits, chosen_domains = self._top_hits(bank.domains.ids, domain_probs, k=domain_top_n)
35
+ domain_ms = int((time.time() - t_dom) * 1000)
36
+
37
+ t_lab = time.time()
38
+ labels_bank = self._merge_label_banks(bank, chosen_domains)
39
+ label_hits: list[tuple[str, float]] = []
40
+ if labels_bank is not None:
41
+ label_probs = scorer.probs(image_feat, labels_bank.feats)
42
+ label_hits, _ = self._top_hits(labels_bank.ids, label_probs, k=top_k)
43
+ labels_ms = int((time.time() - t_lab) * 1000)
44
+
45
+ total_ms = int((time.time() - t0) * 1000)
46
+ return ClassificationResult(
47
+ domain_hits=domain_hits,
48
+ chosen_domains=chosen_domains,
49
+ label_hits=label_hits,
50
+ timings=StageTimings(total_ms=total_ms, domain_ms=domain_ms, labels_ms=labels_ms),
51
+ )
52
+
53
+ @staticmethod
54
+ def _top_hits(ids: tuple[str, ...], probs: torch.Tensor, *, k: int):
55
+ k = min(k, probs.numel())
56
+ values, indices = torch.topk(probs, k)
57
+ hits = [(ids[i], float(values[j])) for j, i in enumerate(indices.tolist())]
58
+ chosen = [i for i, _ in hits]
59
+ return hits, chosen
60
+
61
+ @staticmethod
62
+ def _merge_label_banks(bank: LabelSetBank, chosen_domains: list[str]) -> EmbeddingBank | None:
63
+ banks = [bank.labels_by_domain[d] for d in chosen_domains if d in bank.labels_by_domain]
64
+ if not banks:
65
+ return None
66
+ merged_ids = tuple(x for b in banks for x in b.ids)
67
+ merged_feats = torch.cat([b.feats for b in banks], dim=0)
68
+ return EmbeddingBank(ids=merged_ids, feats=merged_feats)