# Source-viewer header (not code): file size 6,779 bytes, commit eb1aec4.
import os
import inspect
import torch
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch.nn.functional as F
from PIL import Image
from huggingface_hub import hf_hub_download
# FarSLIP must use its vendored open_clip fork (checkpoint format / model defs differ),
# so a missing fork is fatal rather than silently falling back to upstream open_clip.
try:
    from .FarSLIP.open_clip.factory import create_model_and_transforms, get_tokenizer
except ImportError as exc:
    raise ImportError(
        "Failed to import FarSLIP vendored open_clip from 'models/FarSLIP/open_clip'. "
    ) from exc
else:
    # Recorded so load messages can report which open_clip implementation is active.
    _OPENCLIP_BACKEND = "vendored_farslip"
    print("Successfully imported FarSLIP vendored open_clip.")
class FarSLIPModel:
    """FarSLIP image/text encoder plus a precomputed image-embedding index.

    Construction loads the model, tokenizer and preprocessing transform through
    the vendored open_clip factory and, when ``embedding_path`` is truthy, a
    parquet file of image embeddings that ``search`` scores queries against.

    Loading is best-effort: failures are printed rather than raised, and any
    attribute that was not successfully assigned stays ``None``. In that state
    ``encode_text`` / ``encode_image`` return ``None`` and ``search`` returns
    ``(None, None, None)``.
    """

    def __init__(self,
                 ckpt_path="./checkpoints/FarSLIP/FarSLIP2_ViT-B-16.pt",
                 model_name="ViT-B-16",
                 embedding_path="./embedding_datasets/10percent_farslip_encoded/all_farslip_embeddings.parquet",
                 device=None):
        """Load the model and, optionally, the embedding index.

        Args:
            ckpt_path: Path to a local checkpoint. A value containing ``"hf"``
                that is not an existing local file is treated as a request to
                download ``FarSLIP2_ViT-B-16.pt`` from the ``ZhenShiL/FarSLIP``
                Hugging Face Hub repo.
            model_name: open_clip architecture name used for both the model
                factory and the tokenizer.
            embedding_path: Parquet file with an ``embedding`` column; a falsy
                value skips loading the index.
            device: Torch device for the model and embeddings; defaults to
                ``"cuda"`` when available, otherwise ``"cpu"``.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        # "hf" is a download sentinel, but only when it is not a real local
        # file: a bare substring test would also hijack legitimate paths such
        # as "./checkpoints/hf_cache/model.pt" and ignore the user's checkpoint.
        if 'hf' in ckpt_path and not os.path.exists(ckpt_path):
            ckpt_path = hf_hub_download("ZhenShiL/FarSLIP", "FarSLIP2_ViT-B-16.pt")
        self.ckpt_path = ckpt_path
        self.embedding_path = embedding_path

        self.model = None
        self.tokenizer = None
        self.preprocess = None
        self.df_embed = None
        self.image_embeddings = None

        self.load_model()
        if self.embedding_path:
            self.load_embeddings()

    def load_model(self):
        """Build the model, preprocessing transform and tokenizer from ``self.ckpt_path``.

        Errors are printed, not raised; any of ``self.model`` /
        ``self.preprocess`` / ``self.tokenizer`` not yet assigned stays ``None``.
        """
        print(f"Loading FarSLIP model from {self.ckpt_path}...")
        try:
            if not os.path.exists(self.ckpt_path):
                # Not fatal: the factory decides how to interpret `pretrained`.
                print(f"Warning: Checkpoint not found at {self.ckpt_path}")

            # Different open_clip variants expose slightly different factory
            # signatures, so pass only the kwargs the actual callable accepts.
            factory_kwargs = {
                "model_name": self.model_name,
                "pretrained": self.ckpt_path,
                "precision": "amp",
                "device": self.device,
                "output_dict": True,
                "force_quick_gelu": False,
                "long_clip": "load_from_scratch",
            }
            supported = set(inspect.signature(create_model_and_transforms).parameters)

            if "model_name" in supported:
                call_kwargs = {k: v for k, v in factory_kwargs.items() if k in supported}
                self.model, _, self.preprocess = create_model_and_transforms(**call_kwargs)
            else:
                # Variant takes the model name positionally.
                call_kwargs = {k: v for k, v in factory_kwargs.items()
                               if k in supported and k != "model_name"}
                self.model, _, self.preprocess = create_model_and_transforms(self.model_name, **call_kwargs)

            self.tokenizer = get_tokenizer(self.model_name)
            self.model.eval()
            print(f"FarSLIP model loaded on {self.device} (backend={_OPENCLIP_BACKEND})")
        except Exception as e:
            # Deliberate best-effort: callers check for None instead of catching.
            print(f"Error loading FarSLIP model: {e}")

    def load_embeddings(self):
        """Load the parquet index into ``self.df_embed`` and ``self.image_embeddings``.

        The ``embedding`` column is stacked into a float32 tensor on
        ``self.device`` and L2-normalized along the last dimension. Presumably
        each row holds a 1-D vector of the same length (required by
        ``np.stack``) — TODO confirm against the producer of the parquet file.
        Errors are printed, not raised.
        """
        print(f"Loading FarSLIP embeddings from {self.embedding_path}...")
        try:
            if not os.path.exists(self.embedding_path):
                print(f"Warning: Embedding file not found at {self.embedding_path}")
                return
            self.df_embed = pq.read_table(self.embedding_path).to_pandas()
            image_embeddings_np = np.stack(self.df_embed['embedding'].values)
            self.image_embeddings = torch.from_numpy(image_embeddings_np).to(self.device).float()
            self.image_embeddings = F.normalize(self.image_embeddings, dim=-1)
            print(f"FarSLIP Data loaded: {len(self.df_embed)} records")
        except Exception as e:
            print(f"Error loading FarSLIP embeddings: {e}")

    def _uses_cuda(self):
        """Return True when ``self.device`` names any CUDA device ("cuda", "cuda:1", ...)."""
        return torch.device(self.device).type == "cuda"

    def _encode(self, encode_fn, inputs):
        """Run ``encode_fn(inputs)`` without autograd and L2-normalize the result.

        On CUDA devices the forward pass runs under ``torch.amp.autocast``;
        normalization happens after leaving the autocast region.
        """
        with torch.no_grad():
            if self._uses_cuda():
                with torch.amp.autocast('cuda'):
                    features = encode_fn(inputs)
            else:
                features = encode_fn(inputs)
        return F.normalize(features, dim=-1)

    def encode_text(self, text):
        """Return the normalized embedding of a single text string, or None if not loaded."""
        if self.model is None or self.tokenizer is None:
            return None
        text_tokens = self.tokenizer([text], context_length=self.model.context_length).to(self.device)
        return self._encode(self.model.encode_text, text_tokens)

    def encode_image(self, image):
        """Return the normalized embedding of one image (batch of 1), or None if not loaded.

        PIL images are converted to RGB first; other inputs are passed to
        ``self.preprocess`` as-is.
        """
        if self.model is None:
            return None
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
        image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        return self._encode(self.model.encode_image, image_tensor)

    def search(self, query_features, top_k=5, top_percent=None, threshold=0.0):
        """Score every indexed image against one query embedding.

        Args:
            query_features: Query embedding, either shape ``(D,)`` or ``(1, D)``.
                It is cast to float32 and moved to the index's device.
            top_k: Number of best-scoring indices to return; ``<= 0`` yields none.
            top_percent: If given, overrides ``threshold`` with the score of the
                ``ceil``-free ``int(N * top_percent)``-th best item (clamped to
                ``[1, N]``).
            threshold: Minimum score for ``filtered_indices`` when
                ``top_percent`` is None.

        Returns:
            ``(probs, filtered_indices, top_indices)``: the flattened dot-product
            scores as a NumPy array, the indices with score ``>= threshold``,
            and the ``top_k`` indices in descending score order. Returns
            ``(None, None, None)`` when no index is loaded.
        """
        if self.image_embeddings is None:
            return None, None, None

        query_features = query_features.float().to(self.image_embeddings.device)
        if query_features.dim() == 1:
            # `.T` on a 1-D tensor is deprecated; promote to a (1, D) row.
            query_features = query_features.unsqueeze(0)

        # Index rows are L2-normalized, so this is cosine similarity whenever
        # the query is normalized too (as encode_text / encode_image return).
        probs = (self.image_embeddings @ query_features.T).detach().cpu().numpy().flatten()

        if top_percent is not None and len(probs) > 0:
            # Clamp so top_percent > 1 (or tiny indexes) can't push k out of bounds.
            k = min(max(int(len(probs) * top_percent), 1), len(probs))
            threshold = np.partition(probs, -k)[-k]

        filtered_indices = np.where(probs >= threshold)[0]
        # Descending order first, then take a prefix: unlike `[-top_k:]`, this
        # returns nothing for top_k == 0 instead of every index.
        top_indices = np.argsort(probs)[::-1][:max(top_k, 0)]
        return probs, filtered_indices, top_indices
|