File size: 8,027 Bytes
d992912
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
"""
engine/encoder.py

FashionCLIPEncoder — wraps a HuggingFace CLIP model for text and image encoding.
Extracted from finalized_search_engine_full_script.py (lines 482-652).
"""

import logging
from pathlib import Path
from typing import List, Optional

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

from backend.app.config import SearchConfig

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

# Names exported by a star-import of this module.
__all__ = ["FashionCLIPEncoder"]


class FashionCLIPEncoder:
    """
    Wrap a HuggingFace CLIP model to produce L2-normalised embeddings.

    v3.1 — Handles models that return BaseModelOutputWithPooling
    instead of raw tensors from get_text_features / get_image_features.

    Every ``encode_*`` method returns a ``float32`` NumPy array whose rows
    are unit-length embeddings.
    """

    def __init__(self, config: SearchConfig):
        """Store *config* and eagerly load the model and processor.

        Raises:
            RuntimeError: if neither the primary nor the fallback model loads.
        """
        self.config = config
        self.device = config.device
        self.model = None
        self.processor = None
        self.model_name = None
        self._load_model()

    # ------------------------------------------------------------------ #
    # Loading
    # ------------------------------------------------------------------ #
    def _load_model(self):
        """Load the primary model, falling back to the secondary one.

        After a successful load the model is probed with a dummy text to
        discover its real embedding width; ``config.embedding_dim`` is
        overwritten when it disagrees. Instance attributes are only assigned
        once a candidate has loaded *and* been probed successfully, so a
        failed attempt never leaves a half-initialised encoder behind.

        Raises:
            RuntimeError: if no candidate could be loaded (chained to the
                last underlying error).
        """
        # Skip unset names and avoid retrying the same checkpoint twice.
        candidates = []
        for name in (self.config.primary_model, self.config.fallback_model):
            if name and name not in candidates:
                candidates.append(name)

        last_error = None
        for model_name in candidates:
            try:
                logger.info("Loading model: %s", model_name)
                kwargs = {}
                if self.config.hf_token:
                    kwargs['token'] = self.config.hf_token
                model = CLIPModel.from_pretrained(model_name, **kwargs)
                processor = CLIPProcessor.from_pretrained(model_name, **kwargs)
                model = model.to(self.device)
                model.eval()

                # Probe the model to find the actual embedding dim.
                test_inputs = processor(
                    text=["test"], return_tensors="pt",
                    padding=True, truncation=True, max_length=77,
                )
                test_inputs = {k: v.to(self.device) for k, v in test_inputs.items()}
                with torch.no_grad():
                    test_out = model.get_text_features(**test_inputs)
                    test_tensor = self._to_tensor(test_out)
                actual_dim = test_tensor.shape[-1]
            except Exception as e:
                # Best effort: remember the failure and try the next candidate.
                logger.warning("Failed to load %s: %s", model_name, e)
                last_error = e
                continue

            self.model = model
            self.processor = processor
            self.model_name = model_name
            if actual_dim != self.config.embedding_dim:
                logger.info(
                    "Model embedding dim = %s (config said %s). Updating config.",
                    actual_dim, self.config.embedding_dim,
                )
                self.config.embedding_dim = actual_dim

            logger.info(
                "Model loaded: %s on %s (dim=%s)", model_name, self.device, actual_dim,
            )
            return

        raise RuntimeError(
            "Could not load any CLIP model. Check internet connection and HF_TOKEN."
        ) from last_error

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _to_tensor(output) -> torch.Tensor:
        """Extract the embedding tensor from a model output.

        Lookup order: a raw tensor, ``pooler_output``, the mean over dim 1 of
        ``last_hidden_state``, ``text_embeds``, ``image_embeds``, then the
        first tensor found in a tuple/list. Attributes that exist but are
        ``None`` are skipped rather than returned or dereferenced.

        Raises:
            TypeError: if no tensor can be located in *output*.
        """
        if isinstance(output, torch.Tensor):
            return output

        pooled = getattr(output, 'pooler_output', None)
        if pooled is not None:
            return pooled

        hidden = getattr(output, 'last_hidden_state', None)
        if hidden is not None:
            return hidden.mean(dim=1)

        for attr in ('text_embeds', 'image_embeds'):
            value = getattr(output, attr, None)
            if value is not None:
                return value

        if isinstance(output, (tuple, list)):
            for item in output:
                if isinstance(item, torch.Tensor):
                    return item

        raise TypeError(
            f"Cannot extract tensor from model output of type {type(output)}. "
            f"Available attributes: {[a for a in dir(output) if not a.startswith('_')]}"
        )

    def _uses_cuda(self) -> bool:
        """Return True when the configured device is any CUDA device.

        Accepts ``"cuda"``, indexed strings such as ``"cuda:0"`` and
        ``torch.device`` objects.
        """
        device = self.device
        if isinstance(device, torch.device):
            return device.type == "cuda"
        return str(device).startswith("cuda")

    def _empty_embeddings(self) -> np.ndarray:
        """Return a ``(0, embedding_dim)`` float32 array for empty inputs."""
        return np.zeros((0, self.config.embedding_dim), dtype=np.float32)

    def _normalized_numpy(self, raw) -> np.ndarray:
        """Convert a raw model output to unit-length rows as a NumPy array."""
        # Cast up first so the L2 norm is not computed in half precision
        # when the features come out of an autocast region.
        feats = self._to_tensor(raw).float()
        return F.normalize(feats, p=2, dim=-1).cpu().numpy()

    @torch.no_grad()
    def _embed_image_batch(self, images) -> np.ndarray:
        """Run one batch of images through the model and normalise it."""
        inputs = self.processor(images=images, return_tensors="pt", padding=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        if self._uses_cuda():
            with torch.amp.autocast("cuda"):
                raw = self.model.get_image_features(**inputs)
        else:
            raw = self.model.get_image_features(**inputs)
        return self._normalized_numpy(raw)

    # ------------------------------------------------------------------ #
    # Public encoding API
    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def encode_texts(self, texts: List[str], batch_size: Optional[int] = None) -> np.ndarray:
        """Encode *texts* into unit-length embeddings.

        Falsy entries and entries whose string form is ``'nan'`` are encoded
        as the empty string.

        Args:
            texts: strings (or string-convertible values) to encode.
            batch_size: items per forward pass; defaults to
                ``min(config.embed_batch_size * 4, 256)``.

        Returns:
            ``(len(texts), dim)`` float32 array; ``(0, embedding_dim)`` when
            *texts* is empty.
        """
        batch_size = batch_size or min(self.config.embed_batch_size * 4, 256)
        texts = [str(t) if t and str(t) != 'nan' else '' for t in texts]
        if not texts:
            # np.vstack([]) would raise; an empty result is the natural answer.
            return self._empty_embeddings()

        all_emb = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            inputs = self.processor(
                text=batch, return_tensors="pt",
                padding=True, truncation=True, max_length=77,
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            raw = self.model.get_text_features(**inputs)
            all_emb.append(self._normalized_numpy(raw))
        return np.vstack(all_emb).astype(np.float32)

    @torch.no_grad()
    def encode_images_from_paths(
        self, paths: List[Path], batch_size: Optional[int] = None,
    ) -> np.ndarray:
        """Encode image files, keeping row *i* aligned with ``paths[i]``.

        Images that cannot be opened, and batches whose encoding fails, are
        logged and left as all-zero rows so the output stays aligned with
        the input.

        Args:
            paths: image file paths.
            batch_size: items per forward pass; defaults to
                ``config.embed_batch_size``.

        Returns:
            ``(len(paths), config.embedding_dim)`` float32 array.
        """
        batch_size = batch_size or self.config.embed_batch_size
        n = len(paths)
        dim = self.config.embedding_dim
        embeddings = np.zeros((n, dim), dtype=np.float32)
        on_cuda = self._uses_cuda()

        for start in range(0, n, batch_size):
            batch_paths = paths[start:start + batch_size]

            images = []
            valid_in_batch = []
            for offset, p in enumerate(batch_paths, start=start):
                try:
                    # convert() returns an independent, fully loaded image,
                    # so the file handle can be closed immediately.
                    with Image.open(p) as src:
                        images.append(src.convert("RGB"))
                    valid_in_batch.append(offset)
                except Exception as e:
                    logger.warning("Skipping unreadable image %s: %s", p, e)

            if not images:
                continue

            try:
                feats = self._embed_image_batch(images)
                for local_j, global_j in enumerate(valid_in_batch):
                    embeddings[global_j] = feats[local_j]
            except Exception as e:
                logger.warning("Batch encoding failed at %s: %s", start, e)

            # Release cached GPU memory every tenth batch.
            if on_cuda and start % (batch_size * 10) == 0:
                torch.cuda.empty_cache()

        return embeddings

    @torch.no_grad()
    def encode_images(self, images: List[Image.Image], batch_size: Optional[int] = None) -> np.ndarray:
        """Encode in-memory images into unit-length embeddings.

        Args:
            images: images to encode.
            batch_size: items per forward pass; defaults to
                ``config.embed_batch_size``.

        Returns:
            ``(len(images), dim)`` float32 array; ``(0, embedding_dim)`` when
            *images* is empty.
        """
        batch_size = batch_size or self.config.embed_batch_size
        if not images:
            # np.vstack([]) would raise; an empty result is the natural answer.
            return self._empty_embeddings()

        all_emb = [
            self._embed_image_batch(images[i:i + batch_size])
            for i in range(0, len(images), batch_size)
        ]
        return np.vstack(all_emb).astype(np.float32)

    @torch.no_grad()
    def encode_query_text(self, query: str) -> np.ndarray:
        """Encode a search query as the mean of its prompt-templated forms.

        Each entry of ``config.prompt_templates`` is formatted with *query*,
        the resulting embeddings are averaged, and the mean is re-normalised.
        When no templates are configured the raw query is encoded instead.

        Returns:
            ``(1, dim)`` float32 array.
        """
        prompted = [tmpl.format(query) for tmpl in self.config.prompt_templates]
        if not prompted:
            # Averaging zero embeddings would yield NaNs.
            prompted = [query]
        embeddings = self.encode_texts(prompted)
        avg = embeddings.mean(axis=0, keepdims=True)
        # The epsilon guards against division by zero for a null vector.
        avg = avg / (np.linalg.norm(avg, axis=-1, keepdims=True) + 1e-8)
        return avg.astype(np.float32)

    @torch.no_grad()
    def encode_multimodal_query(
        self, text: str, image: Image.Image, text_weight: float = 0.5,
    ) -> np.ndarray:
        """Fuse a text query and an image into one unit-length embedding.

        The result is ``text_weight * text + (1 - text_weight) * image``,
        re-normalised to unit length.

        Returns:
            ``(1, dim)`` float32 array.
        """
        text_emb = self.encode_query_text(text)
        img_emb = self.encode_images([image])
        fused = text_weight * text_emb + (1 - text_weight) * img_emb
        fused = fused / (np.linalg.norm(fused, axis=-1, keepdims=True) + 1e-8)
        return fused.astype(np.float32)