""" Section 5.2 — Category Model Evaluation (Table 2) ================================================== Evaluates GAP-CLIP vs the Fashion-CLIP baseline on hierarchy (category) classification using three datasets: - Fashion-MNIST (10 categories) - KAGL Marqo (external, real-world fashion e-commerce) - Internal validation dataset Produces hierarchy confusion matrices (text + image) for both models on each dataset. Metrics match Table 2 in the paper: - Text/image embedding NN accuracy - Text/image embedding separation score Run directly: python sec52_category_model_eval.py Paper reference: Section 5.2, Table 2. """ import os os.environ["TOKENIZERS_PARALLELISM"] = "false" import torch import pandas as pd import numpy as np import matplotlib.pyplot as plt import difflib from collections import defaultdict from sklearn.metrics.pairwise import cosine_similarity from sklearn.metrics import classification_report, accuracy_score from sklearn.preprocessing import normalize from torch.utils.data import Dataset, DataLoader from torchvision import transforms from PIL import Image from io import BytesIO import warnings warnings.filterwarnings('ignore') from config import ( ROOT_DIR, main_model_path, main_emb_dim, hierarchy_model_path, color_emb_dim, hierarchy_emb_dim, local_dataset_path, column_local_image_path, ) from utils.datasets import ( load_fashion_mnist_dataset, ) from utils.embeddings import extract_clip_embeddings from utils.metrics import ( compute_similarity_metrics, compute_embedding_accuracy, compute_centroid_accuracy, predict_labels_from_embeddings, create_confusion_matrix, ) from utils.model_loader import load_gap_clip, load_baseline_fashion_clip # ============================================================================ # 1b. 
# KAGL Marqo utilities
# ============================================================================

class KaggleHierarchyDataset(Dataset):
    """KAGL Marqo dataset returning ``(image, description, color, hierarchy)``.

    Each row of *dataframe* must provide ``image``, ``text`` and ``hierarchy``;
    ``baseColour`` is optional. ``image`` may be a ``{"bytes": ...}`` dict, an
    object exposing ``convert`` (e.g. a PIL image), or raw encoded bytes.
    """

    def __init__(self, dataframe, image_size=224):
        self.dataframe = dataframe.reset_index(drop=True)
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        image_data = row["image"]
        if isinstance(image_data, dict) and "bytes" in image_data:
            image = Image.open(BytesIO(image_data["bytes"])).convert("RGB")
        elif hasattr(image_data, "convert"):
            image = image_data.convert("RGB")
        else:
            image = Image.open(BytesIO(image_data)).convert("RGB")
        image = self.transform(image)

        description = str(row["text"])
        raw_color = row.get("baseColour", "unknown")
        # Missing colours arrive as NaN/None; report them as "unknown", not "nan".
        color = "unknown" if pd.isna(raw_color) else str(raw_color).lower()
        hierarchy = str(row["hierarchy"])
        return image, description, color, hierarchy


def match_hierarchy_class(raw_type, hierarchy_classes):
    """Map a free-text category label onto one of *hierarchy_classes*.

    Matching is case-insensitive and tried in this order:
    exact match, substring match in either direction (the first class in
    *hierarchy_classes* wins), then the closest ``difflib`` match with a
    similarity ratio of at least 0.6.

    Args:
        raw_type: Category label to map (converted with ``str`` and stripped).
        hierarchy_classes: Candidate class names, in priority order.

    Returns:
        The matching entry of *hierarchy_classes* (original casing), or
        ``None`` when nothing matches or *raw_type* is blank.
    """
    kagl_type = str(raw_type).strip().lower()
    if not kagl_type:
        # "" is a substring of every class name; without this guard a blank
        # label would silently be assigned to the first class.
        return None

    lowered = [h.lower() for h in hierarchy_classes]
    if kagl_type in lowered:
        return hierarchy_classes[lowered.index(kagl_type)]

    for h_class, h_lower in zip(hierarchy_classes, lowered):
        if h_lower and (h_lower in kagl_type or kagl_type in h_lower):
            return h_class

    close = difflib.get_close_matches(kagl_type, lowered, n=1, cutoff=0.6)
    if close:
        return hierarchy_classes[lowered.index(close[0])]
    return None


def load_kaggle_marqo_with_hierarchy(max_samples=10000, hierarchy_classes=None, raw_df=None):
    """Load the KAGL Marqo dataset with hierarchy labels taken from ``category2``.

    Args:
        max_samples: Upper bound on returned rows; larger frames are sampled
            down with ``random_state=42``.
        hierarchy_classes: Optional model class names. When given, every
            ``category2`` value is mapped with :func:`match_hierarchy_class`
            and rows that match nothing are dropped.
        raw_df: Pre-downloaded DataFrame to skip the HuggingFace download.

    Returns:
        A :class:`KaggleHierarchyDataset`, or ``None`` when the ``category2``
        column is missing.
    """
    if raw_df is not None:
        df = raw_df.copy()
        print(f"Using cached KAGL DataFrame for hierarchy evaluation: {len(df)} samples")
    else:
        from datasets import load_dataset
        print("Loading KAGL Marqo dataset for hierarchy evaluation...")
        dataset = load_dataset("Marqo/KAGL")
        df = dataset["data"].to_pandas()
        print(f"Dataset loaded: {len(df)} samples, columns: {list(df.columns)}")

    # Use the most specific category column as hierarchy source
    hierarchy_col = 'category2'
    if hierarchy_col not in df.columns:
        print("WARNING: No hierarchy column found in KAGL dataset")
        return None
    print(f"Using '{hierarchy_col}' as hierarchy source")

    df = df.dropna(subset=["text", "image", hierarchy_col]).copy()
    df["hierarchy"] = df[hierarchy_col].astype(str).str.strip()

    # If hierarchy_classes provided, map KAGL types to model hierarchy classes.
    # Matching depends only on the label text, so resolve each distinct value once.
    if hierarchy_classes:
        mapping = {
            kagl_type: match_hierarchy_class(kagl_type, hierarchy_classes)
            for kagl_type in df["hierarchy"].unique()
        }
        df["hierarchy"] = df["hierarchy"].map(mapping)
        df = df.dropna(subset=["hierarchy"])
        print(f"After hierarchy mapping: {len(df)} samples")

    if len(df) > max_samples:
        df = df.sample(n=max_samples, random_state=42)

    print(f"Using {len(df)} samples, {df['hierarchy'].nunique()} hierarchy classes: "
          f"{sorted(df['hierarchy'].unique())}")
    return KaggleHierarchyDataset(df)

# ============================================================================
# 1c.
# Local validation dataset utilities
# ============================================================================

class LocalHierarchyDataset(Dataset):
    """Internal validation dataset returning ``(image, description, color, hierarchy)``.

    Image paths are read from the ``column_local_image_path`` column and
    resolved against ``ROOT_DIR`` when relative.
    """

    def __init__(self, dataframe, image_size=224):
        self.dataframe = dataframe.reset_index(drop=True)
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        try:
            img_path = row[column_local_image_path]
            if not os.path.isabs(img_path):
                img_path = os.path.join(ROOT_DIR, img_path)
            # Context manager releases the file handle once pixels are decoded.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")
        except Exception:
            # Deliberate best-effort: an unreadable/missing image becomes a grey
            # placeholder so one bad row does not abort the run.
            # NOTE(review): placeholders keep their hierarchy label and so
            # still count towards the image metrics.
            image = Image.new("RGB", (224, 224), color="gray")
        image = self.transform(image)

        description = str(row["text"])
        raw_color = row.get("color", "unknown")
        # Missing colours arrive as NaN/None; report them as "unknown", not "nan".
        color = "unknown" if pd.isna(raw_color) else str(raw_color)
        hierarchy = str(row["hierarchy"])
        return image, description, color, hierarchy


def load_local_validation_with_hierarchy(max_samples=10000, hierarchy_classes=None, raw_df=None):
    """Load the internal validation dataset with hierarchy labels.

    Args:
        max_samples: Upper bound on returned rows; larger frames are sampled
            down with ``random_state=42``.
        hierarchy_classes: Optional class names. When given, only rows whose
            hierarchy equals one of them (case-insensitively) are kept, and
            labels are rewritten with the casing used in *hierarchy_classes*.
        raw_df: Pre-loaded DataFrame to skip the CSV read.

    Returns:
        A :class:`LocalHierarchyDataset`.
    """
    if raw_df is not None:
        df = raw_df.copy()
        print(f"Using cached local DataFrame for hierarchy evaluation: {len(df)} samples")
    else:
        print("Loading local validation dataset for hierarchy evaluation...")
        df = pd.read_csv(local_dataset_path)
        print(f"Dataset loaded: {len(df)} samples")

    required = [column_local_image_path, "hierarchy"]
    if "text" in df.columns:
        # Captionless rows would otherwise be embedded as the literal text "nan".
        required.append("text")
    df = df.dropna(subset=required).copy()
    df["hierarchy"] = df["hierarchy"].astype(str).str.strip()
    df = df[df["hierarchy"].str.len() > 0]

    if hierarchy_classes:
        # Restore proper casing from hierarchy_classes (last spelling wins if
        # two classes differ only by case).
        case_map = {h.lower(): h for h in hierarchy_classes}
        df = df[df["hierarchy"].str.lower().isin(list(case_map))].copy()
        df["hierarchy"] = df["hierarchy"].str.lower().map(case_map)
        print(f"After filtering: {len(df)} samples, {df['hierarchy'].nunique()} classes")

    if len(df) > max_samples:
        df = df.sample(n=max_samples, random_state=42)

    print(f"Using {len(df)} samples, classes: {sorted(df['hierarchy'].unique())}")
    return LocalHierarchyDataset(df)


# ============================================================================
# 2. Evaluator
# ============================================================================

class CategoryModelEvaluator:
    """
    Produces hierarchy confusion matrices for GAP-CLIP and the baseline
    Fashion-CLIP on Fashion-MNIST, KAGL Marqo, and internal datasets.

    GAP-CLIP's hierarchy subspace is the column slice
    ``[color_emb_dim, color_emb_dim + hierarchy_emb_dim)`` of its embeddings.
    Log strings that say "dims 16-79" / "64D" assume
    ``color_emb_dim == 16`` and ``hierarchy_emb_dim == 64`` — NOTE(review):
    confirm against ``config``.
    """

    def __init__(self, device='mps', directory='gap_clip_confusion_matrices',
                 gap_clip_model=None, gap_clip_processor=None,
                 baseline_model=None, baseline_processor=None,
                 hierarchy_classes=None, kaggle_raw_df=None, local_raw_df=None):
        """Set up output directory, hierarchy classes and both models.

        Args:
            device: Torch device or device string.
            directory: Output directory for confusion-matrix PNG/CSV files.
            gap_clip_model, gap_clip_processor: Optional pre-loaded GAP-CLIP
                pair; loaded from ``main_model_path`` unless both are given.
            baseline_model, baseline_processor: Optional pre-loaded baseline
                pair; loaded via ``load_baseline_fashion_clip`` unless both are given.
            hierarchy_classes: Optional class list; otherwise read from the
                ``hierarchy_classes`` entry of the hierarchy-model checkpoint.
            kaggle_raw_df, local_raw_df: Optional cached DataFrames forwarded
                to the KAGL / internal loaders.

        Raises:
            FileNotFoundError: if classes must be read from a missing checkpoint.
        """
        self.device = torch.device(device) if isinstance(device, str) else device
        self.directory = directory
        self.kaggle_raw_df = kaggle_raw_df
        self.local_raw_df = local_raw_df

        self.color_emb_dim = color_emb_dim
        self.hierarchy_emb_dim = hierarchy_emb_dim
        self.main_emb_dim = main_emb_dim
        self.hierarchy_end_dim = self.color_emb_dim + self.hierarchy_emb_dim
        os.makedirs(self.directory, exist_ok=True)

        # --- hierarchy classes ---
        if hierarchy_classes is not None:
            self.hierarchy_classes = hierarchy_classes
            print(f"Using provided hierarchy classes: {len(self.hierarchy_classes)} classes")
        else:
            print("Loading hierarchy classes from hierarchy model...")
            if not os.path.exists(hierarchy_model_path):
                raise FileNotFoundError(f"Hierarchy model file {hierarchy_model_path} not found")
            hierarchy_checkpoint = torch.load(hierarchy_model_path, map_location=self.device)
            self.hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
            print(f"Found {len(self.hierarchy_classes)} hierarchy classes: {sorted(self.hierarchy_classes)}")

        self.validation_hierarchy_classes = self._load_validation_hierarchy_classes()
        if self.validation_hierarchy_classes:
            print(f"Validation dataset hierarchies ({len(self.validation_hierarchy_classes)} classes): "
                  f"{sorted(self.validation_hierarchy_classes)}")
        else:
            print("Unable to load validation hierarchy classes, falling back to hierarchy model classes.")
            self.validation_hierarchy_classes = self.hierarchy_classes

        # --- load GAP-CLIP (accept pre-loaded or load from scratch) ---
        if gap_clip_model is not None and gap_clip_processor is not None:
            self.model = gap_clip_model
            self.processor = gap_clip_processor
            print("Using pre-loaded GAP-CLIP model")
        else:
            self.model, self.processor = load_gap_clip(main_model_path, self.device)
            print("GAP-CLIP model loaded successfully")

        # --- baseline Fashion-CLIP (accept pre-loaded or load from scratch) ---
        if baseline_model is not None and baseline_processor is not None:
            self.baseline_model = baseline_model
            self.baseline_processor = baseline_processor
            print("Using pre-loaded baseline Fashion-CLIP model")
        else:
            self.baseline_model, self.baseline_processor = load_baseline_fashion_clip(self.device)
            print("Baseline Fashion-CLIP model loaded successfully")

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    def _load_validation_hierarchy_classes(self):
        """Return the sorted distinct non-blank ``hierarchy`` values of the local CSV, or ``[]``."""
        if not os.path.exists(local_dataset_path):
            print(f"Validation dataset not found at {local_dataset_path}")
            return []
        try:
            df = pd.read_csv(local_dataset_path)
        except Exception as exc:
            print(f"Failed to read validation dataset: {exc}")
            return []
        if 'hierarchy' not in df.columns:
            print("Validation dataset does not contain 'hierarchy' column.")
            return []
        hierarchies = df['hierarchy'].dropna().astype(str).str.strip()
        return sorted({h for h in hierarchies if h})

    def prepare_shared_fashion_mnist(self, max_samples=10000, batch_size=8):
        """
        Build one shared Fashion-MNIST dataset/dataloader to ensure every model
        is evaluated on the exact same items.

        Returns:
            ``(dataset, dataloader, {hierarchy: count})``. Counts are empty
            when the dataset has no rows or no ``label_mapping``.
        """
        target_classes = self.validation_hierarchy_classes or self.hierarchy_classes
        fashion_dataset = load_fashion_mnist_dataset(max_samples, hierarchy_classes=target_classes)
        dataloader = DataLoader(fashion_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

        hierarchy_counts = {}
        label_mapping = fashion_dataset.label_mapping
        if len(fashion_dataset.dataframe) > 0 and label_mapping:
            hierarchy_counts = self._count_labels(
                label_mapping.get(int(lid), 'unknown')
                for lid in fashion_dataset.dataframe['label']
            )
        return fashion_dataset, dataloader, hierarchy_counts

    @staticmethod
    def _count_labels(labels):
        """Return a plain ``{label: count}`` dict for an iterable of labels."""
        from collections import Counter
        return dict(Counter(labels))

    def _validate_label_distribution(self, labels, expected_counts, context):
        """Raise ``ValueError`` unless *labels* has exactly *expected_counts*."""
        observed = self._count_labels(labels)
        if observed != expected_counts:
            raise ValueError(
                f"Label distribution mismatch in {context}. "
                f"Expected {expected_counts}, observed {observed}"
            )

    def _hierarchy_subspace(self, embeddings):
        """Slice GAP-CLIP's hierarchy columns ``[color_emb_dim, color_emb_dim + hierarchy_emb_dim)``."""
        return embeddings[:, self.color_emb_dim:self.color_emb_dim + self.hierarchy_emb_dim]

    @staticmethod
    def _l2_normalize_rows(matrix):
        """Row-wise L2 normalisation; all-zero rows stay zero (same convention as sklearn ``normalize``)."""
        matrix = np.asarray(matrix)
        if not np.issubdtype(matrix.dtype, np.floating):
            matrix = matrix.astype(np.float64)
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        return matrix / norms

    # ------------------------------------------------------------------
    # embedding extraction (delegates to shared utils)
    # ------------------------------------------------------------------
    def extract_full_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
        """Full embeddings from GAP-CLIP (text or image).

        Delegates to ``extract_clip_embeddings``; callers unpack a 3-tuple whose
        first item is the embedding matrix and last item the hierarchy labels
        (the middle item is unused here — presumably colour labels).
        """
        return extract_clip_embeddings(
            self.model, self.processor, dataloader, self.device,
            embedding_type=embedding_type, max_samples=max_samples,
            desc=f"GAP-CLIP {embedding_type} embeddings",
        )

    def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000):
        """Embeddings from baseline Fashion-CLIP, same 3-tuple shape as ``extract_full_embeddings``."""
        return extract_clip_embeddings(
            self.baseline_model, self.baseline_processor, dataloader, self.device,
            embedding_type=embedding_type, max_samples=max_samples,
            desc=f"Baseline {embedding_type} embeddings",
        )

    def predict_labels_nearest_neighbor(self, embeddings, labels, chunk_size=1024):
        """
        Predict labels using leave-one-out 1-NN (cosine) on the same embedding set.

        Each sample receives the label of its most similar *other* sample;
        ties resolve to the lowest index (``np.argmax``). Similarities are
        computed ``chunk_size`` rows at a time, so peak memory is
        ``chunk_size x N`` rather than a full ``N x N`` matrix.

        Args:
            embeddings: ``(N, D)`` array-like.
            labels: Sequence of N labels aligned with *embeddings*.
            chunk_size: Rows per similarity chunk (values < 1 are treated as 1).

        Returns:
            List of N predicted labels.
        """
        unit = self._l2_normalize_rows(embeddings)
        n_samples = unit.shape[0]
        chunk_size = max(1, int(chunk_size))
        preds = []
        for start in range(0, n_samples, chunk_size):
            stop = min(start + chunk_size, n_samples)
            sims = unit[start:stop] @ unit.T
            # -inf (not -1.0) so a sample can never be its own neighbour,
            # even when every other similarity is exactly -1.
            local_rows = np.arange(stop - start)
            sims[local_rows, local_rows + start] = -np.inf
            preds.extend(labels[j] for j in np.argmax(sims, axis=1))
        return preds

    # ------------------------------------------------------------------
    # image + text ensemble
    # ------------------------------------------------------------------
    def _compute_img_centroids(self, embeddings, labels):
        """Return ``{label: unit-norm mean of that label's unit-normalised embeddings}``."""
        unit = self._l2_normalize_rows(embeddings)
        label_arr = np.asarray(labels, dtype=object)
        centroids = {}
        for label in sorted(set(labels)):
            mean_vec = unit[label_arr == label].mean(axis=0, keepdims=True)
            centroids[label] = self._l2_normalize_rows(mean_vec)[0]
        return centroids

    def predict_labels_image_ensemble(self, img_embeddings, labels, text_protos, cls_names, alpha=0.5):
        """Combine image centroids with text prototypes.

        ``score(c) = alpha * cos(x, centroid_c) + (1 - alpha) * cos(x, text_protos[c])``
        where row k of *text_protos* corresponds to ``cls_names[k]``.
        Centroids are built from the same samples being classified (each
        sample contributes to its own class centroid).

        Returns:
            List of predicted class names, one per row of *img_embeddings*.
        """
        img_unit = self._l2_normalize_rows(img_embeddings)
        centroids = self._compute_img_centroids(img_unit, labels)
        centroid_mat = np.stack([centroids[c] for c in cls_names], axis=0)
        sim_img = img_unit @ self._l2_normalize_rows(centroid_mat).T
        sim_txt = img_unit @ self._l2_normalize_rows(text_protos).T
        scores = alpha * sim_img + (1 - alpha) * sim_txt
        return [cls_names[int(k)] for k in np.argmax(scores, axis=1)]

    # ------------------------------------------------------------------
    # confusion matrix & classification report
    # ------------------------------------------------------------------
    def evaluate_classification_performance(self, embeddings, labels,
                                            embedding_type="Embeddings", label_type="Label",
                                            method="nn"):
        """Classify *embeddings* and build a confusion matrix and report.

        Args:
            method: ``"nn"`` (leave-one-out 1-NN) or ``"centroid"``
                (``predict_labels_from_embeddings``).

        Returns:
            Dict with ``accuracy``, ``predictions``, ``confusion_matrix``,
            ``labels`` (sorted true labels), ``classification_report`` (dict)
            and ``figure`` (caller is responsible for closing it).

        Raises:
            ValueError: for an unknown *method*.
        """
        if method == "nn":
            preds = self.predict_labels_nearest_neighbor(embeddings, labels)
        elif method == "centroid":
            preds = predict_labels_from_embeddings(embeddings, labels)
        else:
            raise ValueError(f"Unknown classification method: {method}")

        acc = accuracy_score(labels, preds)
        unique_labels = sorted(set(labels))
        fig, _, cm = create_confusion_matrix(
            labels, preds,
            f"{embedding_type} - {label_type} Classification ({method.upper()})",
            label_type,
        )
        report = classification_report(labels, preds, labels=unique_labels,
                                       target_names=unique_labels, output_dict=True)
        return {
            'accuracy': acc,
            'predictions': preds,
            'confusion_matrix': cm,
            'labels': unique_labels,
            'classification_report': report,
            'figure': fig,
        }

    def save_confusion_matrix_table(self, cm, labels, output_csv_path):
        """
        Save confusion matrix values with per-row totals to CSV for auditing.

        NOTE(review): assumes *cm* rows/columns are ordered like *labels*.
        """
        cm_df = pd.DataFrame(cm, index=labels, columns=labels)
        cm_df["row_total"] = cm_df.sum(axis=1)
        cm_df.loc["column_total"] = list(cm_df[labels].sum(axis=0)) + [cm_df["row_total"].sum()]
        cm_df.to_csv(output_csv_path)

    def _save_confusion_outputs(self, result, stem):
        """Write ``<stem>.png`` and ``<stem>.csv`` under ``self.directory`` and close the figure."""
        fig = result['figure']
        try:
            fig.savefig(os.path.join(self.directory, f"{stem}.png"), dpi=300, bbox_inches='tight')
            self.save_confusion_matrix_table(
                result['confusion_matrix'], result['labels'],
                os.path.join(self.directory, f"{stem}.csv"),
            )
        finally:
            plt.close(fig)

    @staticmethod
    def _pick_better_image_space(spec_metrics, spec_class, full_metrics, full_class):
        """Keep whichever image space (64D slice vs 512D full) has the higher NN accuracy.

        Ties go to 512D. The losing result's figure is closed here; previously
        it was never closed and accumulated in matplotlib's figure registry.
        NOTE(review): the choice is made on the evaluation set itself.

        Returns:
            ``(metrics, classification_result)`` of the winner.
        """
        if full_class['accuracy'] >= spec_class['accuracy']:
            print(f" 512D wins: {full_class['accuracy'] * 100:.1f}% vs {spec_class['accuracy'] * 100:.1f}%")
            plt.close(spec_class['figure'])
            return full_metrics, full_class
        print(f" 64D wins: {spec_class['accuracy'] * 100:.1f}% vs {full_class['accuracy'] * 100:.1f}%")
        plt.close(full_class['figure'])
        return spec_metrics, spec_class

    # ==================================================================
    # 3. GAP-CLIP evaluation on Fashion-MNIST
    # ==================================================================
    def evaluate_gap_clip_fashion_mnist(self, max_samples=10000, dataloader=None, expected_counts=None):
        """Evaluate GAP-CLIP hierarchy classification on Fashion-MNIST.

        Text is scored on the hierarchy subspace; image is scored on the better
        of that subspace and the full embedding, plus an image-centroid /
        text-prototype ensemble stored as ``ensemble_accuracy``.

        Args:
            max_samples: Cap for dataset preparation and embedding extraction.
            dataloader: Optional shared dataloader; requires *expected_counts*.
            expected_counts: ``{hierarchy: count}`` the extracted labels must match.

        Returns:
            ``{'text_hierarchy': {...}, 'image_hierarchy': {...}}``.

        Raises:
            ValueError: if *dataloader* is given without *expected_counts*, or
                extracted label counts differ from *expected_counts*.
        """
        print(f"\n{'=' * 60}")
        print("Evaluating GAP-CLIP on Fashion-MNIST")
        print(" Hierarchy embeddings (dims 16-79)")
        print(f" Max samples: {max_samples}")
        print(f"{'=' * 60}")

        if dataloader is None:
            fashion_dataset, dataloader, dataset_counts = self.prepare_shared_fashion_mnist(max_samples=max_samples)
            expected_counts = expected_counts or dataset_counts
        else:
            fashion_dataset = getattr(dataloader, "dataset", None)
            if expected_counts is None:
                raise ValueError("expected_counts must be provided when using a custom dataloader.")

        # getattr guards custom datasets that lack these attributes.
        dataset_frame = getattr(fashion_dataset, "dataframe", None)
        if (dataset_frame is not None and len(dataset_frame) > 0
                and getattr(fashion_dataset, "label_mapping", None)):
            print("\nHierarchy distribution in dataset:")
            for h in sorted(expected_counts):
                print(f" {h}: {expected_counts[h]} samples")

        results = {}

        # --- full 512D embeddings (text & image) ---
        print("\nExtracting full 512-dimensional GAP-CLIP embeddings...")
        text_full, _, text_hier = self.extract_full_embeddings(dataloader, 'text', max_samples)
        img_full, _, img_hier = self.extract_full_embeddings(dataloader, 'image', max_samples)
        self._validate_label_distribution(text_hier, expected_counts, "GAP-CLIP text")
        self._validate_label_distribution(img_hier, expected_counts, "GAP-CLIP image")
        print(f" Text shape: {text_full.shape} | Image shape: {img_full.shape}")

        # --- TEXT: hierarchy on specialized 64D (dims 16-79) ---
        print("\n--- GAP-CLIP TEXT HIERARCHY (dims 16-79) ---")
        text_hier_spec = self._hierarchy_subspace(text_full)
        print(f" Specialized text hierarchy shape: {text_hier_spec.shape}")
        text_metrics = compute_similarity_metrics(text_hier_spec, text_hier)
        text_metrics.update(self.evaluate_classification_performance(
            text_hier_spec, text_hier, "GAP-CLIP Text Hierarchy (64D)", "Hierarchy", method="nn",
        ))
        results['text_hierarchy'] = text_metrics

        # --- IMAGE: 64D vs 512D + ensemble ---
        print("\n--- GAP-CLIP IMAGE HIERARCHY (64D vs 512D) ---")
        img_hier_spec = self._hierarchy_subspace(img_full)
        print(f" Specialized image hierarchy shape: {img_hier_spec.shape}")

        print(" Testing specialized 64D...")
        spec_metrics = compute_similarity_metrics(img_hier_spec, img_hier)
        spec_class = self.evaluate_classification_performance(
            img_hier_spec, img_hier, "GAP-CLIP Image Hierarchy (64D)", "Hierarchy", method="nn",
        )
        print(" Testing full 512D...")
        full_metrics = compute_similarity_metrics(img_full, img_hier)
        full_class = self.evaluate_classification_performance(
            img_full, img_hier, "GAP-CLIP Image Hierarchy (512D full)", "Hierarchy", method="nn",
        )
        img_metrics, img_class = self._pick_better_image_space(
            spec_metrics, spec_class, full_metrics, full_class,
        )

        # --- ensemble image + text prototypes ---
        print("\n Testing GAP-CLIP image + text ensemble (prototypes per class)...")
        cls_names = sorted(set(img_hier))
        prompts = [f"a photo of a {c}" for c in cls_names]
        text_inputs = self.processor(text=prompts, return_tensors="pt", padding=True, truncation=True)
        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
        with torch.no_grad():
            txt_feats = self.model.get_text_features(**text_inputs)
            txt_feats = txt_feats / txt_feats.norm(dim=-1, keepdim=True)
        text_protos = txt_feats.cpu().numpy()
        ensemble_preds = self.predict_labels_image_ensemble(
            img_full, img_hier, text_protos, cls_names, alpha=0.7,
        )
        ensemble_acc = accuracy_score(img_hier, ensemble_preds)
        print(f" Ensemble accuracy (alpha=0.7): {ensemble_acc * 100:.2f}%")

        img_metrics.update(img_class)
        img_metrics['ensemble_accuracy'] = ensemble_acc
        results['image_hierarchy'] = img_metrics

        # --- save confusion matrix figures ---
        for key in ['text_hierarchy', 'image_hierarchy']:
            self._save_confusion_outputs(results[key], f"gap_clip_{key}_confusion_matrix")

        del text_full, img_full, text_hier_spec, img_hier_spec
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return results

    # ==================================================================
    # 4. Baseline Fashion-CLIP evaluation on Fashion-MNIST
    # ==================================================================
    def evaluate_baseline_fashion_mnist(self, max_samples=10000, dataloader=None, expected_counts=None):
        """Evaluate baseline Fashion-CLIP hierarchy classification on Fashion-MNIST.

        Returns:
            ``{'text': {'hierarchy': {...}}, 'image': {'hierarchy': {...}}}``.

        Raises:
            ValueError: if *dataloader* is given without *expected_counts*, or
                extracted label counts differ from *expected_counts*.
        """
        print(f"\n{'=' * 60}")
        print("Evaluating Baseline Fashion-CLIP on Fashion-MNIST")
        print(f" Max samples: {max_samples}")
        print(f"{'=' * 60}")

        if dataloader is None:
            _, dataloader, dataset_counts = self.prepare_shared_fashion_mnist(max_samples=max_samples)
            expected_counts = expected_counts or dataset_counts
        elif expected_counts is None:
            raise ValueError("expected_counts must be provided when using a custom dataloader.")

        results = {}

        # --- text ---
        print("\nExtracting baseline text embeddings...")
        text_emb, _, text_hier = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
        self._validate_label_distribution(text_hier, expected_counts, "baseline text")
        print(f" Baseline text shape: {text_emb.shape}")
        text_metrics = compute_similarity_metrics(text_emb, text_hier)
        text_metrics.update(self.evaluate_classification_performance(
            text_emb, text_hier, "Baseline Fashion-CLIP Text - Hierarchy", "Hierarchy", method="nn",
        ))
        results['text'] = {'hierarchy': text_metrics}
        del text_emb
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # --- image ---
        print("\nExtracting baseline image embeddings...")
        img_emb, _, img_hier = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
        self._validate_label_distribution(img_hier, expected_counts, "baseline image")
        print(f" Baseline image shape: {img_emb.shape}")
        img_metrics = compute_similarity_metrics(img_emb, img_hier)
        img_metrics.update(self.evaluate_classification_performance(
            img_emb, img_hier, "Baseline Fashion-CLIP Image - Hierarchy", "Hierarchy", method="nn",
        ))
        results['image'] = {'hierarchy': img_metrics}
        del img_emb
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        for key in ['text', 'image']:
            self._save_confusion_outputs(
                results[key]['hierarchy'], f"baseline_{key}_hierarchy_confusion_matrix",
            )
        return results

    # ==================================================================
    # 5. Generic dataset evaluation (KAGL Marqo / Internal)
    # ==================================================================
    def evaluate_gap_clip_generic(self, dataloader, dataset_name, max_samples=10000):
        """Evaluate GAP-CLIP hierarchy performance on any dataset.

        Same protocol as the Fashion-MNIST evaluation, without label-count
        validation or the text-prototype ensemble. Output files are prefixed
        with the lower-cased, underscore-joined *dataset_name*.
        """
        print(f"\n{'=' * 60}")
        print(f"Evaluating GAP-CLIP on {dataset_name}")
        print(" Hierarchy embeddings (dims 16-79)")
        print(f"{'=' * 60}")

        results = {}

        # --- text hierarchy (64D specialized) ---
        print("\nExtracting GAP-CLIP text embeddings...")
        text_full, _, text_hier = self.extract_full_embeddings(dataloader, 'text', max_samples)
        text_hier_spec = self._hierarchy_subspace(text_full)
        print(f" Text shape: {text_full.shape}, hierarchy subspace: {text_hier_spec.shape}")
        text_metrics = compute_similarity_metrics(text_hier_spec, text_hier)
        text_metrics.update(self.evaluate_classification_performance(
            text_hier_spec, text_hier,
            f"GAP-CLIP Text Hierarchy – {dataset_name}", "Hierarchy", method="nn",
        ))
        results['text_hierarchy'] = text_metrics

        # --- image hierarchy (best of 64D vs 512D) ---
        print("\nExtracting GAP-CLIP image embeddings...")
        img_full, _, img_hier = self.extract_full_embeddings(dataloader, 'image', max_samples)
        img_hier_spec = self._hierarchy_subspace(img_full)

        spec_metrics = compute_similarity_metrics(img_hier_spec, img_hier)
        spec_class = self.evaluate_classification_performance(
            img_hier_spec, img_hier,
            f"GAP-CLIP Image Hierarchy (64D) – {dataset_name}", "Hierarchy", method="nn",
        )
        full_metrics = compute_similarity_metrics(img_full, img_hier)
        full_class = self.evaluate_classification_performance(
            img_full, img_hier,
            f"GAP-CLIP Image Hierarchy (512D) – {dataset_name}", "Hierarchy", method="nn",
        )
        img_metrics, img_class = self._pick_better_image_space(
            spec_metrics, spec_class, full_metrics, full_class,
        )
        img_metrics.update(img_class)
        results['image_hierarchy'] = img_metrics

        # --- save confusion matrices ---
        prefix = dataset_name.lower().replace(" ", "_")
        for key in ['text_hierarchy', 'image_hierarchy']:
            self._save_confusion_outputs(results[key], f"gap_clip_{prefix}_{key}_confusion_matrix")

        del text_full, img_full, text_hier_spec, img_hier_spec
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return results

    def evaluate_baseline_generic(self, dataloader, dataset_name, max_samples=10000):
        """Evaluate baseline Fashion-CLIP hierarchy performance on any dataset.

        Returns:
            ``{'text': {'hierarchy': {...}}, 'image': {'hierarchy': {...}}}``.
        """
        print(f"\n{'=' * 60}")
        print(f"Evaluating Baseline Fashion-CLIP on {dataset_name}")
        print(f"{'=' * 60}")

        results = {}

        # --- text ---
        print("\nExtracting baseline text embeddings...")
        text_emb, _, text_hier = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
        print(f" Baseline text shape: {text_emb.shape}")
        text_metrics = compute_similarity_metrics(text_emb, text_hier)
        text_metrics.update(self.evaluate_classification_performance(
            text_emb, text_hier, f"Baseline Text Hierarchy – {dataset_name}", "Hierarchy", method="nn",
        ))
        results['text'] = {'hierarchy': text_metrics}
        del text_emb
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # --- image ---
        print("\nExtracting baseline image embeddings...")
        img_emb, _, img_hier = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
        print(f" Baseline image shape: {img_emb.shape}")
        img_metrics = compute_similarity_metrics(img_emb, img_hier)
        img_metrics.update(self.evaluate_classification_performance(
            img_emb, img_hier, f"Baseline Image Hierarchy – {dataset_name}", "Hierarchy", method="nn",
        ))
        results['image'] = {'hierarchy': img_metrics}
        del img_emb
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        prefix = dataset_name.lower().replace(" ", "_")
        for key in ['text', 'image']:
            self._save_confusion_outputs(
                results[key]['hierarchy'], f"baseline_{prefix}_{key}_hierarchy_confusion_matrix",
            )
        return results

    # ==================================================================
    # 6. Full evaluation across all datasets
    # ==================================================================
    def run_full_evaluation(self, max_samples=10000, batch_size=8):
        """Run hierarchy evaluation on all 3 datasets for both models.

        KAGL Marqo and internal-dataset failures are reported and skipped
        (best-effort); Fashion-MNIST failures propagate.

        Returns:
            Dict keyed by ``fashion_mnist_gap``, ``fashion_mnist_baseline``,
            ``kaggle_gap``, ``kaggle_baseline``, ``local_gap``, ``local_baseline``
            (only those that ran).
        """
        all_results = {}

        # --- Fashion-MNIST ---
        _, shared_dataloader, shared_counts = self.prepare_shared_fashion_mnist(
            max_samples=max_samples, batch_size=batch_size,
        )
        all_results['fashion_mnist_gap'] = self.evaluate_gap_clip_fashion_mnist(
            max_samples=max_samples, dataloader=shared_dataloader, expected_counts=shared_counts,
        )
        all_results['fashion_mnist_baseline'] = self.evaluate_baseline_fashion_mnist(
            max_samples=max_samples, dataloader=shared_dataloader, expected_counts=shared_counts,
        )

        target_classes = self.validation_hierarchy_classes or self.hierarchy_classes

        # --- KAGL Marqo ---
        try:
            kaggle_dataset = load_kaggle_marqo_with_hierarchy(
                max_samples=max_samples,
                hierarchy_classes=target_classes,
                raw_df=self.kaggle_raw_df,
            )
            if kaggle_dataset is not None and len(kaggle_dataset) > 0:
                kaggle_dataloader = DataLoader(kaggle_dataset, batch_size=batch_size,
                                               shuffle=False, num_workers=0)
                all_results['kaggle_gap'] = self.evaluate_gap_clip_generic(
                    kaggle_dataloader, "KAGL Marqo", max_samples,
                )
                all_results['kaggle_baseline'] = self.evaluate_baseline_generic(
                    kaggle_dataloader, "KAGL Marqo", max_samples,
                )
            else:
                print("WARNING: KAGL Marqo dataset empty after hierarchy mapping, skipping.")
        except Exception as e:
            print(f"WARNING: Could not evaluate on KAGL Marqo: {e}")

        # --- Internal (local validation) ---
        try:
            local_dataset = load_local_validation_with_hierarchy(
                max_samples=max_samples,
                hierarchy_classes=target_classes,
                raw_df=self.local_raw_df,
            )
            if local_dataset is not None and len(local_dataset) > 0:
                local_dataloader = DataLoader(local_dataset, batch_size=batch_size,
                                              shuffle=False, num_workers=0)
                all_results['local_gap'] = self.evaluate_gap_clip_generic(
                    local_dataloader, "Internal", max_samples,
                )
                all_results['local_baseline'] = self.evaluate_baseline_generic(
                    local_dataloader, "Internal", max_samples,
                )
            else:
                print("WARNING: Local validation dataset empty after hierarchy filtering, skipping.")
        except Exception as e:
            print(f"WARNING: Could not evaluate on internal dataset: {e}")

        # --- Print summary ---
        print(f"\n{'=' * 70}")
        print("CATEGORY MODEL EVALUATION SUMMARY")
        print(f"{'=' * 70}")
        for dataset_key, label in [
            ('fashion_mnist_gap', 'Fashion-MNIST (GAP-CLIP)'),
            ('fashion_mnist_baseline', 'Fashion-MNIST (Baseline)'),
            ('kaggle_gap', 'KAGL Marqo (GAP-CLIP)'),
            ('kaggle_baseline', 'KAGL Marqo (Baseline)'),
            ('local_gap', 'Internal (GAP-CLIP)'),
            ('local_baseline', 'Internal (Baseline)'),
        ]:
            res = all_results.get(dataset_key)
            if res is None:
                continue
            print(f"\n{label}:")
            # GAP-CLIP results are flat; baseline results nest under 'hierarchy'.
            if 'text_hierarchy' in res:
                t, i = res['text_hierarchy'], res['image_hierarchy']
            elif 'text' in res:
                t, i = res['text']['hierarchy'], res['image']['hierarchy']
            else:
                continue
            print(f" Text NN Acc: {t['accuracy']*100:.1f}% | Separation: {t['separation_score']:.4f}")
            print(f" Image NN Acc: {i['accuracy']*100:.1f}% | Separation: {i['separation_score']:.4f}")

        return all_results


# ============================================================================
# 7. Main
# ============================================================================

if __name__ == "__main__":
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")
    directory = 'gap_clip_confusion_matrices'
    max_samples = 10000
    evaluator = CategoryModelEvaluator(device=device, directory=directory)
    evaluator.run_full_evaluation(max_samples=max_samples, batch_size=8)