File size: 5,467 Bytes
eb1aec4
 
 
0c5db6e
eb1aec4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c5db6e
66ee1b2
e7e838f
 
 
 
 
66ee1b2
eb1aec4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch
import open_clip
from open_clip.tokenizer import HFTokenizer
from huggingface_hub import hf_hub_download, snapshot_download
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch.nn.functional as F
from PIL import Image
import os

class SigLIPModel:
    """Text/image retrieval over a table of precomputed SigLIP image embeddings.

    Loads an open_clip SigLIP model (preferring local checkpoint/tokenizer
    files, falling back to a Hugging Face Hub download) plus a parquet file of
    precomputed, per-image embedding vectors. Queries are encoded with
    ``encode_text`` / ``encode_image`` and matched against the table with
    cosine similarity via ``search``.

    Loading is best-effort: on failure ``model`` / ``image_embeddings`` stay
    ``None`` and the public methods return ``None`` instead of raising.
    """

    def __init__(self,
                 ckpt_path="./checkpoints/ViT-SO400M-14-SigLIP-384/open_clip_pytorch_model.bin",
                 model_name="ViT-SO400M-14-SigLIP-384",
                 tokenizer_path="./checkpoints/ViT-SO400M-14-SigLIP-384",
                 embedding_path="./embedding_datasets/10percent_siglip_encoded/all_siglip_embeddings.parquet",
                 device=None):
        """Initialize and eagerly load the model and the embedding table.

        Args:
            ckpt_path: Local open_clip checkpoint file; downloaded from the
                HF Hub if missing.
            model_name: open_clip model architecture name.
            tokenizer_path: Local HF tokenizer directory; downloaded if missing.
            embedding_path: Parquet file with an ``embedding`` column of
                per-image vectors (one row per indexed image).
            device: Torch device string; auto-selects CUDA when available.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        self.ckpt_path = ckpt_path
        self.tokenizer_path = tokenizer_path
        self.embedding_path = embedding_path

        self.model = None              # open_clip model, None until loaded
        self.tokenizer = None          # HFTokenizer, None until loaded
        self.preprocess = None         # image transform from open_clip
        self.df_embed = None           # DataFrame backing the embedding table
        self.image_embeddings = None   # (N, D) L2-normalized float tensor

        self.load_model()
        self.load_embeddings()

    def load_model(self):
        """Load the SigLIP model and tokenizer.

        Prefers the local ``tokenizer_path`` / ``ckpt_path``; only downloads
        from the HF Hub when the local files are missing. (The previous
        version warned about a missing checkpoint but then re-downloaded it
        unconditionally, ignoring any local copy.)
        Errors are printed and swallowed so the app can run in degraded mode.
        """
        print(f"Loading SigLIP model from {self.ckpt_path}...")
        try:
            # Use the local tokenizer directory when present; otherwise pull
            # the full repo snapshot (HFTokenizer needs the whole directory).
            if os.path.isdir(self.tokenizer_path):
                self.tokenizer = HFTokenizer(self.tokenizer_path)
            else:
                self.tokenizer = HFTokenizer(
                    snapshot_download(repo_id="timm/ViT-SO400M-14-SigLIP-384")
                )

            # Only download the checkpoint if the local file is absent.
            if not os.path.exists(self.ckpt_path):
                print(f"Warning: Checkpoint not found at {self.ckpt_path}; "
                      f"downloading from the Hugging Face Hub")
                self.ckpt_path = hf_hub_download(
                    "timm/ViT-SO400M-14-SigLIP-384", "open_clip_pytorch_model.bin"
                )

            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                self.model_name,
                pretrained=self.ckpt_path
            )
            self.model = self.model.to(self.device)
            self.model.eval()  # inference only; disables dropout/batchnorm updates

            print(f"SigLIP model loaded on {self.device}")
        except Exception as e:
            # Deliberate best-effort: leave self.model as None; callers of
            # encode_* / search handle the None case.
            print(f"Error loading SigLIP model: {e}")

    def load_embeddings(self):
        """Load precomputed image embeddings from parquet into a GPU/CPU tensor.

        Builds ``self.image_embeddings`` as an L2-normalized (N, D) float32
        tensor so ``search`` can use a plain dot product as cosine similarity.
        Missing file or parse errors are printed and swallowed.
        """
        print(f"Loading SigLIP embeddings from {self.embedding_path}...")
        try:
            if not os.path.exists(self.embedding_path):
                print(f"Warning: Embedding file not found at {self.embedding_path}")
                return

            self.df_embed = pq.read_table(self.embedding_path).to_pandas()

            # Stack the per-row vectors into one contiguous (N, D) matrix and
            # normalize once up front so search() is a pure matmul.
            image_embeddings_np = np.stack(self.df_embed['embedding'].values)
            self.image_embeddings = torch.from_numpy(image_embeddings_np).to(self.device).float()
            self.image_embeddings = F.normalize(self.image_embeddings, dim=-1)
            print(f"SigLIP Data loaded: {len(self.df_embed)} records")
        except Exception as e:
            print(f"Error loading SigLIP embeddings: {e}")

    def _forward_normalized(self, fn, batch):
        """Run ``fn(batch)`` under no_grad (with CUDA autocast on GPU) and
        return the L2-normalized features. Shared by encode_text/encode_image."""
        with torch.no_grad():
            if self.device == "cuda":
                with torch.amp.autocast('cuda'):
                    features = fn(batch)
            else:
                features = fn(batch)
            return F.normalize(features, dim=-1)

    def encode_text(self, text):
        """Encode a single text query.

        Args:
            text: Query string.

        Returns:
            (1, D) L2-normalized feature tensor, or ``None`` if the model or
            tokenizer failed to load.
        """
        if self.model is None or self.tokenizer is None:
            return None

        text_tokens = self.tokenizer(
            [text], context_length=self.model.context_length
        ).to(self.device)
        return self._forward_normalized(self.model.encode_text, text_tokens)

    def encode_image(self, image):
        """Encode a single query image.

        Args:
            image: PIL image (converted to RGB) or anything the open_clip
                preprocess transform accepts.

        Returns:
            (1, D) L2-normalized feature tensor, or ``None`` if the model
            failed to load.
        """
        if self.model is None:
            return None

        # SigLIP preprocessing expects 3-channel input.
        if isinstance(image, Image.Image):
            image = image.convert("RGB")

        image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        return self._forward_normalized(self.model.encode_image, image_tensor)

    def search(self, query_features, top_k=5, top_percent=None, threshold=0.0):
        """Rank the indexed images against a query by cosine similarity.

        Args:
            query_features: (1, D) normalized query tensor (from encode_*).
            top_k: Number of best matches to return in ``top_indices``.
            top_percent: If given (0..1), overrides ``threshold`` with the
                similarity cutoff that keeps the top fraction of images.
            threshold: Minimum similarity for ``filtered_indices``.

        Returns:
            Tuple ``(probs, filtered_indices, top_indices)`` — the full
            similarity array, the indices passing the threshold, and the
            ``top_k`` indices sorted best-first — or ``(None, None, None)``
            when no embeddings are loaded.
        """
        if self.image_embeddings is None:
            return None, None, None

        # Embeddings are stored float32; match dtypes for the matmul.
        query_features = query_features.float()

        # Both sides are L2-normalized, so the dot product IS the cosine
        # similarity (aligned with SigLIP_embdding.ipynb). squeeze(-1) (not a
        # bare squeeze()) keeps the result 1-D even when only one image is
        # indexed, which would otherwise collapse to a 0-d array.
        similarity = (self.image_embeddings @ query_features.T).squeeze(-1)
        probs = similarity.detach().cpu().numpy()

        if top_percent is not None:
            # Keep at least one image; np.partition finds the k-th largest
            # value in O(N) without fully sorting.
            k = max(1, int(len(probs) * top_percent))
            threshold = np.partition(probs, -k)[-k]

        # Indices of every image at or above the (possibly derived) threshold.
        mask = probs >= threshold
        filtered_indices = np.where(mask)[0]

        # Best-first top-k, independent of the threshold filter.
        top_indices = np.argsort(probs)[-top_k:][::-1]

        return probs, filtered_indices, top_indices