| |
|
|
| import torch |
| import torch.nn.functional as F |
| from transformers import AutoModel, AutoProcessor, AutoConfig |
| from typing import List, Union |
| from PIL import Image |
| import requests |
| import io |
| import os |
| import traceback |
| import numpy as np |
|
|
| from transformers.utils import logging as hf_logging |
|
|
| |
| os.environ["TRANSFORMERS_VERBOSITY"] = "error" |
|
|
class MultiModalEncoder:
    """
    Encodes text OR images into a shared, L2-NORMALIZED embedding space
    using google/siglip-so400m-patch16-256-i18n.

    Intended for creating embeddings for cosine-similarity vector search:
    call the instance with a list of strings or a list of PIL images and
    receive a [batch, embedding_dim] tensor of unit-norm vectors.
    """

    def __init__(self, model_id="google/siglip-so400m-patch16-256-i18n", dtype: torch.dtype = torch.bfloat16, device: str = None):
        """
        Load the SigLIP processor, config, and model.

        Args:
            model_id: Hugging Face model identifier.
            dtype: parameter/compute dtype for the model (bfloat16 default).
            device: explicit device string; auto-selects CUDA when available.

        Raises:
            Exception: re-raised (after logging) if any component fails to load.
        """
        hf_logging.set_verbosity_error()
        hf_logging.disable_progress_bar()

        self.model_id = model_id
        # Prefer an explicit device; otherwise pick CUDA when available.
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = dtype

        try:
            # Env vars are case-sensitive on POSIX: check the canonical
            # HF_TOKEN first ("Hf_TOKEN" kept only as a legacy fallback —
            # the original order made the typo'd name shadow the real one).
            hf_token = os.getenv("HF_TOKEN") or os.getenv("Hf_TOKEN")
            hf_kwargs = {"token": hf_token} if hf_token else {}

            self.processor = AutoProcessor.from_pretrained(
                self.model_id,
                use_fast=True,
                **hf_kwargs,
            )

            config = AutoConfig.from_pretrained(self.model_id, **hf_kwargs)

            # Some configs lack projection_dim; fall back to the text tower's
            # hidden size so self.embedding_dim is always defined.
            if not hasattr(config, "projection_dim"):
                config.projection_dim = config.text_config.hidden_size

            self.model = AutoModel.from_pretrained(
                self.model_id,
                config=config,
                dtype=self.dtype,
                trust_remote_code=False,
                **hf_kwargs,
            ).to(self.device).eval()

            self.embedding_dim = config.projection_dim

        except Exception as e:
            print(f"❌ Failed to load SigLIP model or components: {e}")
            traceback.print_exc()
            raise

    def _empty(self) -> torch.Tensor:
        """Return an empty [0, embedding_dim] batch matching the model dtype/device.

        Fix: the previous empty returns were float32, while successful calls
        return self.dtype — callers concatenating batches saw mixed dtypes.
        """
        return torch.empty(0, self.embedding_dim, dtype=self.dtype, device=self.device)

    @torch.no_grad()
    def __call__(self, x: Union[List[str], List[Image.Image]]) -> torch.Tensor:
        """
        Encode a batch of texts or images into normalized vectors.

        Args:
            x: a list of strings OR a list of PIL images (not mixed; the
               modality is inferred from the first element).

        Returns:
            A [len(x), embedding_dim] tensor of unit-norm embeddings in
            self.dtype — suitable for cosine-similarity vector search —
            or an empty [0, embedding_dim] tensor on failure/empty input.
        """
        if not x:
            return self._empty()

        is_text = isinstance(x[0], str)

        # Autocast only for half-precision dtypes; full precision runs as-is.
        autocast_dtype = self.dtype if self.dtype in (torch.float16, torch.bfloat16) else None
        device_str = self.device.type if isinstance(self.device, torch.device) else self.device

        with torch.autocast(device_type=device_str, dtype=autocast_dtype, enabled=(autocast_dtype is not None)):
            try:
                if is_text:
                    inputs = self.processor(text=x, return_tensors="pt", padding=True, truncation=True).to(self.device)
                    embeddings = self.model.get_text_features(**inputs)
                else:
                    # SigLIP expects 3-channel input; convert guards against RGBA/L images.
                    valid_images = [img.convert("RGB") for img in x]
                    inputs = self.processor(images=valid_images, return_tensors="pt").to(self.device)
                    embeddings = self.model.get_image_features(**inputs)

                # Some transformers versions return output objects or tuples
                # rather than a raw tensor; unwrap to the pooled representation.
                if not isinstance(embeddings, torch.Tensor):
                    if hasattr(embeddings, "pooler_output"):
                        embeddings = embeddings.pooler_output
                    elif hasattr(embeddings, "last_hidden_state"):
                        embeddings = embeddings.last_hidden_state
                    elif isinstance(embeddings, (tuple, list)):
                        embeddings = embeddings[0]

                # Normalize in float32 for numerical stability, then cast back.
                embeddings = F.normalize(embeddings.float(), p=2, dim=-1)
                return embeddings.to(self.dtype)

            except Exception as e:
                # Best-effort by design: log and return an empty batch so
                # indexing pipelines can skip the item rather than crash.
                print(f"ERROR in MultiModalEncoder: {e}", flush=True)
                traceback.print_exc()
                return self._empty()
|
|
| |
if __name__ == "__main__":
    # Smoke test: download one image, build one synthetic image, embed a
    # handful of texts, and print the image-vs-text cosine similarities.
    MODEL_ID = "google/siglip-so400m-patch16-256-i18n"
    print(f"\n--- MultiModalEncoder Test ({MODEL_ID}) ---")

    texts = [
        "Uranus",
        "Anus",
        "Ass",
        "Planet",
        "Dog",
    ]

    try:
        img_urls = [
            "https://pbs.twimg.com/media/G3ra9C8W0AAGR8V.jpg",
        ]
        headers = {"User-Agent": "Mozilla/5.0"}
        images = []
        for url in img_urls:
            resp = requests.get(url, headers=headers, timeout=30)
            # Fail fast on HTTP errors instead of handing PIL an error page.
            resp.raise_for_status()
            images.append(Image.open(io.BytesIO(resp.content)))

        size = 256
        images.append(Image.new("RGB", (size, size), color="green"))
        print(f"✅ Downloaded test image and created green square (size {size}x{size}).")

    except Exception as e:
        print(f"❌ Failed to load images: {e}")
        traceback.print_exc()
        # exit() is the interactive-shell helper and exited with status 0;
        # SystemExit(1) reports failure properly to the calling shell.
        raise SystemExit(1)

    try:
        encoder = MultiModalEncoder(model_id=MODEL_ID)

        print("\n--- Encoding Texts (Separately) ---")
        text_embeddings = encoder(texts)
        print(f"Shape: {text_embeddings.shape}")

        print("\n--- Encoding Images (Separately) ---")
        image_embeddings = encoder(images)
        print(f"Shape: {image_embeddings.shape}")

        print("\n--- Similarity Check (Your Goal) ---")

        # Embeddings are unit-norm, so the plain dot product IS cosine
        # similarity. Cast to float32 BEFORE .numpy(): numpy has no bfloat16,
        # so calling .numpy() on the encoder's default bf16 output raised
        # "TypeError: Got unsupported ScalarType BFloat16".
        similarity_matrix = torch.matmul(
            image_embeddings.float().cpu(), text_embeddings.float().cpu().T
        ).numpy()

        np.set_printoptions(precision=4, suppress=True)
        print("\nCosine Similarity matrix (image × text):")
        print(similarity_matrix)

        print("\nSpecific Similarity Scores (Cosine Similarity, -1.0 to 1.0):")
        print(f"Image 0 (Uranus pic) vs Text 0 (Uranus): {similarity_matrix[0][0]:.4f}")
        print(f"Image 0 (Uranus pic) vs Text 1 (Anus): {similarity_matrix[0][1]:.4f}")
        print(f"Image 0 (Uranus pic) vs Text 3 (Planet): {similarity_matrix[0][3]:.4f}")
        print(f"Image 0 (Uranus pic) vs Text 4 (Dog): {similarity_matrix[0][4]:.4f}")
        print(f"Image 1 (Green) vs Text 4 (Dog): {similarity_matrix[1][4]:.4f}")

    except Exception as e:
        print(f"\n--- An error occurred during the SigLIP test run ---")
        print(f"Error: {e}")
        traceback.print_exc()
|