File size: 6,779 Bytes
eb1aec4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import os
import inspect
import torch
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch.nn.functional as F
from PIL import Image
from huggingface_hub import hf_hub_download

try:
    # FarSLIP must use its vendored open_clip fork (checkpoint format / model defs differ).
    from .FarSLIP.open_clip.factory import create_model_and_transforms, get_tokenizer
    _OPENCLIP_BACKEND = "vendored_farslip"
    print("Successfully imported FarSLIP vendored open_clip.")
except ImportError as e:
    # Fail fast at import time: without the vendored fork the FarSLIP
    # checkpoints cannot be loaded, so falling back to stock open_clip
    # would only defer the failure to model-load time.
    raise ImportError(
        "Failed to import FarSLIP vendored open_clip from 'models/FarSLIP/open_clip'. "
    ) from e

    
class FarSLIPModel:
    """Text/image encoder and retrieval helper backed by a FarSLIP checkpoint.

    Builds a FarSLIP (CLIP-style) model through the vendored open_clip
    factory, optionally loads a parquet table of precomputed image embeddings,
    and exposes cosine-similarity search over those embeddings.

    Loading is best-effort: on failure the error is printed and the relevant
    attribute stays ``None``; the encode/search methods check for ``None`` and
    return ``None`` instead of raising, so callers degrade gracefully.
    """

    def __init__(self, 
                 ckpt_path="./checkpoints/FarSLIP/FarSLIP2_ViT-B-16.pt", 
                 model_name="ViT-B-16",
                 embedding_path="./embedding_datasets/10percent_farslip_encoded/all_farslip_embeddings.parquet",
                 device=None):
        """Create the wrapper and eagerly load the model (and embeddings).

        Args:
            ckpt_path: Local checkpoint path. If the string contains ``'hf'``
                the published checkpoint is downloaded from the Hugging Face
                Hub instead of being read from disk.
            model_name: open_clip architecture name passed to the factory.
            embedding_path: Parquet file with an ``'embedding'`` column; pass
                a falsy value to skip loading precomputed image embeddings.
            device: Explicit torch device string; defaults to ``"cuda"`` when
                available, else ``"cpu"``.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        # NOTE: substring check — any path containing 'hf' triggers a Hub download.
        if 'hf' in ckpt_path:
            ckpt_path = hf_hub_download("ZhenShiL/FarSLIP", "FarSLIP2_ViT-B-16.pt")
        self.ckpt_path = ckpt_path
        self.embedding_path = embedding_path

        self.model = None
        self.tokenizer = None
        self.preprocess = None
        self.df_embed = None
        self.image_embeddings = None

        self.load_model()
        if self.embedding_path:
            self.load_embeddings()

    def load_model(self):
        """Instantiate model/transforms via the vendored factory (best-effort).

        On any failure the exception is printed and ``self.model`` remains
        ``None`` so the rest of the application can keep running.
        """
        print(f"Loading FarSLIP model from {self.ckpt_path}...")
        try:
            if not os.path.exists(self.ckpt_path):
                # Warn but still attempt the load; the factory may resolve
                # the name itself (e.g. a registered pretrained tag).
                print(f"Warning: Checkpoint not found at {self.ckpt_path}")

            # Different open_clip variants expose slightly different factory
            # signatures. Build the full kwargs set, then filter it down to
            # what the actual callable accepts (no sys.path hacks).
            factory_kwargs = {
                "model_name": self.model_name,
                "pretrained": self.ckpt_path,
                "precision": "amp",
                "device": self.device,
                "output_dict": True,
                "force_quick_gelu": False,
                "long_clip": "load_from_scratch",
            }

            supported = set(inspect.signature(create_model_and_transforms).parameters)
            if "model_name" in supported:
                call_kwargs = {k: v for k, v in factory_kwargs.items() if k in supported}
                self.model, _, self.preprocess = create_model_and_transforms(**call_kwargs)
            else:
                # This variant takes the model name as a positional argument.
                call_kwargs = {k: v for k, v in factory_kwargs.items()
                               if k in supported and k != "model_name"}
                self.model, _, self.preprocess = create_model_and_transforms(self.model_name, **call_kwargs)

            self.tokenizer = get_tokenizer(self.model_name)
            self.model.eval()
            print(f"FarSLIP model loaded on {self.device} (backend={_OPENCLIP_BACKEND})")
        except Exception as e:
            # Deliberate best-effort: leave self.model as None on failure.
            print(f"Error loading FarSLIP model: {e}")

    def load_embeddings(self):
        """Load the parquet embedding table and cache L2-normalized tensors.

        Populates ``self.df_embed`` (pandas DataFrame) and
        ``self.image_embeddings`` (float tensor on ``self.device``, rows
        unit-normalized so dot products are cosine similarities).
        """
        print(f"Loading FarSLIP embeddings from {self.embedding_path}...")
        try:
            if not os.path.exists(self.embedding_path):
                print(f"Warning: Embedding file not found at {self.embedding_path}")
                return

            self.df_embed = pq.read_table(self.embedding_path).to_pandas()

            # Stack the per-row vectors into one (N, D) matrix and normalize
            # once up front so search() can use a plain matmul.
            image_embeddings_np = np.stack(self.df_embed['embedding'].values)
            self.image_embeddings = torch.from_numpy(image_embeddings_np).to(self.device).float()
            self.image_embeddings = F.normalize(self.image_embeddings, dim=-1)
            print(f"FarSLIP Data loaded: {len(self.df_embed)} records")
        except Exception as e:
            # Best-effort, matching load_model(): search() returns None
            # tuples when embeddings are absent.
            print(f"Error loading FarSLIP embeddings: {e}")

    def encode_text(self, text):
        """Encode a single text string to a unit-normalized feature tensor.

        Returns a (1, D) tensor on ``self.device``, or ``None`` when the
        model/tokenizer failed to load.
        """
        if self.model is None or self.tokenizer is None:
            return None

        text_tokens = self.tokenizer([text], context_length=self.model.context_length).to(self.device)

        with torch.no_grad():
            # Mirror the 'amp' precision used at model creation on CUDA.
            if self.device == "cuda":
                with torch.amp.autocast('cuda'):
                    text_features = self.model.encode_text(text_tokens)
            else:
                text_features = self.model.encode_text(text_tokens)

            text_features = F.normalize(text_features, dim=-1)
        return text_features

    def encode_image(self, image):
        """Encode a single image to a unit-normalized feature tensor.

        Args:
            image: A PIL image (converted to RGB) or anything accepted by
                the factory's preprocess transform.

        Returns a (1, D) tensor on ``self.device``, or ``None`` when the
        model failed to load.
        """
        if self.model is None:
            return None

        if isinstance(image, Image.Image):
            image = image.convert("RGB")

        image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            if self.device == "cuda":
                with torch.amp.autocast('cuda'):
                    image_features = self.model.encode_image(image_tensor)
            else:
                image_features = self.model.encode_image(image_tensor)

            image_features = F.normalize(image_features, dim=-1)
        return image_features

    def search(self, query_features, top_k=5, top_percent=None, threshold=0.0):
        """Rank the cached image embeddings against a normalized query.

        Args:
            query_features: (1, D) query tensor (output of encode_text/image).
            top_k: Number of best-matching indices to return.
            top_percent: If given, overrides ``threshold`` so that roughly
                this fraction of items passes the similarity filter.
            threshold: Minimum similarity for inclusion in filtered_indices.

        Returns:
            ``(probs, filtered_indices, top_indices)`` as numpy arrays, or
            ``(None, None, None)`` when no embeddings are loaded.
        """
        if self.image_embeddings is None:
            return None, None, None

        query_features = query_features.float()

        # Rows and query are L2-normalized, so the dot product is cosine similarity.
        probs = (self.image_embeddings @ query_features.T).detach().cpu().numpy().flatten()

        if top_percent is not None:
            # Clamp k into [1, len(probs)]: the previous code only clamped
            # from below, so top_percent >= 1 (k > len) made np.partition
            # raise on an out-of-range kth index.
            k = min(max(int(len(probs) * top_percent), 1), len(probs))
            threshold = np.partition(probs, -k)[-k]

        filtered_indices = np.where(probs >= threshold)[0]

        top_indices = np.argsort(probs)[-top_k:][::-1]

        return probs, filtered_indices, top_indices