# NOTE: removed non-code render artifacts (file-size header, commit hashes, line-number gutter)
import torch
import open_clip
from open_clip.tokenizer import HFTokenizer
from huggingface_hub import hf_hub_download, snapshot_download
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch.nn.functional as F
from PIL import Image
import os
class SigLIPModel:
    """SigLIP text/image encoder with cosine-similarity search over a
    precomputed table of image embeddings.

    The model/tokenizer are loaded from a local checkpoint, falling back to a
    Hugging Face Hub download when the local path is missing. Loading failures
    are non-fatal by design: the affected attributes stay ``None`` and the
    encode/search methods return ``None`` instead of raising.
    """

    def __init__(self,
                 ckpt_path="./checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin",
                 model_name="ViT-SO400M-14-SigLIP-384",
                 tokenizer_path="./checkpoints/ViT-SO400M-14-SigLIP-384",
                 embedding_path="./embedding_datasets/10percent_siglip_encoded/all_siglip_embeddings.parquet",
                 device=None):
        # Auto-select GPU when available unless the caller pins a device.
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        self.ckpt_path = ckpt_path
        self.tokenizer_path = tokenizer_path
        self.embedding_path = embedding_path
        self.model = None             # open_clip model; None if loading failed
        self.tokenizer = None         # HFTokenizer; None if loading failed
        self.preprocess = None        # image preprocessing transform from open_clip
        self.df_embed = None          # DataFrame backing the embedding index
        self.image_embeddings = None  # (N, D) L2-normalized float32 tensor on self.device
        self.load_model()
        self.load_embeddings()

    def load_model(self):
        """Load model, preprocess transform, and tokenizer onto ``self.device``.

        If the local checkpoint is missing, downloads both the tokenizer files
        and the checkpoint from the Hugging Face Hub. On any failure, prints
        the error and leaves ``self.model``/``self.tokenizer`` as ``None``.
        """
        print(f"Loading SigLIP model from {self.ckpt_path}...")
        try:
            # Fall back to a Hub download when the local checkpoint is absent.
            if not os.path.exists(self.ckpt_path):
                print(f"Warning: Checkpoint not found at {self.ckpt_path}")
                self.tokenizer = HFTokenizer(snapshot_download(repo_id="timm/ViT-SO400M-14-SigLIP-384"))
                self.ckpt_path = hf_hub_download("timm/ViT-SO400M-14-SigLIP-384", "open_clip_pytorch_model.bin")
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                self.model_name,
                pretrained=self.ckpt_path
            )
            self.model = self.model.to(self.device)
            self.model.eval()  # inference only; disable dropout/batchnorm updates
            print(f"SigLIP model loaded on {self.device}")
        except Exception as e:
            # Best-effort: a missing model leaves encode_* returning None.
            print(f"Error loading SigLIP model: {e}")

    def load_embeddings(self):
        """Load the parquet embedding table and pre-normalize it on device.

        Populates ``self.df_embed`` (records) and ``self.image_embeddings``
        (an (N, D) L2-normalized float32 tensor). Missing file or parse errors
        are reported and leave both as ``None``.
        """
        print(f"Loading SigLIP embeddings from {self.embedding_path}...")
        try:
            if not os.path.exists(self.embedding_path):
                print(f"Warning: Embedding file not found at {self.embedding_path}")
                return
            self.df_embed = pq.read_table(self.embedding_path).to_pandas()
            # Stack per-row embedding arrays into one (N, D) matrix once, so
            # search() is a single matmul instead of a per-row loop.
            stacked = np.stack(self.df_embed['embedding'].values)
            self.image_embeddings = torch.from_numpy(stacked).to(self.device).float()
            # Normalize so dot products below are cosine similarities.
            self.image_embeddings = F.normalize(self.image_embeddings, dim=-1)
            print(f"SigLIP Data loaded: {len(self.df_embed)} records")
        except Exception as e:
            print(f"Error loading SigLIP embeddings: {e}")

    def _encode(self, encode_fn, batch):
        """Run ``encode_fn(batch)`` under no_grad (with CUDA autocast when the
        device is a CUDA device) and return L2-normalized features."""
        with torch.no_grad():
            # Substring check covers "cuda", "cuda:0", and torch.device objects
            # (the original `== "cuda"` missed indexed devices).
            if "cuda" in str(self.device):
                with torch.amp.autocast('cuda'):
                    features = encode_fn(batch)
            else:
                features = encode_fn(batch)
        return F.normalize(features, dim=-1)

    def encode_text(self, text):
        """Encode a text query into a normalized (1, D) feature tensor.

        Returns ``None`` when the model or tokenizer failed to load.
        """
        if self.model is None or self.tokenizer is None:
            return None
        tokens = self.tokenizer([text], context_length=self.model.context_length).to(self.device)
        return self._encode(self.model.encode_text, tokens)

    def encode_image(self, image):
        """Encode a PIL image into a normalized (1, D) feature tensor.

        Returns ``None`` when the model failed to load.
        """
        if self.model is None:
            return None
        # SigLIP preprocessing expects 3-channel input.
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
        image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        return self._encode(self.model.encode_image, image_tensor)

    def search(self, query_features, top_k=5, top_percent=None, threshold=0.0):
        """Rank the indexed images by cosine similarity to the query.

        Args:
            query_features: (1, D) normalized query tensor from encode_text /
                encode_image.
            top_k: how many best matches to return in ``top_indices``.
            top_percent: if given, overrides ``threshold`` so that the top
                fraction of images (at least one) passes the filter.
            threshold: minimum similarity for ``filtered_indices``.

        Returns:
            ``(probs, filtered_indices, top_indices)`` — similarity scores as
            a 1-D numpy array, indices passing the threshold, and the top_k
            indices in descending score order. ``(None, None, None)`` when no
            embedding index is loaded.
        """
        if self.image_embeddings is None:
            return None, None, None
        # The index is float32; match dtypes before the matmul.
        query_features = query_features.float()
        # Plain cosine similarity (embeddings are pre-normalized); the SigLIP
        # logit_scale/logit_bias sigmoid is deliberately not applied here.
        # squeeze(-1) rather than squeeze(): a bare squeeze() collapses the
        # (1, 1) result to 0-d when the index holds a single image, which
        # breaks np.argsort below.
        similarity = (self.image_embeddings @ query_features.T).squeeze(-1)
        probs = similarity.detach().cpu().numpy()
        if top_percent is not None:
            # Derive the threshold admitting the top `top_percent` fraction.
            k = max(1, int(len(probs) * top_percent))
            threshold = np.partition(probs, -k)[-k]
        filtered_indices = np.where(probs >= threshold)[0]
        top_indices = np.argsort(probs)[-top_k:][::-1]
        return probs, filtered_indices, top_indices