Spaces:
Sleeping
Sleeping
| """SigLIP 2 embedding model wrapper.""" | |
| from __future__ import annotations | |
| from typing import TYPE_CHECKING | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from tqdm import tqdm | |
| from transformers import AutoModel, AutoProcessor | |
| if TYPE_CHECKING: | |
| from PIL import Image | |
class EmbeddingModel:
    """SigLIP 2 embedding model for text and image encoding.

    Model: google/siglip2-so400m-patch14-384
    Dimension: 1152

    The model is loaded lazily on first encode call (or explicitly via
    ``load()``). All returned embeddings are L2-normalized, so cosine
    similarity reduces to a dot product.
    """

    MODEL_ID = "google/siglip2-so400m-patch14-384"
    EMBEDDING_DIM = 1152

    def __init__(self, device: str = "cpu") -> None:
        """Initialize the embedding model.

        Args:
            device: Device to run the model on ('cpu' or 'cuda').
        """
        self.device = device
        self.model = None  # populated by load()
        self.processor = None  # populated by load()

    def load(self) -> None:
        """Load the model and processor and prepare them for inference."""
        self.processor = AutoProcessor.from_pretrained(self.MODEL_ID)
        self.model = AutoModel.from_pretrained(self.MODEL_ID)
        self.model.to(self.device)
        # eval() is the idiomatic form of train(False): disables dropout
        # and other train-only behavior for deterministic inference.
        self.model.eval()

    def _ensure_loaded(self) -> None:
        """Ensure model is loaded before inference (lazy initialization)."""
        if self.model is None or self.processor is None:
            self.load()

    def _forward_images(self, batch: list[Image.Image]) -> np.ndarray:
        """Run one image batch through the model.

        Args:
            batch: List of PIL Images (one forward pass).

        Returns:
            L2-normalized embeddings of shape (len(batch), 1152).
        """
        inputs = self.processor(images=batch, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)
            features = F.normalize(features, dim=-1)
        return features.cpu().numpy()

    def _forward_texts(self, batch: list[str]) -> np.ndarray:
        """Run one text batch through the model.

        Args:
            batch: List of text strings (one forward pass).

        Returns:
            L2-normalized embeddings of shape (len(batch), 1152).
        """
        # SigLIP requires padding="max_length" to match how it was trained.
        inputs = self.processor(
            text=batch,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
            features = F.normalize(features, dim=-1)
        return features.cpu().numpy()

    def _encode_batched(self, items, batch_size, show_progress, desc, forward):
        """Shared batching loop behind encode_images/encode_texts.

        Args:
            items: Sequence of inputs to encode.
            batch_size: Number of items per forward pass.
            show_progress: Show a tqdm progress bar over batches.
            desc: Progress bar description.
            forward: Callable mapping one batch to a (B, 1152) array.

        Returns:
            Normalized embedding vectors of shape (N, 1152).
        """
        if not items:
            # Keep the empty-input shape/dtype contract without touching the model.
            return np.empty((0, self.EMBEDDING_DIM), dtype=np.float32)
        self._ensure_loaded()
        starts = range(0, len(items), batch_size)
        if show_progress:
            starts = tqdm(starts, desc=desc, unit="batch")
        chunks = [forward(items[i : i + batch_size]) for i in starts]
        return np.concatenate(chunks, axis=0)

    def encode_image(self, image: Image.Image) -> np.ndarray:
        """Encode a single image to embedding vector.

        Args:
            image: PIL Image to encode.

        Returns:
            Normalized embedding vector of shape (1152,).
        """
        self._ensure_loaded()
        # Encode as a batch of one and drop the batch dimension.
        return self._forward_images([image]).squeeze(0)

    def encode_images(
        self,
        images: list[Image.Image],
        batch_size: int = 32,
        show_progress: bool = True,
    ) -> np.ndarray:
        """Encode multiple images to embedding vectors.

        Args:
            images: List of PIL Images to encode.
            batch_size: Batch size for processing.
            show_progress: Show progress bar.

        Returns:
            Normalized embedding vectors of shape (N, 1152).
        """
        return self._encode_batched(
            images, batch_size, show_progress, "Encoding images", self._forward_images
        )

    def encode_text(self, text: str) -> np.ndarray:
        """Encode a single text to embedding vector.

        Args:
            text: Text string to encode.

        Returns:
            Normalized embedding vector of shape (1152,).
        """
        self._ensure_loaded()
        # Encode as a batch of one and drop the batch dimension.
        return self._forward_texts([text]).squeeze(0)

    def encode_texts(
        self,
        texts: list[str],
        batch_size: int = 32,
        show_progress: bool = True,
    ) -> np.ndarray:
        """Encode multiple texts to embedding vectors.

        Args:
            texts: List of text strings to encode.
            batch_size: Batch size for processing.
            show_progress: Show progress bar.

        Returns:
            Normalized embedding vectors of shape (N, 1152).
        """
        return self._encode_batched(
            texts, batch_size, show_progress, "Encoding texts", self._forward_texts
        )