Zero-Shot Image Classification
Transformers
Safetensors
English
clip
fashion
multimodal
image-search
text-search
embeddings
contrastive-learning
zero-shot-classification
Instructions to use Leacb4/gap-clip with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Leacb4/gap-clip with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("zero-shot-image-classification", model="Leacb4/gap-clip")
pipe(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/hub/parrots.png",
    candidate_labels=["animals", "humans", "landscape"],
)

# Load model directly
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification

processor = AutoProcessor.from_pretrained("Leacb4/gap-clip")
model = AutoModelForZeroShotImageClassification.from_pretrained("Leacb4/gap-clip")
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Shared evaluation metrics for GAP-CLIP experiments. | |
| Provides nearest-neighbor accuracy, separation score, centroid-based accuracy, | |
| and confusion matrix generation — used across all evaluation sections. | |
| """ | |
| from __future__ import annotations | |
| from collections import defaultdict | |
| from typing import List, Optional, Tuple | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import seaborn as sns | |
| from sklearn.metrics import accuracy_score, classification_report, confusion_matrix | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from sklearn.preprocessing import normalize | |
def compute_similarity_metrics(
    embeddings: np.ndarray,
    labels: List[str],
    max_samples: int = 5000,
    random_state: Optional[int] = None,
) -> dict:
    """Compute intra/inter-class similarities and nearest-neighbor accuracy.

    One cosine-similarity matrix is computed and reused for the intra-class
    pairs, the inter-class pairs and the leave-one-out nearest-neighbor
    accuracy. Pair extraction uses vectorized NumPy indexing.

    Args:
        embeddings: Array of shape (N, D) (anything ``np.asarray`` accepts).
        labels: List of N class labels.
        max_samples: Cap for large datasets. When N exceeds it, a random
            subset of ``max_samples`` rows is drawn without replacement,
            which bounds the size of the (N, N) similarity matrix.
        random_state: Optional seed for that subsample. ``None`` (default)
            draws from NumPy's global random state, as before. Pass an int
            for reproducible results.

    Returns:
        Dict with keys:
            intra_class_similarities: flat list of same-class pair
                similarities.
            inter_class_similarities: flat list of cross-class pair
                similarities.
            intra_class_mean, inter_class_mean: means of those lists.
            separation_score: intra mean minus inter mean.
            accuracy: leave-one-out nearest-neighbor accuracy.
            centroid_accuracy: nearest-centroid accuracy.
        A mean is 0.0 when its list is empty, and separation_score is 0.0
        unless both lists are non-empty. For empty input every metric
        is 0.0.

    Raises:
        ValueError: If ``embeddings`` and ``labels`` differ in length.
    """
    from itertools import combinations

    # Lists would break the fancy indexing used for subsampling below.
    embeddings = np.asarray(embeddings)
    if len(embeddings) != len(labels):
        raise ValueError(
            "embeddings and labels must have the same length "
            f"(got {len(embeddings)} and {len(labels)})"
        )

    if len(embeddings) > max_samples:
        # The module-level np.random keeps the legacy (global-state)
        # behavior. A seeded Generator makes the subsample reproducible.
        rng = np.random if random_state is None else np.random.default_rng(random_state)
        keep = rng.choice(len(embeddings), max_samples, replace=False)
        embeddings = embeddings[keep]
        labels = [labels[i] for i in keep]

    if len(embeddings) == 0:
        # sklearn's cosine_similarity raises on zero-sample input. Report the
        # same 0.0 defaults the helper metrics use for empty data instead.
        return {
            "intra_class_similarities": [],
            "inter_class_similarities": [],
            "intra_class_mean": 0.0,
            "inter_class_mean": 0.0,
            "separation_score": 0.0,
            "accuracy": 0.0,
            "centroid_accuracy": 0.0,
        }

    similarities = cosine_similarity(embeddings)
    label_array = np.asarray(labels)
    label_groups = {
        label: np.flatnonzero(label_array == label) for label in np.unique(label_array)
    }

    intra_class_similarities: List[float] = []
    for members in label_groups.values():
        if len(members) > 1:
            block = similarities[np.ix_(members, members)]
            # Strict upper triangle: each unordered pair once, no self-similarity.
            upper = np.triu_indices_from(block, k=1)
            intra_class_similarities.extend(block[upper].tolist())

    inter_class_similarities: List[float] = []
    for first, second in combinations(label_groups.values(), 2):
        cross = similarities[np.ix_(first, second)]
        inter_class_similarities.extend(cross.ravel().tolist())

    nn_acc = compute_embedding_accuracy(embeddings, labels, similarities)
    centroid_acc = compute_centroid_accuracy(embeddings, labels)

    intra_mean = float(np.mean(intra_class_similarities)) if intra_class_similarities else 0.0
    inter_mean = float(np.mean(inter_class_similarities)) if inter_class_similarities else 0.0
    separation = (
        intra_mean - inter_mean
        if intra_class_similarities and inter_class_similarities
        else 0.0
    )

    return {
        "intra_class_similarities": intra_class_similarities,
        "inter_class_similarities": inter_class_similarities,
        "intra_class_mean": intra_mean,
        "inter_class_mean": inter_mean,
        "separation_score": separation,
        "accuracy": nn_acc,
        "centroid_accuracy": centroid_acc,
    }
def compute_embedding_accuracy(
    embeddings: np.ndarray,
    labels: List[str],
    similarities: Optional[np.ndarray] = None,
) -> float:
    """Nearest-neighbor classification accuracy (leave-one-out).

    Each sample is assigned the label of its most similar *other* sample.
    The score is the fraction of samples for which that label matches
    their own.

    Args:
        embeddings: Array of shape (N, D).
        labels: List of N class labels.
        similarities: Pre-computed cosine similarity matrix (N, N). Computed
            with ``cosine_similarity(embeddings)`` if not provided. It is
            never modified.

    Returns:
        Fraction of samples whose nearest neighbor shares their label.
        Returns 0.0 when N < 2, because there is no other sample to act
        as a neighbor.
    """
    n = len(embeddings)
    if n < 2:
        return 0.0
    if similarities is None:
        similarities = cosine_similarity(embeddings)
    correct = 0
    for i in range(n):
        # Float copy of the row, so the caller's matrix is left untouched.
        sims = np.array(similarities[i], dtype=float)
        # Use -inf rather than -1.0. Cosine similarities can be exactly -1
        # (or round slightly below it), and argmax returns the first maximum,
        # so a -1.0 sentinel could let a sample pick itself as its neighbor.
        sims[i] = -np.inf
        if labels[int(np.argmax(sims))] == labels[i]:
            correct += 1
    return correct / n
def compute_centroid_accuracy(
    embeddings: np.ndarray,
    labels: List[str],
) -> float:
    """Accuracy of a nearest-centroid classifier under cosine similarity.

    Embeddings are L2-normalized, each class centroid is the mean of its
    normalized members (itself re-normalized), and every sample is assigned
    to the centroid it is most similar to. Ties go to the class that sorts
    first.

    Args:
        embeddings: Array of shape (N, D).
        labels: List of N class labels.

    Returns:
        Share of samples whose predicted class equals their true class
        (0.0 for empty input).
    """
    if len(embeddings) == 0:
        return 0.0

    unit_rows = normalize(embeddings, norm="l2")

    # One pass collects row indices per class, in ascending row order.
    members = defaultdict(list)
    for row, lab in enumerate(labels):
        members[lab].append(row)

    class_order = sorted(members)
    centroid_matrix = np.vstack(
        [
            normalize([unit_rows[members[lab]].mean(axis=0)], norm="l2")[0]
            for lab in class_order
        ]
    )

    scores = cosine_similarity(unit_rows, centroid_matrix)
    best = scores.argmax(axis=1)
    hits = sum(1 for k, truth in zip(best, labels) if class_order[int(k)] == truth)
    return hits / len(labels)
def predict_labels_from_embeddings(
    embeddings: np.ndarray,
    labels: List[str],
) -> List[Optional[str]]:
    """Predict a label for each embedding using nearest centroid.

    Centroids are the mean L2-normalized embedding of each non-None label.
    Each embedding gets the label of the centroid with the highest cosine
    similarity. Samples labelled ``None`` still receive a prediction but do
    not contribute to any centroid.

    Args:
        embeddings: Array of shape (N, D).
        labels: List of N labels. ``None`` entries are ignored when building
            centroids.

    Returns:
        List of predicted labels (same length as embeddings). If every label
        is ``None``, every prediction is ``None``.
    """
    # First-appearance order instead of set order. Set iteration over strings
    # is hash-randomized per process, which made argmax tie-breaking (and
    # therefore predictions) differ between runs.
    valid_labels = list(dict.fromkeys(l for l in labels if l is not None))
    if not valid_labels:
        return [None] * len(embeddings)

    emb_norm = normalize(embeddings, norm="l2")
    label_array = np.asarray(labels, dtype=object)  # built once, not per label

    centroid_labels: list = []
    centroid_rows = []
    for label in valid_labels:
        mask = label_array == label
        if np.any(mask):
            centroid_labels.append(label)
            centroid_rows.append(emb_norm[mask].mean(axis=0))

    centroid_matrix = np.vstack(centroid_rows)
    sims = cosine_similarity(emb_norm, centroid_matrix)
    return [centroid_labels[int(k)] for k in sims.argmax(axis=1)]
def compute_worst_group_accuracy(
    embeddings: np.ndarray,
    labels: List[str],
    groups: List[str],
    similarities: Optional[np.ndarray] = None,
) -> Tuple[float, str, dict]:
    """Compute per-group nearest-neighbor accuracy and return the worst.

    This is the key metric for DFR (Kirichenko et al. 2023): it measures
    robustness to spurious correlations by reporting the accuracy of the
    worst-performing (color x hierarchy) group. Each sample's nearest
    neighbor is searched over the whole dataset (excluding itself). Group
    membership only decides which accuracy bucket the sample counts toward.

    Args:
        embeddings: Array of shape (N, D).
        labels: List of N class labels (what we classify).
        groups: List of N group labels (e.g. 'red_dress', 'blue_jeans').
        similarities: Optional pre-computed cosine similarity matrix (N, N).
            Computed with ``cosine_similarity(embeddings)`` if not provided.
            It is never modified.

    Returns:
        (worst_accuracy, worst_group_name, per_group_dict)
        where per_group_dict maps group -> accuracy. Groups with fewer than
        two samples are skipped. ``(0.0, "", {})`` is returned when no group
        qualifies or the input is empty.

    Raises:
        ValueError: If ``embeddings``, ``labels`` and ``groups`` differ in
            length.
    """
    n = len(embeddings)
    if len(labels) != n or len(groups) != n:
        raise ValueError(
            "embeddings, labels and groups must have the same length "
            f"(got {n}, {len(labels)} and {len(groups)})"
        )
    if n == 0:
        # cosine_similarity rejects zero-sample input.
        return 0.0, "", {}
    if similarities is None:
        similarities = cosine_similarity(embeddings)

    label_array = np.asarray(labels)
    group_array = np.asarray(groups)

    per_group: dict = {}
    for g in np.unique(group_array):
        idxs = np.flatnonzero(group_array == g)
        if len(idxs) < 2:
            continue
        correct = 0
        for i in idxs:
            # Float copy of the row, so the caller's matrix is left untouched.
            sims = np.array(similarities[i], dtype=float)
            # -inf, not -1.0: a real similarity of -1 would otherwise tie with
            # the sentinel, and argmax (first maximum) could pick the sample
            # itself.
            sims[i] = -np.inf
            if label_array[int(np.argmax(sims))] == label_array[i]:
                correct += 1
        per_group[g] = correct / len(idxs)

    if not per_group:
        return 0.0, "", {}
    worst_group = min(per_group, key=per_group.get)
    return per_group[worst_group], worst_group, per_group
def create_confusion_matrix(
    true_labels: List[str],
    predicted_labels: List[str],
    title: str = "Confusion Matrix",
    label_type: str = "Label",
    figsize: Tuple[float, float] = (10, 8),
) -> Tuple[plt.Figure, float, np.ndarray]:
    """Create and return a seaborn confusion-matrix heatmap figure.

    Rows and columns cover the sorted union of labels seen in either input,
    so a class that is only ever predicted (or only ever true) still gets
    its own row and column.

    Args:
        true_labels: Ground-truth labels.
        predicted_labels: Predicted labels.
        title: Plot title prefix. The accuracy is appended on a second line.
        label_type: Axis label (e.g. "Color", "Category").
        figsize: Figure size in inches; defaults to the previous fixed size.

    Returns:
        (fig, accuracy, cm_array), where cm_array is sklearn's confusion
        matrix (rows = true, columns = predicted) in the sorted label order.
        The figure is created through pyplot, so the caller should close it
        (``plt.close(fig)``) when done.
    """
    # Set union rather than ``true_labels + predicted_labels``: ``+`` only
    # concatenates lists/tuples. NumPy arrays would be added element-wise
    # (or raise) instead of concatenated.
    unique_labels = sorted(set(true_labels) | set(predicted_labels))
    cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
    acc = accuracy_score(true_labels, predicted_labels)

    # Draw on an explicit Axes instead of relying on pyplot's global
    # "current axes" state.
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=unique_labels,
        yticklabels=unique_labels,
        ax=ax,
    )
    ax.set_title(f"{title}\nAccuracy: {acc:.3f} ({acc * 100:.1f}%)")
    ax.set_ylabel(f"True {label_type}")
    ax.set_xlabel(f"Predicted {label_type}")
    ax.tick_params(axis="x", labelrotation=45)
    ax.tick_params(axis="y", labelrotation=0)
    fig.tight_layout()
    return fig, acc, cm