|
|
import torch
|
|
|
import numpy as np
|
|
|
from typing import List
|
|
|
from transformers import CLIPModel, CLIPProcessor
|
|
|
|
|
|
class CLIPEmbedder:
    """Wraps a HuggingFace CLIP model to produce L2-normalized text embeddings.

    Embeddings are unit-length, so dot products between them equal cosine
    similarity.
    """

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32", device: str = "cpu"):
        """Load the CLIP model and processor and move the model to *device*.

        Args:
            model_name: HuggingFace model identifier to load.
            device: Torch device string (e.g. "cpu", "cuda").
        """
        self.device = device
        self.model_name = model_name

        print(f"→ Loading CLIP model: {model_name}")

        self.model = CLIPModel.from_pretrained(model_name).to(device)
        self.processor = CLIPProcessor.from_pretrained(model_name)

        # Inference only: put the model in eval mode (disables dropout etc.).
        self.model.eval()

        print(f"✓ CLIP model loaded on {device}")

    def encode_text(self, texts: List[str]) -> np.ndarray:
        """Encode a batch of texts into L2-normalized CLIP embeddings.

        Args:
            texts: Batch of input strings. May be empty.

        Returns:
            float32 array of shape (len(texts), projection_dim); each row has
            unit L2 norm. For an empty batch, a (0, projection_dim) array.
        """
        # Guard: the HF processor raises on an empty text batch, so
        # short-circuit with a correctly-shaped empty array instead.
        if not texts:
            return np.empty((0, self.model.config.projection_dim), dtype=np.float32)

        with torch.no_grad():
            inputs = self.processor(
                text=texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=77,  # CLIP's fixed text context length
            ).to(self.device)

            text_features = self.model.get_text_features(**inputs)

            # Normalize to unit length so dot products equal cosine similarity.
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            return text_features.cpu().numpy()

    def encode_single_text(self, text: str) -> np.ndarray:
        """Encode a single text; returns a 1-D embedding vector."""
        return self.encode_text([text])[0]

    def __call__(self, texts: List[str]) -> np.ndarray:
        """Make the embedder directly callable on a batch of texts."""
        return self.encode_text(texts)
|
|
|
|