# multi_modal_processor.py
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoProcessor, AutoConfig
from typing import List, Union
from PIL import Image
import requests
import io
import os
import traceback
import numpy as np
from transformers.utils import logging as hf_logging
# Suppress warnings
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
class MultiModalEncoder:
    """
    Encode text OR images into a shared, L2-NORMALIZED embedding space
    using a SigLIP checkpoint (default: google/siglip-so400m-patch16-256-i18n).

    Intended for creating embeddings for vector search: every returned row has
    unit L2 norm, so cosine similarity reduces to a plain dot product.
    """

    def __init__(
        self,
        model_id="google/siglip-so400m-patch16-256-i18n",
        dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device, None] = None,
    ):
        """
        Load the processor, config and model, and move the model to `device`.

        Args:
            model_id: Hugging Face model id (or local path) of the checkpoint.
            dtype: Dtype the weights are loaded in. float16/bfloat16 also
                enable autocast in __call__.
            device: Target device. Defaults to "cuda" when available, else "cpu".

        Raises:
            Exception: Any loading failure is printed with its traceback and
                re-raised.
        """
        # Silence transformers logging and progress bars. Note this is a
        # process-wide transformers setting, not local to this instance.
        hf_logging.set_verbosity_error()
        hf_logging.disable_progress_bar()

        self.model_id = model_id
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = dtype

        try:
            # "Hf_TOKEN" is an unusual spelling; it is still honored (and
            # checked first) so existing environments keep working.
            hf_token = os.getenv("Hf_TOKEN") or os.getenv("HF_TOKEN")
            hf_kwargs = {"token": hf_token} if hf_token else {}

            self.processor = AutoProcessor.from_pretrained(
                self.model_id,
                use_fast=True,
                **hf_kwargs,
            )

            # --- SigLIP loading with config fix ---
            config = AutoConfig.from_pretrained(self.model_id, **hf_kwargs)
            if not hasattr(config, "projection_dim"):
                # The config lacks a top-level projection_dim. Fall back to the
                # text tower's hidden size so embedding_dim is always defined.
                # NOTE(review): assumes the image tower outputs the same width
                # (holds for the default checkpoint). TODO confirm for others.
                config.projection_dim = config.text_config.hidden_size

            self.model = AutoModel.from_pretrained(
                self.model_id,
                config=config,
                # `dtype` is the current transformers keyword; older releases
                # named it `torch_dtype`.
                dtype=self.dtype,
                trust_remote_code=False,
                **hf_kwargs,
            ).to(self.device).eval()

            self.embedding_dim = config.projection_dim
            # Fixed token length used to pad text in __call__. Padding to the
            # positional-embedding length can never exceed the model's capacity.
            self.max_text_length = getattr(
                config.text_config, "max_position_embeddings", 64
            )
        except Exception as e:
            print(f"❌ Failed to load SigLIP model or components: {e}")
            traceback.print_exc()
            raise

    def _empty_embeddings(self) -> torch.Tensor:
        """Return an empty [0, embedding_dim] tensor with the normal output dtype/device."""
        return torch.empty(0, self.embedding_dim, dtype=self.dtype, device=self.device)

    @staticmethod
    def _as_tensor(output) -> torch.Tensor:
        """
        Unwrap the embedding tensor from a get_*_features result.

        Depending on the transformers version, get_*_features may return a
        model-output object rather than a bare tensor. This picks, in order:
        the tensor itself, a non-None `pooler_output`, `last_hidden_state`, or
        the first element of a tuple/list.
        """
        if isinstance(output, torch.Tensor):
            return output
        pooled = getattr(output, "pooler_output", None)
        if pooled is not None:
            return pooled
        if hasattr(output, "last_hidden_state"):
            # NOTE(review): this is presumably per-token ([batch, seq, dim]) rather
            # than pooled. Kept only as a last-resort fallback.
            return output.last_hidden_state
        if isinstance(output, (tuple, list)):
            return output[0]
        return output

    @torch.no_grad()
    def __call__(self, x: "Union[List[str], List[Image.Image]]") -> torch.Tensor:
        """
        Encode a batch of texts or images into L2-normalized embeddings.

        Args:
            x: A list of strings, or a list of PIL images. The type of the first
                element selects the branch. An empty list is allowed.

        Returns:
            A [batch_size, embedding_dim] tensor in `self.dtype` on `self.device`,
            suitable for cosine similarity in a vector DB. For an empty input, or
            when encoding fails (the error is printed), returns an empty
            [0, embedding_dim] tensor with the same dtype/device.
        """
        if not x:
            return self._empty_embeddings()

        is_text = isinstance(x[0], str)
        autocast_dtype = self.dtype if self.dtype in (torch.float16, torch.bfloat16) else None
        # torch.autocast expects a bare device type ("cuda"), not an indexed
        # device string such as "cuda:0". Normalize both str and torch.device.
        device_type = torch.device(self.device).type

        with torch.autocast(
            device_type=device_type,
            dtype=autocast_dtype,
            enabled=(autocast_dtype is not None),
        ):
            try:
                if is_text:
                    # SigLIP was trained on fixed-length padded text. Padding to
                    # the longest item instead makes a text's embedding depend on
                    # what else is in the batch.
                    inputs = self.processor(
                        text=x,
                        return_tensors="pt",
                        padding="max_length",
                        max_length=self.max_text_length,
                        truncation=True,
                    ).to(self.device)
                    embeddings = self.model.get_text_features(**inputs)
                else:
                    # Ensure all images are RGB to avoid "Unable to infer channel
                    # dimension format" (e.g. for grayscale or RGBA inputs).
                    valid_images = [img.convert("RGB") for img in x]
                    inputs = self.processor(images=valid_images, return_tensors="pt").to(self.device)
                    embeddings = self.model.get_image_features(**inputs)

                embeddings = self._as_tensor(embeddings)
                # Normalize in float32 for numerical stability, then cast back.
                embeddings = F.normalize(embeddings.float(), p=2, dim=-1)
                return embeddings.to(self.dtype)
            except Exception as e:
                # Deliberate best-effort: report and return an empty batch
                # instead of raising.
                print(f"ERROR in MultiModalEncoder: {e}", flush=True)
                traceback.print_exc()
                return self._empty_embeddings()
# --- Test block (SigLIP) ---
if __name__ == "__main__":
    # Manual smoke test: downloads one image, loads the encoder, and prints an
    # image x text cosine-similarity matrix. Requires network access.
    MODEL_ID = "google/siglip-so400m-patch16-256-i18n"
    print(f"\n--- MultiModalEncoder Test ({MODEL_ID}) ---")

    texts = [
        "Uranus",  # Text 0
        "Anus",  # Text 1
        "Ass",  # Text 2
        "Planet",  # Text 3
        "Dog",  # Text 4
    ]

    try:
        img_urls = [
            "https://pbs.twimg.com/media/G3ra9C8W0AAGR8V.jpg",  # Image 0: Uranus meme pic
        ]
        headers = {"User-Agent": "Mozilla/5.0"}
        images = []
        for u in img_urls:
            # The timeout stops a stalled server from hanging the script, and
            # raise_for_status stops an HTTP error page from reaching PIL.
            resp = requests.get(u, headers=headers, timeout=30)
            resp.raise_for_status()
            images.append(Image.open(io.BytesIO(resp.content)))
        size = 256  # Input resolution implied by the model id (patch16-256)
        images.append(Image.new("RGB", (size, size), color="green"))  # Image 1: Green Square
        print(f"✅ Downloaded test image and created green square (size {size}x{size}).")
    except Exception as e:
        print(f"❌ Failed to load images: {e}")
        traceback.print_exc()
        # Exit with a non-zero status so callers/CI see the failure
        # (a bare exit() exits with status 0).
        raise SystemExit(1)

    try:
        # 1. Initialize the encoder.
        encoder = MultiModalEncoder(model_id=MODEL_ID)

        print("\n--- Encoding Texts (Separately) ---")
        text_embeddings = encoder(texts)  # Uses __call__
        print(f"Shape: {text_embeddings.shape}")

        print("\n--- Encoding Images (Separately) ---")
        image_embeddings = encoder(images)  # Uses __call__
        print(f"Shape: {image_embeddings.shape}")

        # The encoder returns an empty batch on failure. Report that clearly
        # instead of failing later with an IndexError.
        if text_embeddings.shape[0] == 0 or image_embeddings.shape[0] == 0:
            raise RuntimeError("Encoding returned no embeddings; see the error printed above.")

        print("\n--- Similarity Check (Your Goal) ---")
        # 2. Cosine similarity is just a dot product, because __call__ already
        # L2-normalized the vectors. Cast to float32 first: the encoder returns
        # its configured dtype (bfloat16 by default), which .numpy() cannot convert.
        similarity_matrix = torch.matmul(
            image_embeddings.cpu().float(), text_embeddings.cpu().float().T
        ).numpy()

        np.set_printoptions(precision=4, suppress=True)
        print("\nCosine Similarity matrix (image × text):")
        # Rows: Images (0: Uranus Pic, 1: Green)
        # Cols: Texts (0: Uranus, 1: Anus, 2: Ass, 3: Planet, 4: Dog)
        print(similarity_matrix)

        print("\nSpecific Similarity Scores (Cosine Similarity, -1.0 to 1.0):")
        print(f"Image 0 (Uranus pic) vs Text 0 (Uranus): {similarity_matrix[0][0]:.4f}")
        print(f"Image 0 (Uranus pic) vs Text 1 (Anus): {similarity_matrix[0][1]:.4f}")
        print(f"Image 0 (Uranus pic) vs Text 3 (Planet): {similarity_matrix[0][3]:.4f}")
        print(f"Image 0 (Uranus pic) vs Text 4 (Dog): {similarity_matrix[0][4]:.4f}")
        print(f"Image 1 (Green) vs Text 4 (Dog): {similarity_matrix[1][4]:.4f}")
    except Exception as e:
        print("\n--- An error occurred during the SigLIP test run ---")
        print(f"Error: {e}")
        traceback.print_exc()
|