# multi_modal_processor.py
import io
import os
import traceback
from typing import List, Optional, Union

import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoConfig, AutoModel, AutoProcessor
from transformers.utils import logging as hf_logging

# Suppress warnings.
# NOTE(review): this is set after transformers has already been imported, so it
# may not change the already-initialised transformers logger in this process;
# MultiModalEncoder.__init__ calls hf_logging.set_verbosity_error() for that.
os.environ["TRANSFORMERS_VERBOSITY"] = "error"


class MultiModalEncoder:
    """Encode text OR images into a shared, L2-normalized embedding space.

    Defaults to ``google/siglip-so400m-patch16-256-i18n``. Intended for
    producing embeddings to store in a vector DB: outputs are unit-length, so
    cosine similarity is a plain dot product.
    """

    def __init__(
        self,
        model_id: str = "google/siglip-so400m-patch16-256-i18n",
        dtype: torch.dtype = torch.bfloat16,
        device: Optional[Union[str, torch.device]] = None,
    ):
        """Load the processor, config and model.

        Args:
            model_id: Hugging Face model id to load.
            dtype: Weight dtype. Also used for autocast (when fp16/bf16) and as
                the dtype of the returned embeddings.
            device: Target device (e.g. "cpu", "cuda", "cuda:1"). Defaults to
                "cuda" when available, else "cpu".

        Raises:
            Exception: Any loading error is printed with a traceback and re-raised.
        """
        # Silence transformers logging and progress bars (process-wide settings).
        hf_logging.set_verbosity_error()
        hf_logging.disable_progress_bar()

        self.model_id = model_id
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = dtype

        try:
            # Accept either spelling of the token environment variable.
            hf_token = os.getenv("Hf_TOKEN") or os.getenv("HF_TOKEN")
            hf_kwargs = {"token": hf_token} if hf_token else {}

            self.processor = AutoProcessor.from_pretrained(
                self.model_id, use_fast=True, **hf_kwargs
            )

            config = AutoConfig.from_pretrained(self.model_id, **hf_kwargs)
            if not hasattr(config, "projection_dim"):
                # Some configs (e.g. SigLIP) have no projection_dim; use the
                # text hidden size as the embedding dimension instead.
                config.projection_dim = config.text_config.hidden_size

            self.model = AutoModel.from_pretrained(
                self.model_id,
                config=config,
                # `dtype` is the from_pretrained keyword in recent transformers
                # releases (older releases named it `torch_dtype`).
                dtype=self.dtype,
                trust_remote_code=False,
                **hf_kwargs,
            ).to(self.device).eval()

            self.embedding_dim = config.projection_dim
            # Fixed text length used for padding in __call__ (falls back to 64
            # when the config does not expose it).
            text_config = getattr(config, "text_config", None)
            self.text_max_length = getattr(text_config, "max_position_embeddings", 64)
        except Exception as e:
            print(f"❌ Failed to load SigLIP model or components: {e}")
            traceback.print_exc()
            raise

    def _empty(self) -> torch.Tensor:
        """Return a [0, embedding_dim] tensor on the encoder's device and dtype."""
        return torch.empty(0, self.embedding_dim, device=self.device, dtype=self.dtype)

    @staticmethod
    def _extract_features(output) -> torch.Tensor:
        """Return the [batch, dim] feature tensor from a get_*_features result.

        The result may be a plain tensor or a model-output object; in the
        latter case a non-None ``pooler_output`` is used.

        Raises:
            TypeError: If no feature tensor can be found in ``output``.
            ValueError: If the features are not 2-D.
        """
        if isinstance(output, torch.Tensor):
            features = output
        elif getattr(output, "pooler_output", None) is not None:
            features = output.pooler_output
        elif isinstance(output, (tuple, list)) and len(output) > 0:
            features = output[0]
        else:
            raise TypeError(f"Unsupported feature output type: {type(output).__name__}")

        if features.dim() != 2:
            # Returning e.g. a [batch, seq, dim] hidden state would silently
            # break the [batch_size, embedding_dim] contract of __call__.
            raise ValueError(
                f"Expected 2-D [batch, dim] features, got shape {tuple(features.shape)}"
            )
        return features

    @torch.no_grad()
    def __call__(self, x: Union[List[str], List[Image.Image]]) -> torch.Tensor:
        """Encode a batch of texts or images into normalized vectors.

        Args:
            x: Either all strings or all PIL images (not a mix).

        Returns:
            A [batch_size, embedding_dim] L2-normalized tensor in ``self.dtype``,
            suitable for cosine similarity in a vector DB. For empty input, or
            if encoding fails (the error is printed), an empty
            [0, embedding_dim] tensor is returned instead.
        """
        if not x:
            return self._empty()

        is_text = isinstance(x[0], str)
        # Autocast is only enabled for the reduced-precision dtypes.
        autocast_dtype = self.dtype if self.dtype in (torch.float16, torch.bfloat16) else None
        # torch.autocast wants the bare device type ("cuda"), not "cuda:1".
        device_type = torch.device(self.device).type

        with torch.autocast(
            device_type=device_type,
            dtype=autocast_dtype,
            enabled=autocast_dtype is not None,
        ):
            try:
                if any(isinstance(item, str) != is_text for item in x):
                    raise TypeError(
                        "MultiModalEncoder expects all texts or all images, not a mix."
                    )

                if is_text:
                    # The HF SigLIP docs state the model was trained with
                    # padding="max_length". Dynamic padding (padding=True) would
                    # make a text's embedding depend on the other texts in the
                    # batch, which is wrong for vectors stored in a DB.
                    # NOTE: vectors computed with the old dynamic padding are not
                    # directly comparable with ones produced here.
                    inputs = self.processor(
                        text=x,
                        return_tensors="pt",
                        padding="max_length",
                        max_length=self.text_max_length,
                        truncation=True,
                    ).to(self.device)
                    output = self.model.get_text_features(**inputs)
                else:
                    # Ensure all images are RGB to avoid "Unable to infer
                    # channel dimension format" on grayscale/RGBA/palette images.
                    rgb_images = [img.convert("RGB") for img in x]
                    inputs = self.processor(images=rgb_images, return_tensors="pt").to(self.device)
                    output = self.model.get_image_features(**inputs)

                embeddings = self._extract_features(output)
                # Normalize in float32 for numerical stability, then cast back.
                embeddings = F.normalize(embeddings.float(), p=2, dim=-1)
                return embeddings.to(self.dtype)
            except Exception as e:
                # Best-effort contract: report the error and return an empty batch.
                print(f"ERROR in MultiModalEncoder: {e}", flush=True)
                traceback.print_exc()
                return self._empty()


# --- Test block (SigLIP) ---
if __name__ == "__main__":
    # This test uses the encoder class exactly as intended for production.
    MODEL_ID = "google/siglip-so400m-patch16-256-i18n"
    print(f"\n--- MultiModalEncoder Test ({MODEL_ID}) ---")

    texts = [
        "Uranus",  # Text 0
        "Anus",    # Text 1
        "Ass",     # Text 2
        "Planet",  # Text 3
        "Dog",     # Text 4
    ]

    try:
        img_urls = [
            "https://pbs.twimg.com/media/G3ra9C8W0AAGR8V.jpg",  # Image 0: Uranus meme pic
        ]
        headers = {"User-Agent": "Mozilla/5.0"}
        images = []
        for url in img_urls:
            # Time out instead of hanging forever, and fail on HTTP errors
            # instead of trying to decode an error page as an image.
            response = requests.get(url, headers=headers, timeout=30)
            response.raise_for_status()
            images.append(Image.open(io.BytesIO(response.content)))

        size = 256  # Model's expected size
        images.append(Image.new("RGB", (size, size), color="green"))  # Image 1: Green Square
        print(f"✅ Downloaded test image and created green square (size {size}x{size}).")
    except Exception as e:
        print(f"❌ Failed to load images: {e}")
        traceback.print_exc()
        raise SystemExit(1)  # non-zero status: the test could not run

    try:
        # 1. Initialize the encoder
        encoder = MultiModalEncoder(model_id=MODEL_ID)

        print("\n--- Encoding Texts (Separately) ---")
        text_embeddings = encoder(texts)  # Uses __call__
        print(f"Shape: {text_embeddings.shape}")

        print("\n--- Encoding Images (Separately) ---")
        image_embeddings = encoder(images)  # Uses __call__
        print(f"Shape: {image_embeddings.shape}")

        print("\n--- Similarity Check (Your Goal) ---")
        # 2. Cosine similarity is just a dot product because __call__ already
        # normalized the vectors. Convert to float32 first: the default
        # embeddings are bfloat16, which .numpy() cannot represent.
        similarity_matrix = torch.matmul(
            image_embeddings.float().cpu(), text_embeddings.float().cpu().T
        ).numpy()

        np.set_printoptions(precision=4, suppress=True)
        print("\nCosine Similarity matrix (image × text):")
        # Row: Images (0: Uranus Pic, 1: Green)
        # Col: Texts (0: Uranus, 1: Anus, 2: Ass, 3: Planet, 4: Dog)
        print(similarity_matrix)

        print("\nSpecific Similarity Scores (Cosine Similarity, -1.0 to 1.0):")
        print(f"Image 0 (Uranus pic) vs Text 0 (Uranus): {similarity_matrix[0][0]:.4f}")
        print(f"Image 0 (Uranus pic) vs Text 1 (Anus): {similarity_matrix[0][1]:.4f}")
        print(f"Image 0 (Uranus pic) vs Text 3 (Planet): {similarity_matrix[0][3]:.4f}")
        print(f"Image 0 (Uranus pic) vs Text 4 (Dog): {similarity_matrix[0][4]:.4f}")
        print(f"Image 1 (Green) vs Text 4 (Dog): {similarity_matrix[1][4]:.4f}")

    except Exception as e:
        print("\n--- An error occurred during the SigLIP test run ---")
        print(f"Error: {e}")
        traceback.print_exc()