File size: 7,193 Bytes
4807234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
"""
Shared embedding extraction utilities for GAP-CLIP evaluation scripts.

Consolidates the batch embedding extraction logic that was duplicated across
sec51, sec52, sec533, and sec536 into two reusable functions:

  - extract_clip_embeddings()         — for any CLIP-based model (GAP-CLIP, Fashion-CLIP)
  - extract_color_model_embeddings()  — for the specialized 16D ColorCLIP model
"""

from __future__ import annotations

from typing import List, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _batch_tensors_to_pil(images: torch.Tensor) -> list:
    """Convert a batch of image tensors into a list of PIL images.

    Each image is inspected on its own: when any of its values falls outside
    ``[0, 1]`` it is treated as ImageNet-normalised and mapped back with
    ``t * std + mean`` (then clamped to ``[0, 1]``); otherwise it is converted
    unchanged.  This is the shared denormalization logic that was duplicated
    in every evaluator's image-embedding extraction method.

    Args:
        images: Batch tensor indexed along dim 0.  The ImageNet statistics are
            shaped ``(3, 1, 1)``, so images that take the denormalization
            branch must broadcast against three channels.

    Returns:
        One PIL image per batch entry, in batch order.  An empty batch yields
        an empty list.
    """
    if images.shape[0] == 0:
        return []

    # Loop-invariant: build the statistics and the converter once per batch
    # rather than once per image.
    mean = torch.tensor([0.485, 0.456, 0.406], device=images.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=images.device).view(3, 1, 1)
    to_pil = transforms.ToPILImage()

    pil_images = []
    for t in images:
        # Range heuristic: values outside [0, 1] mean the tensor is still in
        # normalised space.  A normalised image whose values all happen to lie
        # inside [0, 1] is NOT detected and is converted as-is.
        if t.min() < 0 or t.max() > 1:
            t = torch.clamp(t * std + mean, 0, 1)
        # ToPILImage operates on CPU tensors.
        pil_images.append(to_pil(t.cpu()))
    return pil_images


def _normalize_label(value: object, default: str = "unknown") -> str:
    """Convert a label-like value to a consistent, non-empty string.

    The value is stringified, stripped, lower-cased, and the British spelling
    ``"grey"`` is rewritten to ``"gray"`` wherever it occurs in the label.

    Args:
        value: Raw label (string, number, ``None``, NaN, ...).
        default: Label returned when *value* is considered missing.

    Returns:
        The normalised label, or *default* when *value* is ``None``, NaN,
        blank, or stringifies to ``"none"``, ``"nan"`` or ``"<NA>"``
        (case-insensitive).
    """
    if value is None:
        return default

    # Handle pandas/NumPy missing values without importing pandas here.
    try:
        if bool(np.isnan(value)):  # type: ignore[arg-type]
            return default
    except (TypeError, ValueError):
        # TypeError: np.isnan does not support the input (e.g. a string), or
        # the result has no truth value.  ValueError: multi-element array
        # truthiness.  Either way the value is not a scalar NaN, so fall
        # through to string handling.
        pass

    label = str(value).strip().lower()
    # "<na>" is the lower-cased string form of pandas' NA scalar, which the
    # isnan probe above cannot detect.
    if not label or label in {"none", "nan", "<na>"}:
        return default
    return label.replace("grey", "gray")


# ---------------------------------------------------------------------------
# CLIP-based embedding extraction (GAP-CLIP or Fashion-CLIP)
# ---------------------------------------------------------------------------

def extract_clip_embeddings(
    model,
    processor,
    dataloader: DataLoader,
    device: torch.device,
    embedding_type: str = "text",
    max_samples: int = 10_000,
    desc: str | None = None,
) -> Tuple[np.ndarray, List[str], List[str]]:
    """Extract L2-normalised embeddings from any CLIP-based model.

    Works with both 3-element batches ``(image, text, color)`` and 4-element
    batches ``(image, text, color, hierarchy)``.  Always returns three values
    (embeddings, colors, hierarchies); when the batch has no hierarchy column
    the third list is filled with ``"unknown"``.

    Args:
        model: A ``CLIPModel`` (GAP-CLIP, Fashion-CLIP, etc.).
        processor: Matching ``CLIPProcessor``.
        dataloader: PyTorch DataLoader yielding 3- or 4-element tuples.
        device: Target torch device.
        embedding_type: ``"image"`` selects image features; any other value
            takes the text path.
        max_samples: Sample budget.  It is checked before each batch, so the
            result may exceed it by up to one batch minus one sample.
        desc: Optional tqdm description override.

    Returns:
        ``(embeddings, colors, hierarchies)`` where *embeddings* is an
        ``(N, D)`` numpy array and the other two are lists of normalised
        label strings of length ``N``.  When no batch is processed (empty
        dataloader or ``max_samples <= 0``) the array has shape ``(0, 0)``
        and both lists are empty.
    """
    if desc is None:
        desc = f"Extracting {embedding_type} embeddings"

    want_images = embedding_type == "image"

    all_embeddings: list[np.ndarray] = []
    all_colors: list[str] = []
    all_hierarchies: list[str] = []
    sample_count = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc):
            if sample_count >= max_samples:
                break

            # Support both 3-element and 4-element batch tuples
            if len(batch) == 4:
                images, texts, colors, hierarchies = batch
            else:
                images, texts, colors = batch
                hierarchies = ["unknown"] * len(colors)

            if want_images:
                # Only the image path needs the pixels on the device; expand
                # broadcasts a single-channel batch to three channels.
                images = images.to(device).expand(-1, 3, -1, -1)
                pil_images = _batch_tensors_to_pil(images)
                inputs = processor(images=pil_images, return_tensors="pt")
                inputs = {k: v.to(device) for k, v in inputs.items()}
                emb = model.get_image_features(**inputs)
            else:
                inputs = processor(
                    text=list(texts),
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=77,
                )
                inputs = {k: v.to(device) for k, v in inputs.items()}
                emb = model.get_text_features(**inputs)

            emb = F.normalize(emb, dim=-1)

            all_embeddings.append(emb.cpu().numpy())
            all_colors.extend(_normalize_label(c) for c in colors)
            all_hierarchies.extend(_normalize_label(h) for h in hierarchies)
            sample_count += len(images)

            # Drop per-batch tensors before the next iteration to keep peak
            # device memory down.
            del images, emb
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    if not all_embeddings:
        # np.vstack raises on an empty sequence; hand back an empty result
        # that still unpacks like the normal one.
        return np.empty((0, 0), dtype=np.float32), all_colors, all_hierarchies

    return np.vstack(all_embeddings), all_colors, all_hierarchies


# ---------------------------------------------------------------------------
# Specialized ColorCLIP embedding extraction
# ---------------------------------------------------------------------------

def extract_color_model_embeddings(
    color_model,
    dataloader: DataLoader,
    device: torch.device,
    embedding_type: str = "text",
    max_samples: int = 10_000,
    desc: str | None = None,
) -> Tuple[np.ndarray, List[str]]:
    """Extract L2-normalised embeddings from the 16D ColorCLIP model.

    Args:
        color_model: A ``ColorCLIP`` instance.
        dataloader: DataLoader yielding at least ``(image, text, color, ...)``;
            only the first three elements of each batch are read.
        device: Target torch device.
        embedding_type: ``"text"`` selects text embeddings; any other value
            takes the image path.
        max_samples: Sample budget.  It is checked before each batch, so the
            result may exceed it by up to one batch minus one sample.
        desc: Optional tqdm description override.

    Returns:
        ``(embeddings, colors)`` — embeddings is ``(N, 16)`` numpy array and
        *colors* holds ``N`` labels normalised with ``_normalize_label`` (so
        missing labels become ``"unknown"``, matching
        ``extract_clip_embeddings``).  When no batch is processed (empty
        dataloader or ``max_samples <= 0``) the array has shape ``(0, 0)``
        and the list is empty.
    """
    if desc is None:
        desc = f"Extracting {embedding_type} color-model embeddings"

    all_embeddings: list[np.ndarray] = []
    all_colors: list[str] = []
    sample_count = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc):
            if sample_count >= max_samples:
                break

            images, texts, colors = batch[0], batch[1], batch[2]

            if embedding_type == "text":
                emb = color_model.get_text_embeddings(list(texts))
            else:
                # Only the image path needs the pixels on the device; expand
                # broadcasts a single-channel batch to three channels.
                images = images.to(device).expand(-1, 3, -1, -1)
                emb = color_model.get_image_embeddings(images)
            emb = F.normalize(emb, dim=-1)

            all_embeddings.append(emb.cpu().numpy())
            # Shared helper keeps labels comparable with the CLIP extractor.
            all_colors.extend(_normalize_label(c) for c in colors)
            sample_count += len(images)

    if not all_embeddings:
        # np.vstack raises on an empty sequence; hand back an empty result
        # that still unpacks like the normal one.
        return np.empty((0, 0), dtype=np.float32), all_colors

    return np.vstack(all_embeddings), all_colors