# Spaces: Sleeping — Hugging Face Spaces listing residue, not part of the module.
| from __future__ import annotations | |
| import warnings | |
| import torch | |
| from transformers import CLIPModel, CLIPProcessor | |
| from api.classify.banks import EmbeddingBank, LabelSetBank | |
| from api.label_sets.hash import stable_hash | |
| from api.label_sets.schemas import LabelSet | |
| from api.common.settings import settings | |
class ClipStore:
    """
    Infra: holds CLIP model + processor and can build embedding banks from label sets.
    CPU-only for HF CPU Spaces.
    """

    def __init__(self) -> None:
        # HF CPU Spaces have no accelerator, so everything is pinned to CPU.
        self.device = torch.device("cpu")
        self.model = (
            CLIPModel.from_pretrained(settings.clip_model_id)
            .to(self.device)
            .eval()
        )
        # Suppress the transformers FutureWarning about the unset tokenizer
        # default while loading; the flag is pinned explicitly right after.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="`clean_up_tokenization_spaces` was not set",
                category=FutureWarning,
                module="transformers.tokenization_utils_base",
            )
            self.processor = CLIPProcessor.from_pretrained(settings.clip_model_id)
        tokenizer = getattr(self.processor, "tokenizer", None)
        if tokenizer is not None and hasattr(tokenizer, "clean_up_tokenization_spaces"):
            tokenizer.clean_up_tokenization_spaces = True

    def build_bank(self, label_set: LabelSet) -> LabelSetBank:
        """Embed every domain and label prompt of *label_set* into a LabelSetBank.

        The bank carries a stable hash of the label set's serialized payload so
        callers can detect when a cached bank is stale.
        """
        fingerprint = stable_hash(label_set.model_dump())
        domain_bank = EmbeddingBank(
            ids=tuple(domain.id for domain in label_set.domains),
            feats=self._encode_text([domain.prompt for domain in label_set.domains]),
        )
        # One embedding bank per domain, built from that domain's label prompts.
        per_domain: dict[str, EmbeddingBank] = {
            domain_id: EmbeddingBank(
                ids=tuple(item.id for item in items),
                feats=self._encode_text([item.prompt for item in items]),
            )
            for domain_id, items in label_set.labels_by_domain.items()
        }
        return LabelSetBank(
            label_set_hash=fingerprint,
            name=label_set.name,
            domains=domain_bank,
            labels_by_domain=per_domain,
        )

    def encode_image(self, image) -> torch.Tensor:
        """Return the L2-normalized CLIP embedding of a single image, shape (D,)."""
        batch = self.processor(images=image, return_tensors="pt")
        batch = {name: tensor.to(self.device) for name, tensor in batch.items()}
        with torch.no_grad():
            embedding = self.model.get_image_features(**batch)
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)
        return embedding[0]  # (D,)

    def logit_scale(self) -> float:
        """CLIP's learned temperature (exponentiated logit scale) as a float."""
        scale = self.model.logit_scale.exp()
        return float(scale.item())

    def _encode_text(self, prompts: list[str]) -> torch.Tensor:
        """Encode *prompts* into L2-normalized CLIP text embeddings, shape (N, D)."""
        tokenized = self.processor(
            text=prompts, return_tensors="pt", padding=True, truncation=True
        )
        tokenized = {name: tensor.to(self.device) for name, tensor in tokenized.items()}
        with torch.no_grad():
            embeddings = self.model.get_text_features(**tokenized)
        return embeddings / embeddings.norm(dim=-1, keepdim=True)  # (N, D)