# EarthEmbeddingExplorer / models / siglip_model.py
# Source: ML4RS-Anonymous (Hugging Face Space)
# Commit 0c5db6e (verified): "fix: download repo using snapshot_download"
import torch
import open_clip
from open_clip.tokenizer import HFTokenizer
from huggingface_hub import hf_hub_download, snapshot_download
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch.nn.functional as F
from PIL import Image
import os
class SigLIPModel:
    """SigLIP text/image encoder plus cosine-similarity search over a table
    of precomputed image embeddings.

    Loads an open_clip ViT-SO400M-14-SigLIP-384 model — from local
    checkpoint paths when present, otherwise downloading from the Hugging
    Face Hub — and a parquet table of image embeddings, then exposes
    encode_text / encode_image / search.
    """

    def __init__(self,
                 ckpt_path="./checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin",
                 model_name="ViT-SO400M-14-SigLIP-384",
                 tokenizer_path="./checkpoints/ViT-SO400M-14-SigLIP-384",
                 embedding_path="./embedding_datasets/10percent_siglip_encoded/all_siglip_embeddings.parquet",
                 device=None):
        # Pick CUDA when available so CPU-only machines still work.
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        self.ckpt_path = ckpt_path
        self.tokenizer_path = tokenizer_path
        self.embedding_path = embedding_path
        self.model = None        # open_clip model, or None if loading failed
        self.tokenizer = None    # HFTokenizer, or None if loading failed
        self.preprocess = None   # open_clip image preprocessing transform
        self.df_embed = None     # DataFrame with an 'embedding' column
        self.image_embeddings = None  # (N, D) L2-normalized float32 tensor on self.device
        self.load_model()
        self.load_embeddings()

    def _use_amp(self):
        # Autocast on any CUDA device. The original compared self.device with
        # == "cuda", which silently skipped AMP for devices like "cuda:0".
        return str(self.device).startswith("cuda")

    def load_model(self):
        """Create the model, tokenizer, and preprocess transform.

        Falls back to downloading from the Hugging Face Hub when the
        configured local paths are missing. Best-effort: on failure the
        attributes stay None and encode_* methods return None.
        """
        print(f"Loading SigLIP model from {self.ckpt_path}...")
        try:
            # BUGFIX: the tokenizer was previously created only on the
            # download-fallback branch, so a present local checkpoint left
            # self.tokenizer as None and encode_text() always returned None.
            if os.path.exists(self.tokenizer_path):
                self.tokenizer = HFTokenizer(self.tokenizer_path)
            else:
                self.tokenizer = HFTokenizer(snapshot_download(repo_id="timm/ViT-SO400M-14-SigLIP-384"))
            if not os.path.exists(self.ckpt_path):
                print(f"Warning: Checkpoint not found at {self.ckpt_path}")
                self.ckpt_path = hf_hub_download("timm/ViT-SO400M-14-SigLIP-384", "open_clip_pytorch_model.bin")
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                self.model_name,
                pretrained=self.ckpt_path
            )
            self.model = self.model.to(self.device)
            self.model.eval()
            print(f"SigLIP model loaded on {self.device}")
        except Exception as e:
            # Best-effort: leave attributes as None so callers get None results.
            print(f"Error loading SigLIP model: {e}")

    def load_embeddings(self):
        """Read the parquet embedding table and cache a normalized (N, D) tensor."""
        print(f"Loading SigLIP embeddings from {self.embedding_path}...")
        try:
            if not os.path.exists(self.embedding_path):
                print(f"Warning: Embedding file not found at {self.embedding_path}")
                return
            self.df_embed = pq.read_table(self.embedding_path).to_pandas()
            # Stack the per-row vectors into one matrix and L2-normalize once,
            # so search() reduces to a single matmul against unit vectors.
            image_embeddings_np = np.stack(self.df_embed['embedding'].values)
            self.image_embeddings = torch.from_numpy(image_embeddings_np).to(self.device).float()
            self.image_embeddings = F.normalize(self.image_embeddings, dim=-1)
            print(f"SigLIP Data loaded: {len(self.df_embed)} records")
        except Exception as e:
            print(f"Error loading SigLIP embeddings: {e}")

    def encode_text(self, text):
        """Return L2-normalized text features of shape (1, D), or None if the
        model/tokenizer failed to load."""
        if self.model is None or self.tokenizer is None:
            return None
        text_tokens = self.tokenizer([text], context_length=self.model.context_length).to(self.device)
        with torch.no_grad():
            if self._use_amp():
                with torch.amp.autocast('cuda'):
                    text_features = self.model.encode_text(text_tokens)
            else:
                text_features = self.model.encode_text(text_tokens)
            text_features = F.normalize(text_features, dim=-1)
        return text_features

    def encode_image(self, image):
        """Return L2-normalized image features of shape (1, D), or None if the
        model failed to load."""
        # Guard preprocess too: load_model may have failed after setting model.
        if self.model is None or self.preprocess is None:
            return None
        if isinstance(image, Image.Image):
            image = image.convert("RGB")  # SigLIP expects 3-channel input
        image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            if self._use_amp():
                with torch.amp.autocast('cuda'):
                    image_features = self.model.encode_image(image_tensor)
            else:
                image_features = self.model.encode_image(image_tensor)
            image_features = F.normalize(image_features, dim=-1)
        return image_features

    def search(self, query_features, top_k=5, top_percent=None, threshold=0.0):
        """Cosine-similarity search of query_features (1, D) against the
        cached image embeddings.

        Args:
            query_features: (1, D) tensor (normalized by encode_*).
            top_k: number of best matches to return.
            top_percent: if given, overrides `threshold` with the similarity
                of the k-th best match, k = len * top_percent (min 1).
            threshold: minimum similarity for filtered_indices.

        Returns:
            (probs, filtered_indices, top_indices) where probs is the (N,)
            numpy array of similarities, filtered_indices the indices with
            probs >= threshold, and top_indices the top_k indices sorted by
            descending similarity. (None, None, None) if no embeddings loaded.
        """
        if self.image_embeddings is None:
            return None, None, None
        # Ensure float32 to match the cached embedding dtype.
        query_features = query_features.float()
        # (N, D) @ (D, 1) -> (N, 1); squeeze only the trailing dim so a
        # single-row table (N == 1) still yields a 1-D result (a bare
        # .squeeze() collapsed it to a 0-d scalar and broke indexing).
        similarity = (self.image_embeddings @ query_features.T).squeeze(-1)
        probs = similarity.detach().cpu().numpy()
        if top_percent is not None:
            # Derive the threshold from the k-th largest similarity.
            k = max(1, int(len(probs) * top_percent))
            threshold = np.partition(probs, -k)[-k]
        # Filter by threshold
        mask = probs >= threshold
        filtered_indices = np.where(mask)[0]
        # Get top k, highest similarity first (argsort is ascending).
        top_indices = np.argsort(probs)[-top_k:][::-1]
        return probs, filtered_indices, top_indices