# multimodal-search / core / embeddings.py
# Uploaded via huggingface_hub by YoungjaeDev (revision 2e15a8b).
"""SigLIP 2 embedding model wrapper."""
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor
if TYPE_CHECKING:
from PIL import Image
class EmbeddingModel:
    """SigLIP 2 embedding model for text and image encoding.

    Wraps the Hugging Face processor/model pair and exposes single-item
    and batched encoders for images and texts. All returned embeddings
    are L2-normalized, so cosine similarity reduces to a dot product.

    Model: google/siglip2-so400m-patch14-384
    Dimension: 1152
    """

    MODEL_ID = "google/siglip2-so400m-patch14-384"
    EMBEDDING_DIM = 1152

    def __init__(self, device: str = "cpu") -> None:
        """Initialize the embedding model.

        Args:
            device: Device to run the model on ('cpu' or 'cuda').
        """
        self.device = device
        # Loaded lazily on first encode call (see _ensure_loaded).
        self.model = None
        self.processor = None

    def load(self) -> None:
        """Load the processor and model, and move the model to `self.device`."""
        self.processor = AutoProcessor.from_pretrained(self.MODEL_ID)
        self.model = AutoModel.from_pretrained(self.MODEL_ID)
        self.model.to(self.device)
        # eval() is the idiomatic spelling of train(False): disables
        # dropout etc. for deterministic inference.
        self.model.eval()

    def _ensure_loaded(self) -> None:
        """Ensure model is loaded before inference (lazy first-use load)."""
        if self.model is None or self.processor is None:
            self.load()

    def _embed(self, inputs: dict, *, modality: str) -> np.ndarray:
        """Run one processor-output batch through the model.

        Args:
            inputs: Tensor dict produced by `self.processor`.
            modality: 'image' or 'text' — selects the model head.

        Returns:
            L2-normalized embeddings of shape (batch, 1152).
        """
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            if modality == "image":
                features = self.model.get_image_features(**inputs)
            else:
                features = self.model.get_text_features(**inputs)
            features = F.normalize(features, dim=-1)
        return features.cpu().numpy()

    def _text_inputs(self, text: str | list[str]) -> dict:
        """Tokenize text the way SigLIP expects.

        SigLIP was trained with fixed-length padding, so inference must
        use padding="max_length" to match training behavior.
        """
        return self.processor(
            text=text,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

    def encode_image(self, image: Image.Image) -> np.ndarray:
        """Encode a single image to an embedding vector.

        Args:
            image: PIL Image to encode.

        Returns:
            Normalized embedding vector of shape (1152,).
        """
        self._ensure_loaded()
        inputs = self.processor(images=image, return_tensors="pt")
        return self._embed(inputs, modality="image").squeeze(0)

    def encode_images(
        self,
        images: list[Image.Image],
        batch_size: int = 32,
        show_progress: bool = True,
    ) -> np.ndarray:
        """Encode multiple images to embedding vectors.

        Args:
            images: List of PIL Images to encode.
            batch_size: Batch size for processing.
            show_progress: Show progress bar.

        Returns:
            Normalized embedding vectors of shape (N, 1152).
        """
        if not images:
            # Preserve the output contract even for empty input.
            return np.empty((0, self.EMBEDDING_DIM), dtype=np.float32)
        self._ensure_loaded()
        all_embeddings = []
        iterator = range(0, len(images), batch_size)
        if show_progress:
            iterator = tqdm(iterator, desc="Encoding images", unit="batch")
        for i in iterator:
            batch = images[i : i + batch_size]
            inputs = self.processor(images=batch, return_tensors="pt")
            all_embeddings.append(self._embed(inputs, modality="image"))
        return np.concatenate(all_embeddings, axis=0)

    def encode_text(self, text: str) -> np.ndarray:
        """Encode a single text to an embedding vector.

        Args:
            text: Text string to encode.

        Returns:
            Normalized embedding vector of shape (1152,).
        """
        self._ensure_loaded()
        return self._embed(self._text_inputs(text), modality="text").squeeze(0)

    def encode_texts(
        self,
        texts: list[str],
        batch_size: int = 32,
        show_progress: bool = True,
    ) -> np.ndarray:
        """Encode multiple texts to embedding vectors.

        Args:
            texts: List of text strings to encode.
            batch_size: Batch size for processing.
            show_progress: Show progress bar.

        Returns:
            Normalized embedding vectors of shape (N, 1152).
        """
        if not texts:
            # Preserve the output contract even for empty input.
            return np.empty((0, self.EMBEDDING_DIM), dtype=np.float32)
        self._ensure_loaded()
        all_embeddings = []
        iterator = range(0, len(texts), batch_size)
        if show_progress:
            iterator = tqdm(iterator, desc="Encoding texts", unit="batch")
        for i in iterator:
            batch = texts[i : i + batch_size]
            all_embeddings.append(
                self._embed(self._text_inputs(batch), modality="text")
            )
        return np.concatenate(all_embeddings, axis=0)