# Source: oracle/models/multi_modal_processor.py
# Uploaded via huggingface_hub by zirobtc (commit 2c39730, verified)
# multi_modal_processor.py
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoProcessor, AutoConfig
from typing import List, Union
from PIL import Image
import requests
import io
import os
import traceback
import numpy as np
from transformers.utils import logging as hf_logging
# Suppress warnings: tell the transformers library (via env var) to log
# errors only, for the whole process.
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
class MultiModalEncoder:
    """
    Encodes text OR images into a shared, NORMALIZED embedding space
    using google/siglip-so400m-patch16-256-i18n.

    This class is intended for creating embeddings for vector search:
    returned vectors are L2-normalized, so a plain dot product equals
    cosine similarity.
    """

    def __init__(self, model_id: str = "google/siglip-so400m-patch16-256-i18n",
                 dtype: torch.dtype = torch.bfloat16, device: str = None):
        """
        Load the SigLIP processor, config, and model.

        Args:
            model_id: Hugging Face model id to load.
            dtype: Model weight / returned embedding dtype.
            device: Explicit device string; auto-selects "cuda" when
                available, otherwise "cpu".

        Raises:
            Exception: any loading failure is printed and re-raised.
        """
        # Force silence progress bars locally for this class
        hf_logging.set_verbosity_error()
        hf_logging.disable_progress_bar()
        self.model_id = model_id
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = dtype
        try:
            # Env vars are case-sensitive: both spellings are checked for
            # backward compatibility with existing deployments.
            hf_token = os.getenv("Hf_TOKEN") or os.getenv("HF_TOKEN")
            hf_kwargs = {"token": hf_token} if hf_token else {}
            # --- SigLIP Loading with Config Fix ---
            self.processor = AutoProcessor.from_pretrained(
                self.model_id,
                use_fast=True,
                **hf_kwargs
            )
            config = AutoConfig.from_pretrained(self.model_id, **hf_kwargs)
            # Some configs lack projection_dim; patch it from the text
            # tower's hidden size so embedding_dim is always defined.
            if not hasattr(config, 'projection_dim'):
                config.projection_dim = config.text_config.hidden_size
            self.model = AutoModel.from_pretrained(
                self.model_id,
                config=config,
                dtype=self.dtype,  # `dtype` kwarg (transformers >= 4.56; older releases call it `torch_dtype`)
                trust_remote_code=False,
                **hf_kwargs
            ).to(self.device).eval()
            # -----------------------------------------------
            self.embedding_dim = config.projection_dim
        except Exception as e:
            print(f"❌ Failed to load SigLIP model or components: {e}")
            traceback.print_exc()
            raise

    def _empty(self) -> torch.Tensor:
        """Return an empty [0, embedding_dim] batch.

        FIX: previously the empty fallback was created in the default
        float32 dtype while the success path returns self.dtype; create
        it with matching dtype/device so callers see a consistent type.
        """
        return torch.empty(0, self.embedding_dim, device=self.device, dtype=self.dtype)

    @staticmethod
    def _as_tensor(embeddings) -> torch.Tensor:
        """Unwrap a ModelOutput / tuple into a plain feature tensor."""
        if isinstance(embeddings, torch.Tensor):
            return embeddings
        if hasattr(embeddings, 'pooler_output'):
            return embeddings.pooler_output
        if hasattr(embeddings, 'last_hidden_state'):
            # NOTE(review): this yields [batch, seq, dim], not [batch, dim].
            # SigLIP's get_*_features should return features directly, so
            # this branch is a fallback for other architectures — confirm
            # before relying on it.
            return embeddings.last_hidden_state
        if isinstance(embeddings, (tuple, list)):
            return embeddings[0]
        return embeddings

    @torch.no_grad()
    def __call__(self, x: Union[List[str], List[Image.Image]]) -> torch.Tensor:
        """
        Encode a batch of text or images into normalized
        [batch_size, embedding_dim] vectors (dtype == self.dtype).
        This is correct for storing in a vector DB for cosine similarity.

        Returns an empty [0, embedding_dim] tensor for empty input, or on
        any encoding error (the error is printed, not raised).
        """
        if not x:
            return self._empty()
        is_text = isinstance(x[0], str)
        # Autocast only for half precisions; other dtypes run un-autocast.
        autocast_dtype = self.dtype if self.dtype in (torch.float16, torch.bfloat16) else None
        device_str = self.device.type if isinstance(self.device, torch.device) else self.device
        with torch.autocast(device_type=device_str, dtype=autocast_dtype, enabled=(autocast_dtype is not None)):
            try:
                if is_text:
                    inputs = self.processor(text=x, return_tensors="pt", padding=True, truncation=True).to(self.device)
                    embeddings = self.model.get_text_features(**inputs)
                else:
                    # Ensure all images are RGB to avoid "Unable to infer channel dimension format"
                    valid_images = [img.convert("RGB") for img in x]
                    inputs = self.processor(images=valid_images, return_tensors="pt").to(self.device)
                    embeddings = self.model.get_image_features(**inputs)
                embeddings = self._as_tensor(embeddings)
                # Normalize in float32 for numerical stability, then cast back.
                embeddings = F.normalize(embeddings.float(), p=2, dim=-1)
                return embeddings.to(self.dtype)
            except Exception as e:
                # Best-effort contract: log and return an empty batch so
                # callers embedding streams of items are not interrupted.
                print(f"ERROR in MultiModalEncoder: {e}", flush=True)
                traceback.print_exc()
                return self._empty()
# --- Test block (SigLIP) ---
if __name__ == "__main__":
    # Smoke-test: embed a few texts and images, then inspect cosine
    # similarities (plain dot products, since the encoder normalizes).
    MODEL_ID = "google/siglip-so400m-patch16-256-i18n"
    print(f"\n--- MultiModalEncoder Test ({MODEL_ID}) ---")
    texts = [
        "Uranus",  # Text 0
        "Anus",    # Text 1
        "Ass",     # Text 2
        "Planet",  # Text 3
        "Dog"      # Text 4
    ]
    try:
        img_urls = [
            "https://pbs.twimg.com/media/G3ra9C8W0AAGR8V.jpg",  # Image 0: Uranus meme pic
        ]
        headers = {"User-Agent": "Mozilla/5.0"}
        images = []
        for u in img_urls:
            # FIX: check HTTP status and bound the request time; previously
            # a 404 produced a confusing PIL decode error, and a stalled
            # connection would hang forever (requests has no default timeout).
            resp = requests.get(u, headers=headers, timeout=30)
            resp.raise_for_status()
            images.append(Image.open(io.BytesIO(resp.content)))
        size = 256  # Model's expected size
        images.append(Image.new("RGB", (size, size), color="green"))  # Image 1: Green Square
        print(f"✅ Downloaded test image and created green square (size {size}x{size}).")
    except Exception as e:
        print(f"❌ Failed to load images: {e}")
        traceback.print_exc()
        # FIX: `exit()` is a site-module convenience not guaranteed outside
        # interactive use; raise SystemExit explicitly with a failure code.
        raise SystemExit(1)
    try:
        # 1. Initialize the encoder
        encoder = MultiModalEncoder(model_id=MODEL_ID)
        print("\n--- Encoding Texts (Separately) ---")
        text_embeddings = encoder(texts)  # Uses __call__
        print(f"Shape: {text_embeddings.shape}")
        print("\n--- Encoding Images (Separately) ---")
        image_embeddings = encoder(images)  # Uses __call__
        print(f"Shape: {image_embeddings.shape}")
        print("\n--- Similarity Check (Your Goal) ---")
        # 2. Cosine similarity is just a dot product because the encoder
        # already normalized the vectors.
        # FIX: cast to float32 before .numpy() — Tensor.numpy() raises on
        # bfloat16 (the encoder's default dtype), so this line used to fail.
        similarity_matrix = torch.matmul(
            image_embeddings.cpu().float(), text_embeddings.cpu().float().T
        ).numpy()
        np.set_printoptions(precision=4, suppress=True)
        print("\nCosine Similarity matrix (image × text):")
        # Rows: Images (0: Uranus Pic, 1: Green)
        # Cols: Texts (0: Uranus, 1: Anus, 2: Ass, 3: Planet, 4: Dog)
        print(similarity_matrix)
        print("\nSpecific Similarity Scores (Cosine Similarity, -1.0 to 1.0):")
        print(f"Image 0 (Uranus pic) vs Text 0 (Uranus): {similarity_matrix[0][0]:.4f}")
        print(f"Image 0 (Uranus pic) vs Text 1 (Anus): {similarity_matrix[0][1]:.4f}")
        print(f"Image 0 (Uranus pic) vs Text 3 (Planet): {similarity_matrix[0][3]:.4f}")
        print(f"Image 0 (Uranus pic) vs Text 4 (Dog): {similarity_matrix[0][4]:.4f}")
        print(f"Image 1 (Green) vs Text 4 (Dog): {similarity_matrix[1][4]:.4f}")
    except Exception as e:
        print(f"\n--- An error occurred during the SigLIP test run ---")
        print(f"Error: {e}")
        traceback.print_exc()