File size: 2,578 Bytes
0f38bd9
 
 
 
 
 
 
68f48a7
 
 
0f38bd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from __future__ import annotations

import time
from dataclasses import dataclass

import torch

from api.classify.banks import EmbeddingBank, LabelSetBank
from api.model.clip_store import ClipStore
from api.classify.results import ClassificationResult, StageTimings


@dataclass(slots=True)
class ClipScorer:
    scale: float

    def probs(self, image_feat: torch.Tensor, text_feats: torch.Tensor) -> torch.Tensor:
        logits = (image_feat @ text_feats.T) * self.scale
        return torch.softmax(logits - logits.max(), dim=-1)


@dataclass(slots=True)
class TwoStageClassifier:
    store: ClipStore

    def classify(self, bank: LabelSetBank, image, *, domain_top_n: int, top_k: int) -> ClassificationResult:
        t0 = time.time()

        image_feat = self.store.encode_image(image)
        scorer = ClipScorer(scale=self.store.logit_scale())

        t_dom = time.time()
        domain_probs = scorer.probs(image_feat, bank.domains.feats)
        domain_hits, chosen_domains = self._top_hits(bank.domains.ids, domain_probs, k=domain_top_n)
        domain_ms = int((time.time() - t_dom) * 1000)

        t_lab = time.time()
        labels_bank = self._merge_label_banks(bank, chosen_domains)
        label_hits: list[tuple[str, float]] = []
        if labels_bank is not None:
            label_probs = scorer.probs(image_feat, labels_bank.feats)
            label_hits, _ = self._top_hits(labels_bank.ids, label_probs, k=top_k)
        labels_ms = int((time.time() - t_lab) * 1000)

        total_ms = int((time.time() - t0) * 1000)
        return ClassificationResult(
            domain_hits=domain_hits,
            chosen_domains=chosen_domains,
            label_hits=label_hits,
            timings=StageTimings(total_ms=total_ms, domain_ms=domain_ms, labels_ms=labels_ms),
        )

    @staticmethod
    def _top_hits(ids: tuple[str, ...], probs: torch.Tensor, *, k: int):
        k = min(k, probs.numel())
        values, indices = torch.topk(probs, k)
        hits = [(ids[i], float(values[j])) for j, i in enumerate(indices.tolist())]
        chosen = [i for i, _ in hits]
        return hits, chosen

    @staticmethod
    def _merge_label_banks(bank: LabelSetBank, chosen_domains: list[str]) -> EmbeddingBank | None:
        banks = [bank.labels_by_domain[d] for d in chosen_domains if d in bank.labels_by_domain]
        if not banks:
            return None
        merged_ids = tuple(x for b in banks for x in b.ids)
        merged_feats = torch.cat([b.feats for b in banks], dim=0)
        return EmbeddingBank(ids=merged_ids, feats=merged_feats)