Zero-Shot Image Classification
Transformers
Safetensors
English
clip
fashion
multimodal
image-search
text-search
embeddings
contrastive-learning
zero-shot-classification
Instructions to use Leacb4/gap-clip with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Leacb4/gap-clip with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("zero-shot-image-classification", model="Leacb4/gap-clip")
pipe(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/hub/parrots.png",
    candidate_labels=["animals", "humans", "landscape"],
)

# Load model directly
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification

processor = AutoProcessor.from_pretrained("Leacb4/gap-clip")
model = AutoModelForZeroShotImageClassification.from_pretrained("Leacb4/gap-clip")
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Section 5.2 — Category Model Evaluation (Table 2) | |
| ================================================== | |
| Evaluates GAP-CLIP vs the Fashion-CLIP baseline on hierarchy (category) | |
| classification using three datasets: | |
| - Fashion-MNIST (10 categories) | |
| - KAGL Marqo (external, real-world fashion e-commerce) | |
| - Internal validation dataset | |
| Produces hierarchy confusion matrices (text + image) for both models on each | |
| dataset. | |
| Metrics match Table 2 in the paper: | |
| - Text/image embedding NN accuracy | |
| - Text/image embedding separation score | |
| Run directly: | |
| python sec52_category_model_eval.py | |
| Paper reference: Section 5.2, Table 2. | |
| """ | |
| import os | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| import torch | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import difflib | |
| from collections import defaultdict | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from sklearn.metrics import classification_report, accuracy_score | |
| from sklearn.preprocessing import normalize | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision import transforms | |
| from PIL import Image | |
| from io import BytesIO | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| from config import ( | |
| ROOT_DIR, | |
| main_model_path, | |
| main_emb_dim, | |
| hierarchy_model_path, | |
| color_emb_dim, | |
| hierarchy_emb_dim, | |
| local_dataset_path, | |
| column_local_image_path, | |
| ) | |
| from utils.datasets import ( | |
| load_fashion_mnist_dataset, | |
| ) | |
| from utils.embeddings import extract_clip_embeddings | |
| from utils.metrics import ( | |
| compute_similarity_metrics, | |
| compute_embedding_accuracy, | |
| compute_centroid_accuracy, | |
| predict_labels_from_embeddings, | |
| create_confusion_matrix, | |
| ) | |
| from utils.model_loader import load_gap_clip, load_baseline_fashion_clip | |
| # ============================================================================ | |
| # 1b. KAGL Marqo utilities | |
| # ============================================================================ | |
class KaggleHierarchyDataset(Dataset):
    """KAGL Marqo rows exposed as ``(image_tensor, description, color, hierarchy)``.

    The DataFrame must provide 'image', 'text' and 'hierarchy' columns; a
    'baseColour' column is optional (missing -> "unknown", lower-cased).
    Images are resized to ``image_size`` x ``image_size``, converted to a
    tensor and normalised with the ImageNet channel mean/std.
    """

    def __init__(self, dataframe, image_size=224):
        self.dataframe = dataframe.reset_index(drop=True)
        pipeline_steps = [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.transform = transforms.Compose(pipeline_steps)

    def __len__(self):
        return len(self.dataframe)

    @staticmethod
    def _decode_rgb(image_data):
        """Return an RGB image from a dict holding a 'bytes' entry, an object
        exposing ``.convert`` (e.g. a PIL image), or raw encoded bytes."""
        if isinstance(image_data, dict) and "bytes" in image_data:
            return Image.open(BytesIO(image_data["bytes"])).convert("RGB")
        if hasattr(image_data, "convert"):
            return image_data.convert("RGB")
        return Image.open(BytesIO(image_data)).convert("RGB")

    def __getitem__(self, idx):
        record = self.dataframe.iloc[idx]
        pixels = self.transform(self._decode_rgb(record["image"]))
        return (
            pixels,
            str(record["text"]),
            str(record.get("baseColour", "unknown")).lower(),
            str(record["hierarchy"]),
        )
| def load_kaggle_marqo_with_hierarchy(max_samples=10000, hierarchy_classes=None, raw_df=None): | |
| """Load KAGL Marqo dataset with hierarchy labels derived from articleType. | |
| Args: | |
| raw_df: Pre-downloaded DataFrame to skip the HuggingFace download. | |
| """ | |
| if raw_df is not None: | |
| df = raw_df.copy() | |
| print(f"Using cached KAGL DataFrame for hierarchy evaluation: {len(df)} samples") | |
| else: | |
| from datasets import load_dataset | |
| print("Loading KAGL Marqo dataset for hierarchy evaluation...") | |
| dataset = load_dataset("Marqo/KAGL") | |
| df = dataset["data"].to_pandas() | |
| print(f"Dataset loaded: {len(df)} samples, columns: {list(df.columns)}") | |
| # Use the most specific category column as hierarchy source | |
| hierarchy_col = 'category2' | |
| if hierarchy_col is None: | |
| print("WARNING: No hierarchy column found in KAGL dataset") | |
| return None | |
| print(f"Using '{hierarchy_col}' as hierarchy source") | |
| df = df.dropna(subset=["text", "image", hierarchy_col]) | |
| df["hierarchy"] = df[hierarchy_col].astype(str).str.strip() | |
| # If hierarchy_classes provided, map KAGL types to model hierarchy classes | |
| if hierarchy_classes: | |
| hierarchy_classes_lower = [h.lower() for h in hierarchy_classes] | |
| mapped = [] | |
| for _, row in df.iterrows(): | |
| kagl_type = row["hierarchy"].lower() | |
| matched = None | |
| # Exact match | |
| if kagl_type in hierarchy_classes_lower: | |
| matched = hierarchy_classes[hierarchy_classes_lower.index(kagl_type)] | |
| else: | |
| # Substring match | |
| for h_class in hierarchy_classes: | |
| h_lower = h_class.lower() | |
| if h_lower in kagl_type or kagl_type in h_lower: | |
| matched = h_class | |
| break | |
| if matched is None: | |
| close = difflib.get_close_matches(kagl_type, hierarchy_classes_lower, n=1, cutoff=0.6) | |
| if close: | |
| matched = hierarchy_classes[hierarchy_classes_lower.index(close[0])] | |
| mapped.append(matched) | |
| df["hierarchy"] = mapped | |
| df = df.dropna(subset=["hierarchy"]) | |
| print(f"After hierarchy mapping: {len(df)} samples") | |
| if len(df) > max_samples: | |
| df = df.sample(n=max_samples, random_state=42) | |
| print(f"Using {len(df)} samples, {df['hierarchy'].nunique()} hierarchy classes: " | |
| f"{sorted(df['hierarchy'].unique())}") | |
| return KaggleHierarchyDataset(df) | |
| # ============================================================================ | |
| # 1c. Local validation dataset utilities | |
| # ============================================================================ | |
class LocalHierarchyDataset(Dataset):
    """Local validation rows exposed as ``(image_tensor, description, color, hierarchy)``.

    Image paths are read from the ``column_local_image_path`` column; relative
    paths are resolved against ``ROOT_DIR``. An image that cannot be loaded is
    replaced by a gray placeholder so evaluation can continue. The row index
    is recorded in ``failed_image_indices``, and the first failure is printed,
    so substituted samples do not silently skew the metrics.
    """

    def __init__(self, dataframe, image_size=224):
        self.dataframe = dataframe.reset_index(drop=True)
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Row indices whose image could not be loaded (placeholder used instead).
        self.failed_image_indices = set()

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        try:
            img_path = row[column_local_image_path]
            if not os.path.isabs(img_path):
                img_path = os.path.join(ROOT_DIR, img_path)
            # Context manager closes the underlying file handle deterministically.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        except Exception as exc:
            # Deliberate best-effort fallback, but never silent: report the
            # first failure and keep a record of every substituted row.
            if not self.failed_image_indices:
                print(f"WARNING: could not load image for row {idx} ({exc!r}); "
                      "using a gray placeholder. Further failures are recorded "
                      "in failed_image_indices.")
            self.failed_image_indices.add(idx)
            image = Image.new("RGB", (224, 224), color="gray")
        image = self.transform(image)
        description = str(row["text"])
        color = str(row.get("color", "unknown"))
        hierarchy = str(row["hierarchy"])
        return image, description, color, hierarchy
def load_local_validation_with_hierarchy(max_samples=10000, hierarchy_classes=None, raw_df=None):
    """Load the internal validation dataset (CSV or cached copy) with hierarchy labels.

    Rows without an image path or a non-blank 'hierarchy' are dropped. When
    ``hierarchy_classes`` is given, only rows whose hierarchy equals one of
    those classes case-insensitively are kept, relabelled with the class's
    own casing. At most ``max_samples`` rows are kept (seeded subsample).

    Args:
        raw_df: Pre-loaded DataFrame to skip CSV read.
    """
    if raw_df is None:
        print("Loading local validation dataset for hierarchy evaluation...")
        df = pd.read_csv(local_dataset_path)
        print(f"Dataset loaded: {len(df)} samples")
    else:
        df = raw_df.copy()
        print(f"Using cached local DataFrame for hierarchy evaluation: {len(df)} samples")

    df = df.dropna(subset=[column_local_image_path, "hierarchy"])
    df = df.assign(hierarchy=df["hierarchy"].astype(str).str.strip())
    df = df[df["hierarchy"].str.len() > 0]

    if hierarchy_classes:
        wanted_lower = [name.lower() for name in hierarchy_classes]
        proper_case = {name.lower(): name for name in hierarchy_classes}
        df = df.assign(hierarchy_lower=df["hierarchy"].str.lower())
        df = df[df["hierarchy_lower"].isin(wanted_lower)]
        df = df.assign(hierarchy=df["hierarchy_lower"].map(proper_case))
        df = df.drop(columns=["hierarchy_lower"])
        print(f"After filtering: {len(df)} samples, {df['hierarchy'].nunique()} classes")

    if len(df) > max_samples:
        df = df.sample(n=max_samples, random_state=42)
    print(f"Using {len(df)} samples, classes: {sorted(df['hierarchy'].unique())}")
    return LocalHierarchyDataset(df)
| # ============================================================================ | |
| # 2. Evaluator | |
| # ============================================================================ | |
| class CategoryModelEvaluator: | |
| """ | |
| Produces hierarchy confusion matrices for GAP-CLIP and the | |
| baseline Fashion-CLIP on Fashion-MNIST, KAGL Marqo, and internal datasets. | |
| """ | |
def __init__(self, device='mps', directory='gap_clip_confusion_matrices',
             gap_clip_model=None, gap_clip_processor=None,
             baseline_model=None, baseline_processor=None,
             hierarchy_classes=None,
             kaggle_raw_df=None, local_raw_df=None):
    """
    Prepare the output directory, hierarchy class lists and both models.

    Args:
        device: torch device, or a device string converted with ``torch.device``.
        directory: output directory for confusion-matrix PNG/CSV files
            (created if missing).
        gap_clip_model, gap_clip_processor: optional pre-loaded GAP-CLIP pair;
            unless both are given, GAP-CLIP is loaded from ``main_model_path``.
        baseline_model, baseline_processor: optional pre-loaded Fashion-CLIP
            pair; unless both are given, the baseline is loaded.
        hierarchy_classes: optional explicit class list; otherwise read from
            the 'hierarchy_classes' entry of the checkpoint at
            ``hierarchy_model_path``.
        kaggle_raw_df, local_raw_df: optional cached DataFrames stored for
            the dataset loaders (avoid re-downloading / re-reading).

    Raises:
        FileNotFoundError: when classes must come from a missing checkpoint.
    """
    self.device = torch.device(device) if isinstance(device, str) else device
    self.directory = directory
    self.kaggle_raw_df = kaggle_raw_df
    self.local_raw_df = local_raw_df
    self.color_emb_dim = color_emb_dim
    self.hierarchy_emb_dim = hierarchy_emb_dim
    self.main_emb_dim = main_emb_dim
    # Exclusive end of the hierarchy sub-space, which follows the color dims.
    self.hierarchy_end_dim = self.color_emb_dim + self.hierarchy_emb_dim
    os.makedirs(self.directory, exist_ok=True)

    # --- hierarchy classes ---
    if hierarchy_classes is not None:
        self.hierarchy_classes = hierarchy_classes
        print(f"Using provided hierarchy classes: {len(self.hierarchy_classes)} classes")
    else:
        print("Loading hierarchy classes from hierarchy model...")
        if not os.path.exists(hierarchy_model_path):
            raise FileNotFoundError(f"Hierarchy model file {hierarchy_model_path} not found")
        # Only the class list is needed, so deserialise on CPU: mapping to
        # self.device would copy every tensor in the checkpoint onto the
        # accelerator just to discard it. Assumes the checkpoint is a dict.
        hierarchy_checkpoint = torch.load(hierarchy_model_path, map_location="cpu")
        self.hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
        del hierarchy_checkpoint
        print(f"Found {len(self.hierarchy_classes)} hierarchy classes: {sorted(self.hierarchy_classes)}")

    self.validation_hierarchy_classes = self._load_validation_hierarchy_classes()
    if self.validation_hierarchy_classes:
        print(f"Validation dataset hierarchies ({len(self.validation_hierarchy_classes)} classes): "
              f"{sorted(self.validation_hierarchy_classes)}")
    else:
        print("Unable to load validation hierarchy classes, falling back to hierarchy model classes.")
        self.validation_hierarchy_classes = self.hierarchy_classes

    # --- load GAP-CLIP (accept pre-loaded or load from scratch) ---
    if gap_clip_model is not None and gap_clip_processor is not None:
        self.model = gap_clip_model
        self.processor = gap_clip_processor
        print("Using pre-loaded GAP-CLIP model")
    else:
        self.model, self.processor = load_gap_clip(main_model_path, self.device)
        print("GAP-CLIP model loaded successfully")

    # --- baseline Fashion-CLIP (accept pre-loaded or load from scratch) ---
    if baseline_model is not None and baseline_processor is not None:
        self.baseline_model = baseline_model
        self.baseline_processor = baseline_processor
        print("Using pre-loaded baseline Fashion-CLIP model")
    else:
        self.baseline_model, self.baseline_processor = load_baseline_fashion_clip(self.device)
        print("Baseline Fashion-CLIP model loaded successfully")
| # ------------------------------------------------------------------ | |
| # helpers | |
| # ------------------------------------------------------------------ | |
def _load_validation_hierarchy_classes(self):
    """Return the sorted, de-duplicated, non-blank 'hierarchy' values of the
    local validation CSV at ``local_dataset_path``.

    Returns an empty list, after printing the reason, when the file is
    missing, cannot be parsed, or has no 'hierarchy' column.
    """
    if not os.path.exists(local_dataset_path):
        print(f"Validation dataset not found at {local_dataset_path}")
        return []
    try:
        # Parse only the 'hierarchy' column instead of the whole CSV. A
        # callable usecols (unlike usecols=['hierarchy']) does not raise when
        # the column is absent; the frame just comes back without it.
        df = pd.read_csv(local_dataset_path, usecols=lambda col: col == 'hierarchy')
    except Exception as exc:
        print(f"Failed to read validation dataset: {exc}")
        return []
    if 'hierarchy' not in df.columns:
        print("Validation dataset does not contain 'hierarchy' column.")
        return []
    hierarchies = df['hierarchy'].dropna().astype(str).str.strip()
    return sorted({h for h in hierarchies if h})
def prepare_shared_fashion_mnist(self, max_samples=10000, batch_size=8):
    """
    Build one Fashion-MNIST dataset and an unshuffled DataLoader shared by
    every model, so all models are evaluated on exactly the same items.

    Returns:
        ``(dataset, dataloader, counts)`` where ``counts`` maps each hierarchy
        name (via ``dataset.label_mapping``; unmapped ids -> 'unknown') to its
        number of rows. ``counts`` is empty when the dataset has no rows or no
        label mapping.
    """
    wanted = self.validation_hierarchy_classes or self.hierarchy_classes
    dataset = load_fashion_mnist_dataset(max_samples, hierarchy_classes=wanted)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    tally = defaultdict(int)
    if len(dataset.dataframe) > 0 and dataset.label_mapping:
        id_to_name = dataset.label_mapping
        for raw_id in dataset.dataframe['label']:
            tally[id_to_name.get(int(raw_id), 'unknown')] += 1
    return dataset, loader, dict(tally)
| def _count_labels(labels): | |
| counts = defaultdict(int) | |
| for label in labels: | |
| counts[label] += 1 | |
| return dict(counts) | |
| def _validate_label_distribution(self, labels, expected_counts, context): | |
| observed = self._count_labels(labels) | |
| if observed != expected_counts: | |
| raise ValueError( | |
| f"Label distribution mismatch in {context}. " | |
| f"Expected {expected_counts}, observed {observed}" | |
| ) | |
| # ------------------------------------------------------------------ | |
| # embedding extraction (delegates to shared utils) | |
| # ------------------------------------------------------------------ | |
def extract_full_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
    """Extract full GAP-CLIP embeddings ('text' or 'image') for *dataloader*.

    Thin wrapper around ``extract_clip_embeddings``; its result is returned
    unchanged. Callers in this class unpack it as
    ``(embeddings, _, hierarchy_labels)`` and treat the embeddings as 512-D.
    """
    progress_desc = f"GAP-CLIP {embedding_type} embeddings"
    return extract_clip_embeddings(
        self.model,
        self.processor,
        dataloader,
        self.device,
        embedding_type=embedding_type,
        max_samples=max_samples,
        desc=progress_desc,
    )
def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000):
    """Extract baseline Fashion-CLIP embeddings ('text' or 'image') for *dataloader*.

    Uses the baseline model/processor with ``extract_clip_embeddings`` and
    returns its result unchanged (unpacked by callers as
    ``(embeddings, _, hierarchy_labels)``). Any L2 normalisation happens
    inside that helper and is not visible here.
    """
    progress_desc = f"Baseline {embedding_type} embeddings"
    return extract_clip_embeddings(
        self.baseline_model,
        self.baseline_processor,
        dataloader,
        self.device,
        embedding_type=embedding_type,
        max_samples=max_samples,
        desc=progress_desc,
    )
def predict_labels_nearest_neighbor(self, embeddings, labels):
    """
    Leave-one-out 1-NN prediction within a single embedding set.

    Each row receives the label of its most cosine-similar *other* row; the
    self-match on the diagonal is masked with -1.0. Ties go to the lowest
    index (``argmax`` semantics). This matches the accuracy logic used in
    the evaluation pipeline.
    """
    sim_matrix = cosine_similarity(embeddings)
    # A sample must never be chosen as its own neighbour.
    np.fill_diagonal(sim_matrix, -1.0)
    neighbour_idx = sim_matrix.argmax(axis=1)
    return [labels[int(j)] for j in neighbour_idx]
| # ------------------------------------------------------------------ | |
| # image + text ensemble | |
| # ------------------------------------------------------------------ | |
def _compute_img_centroids(self, embeddings, labels):
    """Return ``{label: unit-norm mean of that label's L2-normalised rows}``.

    Keys are inserted in sorted label order.
    """
    unit_rows = normalize(embeddings, norm='l2')
    members = defaultdict(list)
    for row_idx, lab in enumerate(labels):
        members[lab].append(row_idx)
    centroids = {}
    for lab in sorted(members):
        mean_vec = unit_rows[members[lab]].mean(axis=0)
        centroids[lab] = normalize([mean_vec], norm='l2')[0]
    return centroids
def predict_labels_image_ensemble(self, img_embeddings, labels,
                                  text_protos, cls_names, alpha=0.5):
    """Classify images by blending image-centroid and text-prototype similarity.

    For each image, class ``cls_names[k]`` scores
    ``alpha * cos(img, centroid_k) + (1 - alpha) * cos(img, text_protos[k])``
    and the arg-max class is predicted (ties -> lowest ``k``).

    Args:
        img_embeddings: 2-D array, one image embedding per row.
        labels: ground-truth label per row; used to build the per-class image
            centroids, which therefore include the query images themselves.
        text_protos: 2-D array with one row per entry of ``cls_names``, in the
            same order, and the same width as ``img_embeddings``.
        cls_names: ordered class names; each must occur in ``labels``.
        alpha: weight of the image-centroid term.

    Returns:
        list with one predicted class name per image.

    Raises:
        ValueError: if ``text_protos`` does not have one row per class.
    """
    text_protos = np.asarray(text_protos)
    # A mismatched prototype matrix would otherwise broadcast silently
    # (e.g. a single row) or fail deep inside numpy.
    if text_protos.ndim != 2 or text_protos.shape[0] != len(cls_names):
        raise ValueError(
            f"text_protos must have one row per class: got shape "
            f"{text_protos.shape} for {len(cls_names)} classes"
        )
    img_norm = normalize(img_embeddings, norm='l2')
    img_centroids = self._compute_img_centroids(img_norm, labels)
    centroid_mat = np.stack([img_centroids[c] for c in cls_names], axis=0)
    # One (n_images x n_classes) similarity matrix per term, instead of two
    # sklearn calls (each re-normalising the class matrices) per image.
    sim_img = cosine_similarity(img_norm, centroid_mat)
    sim_txt = cosine_similarity(img_norm, text_protos)
    scores = alpha * sim_img + (1 - alpha) * sim_txt
    return [cls_names[int(k)] for k in scores.argmax(axis=1)]
| # ------------------------------------------------------------------ | |
| # confusion matrix & classification report | |
| # ------------------------------------------------------------------ | |
def evaluate_classification_performance(self, embeddings, labels,
                                        embedding_type="Embeddings",
                                        label_type="Label",
                                        method="nn"):
    """
    Predict a label for every embedding, then score and plot the predictions.

    Args:
        embeddings: embedding matrix, one row per sample.
        labels: ground-truth label per row.
        embedding_type, label_type: text used in the confusion-matrix title
            (``label_type`` is also passed to ``create_confusion_matrix``).
        method: "nn" (``predict_labels_nearest_neighbor``) or "centroid"
            (``predict_labels_from_embeddings``).

    Returns:
        dict with 'accuracy', 'predictions', 'confusion_matrix', 'labels'
        (sorted unique ground-truth labels), 'classification_report'
        (sklearn ``output_dict`` form) and 'figure' (from
        ``create_confusion_matrix``; callers in this file save and close it).

    Raises:
        ValueError: for an unknown ``method``.
    """
    if method not in ("nn", "centroid"):
        raise ValueError(f"Unknown classification method: {method}")
    if method == "nn":
        predicted = self.predict_labels_nearest_neighbor(embeddings, labels)
    else:
        predicted = predict_labels_from_embeddings(embeddings, labels)

    accuracy = accuracy_score(labels, predicted)
    class_names = sorted(set(labels))
    title = f"{embedding_type} - {label_type} Classification ({method.upper()})"
    figure, _, matrix = create_confusion_matrix(labels, predicted, title, label_type)
    report = classification_report(
        labels, predicted,
        labels=class_names, target_names=class_names, output_dict=True,
    )
    return {
        'accuracy': accuracy,
        'predictions': predicted,
        'confusion_matrix': matrix,
        'labels': class_names,
        'classification_report': report,
        'figure': figure,
    }
def save_confusion_matrix_table(self, cm, labels, output_csv_path):
    """
    Write the confusion matrix to CSV with a 'row_total' column and a
    'column_total' row (whose last cell is the grand total) for auditing.

    Args:
        cm: square matrix whose rows/columns follow ``labels``.
        labels: class names used as both index and header.
        output_csv_path: destination CSV path.
    """
    table = pd.DataFrame(cm, index=labels, columns=labels)
    per_class_totals = list(table[labels].sum(axis=0))
    table["row_total"] = table.sum(axis=1)
    grand_total = table["row_total"].sum()
    table.loc["column_total"] = per_class_totals + [grand_total]
    table.to_csv(output_csv_path)
| # ================================================================== | |
| # 3. GAP-CLIP evaluation on Fashion-MNIST | |
| # ================================================================== | |
def evaluate_gap_clip_fashion_mnist(self, max_samples=10000, dataloader=None, expected_counts=None):
    """
    Evaluate GAP-CLIP hierarchy classification on Fashion-MNIST.

    Text: 1-NN (method="nn") on the hierarchy sub-space of the text
    embeddings. Image: 1-NN on both the hierarchy sub-space and the full
    embedding; the more accurate of the two is reported (chosen on this same
    evaluation data). The accuracy of an image-centroid / text-prototype
    ensemble (alpha=0.7) is also stored under 'ensemble_accuracy'.
    Confusion matrices are written as PNG and CSV to ``self.directory``.

    Args:
        max_samples: cap on samples extracted per modality.
        dataloader: optional shared Fashion-MNIST DataLoader; built with
            ``prepare_shared_fashion_mnist`` when omitted.
        expected_counts: ``{hierarchy: count}`` the extracted labels must
            match; required when ``dataloader`` is given.

    Returns:
        ``{'text_hierarchy': metrics, 'image_hierarchy': metrics}``.

    Raises:
        ValueError: custom dataloader without ``expected_counts``, or
            extracted labels that disagree with ``expected_counts``.
    """
    print(f"\n{'=' * 60}")
    print("Evaluating GAP-CLIP on Fashion-MNIST")
    # NOTE: the printed "dims 16-79" assumes color_emb_dim=16 and
    # hierarchy_emb_dim=64; the slicing below uses the configured values.
    print(" Hierarchy embeddings (dims 16-79)")
    print(f" Max samples: {max_samples}")
    print(f"{'=' * 60}")
    if dataloader is None:
        fashion_dataset, dataloader, dataset_counts = self.prepare_shared_fashion_mnist(max_samples=max_samples)
        expected_counts = expected_counts or dataset_counts
    else:
        fashion_dataset = getattr(dataloader, "dataset", None)
        if expected_counts is None:
            raise ValueError("expected_counts must be provided when using a custom dataloader.")
    if fashion_dataset is not None and len(fashion_dataset.dataframe) > 0 and fashion_dataset.label_mapping:
        print("\nHierarchy distribution in dataset:")
        for h in sorted(expected_counts):
            print(f" {h}: {expected_counts[h]} samples")

    results = {}
    # Hierarchy sub-space: the dims right after the color sub-space.
    hier_dims = slice(self.color_emb_dim, self.color_emb_dim + self.hierarchy_emb_dim)

    # --- full embeddings (text & image) ---
    print("\nExtracting full 512-dimensional GAP-CLIP embeddings...")
    text_full, _, text_hier = self.extract_full_embeddings(dataloader, 'text', max_samples)
    img_full, _, img_hier = self.extract_full_embeddings(dataloader, 'image', max_samples)
    self._validate_label_distribution(text_hier, expected_counts, "GAP-CLIP text")
    self._validate_label_distribution(img_hier, expected_counts, "GAP-CLIP image")
    print(f" Text shape: {text_full.shape} | Image shape: {img_full.shape}")

    # --- TEXT: hierarchy on the specialised sub-space ---
    print("\n--- GAP-CLIP TEXT HIERARCHY (dims 16-79) ---")
    text_hier_spec = text_full[:, hier_dims]
    print(f" Specialized text hierarchy shape: {text_hier_spec.shape}")
    text_metrics = compute_similarity_metrics(text_hier_spec, text_hier)
    text_class = self.evaluate_classification_performance(
        text_hier_spec, text_hier,
        "GAP-CLIP Text Hierarchy (64D)", "Hierarchy",
        method="nn",
    )
    text_metrics.update(text_class)
    results['text_hierarchy'] = text_metrics

    # --- IMAGE: specialised sub-space vs full embedding ---
    print("\n--- GAP-CLIP IMAGE HIERARCHY (64D vs 512D) ---")
    img_hier_spec = img_full[:, hier_dims]
    print(f" Specialized image hierarchy shape: {img_hier_spec.shape}")
    print(" Testing specialized 64D...")
    spec_metrics = compute_similarity_metrics(img_hier_spec, img_hier)
    spec_class = self.evaluate_classification_performance(
        img_hier_spec, img_hier,
        "GAP-CLIP Image Hierarchy (64D)", "Hierarchy",
        method="nn",
    )
    print(" Testing full 512D...")
    full_metrics = compute_similarity_metrics(img_full, img_hier)
    full_class = self.evaluate_classification_performance(
        img_full, img_hier,
        "GAP-CLIP Image Hierarchy (512D full)", "Hierarchy",
        method="nn",
    )
    if full_class['accuracy'] >= spec_class['accuracy']:
        print(f" 512D wins: {full_class['accuracy'] * 100:.1f}% vs {spec_class['accuracy'] * 100:.1f}%")
        img_metrics, img_class = full_metrics, full_class
        discarded_class = spec_class
    else:
        print(f" 64D wins: {spec_class['accuracy'] * 100:.1f}% vs {full_class['accuracy'] * 100:.1f}%")
        img_metrics, img_class = spec_metrics, spec_class
        discarded_class = full_class
    # Only the winning figure is saved and closed below; close the losing one
    # now so it does not stay open in matplotlib for the rest of the run.
    plt.close(discarded_class['figure'])

    # --- ensemble image + text prototypes ---
    print("\n Testing GAP-CLIP image + text ensemble (prototypes per class)...")
    cls_names = sorted(set(img_hier))
    prompts = [f"a photo of a {c}" for c in cls_names]
    text_inputs = self.processor(text=prompts, return_tensors="pt", padding=True, truncation=True)
    text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
    with torch.no_grad():
        txt_feats = self.model.get_text_features(**text_inputs)
        txt_feats = txt_feats / txt_feats.norm(dim=-1, keepdim=True)
    text_protos = txt_feats.cpu().numpy()
    ensemble_preds = self.predict_labels_image_ensemble(
        img_full, img_hier, text_protos, cls_names, alpha=0.7,
    )
    ensemble_acc = accuracy_score(img_hier, ensemble_preds)
    print(f" Ensemble accuracy (alpha=0.7): {ensemble_acc * 100:.2f}%")
    img_metrics.update(img_class)
    img_metrics['ensemble_accuracy'] = ensemble_acc
    results['image_hierarchy'] = img_metrics

    # --- save confusion matrix figures and tables ---
    for key in ['text_hierarchy', 'image_hierarchy']:
        fig = results[key]['figure']
        fig.savefig(
            os.path.join(self.directory, f"gap_clip_{key}_confusion_matrix.png"),
            dpi=300, bbox_inches='tight',
        )
        self.save_confusion_matrix_table(
            results[key]['confusion_matrix'],
            results[key]['labels'],
            os.path.join(self.directory, f"gap_clip_{key}_confusion_matrix.csv"),
        )
        plt.close(fig)

    del text_full, img_full, text_hier_spec, img_hier_spec
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return results
| # ================================================================== | |
| # 4. Baseline Fashion-CLIP evaluation on Fashion-MNIST | |
| # ================================================================== | |
def evaluate_baseline_fashion_mnist(self, max_samples=10000, dataloader=None, expected_counts=None):
    """
    Evaluate baseline Fashion-CLIP hierarchy classification on Fashion-MNIST.

    For text and then image embeddings: extract, check the label distribution
    against ``expected_counts``, compute similarity metrics and 1-NN
    classification. Then write both confusion matrices (PNG + CSV) to
    ``self.directory``.

    Returns:
        ``{'text': {'hierarchy': metrics}, 'image': {'hierarchy': metrics}}``.

    Raises:
        ValueError: custom dataloader without ``expected_counts``, or a label
            distribution mismatch.
    """
    rule = '=' * 60
    print(f"\n{rule}")
    print("Evaluating Baseline Fashion-CLIP on Fashion-MNIST")
    print(f" Max samples: {max_samples}")
    print(rule)
    if dataloader is None:
        _, dataloader, dataset_counts = self.prepare_shared_fashion_mnist(max_samples=max_samples)
        expected_counts = expected_counts or dataset_counts
    elif expected_counts is None:
        raise ValueError("expected_counts must be provided when using a custom dataloader.")

    results = {}
    for modality in ('text', 'image'):
        print(f"\nExtracting baseline {modality} embeddings...")
        emb, _, hier = self.extract_baseline_embeddings_batch(dataloader, modality, max_samples)
        self._validate_label_distribution(hier, expected_counts, f"baseline {modality}")
        print(f" Baseline {modality} shape: {emb.shape}")
        metrics = compute_similarity_metrics(emb, hier)
        metrics.update(self.evaluate_classification_performance(
            emb, hier,
            f"Baseline Fashion-CLIP {modality.capitalize()} - Hierarchy", "Hierarchy",
            method="nn",
        ))
        results[modality] = {'hierarchy': metrics}
        # Free the embedding matrix before extracting the next modality.
        del emb
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    for modality in ('text', 'image'):
        entry = results[modality]['hierarchy']
        fig = entry['figure']
        fig.savefig(
            os.path.join(self.directory, f"baseline_{modality}_hierarchy_confusion_matrix.png"),
            dpi=300, bbox_inches='tight',
        )
        self.save_confusion_matrix_table(
            entry['confusion_matrix'],
            entry['labels'],
            os.path.join(self.directory, f"baseline_{modality}_hierarchy_confusion_matrix.csv"),
        )
        plt.close(fig)
    return results
| # ================================================================== | |
| # 5. Generic dataset evaluation (KAGL Marqo / Internal) | |
| # ================================================================== | |
def evaluate_gap_clip_generic(self, dataloader, dataset_name, max_samples=10000):
    """
    Evaluate GAP-CLIP hierarchy performance on any dataset (e.g. KAGL Marqo).

    Text: 1-NN on the hierarchy sub-space of the text embeddings. Image: 1-NN
    on both the hierarchy sub-space and the full embedding; the more accurate
    is reported (chosen on this same evaluation data). Confusion matrices are
    written to ``self.directory`` with a prefix derived from ``dataset_name``.

    Returns:
        ``{'text_hierarchy': metrics, 'image_hierarchy': metrics}``.
    """
    print(f"\n{'=' * 60}")
    print(f"Evaluating GAP-CLIP on {dataset_name}")
    # NOTE: "dims 16-79" assumes color_emb_dim=16 and hierarchy_emb_dim=64.
    print(" Hierarchy embeddings (dims 16-79)")
    print(f"{'=' * 60}")
    results = {}
    # Hierarchy sub-space: the dims right after the color sub-space.
    hier_dims = slice(self.color_emb_dim, self.color_emb_dim + self.hierarchy_emb_dim)

    # --- text hierarchy (specialised sub-space) ---
    print("\nExtracting GAP-CLIP text embeddings...")
    text_full, _, text_hier = self.extract_full_embeddings(dataloader, 'text', max_samples)
    text_hier_spec = text_full[:, hier_dims]
    print(f" Text shape: {text_full.shape}, hierarchy subspace: {text_hier_spec.shape}")
    text_metrics = compute_similarity_metrics(text_hier_spec, text_hier)
    text_class = self.evaluate_classification_performance(
        text_hier_spec, text_hier,
        f"GAP-CLIP Text Hierarchy – {dataset_name}", "Hierarchy", method="nn",
    )
    text_metrics.update(text_class)
    results['text_hierarchy'] = text_metrics

    # --- image hierarchy (best of specialised sub-space vs full embedding) ---
    print("\nExtracting GAP-CLIP image embeddings...")
    img_full, _, img_hier = self.extract_full_embeddings(dataloader, 'image', max_samples)
    img_hier_spec = img_full[:, hier_dims]
    spec_metrics = compute_similarity_metrics(img_hier_spec, img_hier)
    spec_class = self.evaluate_classification_performance(
        img_hier_spec, img_hier,
        f"GAP-CLIP Image Hierarchy (64D) – {dataset_name}", "Hierarchy", method="nn",
    )
    full_metrics = compute_similarity_metrics(img_full, img_hier)
    full_class = self.evaluate_classification_performance(
        img_full, img_hier,
        f"GAP-CLIP Image Hierarchy (512D) – {dataset_name}", "Hierarchy", method="nn",
    )
    if full_class['accuracy'] >= spec_class['accuracy']:
        print(f" 512D wins: {full_class['accuracy']*100:.1f}% vs {spec_class['accuracy']*100:.1f}%")
        img_metrics, img_class = full_metrics, full_class
        discarded_class = spec_class
    else:
        print(f" 64D wins: {spec_class['accuracy']*100:.1f}% vs {full_class['accuracy']*100:.1f}%")
        img_metrics, img_class = spec_metrics, spec_class
        discarded_class = full_class
    # Only the winning figure is saved and closed below; close the losing one
    # now so it does not stay open in matplotlib for the rest of the run.
    plt.close(discarded_class['figure'])
    img_metrics.update(img_class)
    results['image_hierarchy'] = img_metrics

    # --- save confusion matrices ---
    prefix = dataset_name.lower().replace(" ", "_")
    for key in ['text_hierarchy', 'image_hierarchy']:
        fig = results[key]['figure']
        fig.savefig(
            os.path.join(self.directory, f"gap_clip_{prefix}_{key}_confusion_matrix.png"),
            dpi=300, bbox_inches='tight',
        )
        self.save_confusion_matrix_table(
            results[key]['confusion_matrix'], results[key]['labels'],
            os.path.join(self.directory, f"gap_clip_{prefix}_{key}_confusion_matrix.csv"),
        )
        plt.close(fig)

    del text_full, img_full, text_hier_spec, img_hier_spec
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return results
def evaluate_baseline_generic(self, dataloader, dataset_name, max_samples=10000):
    """
    Evaluate baseline Fashion-CLIP hierarchy performance on any dataset.

    For text and then image embeddings: extract, compute similarity metrics
    and 1-NN classification. Then write both confusion matrices (PNG + CSV)
    to ``self.directory`` with a prefix derived from ``dataset_name``.

    Returns:
        ``{'text': {'hierarchy': metrics}, 'image': {'hierarchy': metrics}}``.
    """
    rule = '=' * 60
    print(f"\n{rule}")
    print(f"Evaluating Baseline Fashion-CLIP on {dataset_name}")
    print(rule)

    results = {}
    for modality in ('text', 'image'):
        print(f"\nExtracting baseline {modality} embeddings...")
        emb, _, hier = self.extract_baseline_embeddings_batch(dataloader, modality, max_samples)
        print(f" Baseline {modality} shape: {emb.shape}")
        metrics = compute_similarity_metrics(emb, hier)
        metrics.update(self.evaluate_classification_performance(
            emb, hier,
            f"Baseline {modality.capitalize()} Hierarchy – {dataset_name}", "Hierarchy", method="nn",
        ))
        results[modality] = {'hierarchy': metrics}
        # Free the embedding matrix before extracting the next modality.
        del emb
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    prefix = dataset_name.lower().replace(" ", "_")
    for modality in ('text', 'image'):
        entry = results[modality]['hierarchy']
        fig = entry['figure']
        fig.savefig(
            os.path.join(self.directory, f"baseline_{prefix}_{modality}_hierarchy_confusion_matrix.png"),
            dpi=300, bbox_inches='tight',
        )
        self.save_confusion_matrix_table(
            entry['confusion_matrix'],
            entry['labels'],
            os.path.join(self.directory, f"baseline_{prefix}_{modality}_hierarchy_confusion_matrix.csv"),
        )
        plt.close(fig)
    return results
| # ================================================================== | |
| # 6. Full evaluation across all datasets | |
| # ================================================================== | |
def run_full_evaluation(self, max_samples=10000, batch_size=8):
    """Run hierarchy evaluation on all 3 datasets for both models.

    Fashion-MNIST is prepared once, and the same dataloader and expected
    counts go to both the GAP-CLIP and the baseline evaluators. KAGL Marqo and
    the internal validation set are best-effort:
      - an empty dataset is skipped with a warning;
      - an exception while loading or evaluating one of them is reported
        (message plus traceback) without aborting the remaining datasets.
    A short accuracy/separation summary is printed at the end.

    Args:
        max_samples: Sample cap forwarded to every loader and evaluator.
        batch_size: Batch size for every dataloader used here.

    Returns:
        dict: Results keyed by ``'fashion_mnist_gap'``,
        ``'fashion_mnist_baseline'``, ``'kaggle_gap'``, ``'kaggle_baseline'``,
        ``'local_gap'`` and ``'local_baseline'``. Keys for a skipped or failed
        dataset are absent.
    """
    import traceback

    all_results = {}

    # --- Fashion-MNIST ---
    # Only the shared dataloader and expected counts are used; the dataset
    # object returned first is not needed here.
    _, shared_dataloader, shared_counts = self.prepare_shared_fashion_mnist(
        max_samples=max_samples, batch_size=batch_size,
    )
    all_results['fashion_mnist_gap'] = self.evaluate_gap_clip_fashion_mnist(
        max_samples=max_samples, dataloader=shared_dataloader, expected_counts=shared_counts,
    )
    all_results['fashion_mnist_baseline'] = self.evaluate_baseline_fashion_mnist(
        max_samples=max_samples, dataloader=shared_dataloader, expected_counts=shared_counts,
    )

    # --- KAGL Marqo ---
    try:
        kaggle_dataset = load_kaggle_marqo_with_hierarchy(
            max_samples=max_samples,
            hierarchy_classes=self.validation_hierarchy_classes or self.hierarchy_classes,
            raw_df=self.kaggle_raw_df,
        )
        if kaggle_dataset is not None and len(kaggle_dataset) > 0:
            kaggle_dataloader = DataLoader(kaggle_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
            all_results['kaggle_gap'] = self.evaluate_gap_clip_generic(
                kaggle_dataloader, "KAGL Marqo", max_samples,
            )
            all_results['kaggle_baseline'] = self.evaluate_baseline_generic(
                kaggle_dataloader, "KAGL Marqo", max_samples,
            )
        else:
            print("WARNING: KAGL Marqo dataset empty after hierarchy mapping, skipping.")
    except Exception as e:
        # Best-effort: carry on with the other datasets, but print the
        # traceback so genuine bugs in the evaluators stay diagnosable.
        print(f"WARNING: Could not evaluate on KAGL Marqo: {e}")
        traceback.print_exc()

    # --- Internal (local validation) ---
    try:
        local_dataset = load_local_validation_with_hierarchy(
            max_samples=max_samples,
            hierarchy_classes=self.validation_hierarchy_classes or self.hierarchy_classes,
            raw_df=self.local_raw_df,
        )
        if local_dataset is not None and len(local_dataset) > 0:
            local_dataloader = DataLoader(local_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
            all_results['local_gap'] = self.evaluate_gap_clip_generic(
                local_dataloader, "Internal", max_samples,
            )
            all_results['local_baseline'] = self.evaluate_baseline_generic(
                local_dataloader, "Internal", max_samples,
            )
        else:
            print("WARNING: Local validation dataset empty after hierarchy filtering, skipping.")
    except Exception as e:
        # Same best-effort policy as for KAGL Marqo.
        print(f"WARNING: Could not evaluate on internal dataset: {e}")
        traceback.print_exc()

    # --- Print summary ---
    print(f"\n{'=' * 70}")
    print("CATEGORY MODEL EVALUATION SUMMARY")
    print(f"{'=' * 70}")
    for dataset_key, label in [
        ('fashion_mnist_gap', 'Fashion-MNIST (GAP-CLIP)'),
        ('fashion_mnist_baseline', 'Fashion-MNIST (Baseline)'),
        ('kaggle_gap', 'KAGL Marqo (GAP-CLIP)'),
        ('kaggle_baseline', 'KAGL Marqo (Baseline)'),
        ('local_gap', 'Internal (GAP-CLIP)'),
        ('local_baseline', 'Internal (Baseline)'),
    ]:
        if dataset_key not in all_results:
            continue
        res = all_results[dataset_key]
        print(f"\n{label}:")
        # Two result layouts are accepted:
        #   - flat 'text_hierarchy' / 'image_hierarchy' entries;
        #   - 'text' / 'image' dicts with the metrics nested under 'hierarchy'.
        if 'text_hierarchy' in res:
            text_metrics = res['text_hierarchy']
            image_metrics = res['image_hierarchy']
        elif 'text' in res:
            text_metrics = res['text']['hierarchy']
            image_metrics = res['image']['hierarchy']
        else:
            continue
        print(f" Text NN Acc: {text_metrics['accuracy']*100:.1f}% | Separation: {text_metrics['separation_score']:.4f}")
        print(f" Image NN Acc: {image_metrics['accuracy']*100:.1f}% | Separation: {image_metrics['separation_score']:.4f}")
    return all_results
# ============================================================================
# 7. Main
# ============================================================================
def select_device():
    """Return the best available torch device: CUDA first, then Apple MPS, then CPU.

    Returns:
        torch.device: ``cuda`` if a CUDA GPU is available, otherwise ``mps``
        if the MPS backend is available, otherwise ``cpu``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Older torch builds have no ``torch.backends.mps`` attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


if __name__ == "__main__":
    device = select_device()
    print(f"Using device: {device}")
    directory = 'gap_clip_confusion_matrices'
    max_samples = 10000
    evaluator = CategoryModelEvaluator(device=device, directory=directory)
    evaluator.run_full_evaluation(max_samples=max_samples, batch_size=8)