from __future__ import annotations

import warnings

import torch
from transformers import CLIPModel, CLIPProcessor

from api.classify.banks import EmbeddingBank, LabelSetBank
from api.label_sets.hash import stable_hash
from api.label_sets.schemas import LabelSet
from api.common.settings import settings


class ClipStore:
    """Hold the CLIP model + processor and build embedding banks from label sets.

    Infra component. CPU-only for HF CPU Spaces: the model is always placed on
    ``torch.device("cpu")`` and switched to eval mode.
    """

    def __init__(self) -> None:
        """Load the model and processor identified by ``settings.clip_model_id``."""
        self.device = torch.device("cpu")
        self.model = CLIPModel.from_pretrained(settings.clip_model_id).to(self.device).eval()

        # Loading the processor may emit a FutureWarning about
        # `clean_up_tokenization_spaces`; silence just that warning here and
        # set the flag explicitly below.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="`clean_up_tokenization_spaces` was not set",
                category=FutureWarning,
                module="transformers.tokenization_utils_base",
            )
            self.processor = CLIPProcessor.from_pretrained(settings.clip_model_id)

        tokenizer = getattr(self.processor, "tokenizer", None)
        if tokenizer is not None and hasattr(tokenizer, "clean_up_tokenization_spaces"):
            tokenizer.clean_up_tokenization_spaces = True

    def build_bank(self, label_set: LabelSet) -> LabelSetBank:
        """Encode every prompt of *label_set* and return the resulting bank.

        The bank carries a stable hash of ``label_set.model_dump()``, the label
        set name, one embedding bank for the domains, and one embedding bank
        per entry of ``label_set.labels_by_domain``.
        """
        label_set_hash = stable_hash(label_set.model_dump())
        domains = self._bank_from_items(label_set.domains)
        labels_by_domain: dict[str, EmbeddingBank] = {
            domain_id: self._bank_from_items(items)
            for domain_id, items in label_set.labels_by_domain.items()
        }
        return LabelSetBank(
            label_set_hash=label_set_hash,
            name=label_set.name,
            domains=domains,
            labels_by_domain=labels_by_domain,
        )

    def encode_image(self, image) -> torch.Tensor:
        """Return the L2-normalized CLIP image embedding of *image*, shape (D,).

        Only the first row of the model output is returned, so *image* is
        presumably a single image in a format the processor accepts.
        """
        inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            feat = self.model.get_image_features(**self._to_device(inputs))
            feat = feat / feat.norm(dim=-1, keepdim=True)
        return feat[0]  # (D,)

    def logit_scale(self) -> float:
        """Return ``exp`` of the model's ``logit_scale`` parameter as a float."""
        return float(self.model.logit_scale.exp().item())

    def _bank_from_items(self, items) -> EmbeddingBank:
        """Build an ``EmbeddingBank`` from items exposing ``.id`` and ``.prompt``."""
        ids = tuple(item.id for item in items)
        feats = self._encode_text([item.prompt for item in items])
        return EmbeddingBank(ids=ids, feats=feats)

    def _to_device(self, inputs) -> dict:
        """Move every tensor in the processor output onto ``self.device``."""
        return {key: value.to(self.device) for key, value in inputs.items()}

    def _encode_text(self, prompts: list[str]) -> torch.Tensor:
        """Return L2-normalized CLIP text embeddings for *prompts*, shape (N, D).

        An empty *prompts* list yields an empty ``(0, D)`` tensor instead of
        being forwarded to the tokenizer/model, which cannot handle an empty
        batch.
        """
        if not prompts:
            # Tokenizing an empty batch fails inside the processor (typically an
            # IndexError); a label-less domain should simply produce an empty bank.
            return torch.empty(
                (0, self.model.config.projection_dim),
                dtype=self.model.dtype,
                device=self.device,
            )

        inputs = self.processor(text=prompts, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            feats = self.model.get_text_features(**self._to_device(inputs))
            feats = feats / feats.norm(dim=-1, keepdim=True)
        return feats  # (N, D)