"""
engine/encoder.py

FashionCLIPEncoder — wraps a HuggingFace CLIP model for text and image encoding.
Extracted from finalized_search_engine_full_script.py (lines 482-652).
"""
|
|
| import logging |
| from pathlib import Path |
| from typing import List, Optional |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from PIL import Image |
| from transformers import CLIPModel, CLIPProcessor |
|
|
| from backend.app.config import SearchConfig |
|
|
| logger = logging.getLogger(__name__) |
|
|
| __all__ = ["FashionCLIPEncoder"] |
|
|
|
|
class FashionCLIPEncoder:
    """
    Wraps a HuggingFace CLIP model and turns text / images into
    L2-normalised float32 embedding matrices.

    v3.1 — Handles models that return BaseModelOutputWithPooling
    instead of raw tensors from get_text_features / get_image_features.

    Attributes:
        config: The ``SearchConfig`` passed to the constructor. Its
            ``embedding_dim`` is overwritten when the loaded model reports a
            different width.
        device: Copied from ``config.device``; every processor output is
            moved to it before the model is called.
        model: The loaded ``CLIPModel`` (in eval mode).
        processor: The matching ``CLIPProcessor``.
        model_name: Name of the checkpoint that loaded successfully.
    """

    # Token limit handed to the processor for every text batch.
    _MAX_TEXT_TOKENS = 77
    # Cap on the default text batch size (default is 4x embed_batch_size).
    _MAX_DEFAULT_TEXT_BATCH = 256

    def __init__(self, config: "SearchConfig") -> None:
        """Store *config* and load the model immediately.

        Raises:
            RuntimeError: If neither the primary nor the fallback checkpoint
                can be loaded.
        """
        self.config = config
        self.device = config.device
        self.model = None
        self.processor = None
        self.model_name = None
        self._load_model()

    def _load_model(self) -> None:
        """Load the first usable checkpoint and sync ``config.embedding_dim``.

        Tries ``config.primary_model`` and then ``config.fallback_model``.
        After loading, a one-item text batch is encoded to discover the real
        embedding width; if it differs from ``config.embedding_dim`` the
        config is updated in place.

        Raises:
            RuntimeError: If no candidate loads. The last underlying error is
                attached as ``__cause__``.
        """
        candidates = []
        for name in (self.config.primary_model, self.config.fallback_model):
            # Skip unset names and do not retry an identical checkpoint.
            if name and name not in candidates:
                candidates.append(name)

        hub_kwargs = {}
        if self.config.hf_token:
            hub_kwargs["token"] = self.config.hf_token

        last_error: Optional[Exception] = None
        for model_name in candidates:
            try:
                logger.info("Loading model: %s", model_name)
                self.model = CLIPModel.from_pretrained(model_name, **hub_kwargs)
                self.processor = CLIPProcessor.from_pretrained(model_name, **hub_kwargs)
                self.model = self.model.to(self.device)
                self.model.eval()
                self.model_name = model_name

                # Probe with a dummy text to learn the actual embedding width.
                probe = self._embed_text_batch(["test"])
                actual_dim = probe.shape[-1]
                if actual_dim != self.config.embedding_dim:
                    logger.info(
                        "Model embedding dim = %s (config said %s). Updating config.",
                        actual_dim, self.config.embedding_dim,
                    )
                    self.config.embedding_dim = actual_dim

                logger.info(
                    "Model loaded: %s on %s (dim=%s)",
                    model_name, self.device, actual_dim,
                )
                return
            except Exception as e:
                # Loading can fail in many ways (network, auth, bad weights);
                # any failure just moves on to the next candidate.
                logger.warning("Failed to load %s: %s", model_name, e)
                last_error = e
                # Do not leave a half-initialised model behind.
                self.model = None
                self.processor = None
                self.model_name = None

        raise RuntimeError(
            "Could not load any CLIP model. Check internet connection and HF_TOKEN."
        ) from last_error

    @staticmethod
    def _to_tensor(output) -> torch.Tensor:
        """Extract the feature tensor from whatever the model returned.

        Checked in order: a raw tensor, ``pooler_output``, the mean over
        dim 1 of ``last_hidden_state``, ``text_embeds``, ``image_embeds``,
        and finally the first tensor found in a tuple/list. Attributes that
        exist but are ``None`` are skipped.

        Raises:
            TypeError: If no tensor can be located in *output*.
        """
        if isinstance(output, torch.Tensor):
            return output

        pooled = getattr(output, 'pooler_output', None)
        if pooled is not None:
            return pooled

        hidden = getattr(output, 'last_hidden_state', None)
        if hidden is not None:
            return hidden.mean(dim=1)

        for attr in ('text_embeds', 'image_embeds'):
            value = getattr(output, attr, None)
            if value is not None:
                return value

        if isinstance(output, (tuple, list)):
            for item in output:
                if isinstance(item, torch.Tensor):
                    return item

        raise TypeError(
            f"Cannot extract tensor from model output of type {type(output)}. "
            f"Available attributes: {[a for a in dir(output) if not a.startswith('_')]}"
        )

    def _on_cuda(self) -> bool:
        """Return True for any CUDA device spelling ("cuda", "cuda:0", torch.device)."""
        return str(self.device).startswith("cuda")

    @torch.no_grad()
    def _embed_text_batch(self, batch: List[str]) -> torch.Tensor:
        """Tokenise *batch* and return the model's un-normalised text features."""
        inputs = self.processor(
            text=batch, return_tensors="pt",
            padding=True, truncation=True, max_length=self._MAX_TEXT_TOKENS,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        return self._to_tensor(self.model.get_text_features(**inputs))

    @torch.no_grad()
    def _embed_image_batch(self, images) -> np.ndarray:
        """Return L2-normalised features for one batch of images as a numpy array."""
        inputs = self.processor(images=images, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        if self._on_cuda():
            with torch.amp.autocast("cuda"):
                raw = self.model.get_image_features(**inputs)
        else:
            raw = self.model.get_image_features(**inputs)
        # Autocast can hand back half-precision features; normalise in
        # float32 so the unit norm is not degraded by reduced precision.
        feats = self._to_tensor(raw).float()
        return F.normalize(feats, p=2, dim=-1).cpu().numpy()

    @torch.no_grad()
    def encode_texts(self, texts: List[str], batch_size: Optional[int] = None) -> np.ndarray:
        """Encode *texts* into an ``(len(texts), dim)`` float32 matrix of unit rows.

        Falsy entries and entries whose string form is ``'nan'`` are encoded
        as the empty string. An empty input yields a
        ``(0, config.embedding_dim)`` array.

        Args:
            texts: Strings (or values convertible with ``str``) to encode.
            batch_size: Items per forward pass. Defaults to
                ``min(config.embed_batch_size * 4, 256)``.
        """
        batch_size = batch_size or min(
            self.config.embed_batch_size * 4, self._MAX_DEFAULT_TEXT_BATCH
        )
        cleaned = [str(t) if t and str(t) != 'nan' else '' for t in texts]
        if not cleaned:
            # np.vstack([]) raises; return a well-shaped empty result instead.
            return np.zeros((0, self.config.embedding_dim), dtype=np.float32)

        chunks = []
        for start in range(0, len(cleaned), batch_size):
            feats = self._embed_text_batch(cleaned[start:start + batch_size])
            chunks.append(F.normalize(feats.float(), p=2, dim=-1).cpu().numpy())
        return np.vstack(chunks).astype(np.float32)

    @torch.no_grad()
    def encode_images_from_paths(
        self, paths: List[Path], batch_size: Optional[int] = None,
    ) -> np.ndarray:
        """Encode image files into an ``(len(paths), dim)`` float32 matrix.

        Best effort: a file that cannot be opened, or a batch that fails to
        encode, leaves the corresponding rows as zero vectors so row *i*
        always lines up with ``paths[i]``. Failures are logged.

        Args:
            paths: Image file paths.
            batch_size: Images per forward pass. Defaults to
                ``config.embed_batch_size``.
        """
        batch_size = batch_size or self.config.embed_batch_size
        n = len(paths)
        embeddings = np.zeros((n, self.config.embedding_dim), dtype=np.float32)
        on_cuda = self._on_cuda()
        failed = 0

        for start in range(0, n, batch_size):
            images = []
            rows = []  # global row index of each successfully opened image
            for row, p in enumerate(paths[start:start + batch_size], start=start):
                try:
                    # convert() loads the pixels into a new image, so the
                    # file handle can be closed right away.
                    with Image.open(p) as handle:
                        images.append(handle.convert("RGB"))
                    rows.append(row)
                except Exception as e:
                    failed += 1
                    logger.debug("Could not read image %s: %s", p, e)

            if not images:
                continue

            try:
                embeddings[rows] = self._embed_image_batch(images)
            except Exception as e:
                failed += len(rows)
                logger.warning("Batch encoding failed at %s: %s", start, e)

            # Periodically release cached GPU memory (every 10th batch).
            if on_cuda and start % (batch_size * 10) == 0:
                torch.cuda.empty_cache()

        if failed:
            logger.warning(
                "%d of %d images could not be encoded; their rows are zero vectors.",
                failed, n,
            )
        return embeddings

    @torch.no_grad()
    def encode_images(
        self, images: List["Image.Image"], batch_size: Optional[int] = None,
    ) -> np.ndarray:
        """Encode in-memory images into an ``(len(images), dim)`` float32 matrix.

        Unlike ``encode_images_from_paths`` errors are not swallowed here.
        An empty input yields a ``(0, config.embedding_dim)`` array.

        Args:
            images: Images accepted by the processor.
            batch_size: Images per forward pass. Defaults to
                ``config.embed_batch_size``.
        """
        batch_size = batch_size or self.config.embed_batch_size
        if len(images) == 0:
            # np.vstack([]) raises; return a well-shaped empty result instead.
            return np.zeros((0, self.config.embedding_dim), dtype=np.float32)

        chunks = [
            self._embed_image_batch(images[i:i + batch_size])
            for i in range(0, len(images), batch_size)
        ]
        return np.vstack(chunks).astype(np.float32)

    @torch.no_grad()
    def encode_query_text(self, query: str) -> np.ndarray:
        """Return a ``(1, dim)`` unit vector for *query*.

        The query is substituted into every template in
        ``config.prompt_templates``; the resulting embeddings are averaged
        and re-normalised. With no templates configured the raw query is
        encoded on its own.
        """
        prompted = [tmpl.format(query) for tmpl in self.config.prompt_templates]
        if not prompted:
            prompted = [query]
        embeddings = self.encode_texts(prompted)
        avg = embeddings.mean(axis=0, keepdims=True)
        # Epsilon guards against division by zero for a degenerate mean.
        avg = avg / (np.linalg.norm(avg, axis=-1, keepdims=True) + 1e-8)
        return avg.astype(np.float32)

    @torch.no_grad()
    def encode_multimodal_query(
        self, text: str, image: "Image.Image", text_weight: float = 0.5,
    ) -> np.ndarray:
        """Fuse a text query and an image into one ``(1, dim)`` unit vector.

        Args:
            text: Query text, encoded via ``encode_query_text``.
            image: Query image, encoded via ``encode_images``.
            text_weight: Weight of the text embedding; the image embedding
                gets ``1 - text_weight``. The value is not range-checked.
        """
        text_emb = self.encode_query_text(text)
        img_emb = self.encode_images([image])
        fused = text_weight * text_emb + (1 - text_weight) * img_emb
        fused = fused / (np.linalg.norm(fused, axis=-1, keepdims=True) + 1e-8)
        return fused.astype(np.float32)
|
|