Spaces:
Running
Running
| import os | |
| import inspect | |
| import torch | |
| import numpy as np | |
| import pandas as pd | |
| import pyarrow.parquet as pq | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| try: | |
| # FarSLIP must use its vendored open_clip fork (checkpoint format / model defs differ). | |
| from .FarSLIP.open_clip.factory import create_model_and_transforms, get_tokenizer | |
| _OPENCLIP_BACKEND = "vendored_farslip" | |
| print("Successfully imported FarSLIP vendored open_clip.") | |
| except ImportError as e: | |
| raise ImportError( | |
| "Failed to import FarSLIP vendored open_clip from 'models/FarSLIP/open_clip'. " | |
| ) from e | |
class FarSLIPModel:
    """Wrapper around a FarSLIP (open_clip-style) model for retrieval.

    On construction the model checkpoint is loaded and, when
    ``embedding_path`` is truthy, a parquet table of precomputed image
    embeddings is read, stacked from its ``'embedding'`` column and
    L2-normalized. The instance then offers text/image encoding and a
    dot-product similarity search against those embeddings.

    Loading is best-effort: failures are printed and leave ``model`` /
    ``image_embeddings`` as ``None``; the encode/search methods return
    ``None`` values in that state instead of raising.
    """

    def __init__(self,
                 ckpt_path="./checkpoints/FarSLIP/FarSLIP2_ViT-B-16.pt",
                 model_name="ViT-B-16",
                 embedding_path="./embedding_datasets/10percent_farslip_encoded/all_farslip_embeddings.parquet",
                 device=None):
        """Load the model and, optionally, the precomputed embeddings.

        Args:
            ckpt_path: Local checkpoint path. If the string contains ``'hf'``
                the checkpoint is fetched with ``hf_hub_download`` instead.
            model_name: Model architecture name passed to the factory and
                the tokenizer lookup.
            embedding_path: Parquet file with an ``'embedding'`` column;
                pass a falsy value to skip loading embeddings.
            device: Target device; defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        # NOTE(review): this is a plain substring test, so any path that
        # happens to contain "hf" triggers the hub download -- confirm intent.
        if 'hf' in ckpt_path:
            ckpt_path = hf_hub_download("ZhenShiL/FarSLIP", "FarSLIP2_ViT-B-16.pt")
        self.ckpt_path = ckpt_path
        self.embedding_path = embedding_path

        self.model = None
        self.tokenizer = None
        self.preprocess = None
        self.df_embed = None
        self.image_embeddings = None

        self.load_model()
        if self.embedding_path:
            self.load_embeddings()

    def load_model(self):
        """Create the model, preprocessing transform and tokenizer.

        Errors are caught and printed; on failure ``self.model`` may remain
        ``None``. A missing checkpoint file only produces a warning before
        the factory is called.
        """
        print(f"Loading FarSLIP model from {self.ckpt_path}...")
        try:
            if not os.path.exists(self.ckpt_path):
                print(f"Warning: Checkpoint not found at {self.ckpt_path}")

            # Different open_clip variants expose slightly different factory
            # signatures, so only pass the keyword arguments the imported
            # factory actually declares.
            factory_kwargs = {
                "model_name": self.model_name,
                "pretrained": self.ckpt_path,
                "precision": "amp",
                "device": self.device,
                "output_dict": True,
                "force_quick_gelu": False,
                "long_clip": "load_from_scratch",
            }
            supported = set(inspect.signature(create_model_and_transforms).parameters.keys())

            if "model_name" in supported:
                call_kwargs = {k: v for k, v in factory_kwargs.items() if k in supported}
                self.model, _, self.preprocess = create_model_and_transforms(**call_kwargs)
            else:
                # Factory takes the model name as its first positional argument.
                call_kwargs = {
                    k: v for k, v in factory_kwargs.items()
                    if k in supported and k != "model_name"
                }
                self.model, _, self.preprocess = create_model_and_transforms(self.model_name, **call_kwargs)

            self.tokenizer = get_tokenizer(self.model_name)
            self.model.eval()
            print(f"FarSLIP model loaded on {self.device} (backend={_OPENCLIP_BACKEND})")
        except Exception as e:
            # Deliberate best-effort: report and leave the instance unloaded.
            print(f"Error loading FarSLIP model: {e}")

    def load_embeddings(self):
        """Read the parquet embeddings into a normalized float tensor.

        Populates ``self.df_embed`` (pandas DataFrame) and
        ``self.image_embeddings`` (float tensor on ``self.device``,
        normalized along the last dimension). A missing file or any read
        error is printed and leaves the attributes unchanged.
        """
        print(f"Loading FarSLIP embeddings from {self.embedding_path}...")
        try:
            if not os.path.exists(self.embedding_path):
                print(f"Warning: Embedding file not found at {self.embedding_path}")
                return

            self.df_embed = pq.read_table(self.embedding_path).to_pandas()

            image_embeddings_np = np.stack(self.df_embed['embedding'].values)
            embeddings = torch.from_numpy(image_embeddings_np).to(self.device).float()
            self.image_embeddings = F.normalize(embeddings, dim=-1)
            print(f"FarSLIP Data loaded: {len(self.df_embed)} records")
        except Exception as e:
            print(f"Error loading FarSLIP embeddings: {e}")

    def _use_cuda_autocast(self):
        """Return True when ``self.device`` denotes a CUDA device.

        Accepts ``"cuda"``, indexed forms such as ``"cuda:0"`` and
        ``torch.device`` objects; anything unparsable counts as non-CUDA.
        """
        try:
            return torch.device(self.device).type == "cuda"
        except (RuntimeError, TypeError):
            return False

    def encode_text(self, text):
        """Encode a single text string into a normalized feature tensor.

        Returns ``None`` when the model or tokenizer is not loaded.
        """
        if self.model is None or self.tokenizer is None:
            return None

        text_tokens = self.tokenizer([text], context_length=self.model.context_length).to(self.device)

        with torch.no_grad():
            if self._use_cuda_autocast():
                with torch.amp.autocast('cuda'):
                    text_features = self.model.encode_text(text_tokens)
            else:
                text_features = self.model.encode_text(text_tokens)

            text_features = F.normalize(text_features, dim=-1)
        return text_features

    def encode_image(self, image):
        """Encode one image into a normalized feature tensor.

        PIL images are converted to RGB before preprocessing. Returns
        ``None`` when the model or its preprocessing transform is not loaded.
        """
        if self.model is None or self.preprocess is None:
            return None

        if isinstance(image, Image.Image):
            image = image.convert("RGB")

        image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            if self._use_cuda_autocast():
                with torch.amp.autocast('cuda'):
                    image_features = self.model.encode_image(image_tensor)
            else:
                image_features = self.model.encode_image(image_tensor)

            image_features = F.normalize(image_features, dim=-1)
        return image_features

    def search(self, query_features, top_k=5, top_percent=None, threshold=0.0):
        """Score every stored embedding against a query and select matches.

        Args:
            query_features: Query feature tensor; it is cast to float and
                multiplied (transposed) against the stored embeddings.
            top_k: Number of best-scoring indices to return; values <= 0
                yield an empty result.
            top_percent: Optional fraction of records to keep. When given it
                overrides ``threshold`` with the score of the k-th best
                record, where k is clamped to ``[1, number of records]``.
            threshold: Minimum score for an index to be in the filtered set.

        Returns:
            ``(probs, filtered_indices, top_indices)`` as NumPy arrays, where
            ``top_indices`` is ordered best-first, or ``(None, None, None)``
            when no embeddings are loaded.
        """
        if self.image_embeddings is None:
            return None, None, None

        # Keep the query on the embeddings' device so the matmul cannot fail
        # on a CPU/GPU mismatch.
        query_features = query_features.float().to(self.image_embeddings.device)

        # Dot product of the stored (normalized) embeddings with the query.
        probs = (self.image_embeddings @ query_features.T).detach().cpu().numpy().flatten()

        if top_percent is not None and probs.size > 0:
            # Clamp k: np.partition raises when k exceeds the array length.
            k = min(max(int(probs.size * top_percent), 1), probs.size)
            threshold = np.partition(probs, -k)[-k]

        filtered_indices = np.where(probs >= threshold)[0]

        top_k = max(int(top_k), 0)
        if top_k == 0:
            # A slice of [-0:] would return every index, so handle it here.
            top_indices = np.empty(0, dtype=np.intp)
        else:
            top_indices = np.argsort(probs)[-top_k:][::-1]

        return probs, filtered_indices, top_indices