# esandorfi: Domain features first reorganisation (commit 68f48a7)
# NOTE(review): the three lines above this module's imports were Hugging Face
# web-page residue (author avatar, commit message, commit hash); kept here as a
# comment so the file parses as Python.
from __future__ import annotations
import warnings
import torch
from transformers import CLIPModel, CLIPProcessor
from api.classify.banks import EmbeddingBank, LabelSetBank
from api.label_sets.hash import stable_hash
from api.label_sets.schemas import LabelSet
from api.common.settings import settings
class ClipStore:
    """
    Infra: holds CLIP model + processor and can build embedding banks from label sets.
    CPU-only for HF CPU Spaces.
    """

    def __init__(self) -> None:
        # Everything runs on CPU — HF CPU Spaces have no GPU.
        self.device = torch.device("cpu")
        self.model = (
            CLIPModel.from_pretrained(settings.clip_model_id)
            .to(self.device)
            .eval()
        )
        # transformers emits a FutureWarning when `clean_up_tokenization_spaces`
        # is left unset; silence it during loading, then pin the flag explicitly
        # below so the warning never fires again.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="`clean_up_tokenization_spaces` was not set",
                category=FutureWarning,
                module="transformers.tokenization_utils_base",
            )
            self.processor = CLIPProcessor.from_pretrained(settings.clip_model_id)
        tokenizer = getattr(self.processor, "tokenizer", None)
        if tokenizer is not None and hasattr(tokenizer, "clean_up_tokenization_spaces"):
            tokenizer.clean_up_tokenization_spaces = True

    def build_bank(self, label_set: LabelSet) -> LabelSetBank:
        """Encode every prompt in *label_set* and return the resulting bank.

        Produces one EmbeddingBank for the domain prompts and one per domain
        for its label prompts, keyed by domain id, all tagged with a stable
        hash of the label set so stale banks can be detected.
        """
        fingerprint = stable_hash(label_set.model_dump())
        domains = label_set.domains
        domain_bank = EmbeddingBank(
            ids=tuple(d.id for d in domains),
            feats=self._encode_text([d.prompt for d in domains]),
        )
        per_domain: dict[str, EmbeddingBank] = {
            domain_id: EmbeddingBank(
                ids=tuple(item.id for item in items),
                feats=self._encode_text([item.prompt for item in items]),
            )
            for domain_id, items in label_set.labels_by_domain.items()
        }
        return LabelSetBank(
            label_set_hash=fingerprint,
            name=label_set.name,
            domains=domain_bank,
            labels_by_domain=per_domain,
        )

    def encode_image(self, image) -> torch.Tensor:
        """Return the L2-normalized CLIP embedding of a single image, shape (D,)."""
        batch = self.processor(images=image, return_tensors="pt")
        batch = {name: tensor.to(self.device) for name, tensor in batch.items()}
        with torch.no_grad():
            embedding = self.model.get_image_features(**batch)
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)
        return embedding[0]  # (D,)

    def logit_scale(self) -> float:
        """Exponentiated learned temperature CLIP uses to scale similarities."""
        scale = self.model.logit_scale.exp()
        return float(scale.item())

    def _encode_text(self, prompts: list[str]) -> torch.Tensor:
        """Encode *prompts* into L2-normalized text features, shape (N, D)."""
        tokenized = self.processor(
            text=prompts, return_tensors="pt", padding=True, truncation=True
        )
        tokenized = {name: tensor.to(self.device) for name, tensor in tokenized.items()}
        with torch.no_grad():
            feats = self.model.get_text_features(**tokenized)
        return feats / feats.norm(dim=-1, keepdim=True)  # (N, D)