ParshvPatel's picture
feat: HuggingFace Spaces deployment
d992912
"""
engine/encoder.py
FashionCLIPEncoder — wraps a HuggingFace CLIP model for text and image encoding.
Extracted from finalized_search_engine_full_script.py (lines 482-652).
"""
import logging
from pathlib import Path
from typing import List, Optional
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
from backend.app.config import SearchConfig
logger = logging.getLogger(__name__)
__all__ = ["FashionCLIPEncoder"]
class FashionCLIPEncoder:
"""
v3.1 — Handles models that return BaseModelOutputWithPooling
instead of raw tensors from get_text_features / get_image_features.
"""
def __init__(self, config: SearchConfig):
self.config = config
self.device = config.device
self.model = None
self.processor = None
self.model_name = None
self._load_model()
def _load_model(self):
models_to_try = [self.config.primary_model, self.config.fallback_model]
for model_name in models_to_try:
try:
logger.info(f"Loading model: {model_name}")
kwargs = {}
if self.config.hf_token:
kwargs['token'] = self.config.hf_token
self.model = CLIPModel.from_pretrained(model_name, **kwargs)
self.processor = CLIPProcessor.from_pretrained(model_name, **kwargs)
self.model = self.model.to(self.device)
self.model.eval()
self.model_name = model_name
# ── Probe the model to find actual embedding dim ──
test_inputs = self.processor(
text=["test"], return_tensors="pt",
padding=True, truncation=True, max_length=77,
)
test_inputs = {k: v.to(self.device) for k, v in test_inputs.items()}
with torch.no_grad():
test_out = self.model.get_text_features(**test_inputs)
test_tensor = self._to_tensor(test_out)
actual_dim = test_tensor.shape[-1]
if actual_dim != self.config.embedding_dim:
logger.info(
f"Model embedding dim = {actual_dim} "
f"(config said {self.config.embedding_dim}). Updating config."
)
self.config.embedding_dim = actual_dim
logger.info(f"Model loaded: {model_name} on {self.device} (dim={actual_dim})")
return
except Exception as e:
logger.warning(f"Failed to load {model_name}: {e}")
continue
raise RuntimeError(
"Could not load any CLIP model. Check internet connection and HF_TOKEN."
)
@staticmethod
def _to_tensor(output) -> torch.Tensor:
if isinstance(output, torch.Tensor):
return output
if hasattr(output, 'pooler_output') and output.pooler_output is not None:
return output.pooler_output
if hasattr(output, 'last_hidden_state'):
return output.last_hidden_state.mean(dim=1)
if hasattr(output, 'text_embeds'):
return output.text_embeds
if hasattr(output, 'image_embeds'):
return output.image_embeds
if isinstance(output, (tuple, list)) and len(output) > 0:
return output[0] if isinstance(output[0], torch.Tensor) else output[1]
raise TypeError(
f"Cannot extract tensor from model output of type {type(output)}. "
f"Available attributes: {[a for a in dir(output) if not a.startswith('_')]}"
)
@torch.no_grad()
def encode_texts(self, texts: List[str], batch_size: Optional[int] = None) -> np.ndarray:
batch_size = batch_size or min(self.config.embed_batch_size * 4, 256)
texts = [str(t) if t and str(t) != 'nan' else '' for t in texts]
all_emb = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
inputs = self.processor(
text=batch, return_tensors="pt",
padding=True, truncation=True, max_length=77,
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
raw = self.model.get_text_features(**inputs)
feats = self._to_tensor(raw)
feats = F.normalize(feats, p=2, dim=-1).cpu().numpy()
all_emb.append(feats)
return np.vstack(all_emb).astype(np.float32)
@torch.no_grad()
def encode_images_from_paths(
self, paths: List[Path], batch_size: Optional[int] = None,
) -> np.ndarray:
batch_size = batch_size or self.config.embed_batch_size
n = len(paths)
dim = self.config.embedding_dim
embeddings = np.zeros((n, dim), dtype=np.float32)
for start in range(0, n, batch_size):
end = min(start + batch_size, n)
batch_paths = paths[start:end]
images = []
valid_in_batch = []
for j, p in enumerate(batch_paths):
try:
img = Image.open(p).convert("RGB")
images.append(img)
valid_in_batch.append(start + j)
except Exception:
pass
if not images:
continue
try:
inputs = self.processor(images=images, return_tensors="pt", padding=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
if self.device == "cuda":
with torch.amp.autocast("cuda"):
raw = self.model.get_image_features(**inputs)
else:
raw = self.model.get_image_features(**inputs)
feats = self._to_tensor(raw)
feats = F.normalize(feats, p=2, dim=-1).cpu().numpy()
for local_j, global_j in enumerate(valid_in_batch):
embeddings[global_j] = feats[local_j]
except Exception as e:
logger.warning(f"Batch encoding failed at {start}: {e}")
if self.device == "cuda" and start % (batch_size * 10) == 0:
torch.cuda.empty_cache()
return embeddings
@torch.no_grad()
def encode_images(self, images: List[Image.Image], batch_size: Optional[int] = None) -> np.ndarray:
batch_size = batch_size or self.config.embed_batch_size
all_emb = []
for i in range(0, len(images), batch_size):
batch = images[i:i + batch_size]
inputs = self.processor(images=batch, return_tensors="pt", padding=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
if self.device == "cuda":
with torch.amp.autocast("cuda"):
raw = self.model.get_image_features(**inputs)
else:
raw = self.model.get_image_features(**inputs)
feats = self._to_tensor(raw)
all_emb.append(F.normalize(feats, p=2, dim=-1).cpu().numpy())
return np.vstack(all_emb).astype(np.float32)
@torch.no_grad()
def encode_query_text(self, query: str) -> np.ndarray:
prompted = [tmpl.format(query) for tmpl in self.config.prompt_templates]
embeddings = self.encode_texts(prompted)
avg = embeddings.mean(axis=0, keepdims=True)
avg = avg / (np.linalg.norm(avg, axis=-1, keepdims=True) + 1e-8)
return avg.astype(np.float32)
@torch.no_grad()
def encode_multimodal_query(
self, text: str, image: Image.Image, text_weight: float = 0.5,
) -> np.ndarray:
text_emb = self.encode_query_text(text)
img_emb = self.encode_images([image])
fused = text_weight * text_emb + (1 - text_weight) * img_emb
fused = fused / (np.linalg.norm(fused, axis=-1, keepdims=True) + 1e-8)
return fused.astype(np.float32)