Zero-Shot Image Classification
Transformers
Safetensors
English
clip
fashion
multimodal
image-search
text-search
embeddings
contrastive-learning
zero-shot-classification
Instructions to use Leacb4/gap-clip with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Leacb4/gap-clip with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("zero-shot-image-classification", model="Leacb4/gap-clip")
pipe(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/hub/parrots.png",
    candidate_labels=["animals", "humans", "landscape"],
)

# Load model directly
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification

processor = AutoProcessor.from_pretrained("Leacb4/gap-clip")
model = AutoModelForZeroShotImageClassification.from_pretrained("Leacb4/gap-clip")
- Notebooks
- Google Colab
- Kaggle
File size: 7,193 Bytes
4807234 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 | """
Shared embedding extraction utilities for GAP-CLIP evaluation scripts.
Consolidates the batch embedding extraction logic that was duplicated across
sec51, sec52, sec533, and sec536 into two reusable functions:
- extract_clip_embeddings() — for any CLIP-based model (GAP-CLIP, Fashion-CLIP)
- extract_color_model_embeddings() — for the specialized 16D ColorCLIP model
"""
from __future__ import annotations
from typing import List, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _batch_tensors_to_pil(images: torch.Tensor) -> list:
    """Turn a batch of image tensors into a list of PIL images.

    Each image whose values fall outside ``[0, 1]`` is treated as
    ImageNet mean/std normalised: it is un-normalised with the ImageNet
    constants below and clamped back into ``[0, 1]``. Images already inside
    ``[0, 1]`` are converted unchanged. Every image is moved to the CPU
    before conversion.

    NOTE(review): the ``view(3, 1, 1)`` broadcast assumes a
    ``(batch, 3, H, W)`` tensor -- confirm with callers.
    """
    to_pil = transforms.ToPILImage()
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    converted = []
    # Iterating a tensor walks its first (batch) dimension.
    for tensor in images:
        looks_normalised = tensor.min() < 0 or tensor.max() > 1
        if looks_normalised:
            dev = tensor.device
            mean = torch.tensor(imagenet_mean, device=dev).view(3, 1, 1)
            std = torch.tensor(imagenet_std, device=dev).view(3, 1, 1)
            tensor = (tensor * std + mean).clamp(0, 1)
        converted.append(to_pil(tensor.cpu()))
    return converted
def _normalize_label(value: object, default: str = "unknown") -> str:
    """Convert a label-like value to a consistent, non-empty lowercase string.

    Values treated as missing map to *default*:
    ``None``, NaN floats (Python or NumPy), pandas ``NA``,
    empty/whitespace-only strings, and the strings ``"none"``, ``"nan"``
    and ``"<na>"`` (case-insensitive).

    Every other value is stringified, stripped and lower-cased. Any
    ``"grey"`` substring is then rewritten to ``"gray"`` so both spellings
    compare equal.

    Args:
        value: Raw label (string, number, NumPy scalar, pandas missing value, ...).
        default: String returned when *value* is considered missing.

    Returns:
        The normalised label string.
    """
    if value is None:
        return default
    # Handle pandas/NumPy missing values without importing pandas here.
    # ``np.isnan`` raises TypeError for non-numeric input (e.g. strings), and
    # ``bool()`` raises ValueError for multi-element arrays and TypeError for
    # ``pd.NA``. Those inputs simply fall through to the string checks below.
    try:
        if bool(np.isnan(value)):  # type: ignore[arg-type]
            return default
    except (TypeError, ValueError, OverflowError):
        pass
    label = str(value).strip().lower()
    # "<na>" is ``str(pd.NA)``, the missing marker of pandas' nullable
    # dtypes; it is not detected by the isnan probe above.
    if not label or label in {"none", "nan", "<na>"}:
        return default
    return label.replace("grey", "gray")
# ---------------------------------------------------------------------------
# CLIP-based embedding extraction (GAP-CLIP or Fashion-CLIP)
# ---------------------------------------------------------------------------
def extract_clip_embeddings(
    model,
    processor,
    dataloader: DataLoader,
    device: torch.device,
    embedding_type: str = "text",
    max_samples: int = 10_000,
    desc: str | None = None,
) -> Tuple[np.ndarray, List[str], List[str]]:
    """Extract L2-normalised embeddings from any CLIP-based model.

    Works with both 3-element batches ``(image, text, color)`` and 4-element
    batches ``(image, text, color, hierarchy)``. Always returns three
    parallel results (embeddings, colors, hierarchies). When the batch has
    no hierarchy column, the third list is filled with ``"unknown"``.

    Args:
        model: A ``CLIPModel`` (GAP-CLIP, Fashion-CLIP, etc.).
        processor: Matching ``CLIPProcessor``.
        dataloader: PyTorch DataLoader yielding 3- or 4-element tuples.
        device: Torch device the processor outputs are moved to.
        embedding_type: ``"text"`` or ``"image"``.
        max_samples: Maximum number of samples to embed and return. The last
            batch is trimmed so the result never exceeds this count.
        desc: Optional tqdm description override.

    Returns:
        ``(embeddings, colors, hierarchies)``:
        *embeddings* is an ``(N, D)`` numpy array with ``N <= max_samples``.
        The other two are lists of ``N`` labels normalised by
        ``_normalize_label``.

    Raises:
        ValueError: If ``embedding_type`` is not ``"text"`` or ``"image"``,
            or if no samples were extracted (empty dataloader or
            ``max_samples <= 0``).
    """
    if embedding_type not in ("text", "image"):
        raise ValueError(
            f"embedding_type must be 'text' or 'image', got {embedding_type!r}"
        )
    if desc is None:
        desc = f"Extracting {embedding_type} embeddings"

    all_embeddings: list[np.ndarray] = []
    all_colors: list[str] = []
    all_hierarchies: list[str] = []
    sample_count = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc):
            remaining = max_samples - sample_count
            if remaining <= 0:
                break

            # Support both 3-element and 4-element batch tuples
            if len(batch) == 4:
                images, texts, colors, hierarchies = batch
            else:
                images, texts, colors = batch
                hierarchies = ["unknown"] * len(colors)

            # Trim the final batch so no more than ``max_samples`` rows are
            # embedded or returned.
            images = images[:remaining]
            texts = texts[:remaining]
            colors = colors[:remaining]
            hierarchies = hierarchies[:remaining]

            if embedding_type == "image":
                # PIL conversion happens on the CPU, so the pixel tensor is not
                # sent to ``device`` first; only the processor output is.
                # ``expand`` broadcasts single-channel images to 3 channels.
                pil_images = _batch_tensors_to_pil(images.expand(-1, 3, -1, -1))
                inputs = processor(images=pil_images, return_tensors="pt")
                inputs = {k: v.to(device) for k, v in inputs.items()}
                emb = model.get_image_features(**inputs)
            else:
                # The text path never uses the images, so they are not moved.
                inputs = processor(
                    text=list(texts),
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=77,
                )
                inputs = {k: v.to(device) for k, v in inputs.items()}
                emb = model.get_text_features(**inputs)

            emb = F.normalize(emb, dim=-1)
            all_embeddings.append(emb.cpu().numpy())
            all_colors.extend(_normalize_label(c) for c in colors)
            all_hierarchies.extend(_normalize_label(h) for h in hierarchies)
            sample_count += len(images)

            # Drop references to device tensors before the next batch.
            del inputs, emb
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    if not all_embeddings:
        raise ValueError(
            "No embeddings were extracted: the dataloader yielded no batches "
            "or max_samples <= 0."
        )
    return np.vstack(all_embeddings), all_colors, all_hierarchies
# ---------------------------------------------------------------------------
# Specialized ColorCLIP embedding extraction
# ---------------------------------------------------------------------------
def extract_color_model_embeddings(
    color_model,
    dataloader: DataLoader,
    device: torch.device,
    embedding_type: str = "text",
    max_samples: int = 10_000,
    desc: str | None = None,
) -> Tuple[np.ndarray, List[str]]:
    """Extract L2-normalised embeddings from the 16D ColorCLIP model.

    Only the first three batch fields ``(image, text, color)`` are used.
    Any further columns (e.g. hierarchy) are ignored.

    Args:
        color_model: A ``ColorCLIP`` instance exposing
            ``get_text_embeddings(list[str])`` and
            ``get_image_embeddings(tensor)``.
        dataloader: DataLoader yielding at least ``(image, text, color, ...)``.
        device: Torch device image batches are moved to (image path only).
        embedding_type: ``"text"`` or ``"image"``.
        max_samples: Maximum number of samples to embed and return. The last
            batch is trimmed so the result never exceeds this count.
        desc: Optional tqdm description override.

    Returns:
        ``(embeddings, colors)``:
        *embeddings* is an ``(N, 16)`` numpy array with ``N <= max_samples``.
        *colors* is a list of ``N`` labels normalised by ``_normalize_label``.

    Raises:
        ValueError: If ``embedding_type`` is not ``"text"`` or ``"image"``,
            or if no samples were extracted (empty dataloader or
            ``max_samples <= 0``).
    """
    if embedding_type not in ("text", "image"):
        raise ValueError(
            f"embedding_type must be 'text' or 'image', got {embedding_type!r}"
        )
    if desc is None:
        desc = f"Extracting {embedding_type} color-model embeddings"

    all_embeddings: list[np.ndarray] = []
    all_colors: list[str] = []
    sample_count = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc):
            remaining = max_samples - sample_count
            if remaining <= 0:
                break

            images, texts, colors = batch[0], batch[1], batch[2]
            # Trim the final batch so no more than ``max_samples`` rows are
            # embedded or returned.
            images = images[:remaining]
            texts = texts[:remaining]
            colors = colors[:remaining]

            if embedding_type == "text":
                emb = color_model.get_text_embeddings(list(texts))
            else:
                # Only the image path needs the pixels on ``device``;
                # ``expand`` broadcasts single-channel images to 3 channels.
                emb = color_model.get_image_embeddings(
                    images.to(device).expand(-1, 3, -1, -1)
                )

            emb = F.normalize(emb, dim=-1)
            all_embeddings.append(emb.cpu().numpy())
            # Use the shared label normaliser so colour labels (including
            # missing ones, mapped to "unknown") match the CLIP extractor's.
            all_colors.extend(_normalize_label(c) for c in colors)
            sample_count += len(images)

    if not all_embeddings:
        raise ValueError(
            "No embeddings were extracted: the dataloader yielded no batches "
            "or max_samples <= 0."
        )
    return np.vstack(all_embeddings), all_colors
|