|
|
import torch |
|
|
import open_clip |
|
|
from open_clip.tokenizer import HFTokenizer |
|
|
from huggingface_hub import hf_hub_download, snapshot_download |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import pyarrow.parquet as pq |
|
|
import torch.nn.functional as F |
|
|
from PIL import Image |
|
|
import os |
|
|
|
|
|
class SigLIPModel:
    """SigLIP (ViT-SO400M-14) encoder with a precomputed image-embedding index.

    On construction this eagerly loads the open_clip model/tokenizer (falling
    back to a HuggingFace Hub download when the local checkpoint is missing)
    and a parquet file of precomputed image embeddings used for similarity
    search.  Loading is best-effort: failures are printed and the affected
    attributes stay ``None``; ``encode_text``/``encode_image``/``search``
    return ``None`` values in that case instead of raising.
    """

    def __init__(self,
                 ckpt_path="./checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin",
                 model_name="ViT-SO400M-14-SigLIP-384",
                 tokenizer_path="./checkpoints/ViT-SO400M-14-SigLIP-384",
                 embedding_path="./embedding_datasets/10percent_siglip_encoded/all_siglip_embeddings.parquet",
                 device=None):
        """Create the wrapper and eagerly load the model and embedding index.

        Args:
            ckpt_path: Local path to the open_clip checkpoint; downloaded from
                the Hub if it does not exist.
            model_name: open_clip architecture name.
            tokenizer_path: Local directory holding the HF tokenizer files.
            embedding_path: Parquet file with an ``embedding`` column of
                precomputed image-embedding vectors.
            device: Torch device string; auto-selects cuda when available.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        self.ckpt_path = ckpt_path
        self.tokenizer_path = tokenizer_path
        self.embedding_path = embedding_path

        self.model = None             # open_clip model (None until load_model succeeds)
        self.tokenizer = None         # HFTokenizer (None until load_model succeeds)
        self.preprocess = None        # image preprocessing transform from open_clip
        self.df_embed = None          # DataFrame backing the embedding index
        self.image_embeddings = None  # (N, D) L2-normalized tensor on self.device

        self.load_model()
        self.load_embeddings()

    def load_model(self):
        """Load model, tokenizer and preprocess transform (best-effort)."""
        print(f"Loading SigLIP model from {self.ckpt_path}...")
        try:
            if not os.path.exists(self.ckpt_path):
                print(f"Warning: Checkpoint not found at {self.ckpt_path}")
                # Fall back to downloading tokenizer files and weights from the Hub.
                self.tokenizer = HFTokenizer(snapshot_download(repo_id="timm/ViT-SO400M-14-SigLIP-384"))
                self.ckpt_path = hf_hub_download("timm/ViT-SO400M-14-SigLIP-384", "open_clip_pytorch_model.bin")
            else:
                # Bug fix: the tokenizer was previously only created on the
                # download path, leaving encode_text() unusable with a local
                # checkpoint even though tokenizer_path was configured.
                self.tokenizer = HFTokenizer(self.tokenizer_path)

            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                self.model_name,
                pretrained=self.ckpt_path
            )

            self.model = self.model.to(self.device)
            self.model.eval()

            print(f"SigLIP model loaded on {self.device}")
        except Exception as e:
            # Best-effort: leave model/tokenizer as None; encode_* return None.
            print(f"Error loading SigLIP model: {e}")

    def load_embeddings(self):
        """Load precomputed image embeddings and L2-normalize them (best-effort)."""
        print(f"Loading SigLIP embeddings from {self.embedding_path}...")
        try:
            if not os.path.exists(self.embedding_path):
                print(f"Warning: Embedding file not found at {self.embedding_path}")
                return

            self.df_embed = pq.read_table(self.embedding_path).to_pandas()

            # Stack the per-row embedding arrays into one (N, D) matrix.
            image_embeddings_np = np.stack(self.df_embed['embedding'].values)
            self.image_embeddings = torch.from_numpy(image_embeddings_np).to(self.device).float()
            # Normalize once so search() can use a plain dot product as cosine similarity.
            self.image_embeddings = F.normalize(self.image_embeddings, dim=-1)
            print(f"SigLIP Data loaded: {len(self.df_embed)} records")
        except Exception as e:
            print(f"Error loading SigLIP embeddings: {e}")

    def _forward_no_grad(self, encoder_fn, batch):
        """Run ``encoder_fn(batch)`` under no_grad, with cuda autocast on cuda devices."""
        with torch.no_grad():
            # str() + startswith so 'cuda:0' and torch.device('cuda') also
            # enable autocast (the old `== "cuda"` check silently missed them).
            if str(self.device).startswith("cuda"):
                with torch.amp.autocast('cuda'):
                    return encoder_fn(batch)
            return encoder_fn(batch)

    def encode_text(self, text):
        """Encode a single text string into a normalized (1, D) feature tensor.

        Returns:
            Normalized features, or None when the model/tokenizer failed to load.
        """
        if self.model is None or self.tokenizer is None:
            return None

        text_tokens = self.tokenizer([text], context_length=self.model.context_length).to(self.device)
        text_features = self._forward_no_grad(self.model.encode_text, text_tokens)
        return F.normalize(text_features, dim=-1)

    def encode_image(self, image):
        """Encode one image into a normalized (1, D) feature tensor.

        Args:
            image: A PIL image (converted to RGB) or anything the open_clip
                preprocess transform accepts.

        Returns:
            Normalized features, or None when the model failed to load.
        """
        if self.model is None:
            return None

        if isinstance(image, Image.Image):
            image = image.convert("RGB")

        image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        image_features = self._forward_no_grad(self.model.encode_image, image_tensor)
        return F.normalize(image_features, dim=-1)

    def search(self, query_features, top_k=5, top_percent=None, threshold=0.0):
        """Cosine-similarity search of query features against the loaded index.

        Args:
            query_features: (1, D) query tensor (as produced by encode_*).
            top_k: Number of best matches returned in ``top_indices``.
            top_percent: If given, overrides ``threshold`` so that roughly this
                fraction of the index passes the filter (at least one item).
            threshold: Minimum similarity for ``filtered_indices``.

        Returns:
            ``(probs, filtered_indices, top_indices)``, or ``(None, None, None)``
            when no embeddings are loaded.  ``top_indices`` is computed over the
            full score array and is NOT restricted by the threshold filter.
        """
        if self.image_embeddings is None:
            return None, None, None

        query_features = query_features.float()

        # squeeze(-1) (not squeeze()) so a single-row index still yields a
        # 1-D score array instead of collapsing to a 0-d scalar.
        similarity = (self.image_embeddings @ query_features.T).squeeze(-1)
        probs = similarity.detach().cpu().numpy()

        if top_percent is not None:
            # Clamp k into [1, N] so np.partition never receives an invalid kth.
            k = int(len(probs) * top_percent)
            k = max(1, min(k, len(probs)))
            threshold = np.partition(probs, -k)[-k]

        mask = probs >= threshold
        filtered_indices = np.where(mask)[0]

        top_indices = np.argsort(probs)[-top_k:][::-1]

        return probs, filtered_indices, top_indices