from __future__ import annotations
import numpy as np
import torch
from transformers import CLIPModel, CLIPProcessor


class TextEmbedder:
    """
    CLIP projected text features (512-d for the default ViT-B/32 checkpoint).
    """

    def __init__(
        self,
        model_name: str = "openai/clip-vit-base-patch32",
        device: str = "cpu",
    ):
        self.device = device
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        # Inference only: disable dropout and other training-time behaviour.
        self.model.eval()

    @torch.no_grad()
    def embed(self, text: str) -> np.ndarray:
        inputs = self.processor(
            text=[text],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)
        feats = self.model.get_text_features(**inputs)

        # Handle different transformers versions: older releases may return a
        # model-output object or un-projected hidden states instead of the
        # projected (batch, projection_dim) tensor.
        target_dim = getattr(self.model.config, "projection_dim", 512)
        if not isinstance(feats, torch.Tensor):
            # Model-output object: take the pooled output and project it if needed.
            pooled = feats.pooler_output
            if pooled.shape[-1] != target_dim:
                proj = getattr(self.model, "text_projection", None)
                if proj is not None:
                    pooled = proj(pooled)
            feats = pooled
        if feats.dim() == 3:
            # Per-token hidden states: take the first token as a pooled
            # representation, then project if the width does not match.
            pooled = feats[:, 0, :]
            if pooled.shape[-1] != target_dim:
                proj = getattr(self.model, "text_projection", None)
                if proj is not None:
                    pooled = proj(pooled)
            feats = pooled
        if feats.dim() == 2:
            # A single text was encoded: drop the batch dimension.
            feats = feats[0]
        return feats.cpu().numpy().astype("float32")
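

# Illustrative usage sketch (not part of the original module): shows how the
# embedder might be called. The example prompts and the cosine-similarity
# check below are assumptions added for demonstration only.
if __name__ == "__main__":
    embedder = TextEmbedder(device="cuda" if torch.cuda.is_available() else "cpu")
    vec = embedder.embed("a photo of a golden retriever")
    print(vec.shape)  # expected: (512,) for clip-vit-base-patch32
    print(vec.dtype)  # float32

    # Quick sanity check: semantically close prompts should score high.
    other = embedder.embed("a photo of a dog")
    cos = float(np.dot(vec, other) / (np.linalg.norm(vec) * np.linalg.norm(other)))
    print(f"cosine similarity: {cos:.3f}")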