Pipeline: Zero-Shot Image Classification
Libraries: Transformers, Safetensors
Language: English
Tags: clip, fashion, multimodal, image-search, text-search, embeddings, contrastive-learning, zero-shot-classification
Instructions for using Leacb4/gap-clip with libraries, inference providers, notebooks, and local apps. Follow the links below to get started.
- Libraries
  - Transformers

How to use Leacb4/gap-clip with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("zero-shot-image-classification", model="Leacb4/gap-clip")
pipe(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/hub/parrots.png",
    candidate_labels=["animals", "humans", "landscape"],
)
```

```python
# Load model directly
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification

processor = AutoProcessor.from_pretrained("Leacb4/gap-clip")
model = AutoModelForZeroShotImageClassification.from_pretrained("Leacb4/gap-clip")
```

A hedged end-to-end scoring sketch follows the notebook links below.

- Notebooks
- Google Colab
- Kaggle
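
When the processor and model are loaded directly (as above), the image-text similarity step is up to you. The following is a minimal sketch of that step; it assumes Leacb4/gap-clip exposes the standard CLIP interface in Transformers (a `logits_per_image` field on the model output) and reuses the parrots example image from the pipeline snippet.

```python
# Minimal sketch: zero-shot scoring with the directly loaded model.
# Assumes the standard CLIP interface exposed by transformers (logits_per_image).
import requests
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification

processor = AutoProcessor.from_pretrained("Leacb4/gap-clip")
model = AutoModelForZeroShotImageClassification.from_pretrained("Leacb4/gap-clip")

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/hub/parrots.png"
image = Image.open(requests.get(url, stream=True).raw)
labels = ["animals", "humans", "landscape"]

inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
outputs = model(**inputs)
probs = outputs.logits_per_image.softmax(dim=-1)[0]
for label, p in zip(labels, probs.tolist()):
    print(f"{label}: {p:.3f}")
```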
Shared embedding-extraction utilities for the GAP-CLIP evaluation scripts:

```python
"""
Shared embedding extraction utilities for GAP-CLIP evaluation scripts.

Consolidates the batch embedding extraction logic that was duplicated across
sec51, sec52, sec533, and sec536 into two reusable functions:

- extract_clip_embeddings() — for any CLIP-based model (GAP-CLIP, Fashion-CLIP)
- extract_color_model_embeddings() — for the specialized 16D ColorCLIP model
"""
from __future__ import annotations

from typing import List, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _batch_tensors_to_pil(images: torch.Tensor) -> list:
    """Convert a batch of ImageNet-normalised tensors back to PIL images.

    This is the shared denormalization logic that was duplicated in every
    evaluator's image-embedding extraction method.
    """
    pil_images = []
    for i in range(images.shape[0]):
        t = images[i]
        if t.min() < 0 or t.max() > 1:
            mean = torch.tensor([0.485, 0.456, 0.406], device=t.device).view(3, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225], device=t.device).view(3, 1, 1)
            t = torch.clamp(t * std + mean, 0, 1)
        pil_images.append(transforms.ToPILImage()(t.cpu()))
    return pil_images


def _normalize_label(value: object, default: str = "unknown") -> str:
    """Convert label-like values to consistent non-empty strings."""
    if value is None:
        return default
    # Handle pandas/NumPy missing values without importing pandas here.
    try:
        if bool(np.isnan(value)):  # type: ignore[arg-type]
            return default
    except Exception:
        pass
    label = str(value).strip().lower()
    if not label or label in {"none", "nan"}:
        return default
    return label.replace("grey", "gray")


# ---------------------------------------------------------------------------
# CLIP-based embedding extraction (GAP-CLIP or Fashion-CLIP)
# ---------------------------------------------------------------------------
def extract_clip_embeddings(
    model,
    processor,
    dataloader: DataLoader,
    device: torch.device,
    embedding_type: str = "text",
    max_samples: int = 10_000,
    desc: str | None = None,
) -> Tuple[np.ndarray, List[str], List[str]]:
    """Extract L2-normalised embeddings from any CLIP-based model.

    Works with both 3-element batches ``(image, text, color)`` and 4-element
    batches ``(image, text, color, hierarchy)``. Always returns three lists
    (embeddings, colors, hierarchies); when the batch has no hierarchy column
    the third list is filled with ``"unknown"``.

    Args:
        model: A ``CLIPModel`` (GAP-CLIP, Fashion-CLIP, etc.).
        processor: Matching ``CLIPProcessor``.
        dataloader: PyTorch DataLoader yielding 3- or 4-element tuples.
        device: Target torch device.
        embedding_type: ``"text"`` or ``"image"``.
        max_samples: Stop after collecting this many samples.
        desc: Optional tqdm description override.

    Returns:
        ``(embeddings, colors, hierarchies)`` where *embeddings* is an
        ``(N, D)`` numpy array and the other two are lists of strings.
    """
    if desc is None:
        desc = f"Extracting {embedding_type} embeddings"

    all_embeddings: list[np.ndarray] = []
    all_colors: list[str] = []
    all_hierarchies: list[str] = []
    sample_count = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc):
            if sample_count >= max_samples:
                break

            # Support both 3-element and 4-element batch tuples
            if len(batch) == 4:
                images, texts, colors, hierarchies = batch
            else:
                images, texts, colors = batch
                hierarchies = ["unknown"] * len(colors)

            images = images.to(device).expand(-1, 3, -1, -1)

            if embedding_type == "image":
                pil_images = _batch_tensors_to_pil(images)
                inputs = processor(images=pil_images, return_tensors="pt")
                inputs = {k: v.to(device) for k, v in inputs.items()}
                emb = model.get_image_features(**inputs)
            else:
                inputs = processor(
                    text=list(texts),
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=77,
                )
                inputs = {k: v.to(device) for k, v in inputs.items()}
                emb = model.get_text_features(**inputs)

            emb = F.normalize(emb, dim=-1)
            all_embeddings.append(emb.cpu().numpy())
            all_colors.extend(_normalize_label(c) for c in colors)
            all_hierarchies.extend(_normalize_label(h) for h in hierarchies)
            sample_count += len(images)

            del images, emb
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return np.vstack(all_embeddings), all_colors, all_hierarchies


# ---------------------------------------------------------------------------
# Specialized ColorCLIP embedding extraction
# ---------------------------------------------------------------------------
def extract_color_model_embeddings(
    color_model,
    dataloader: DataLoader,
    device: torch.device,
    embedding_type: str = "text",
    max_samples: int = 10_000,
    desc: str | None = None,
) -> Tuple[np.ndarray, List[str]]:
    """Extract L2-normalised embeddings from the 16D ColorCLIP model.

    Args:
        color_model: A ``ColorCLIP`` instance.
        dataloader: DataLoader yielding at least ``(image, text, color, ...)``.
        device: Target torch device.
        embedding_type: ``"text"`` or ``"image"``.
        max_samples: Stop after collecting this many samples.
        desc: Optional tqdm description override.

    Returns:
        ``(embeddings, colors)`` — embeddings is ``(N, 16)`` numpy array.
    """
    if desc is None:
        desc = f"Extracting {embedding_type} color-model embeddings"

    all_embeddings: list[np.ndarray] = []
    all_colors: list[str] = []
    sample_count = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc):
            if sample_count >= max_samples:
                break

            images, texts, colors = batch[0], batch[1], batch[2]
            images = images.to(device).expand(-1, 3, -1, -1)

            if embedding_type == "text":
                emb = color_model.get_text_embeddings(list(texts))
            else:
                emb = color_model.get_image_embeddings(images)

            emb = F.normalize(emb, dim=-1)
            all_embeddings.append(emb.cpu().numpy())

            normalized_colors = [
                str(c).lower().strip().replace("grey", "gray") for c in colors
            ]
            all_colors.extend(normalized_colors)
            sample_count += len(images)

    return np.vstack(all_embeddings), all_colors
```
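
For reference, here is a hedged usage sketch of `extract_clip_embeddings`. The `ToyFashionDataset` class, the module name `embedding_utils`, and loading the checkpoint through `CLIPModel`/`CLIPProcessor` are assumptions made for illustration; any DataLoader that yields `(image, text, color)` or `(image, text, color, hierarchy)` batches will do.

```python
# Hypothetical usage sketch: dataset, module name, and checkpoint loading are assumptions.
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPModel, CLIPProcessor

from embedding_utils import extract_clip_embeddings  # hypothetical module name for the file above


class ToyFashionDataset(Dataset):
    """Tiny stand-in dataset; real evaluation scripts supply their own."""

    def __init__(self, n: int = 8):
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int):
        image = torch.rand(3, 224, 224)        # values in [0, 1], so no denormalization is applied
        text = f"a photo of fashion item {idx}"
        color = "grey" if idx % 2 else "blue"  # "grey" is normalised to "gray" by the helper
        return image, text, color              # 3-element batches -> hierarchies default to "unknown"


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CLIPModel.from_pretrained("Leacb4/gap-clip").to(device).eval()
processor = CLIPProcessor.from_pretrained("Leacb4/gap-clip")

loader = DataLoader(ToyFashionDataset(), batch_size=4)
text_emb, colors, hierarchies = extract_clip_embeddings(
    model, processor, loader, device, embedding_type="text"
)
image_emb, _, _ = extract_clip_embeddings(
    model, processor, loader, device, embedding_type="image"
)
print(text_emb.shape, image_emb.shape, colors[:2], hierarchies[:2])
```

`extract_color_model_embeddings` follows the same pattern but takes the ColorCLIP model directly and returns only `(embeddings, colors)`.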