""" Hierarchy Embedding Evaluation with Fashion-CLIP Baseline Comparison This module provides comprehensive evaluation tools for hierarchy classification models, comparing custom model performance against the Fashion-CLIP baseline. It includes: - Embedding quality metrics (intra-class/inter-class similarity) - Classification accuracy with multiple methods (nearest neighbor, centroid-based) - Confusion matrix generation and visualization - Support for multiple datasets (validation set, Fashion-MNIST, Kaggle Marqo) - Advanced techniques: ZCA whitening, Mahalanobis distance, Test-Time Augmentation Key Features: - Custom model evaluation with full hierarchy classification pipeline - Fashion-CLIP baseline comparison for performance benchmarking - Multi-dataset evaluation (validation, Fashion-MNIST, Kaggle Marqo) - Flexible evaluation options (whitening, Mahalanobis distance) - Detailed metrics: accuracy, F1 scores, confusion matrices Author: Fashion Search Team License: Apache 2.0 """ # Standard library imports import os import warnings from collections import defaultdict from io import BytesIO from typing import Dict, List, Tuple, Optional, Union, Any # Third-party imports import numpy as np import pandas as pd import requests import torch import matplotlib.pyplot as plt import seaborn as sns from PIL import Image from sklearn.metrics import ( accuracy_score, classification_report, confusion_matrix, f1_score, ) from sklearn.metrics.pairwise import cosine_similarity from sklearn.model_selection import train_test_split from torch.utils.data import Dataset, DataLoader from torchvision import transforms from tqdm import tqdm from transformers import CLIPProcessor, CLIPModel as TransformersCLIPModel # Local imports import config from config import device, hierarchy_model_path, hierarchy_column, local_dataset_path from hierarchy_model import Model, HierarchyExtractor, HierarchyDataset, collate_fn # Suppress warnings for cleaner output warnings.filterwarnings('ignore') # 
============================================================================
# CONSTANTS AND CONFIGURATION
# ============================================================================

# Cap on evaluation-set size: the pairwise similarity matrix is O(n²) in
# memory, so larger sets are stratified-sampled down to this many rows.
MAX_SAMPLES_EVALUATION = 10000

# Cap on the total number of inter-class pair comparisons to prevent O(n²)
# complexity when computing between-class similarities.
MAX_INTER_CLASS_COMPARISONS = 10000

# Canonical Fashion-MNIST label mapping (label id -> class name).
FASHION_MNIST_LABELS = {
    0: "T-shirt/top",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle boot"
}


# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================

def convert_fashion_mnist_to_image(pixel_values: np.ndarray) -> Image.Image:
    """
    Convert Fashion-MNIST pixel values to an RGB PIL Image.

    Args:
        pixel_values: Flat array of 784 pixel values (28x28 grayscale).

    Returns:
        PIL Image in RGB format (the single grayscale channel is duplicated
        across all three RGB channels).
    """
    # Reshape to 28x28 and convert to uint8
    image_array = np.array(pixel_values).reshape(28, 28).astype(np.uint8)

    # Convert grayscale to RGB by duplicating channels
    image_array = np.stack([image_array] * 3, axis=-1)

    return Image.fromarray(image_array)


def get_fashion_mnist_labels() -> Dict[int, str]:
    """
    Get the Fashion-MNIST class labels mapping.

    Returns:
        A copy of the label-id -> class-name dictionary (safe for callers
        to mutate without affecting the module-level constant).
    """
    return FASHION_MNIST_LABELS.copy()


def create_fashion_mnist_to_hierarchy_mapping(
    hierarchy_classes: List[str]
) -> Dict[int, Optional[str]]:
    """
    Create mapping from Fashion-MNIST labels to custom hierarchy classes.

    Matching is attempted in three passes per Fashion-MNIST label:
    exact (case-insensitive) match, substring match in either direction,
    and finally hand-written semantic fallbacks for common fashion
    categories. Prints each mapping decision as a side effect.

    Args:
        hierarchy_classes: List of hierarchy class names from the custom model

    Returns:
        Dictionary mapping Fashion-MNIST label IDs to hierarchy class names
        (None if no match found; callers are expected to filter those out)
    """
    # Normalize hierarchy classes to lowercase for case-insensitive matching
    hierarchy_classes_lower = [h.lower() for h in hierarchy_classes]

    # Create mapping dictionary
    mapping = {}

    for fm_label_id, fm_label in FASHION_MNIST_LABELS.items():
        fm_label_lower = fm_label.lower()
        matched_hierarchy = None

        # Strategy 1: Try exact match first
        if fm_label_lower in hierarchy_classes_lower:
            matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(fm_label_lower)]

        # Strategy 2: Try partial (substring) matches in either direction.
        # NOTE(review): the any() guard scans the same candidates the loop
        # below re-scans; the first substring hit in hierarchy_classes
        # order wins.
        elif any(h in fm_label_lower or fm_label_lower in h for h in hierarchy_classes_lower):
            for h_class in hierarchy_classes:
                h_lower = h_class.lower()
                if h_lower in fm_label_lower or fm_label_lower in h_lower:
                    matched_hierarchy = h_class
                    break

        # Strategy 3: Semantic matching for common fashion categories
        else:
            # T-shirt/top -> shirt or top
            if fm_label_lower in ['t-shirt/top', 'top']:
                if 'top' in hierarchy_classes_lower:
                    matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index('top')]
                elif 'shirt' in hierarchy_classes_lower:
                    matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index('shirt')]
            # Trouser -> pant, bottom
            elif 'trouser' in fm_label_lower:
                for possible in ['pant', 'pants', 'trousers', 'trouser', 'bottom']:
                    if possible in hierarchy_classes_lower:
                        matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
                        break
            # Pullover -> sweater, top
            elif 'pullover' in fm_label_lower:
                for possible in ['sweater', 'pullover', 'top']:
                    if possible in hierarchy_classes_lower:
                        matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
                        break
            # Dress -> dress
            elif 'dress' in fm_label_lower:
                if 'dress' in hierarchy_classes_lower:
                    matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index('dress')]
            # Coat -> coat, jacket
            elif 'coat' in fm_label_lower:
                for possible in ['coat', 'jacket', 'outerwear']:
                    if possible in hierarchy_classes_lower:
                        matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
                        break
            # Footwear: Sandal, Sneaker, Ankle boot -> shoes
            elif fm_label_lower in ['sandal', 'sneaker', 'ankle boot']:
                for possible in ['shoes', 'shoe', 'footwear', 'sandal', 'sneaker', 'boot']:
                    if possible in hierarchy_classes_lower:
                        matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(possible)]
                        break
            # Bag -> bag
            elif 'bag' in fm_label_lower:
                if 'bag' in hierarchy_classes_lower:
                    matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index('bag')]

        mapping[fm_label_id] = matched_hierarchy

        # Print mapping result
        if matched_hierarchy:
            print(f" {fm_label} ({fm_label_id}) -> {matched_hierarchy}")
        else:
            print(f" ⚠️ {fm_label} ({fm_label_id}) -> NO MATCH (will be filtered out)")

    return mapping


# ============================================================================
# DATASET CLASSES
# ============================================================================

class FashionMNISTDataset(Dataset):
    """
    Fashion-MNIST Dataset class for evaluation.

    This dataset handles Fashion-MNIST images with proper preprocessing
    and label mapping to custom hierarchy classes.

    Aligned with main_model_evaluation.py for consistent evaluation
    across different scripts.
Args:
        dataframe: Pandas DataFrame containing Fashion-MNIST data with
            pixel columns
        image_size: Target size for image resizing (default: 224)
        label_mapping: Optional mapping from Fashion-MNIST label IDs to
            hierarchy classes

    Returns:
        Tuple of (image_tensor, description, color, hierarchy)
    """

    def __init__(
        self,
        dataframe: pd.DataFrame,
        image_size: int = 224,
        label_mapping: Optional[Dict[int, str]] = None
    ):
        self.dataframe = dataframe
        self.image_size = image_size
        # Label-id -> Fashion-MNIST class name (copy, safe to keep)
        self.labels_map = get_fashion_mnist_labels()
        # Optional label-id -> custom hierarchy class override
        self.label_mapping = label_mapping

        # Standard ImageNet normalization for transfer learning
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
        ])

    def __len__(self) -> int:
        return len(self.dataframe)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, str, str]:
        """
        Get a single item from the dataset.

        Args:
            idx: Index of the item to retrieve

        Returns:
            Tuple of (image_tensor, description, color, hierarchy)
        """
        row = self.dataframe.iloc[idx]

        # Extract pixel values (784 pixels for 28x28 image)
        pixel_cols = [f"pixel{i}" for i in range(1, 785)]
        pixel_values = row[pixel_cols].values

        # Convert to PIL Image and apply transforms
        image = convert_fashion_mnist_to_image(pixel_values)
        image = self.transform(image)

        # Get label information
        label_id = int(row['label'])
        description = self.labels_map[label_id]
        color = "unknown"  # Fashion-MNIST doesn't have color information

        # Use mapped hierarchy if available, otherwise use original label.
        # NOTE(review): if the mapping stored None for this label (no match),
        # None is returned as the hierarchy — confirm callers filter such
        # rows out before evaluation.
        if self.label_mapping and label_id in self.label_mapping:
            hierarchy = self.label_mapping[label_id]
        else:
            hierarchy = self.labels_map[label_id]

        return image, description, color, hierarchy


class CLIPDataset(Dataset):
    """
    Dataset class for Fashion-CLIP baseline evaluation.

    This dataset handles image loading from various sources (local paths,
    URLs, PIL Images) and applies standard validation transforms without
    augmentation.

    Args:
        dataframe: Pandas DataFrame containing image and text data

    Returns:
        Tuple of (image_tensor, description, hierarchy)
    """

    def __init__(self, dataframe: pd.DataFrame):
        self.dataframe = dataframe

        # Validation transforms (no augmentation for fair comparison)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def __len__(self) -> int:
        return len(self.dataframe)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, str]:
        """
        Get a single item from the dataset.

        Args:
            idx: Index of the item to retrieve

        Returns:
            Tuple of (image_tensor, description, hierarchy)
        """
        row = self.dataframe.iloc[idx]

        # Handle image loading from various sources
        image = self._load_image(row, idx)

        # Apply transforms
        image_tensor = self.transform(image)

        description = row[config.text_column]
        hierarchy = row[config.hierarchy_column]

        return image_tensor, description, hierarchy

    def _load_image(self, row: pd.Series, idx: int) -> Image.Image:
        """
        Load image from various sources with fallback handling.

        Tried in order: local file path, dict-with-bytes, raw pixel array
        (Fashion-MNIST), PIL Image, then HTTP URL; a gray placeholder is
        returned if everything fails.

        Args:
            row: DataFrame row containing image information
            idx: Index for error reporting

        Returns:
            PIL Image in RGB format
        """
        # Try loading from local path first
        if config.column_local_image_path in row.index and pd.notna(row[config.column_local_image_path]):
            local_path = row[config.column_local_image_path]
            try:
                if os.path.exists(local_path):
                    return Image.open(local_path).convert("RGB")
                else:
                    print(f"⚠️ Local image not found: {local_path}")
            except Exception as e:
                print(f"⚠️ Failed to load local image {idx}: {e}")

        # Try loading from various data formats
        image_data = row.get(config.column_url_image)

        # Handle dictionary format (with bytes)
        if isinstance(image_data, dict) and 'bytes' in image_data:
            return Image.open(BytesIO(image_data['bytes'])).convert('RGB')

        # Handle numpy array (Fashion-MNIST format, assumes 784 values)
        if isinstance(image_data, (list, np.ndarray)):
            pixels = np.array(image_data).reshape(28, 28)
            return Image.fromarray(pixels.astype(np.uint8)).convert("RGB")

        # Handle PIL Image directly
        if isinstance(image_data, Image.Image):
            return image_data.convert("RGB")

        # Try loading from URL as fallback
        try:
            response = requests.get(image_data, timeout=10)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            print(f"⚠️ Failed to load image {idx}: {e}")
            # Return gray placeholder image so the batch can still be built
            return Image.new('RGB', (224, 224), color='gray')


# ============================================================================
# EVALUATOR CLASSES
# ============================================================================

class CLIPBaselineEvaluator:
    """
    Fashion-CLIP Baseline Evaluator.

    This class handles the loading and evaluation of the Fashion-CLIP
    baseline model (patrickjohncyh/fashion-clip) for comparison with
    custom models.
Args:
        device: Device to run the model on ('cuda', 'mps', or 'cpu')
    """

    def __init__(self, device: str = 'mps'):
        self.device = torch.device(device)

        # Load Fashion-CLIP model and processor from the Hugging Face hub
        print("🤗 Loading Fashion-CLIP baseline model from transformers...")
        model_name = "patrickjohncyh/fashion-clip"
        self.clip_model = TransformersCLIPModel.from_pretrained(model_name).to(self.device)
        self.clip_processor = CLIPProcessor.from_pretrained(model_name)
        self.clip_model.eval()
        print("✅ Fashion-CLIP model loaded successfully")

    def extract_clip_embeddings(
        self,
        images: List[Union[torch.Tensor, Image.Image]],
        texts: List[str]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Extract Fashion-CLIP embeddings for images and texts.

        This method processes images and texts through the Fashion-CLIP
        model in batches of 32 to generate L2-normalized embeddings.
        Aligned with main_model_evaluation.py for consistency.

        Args:
            images: List of images (tensors or PIL Images)
            texts: List of text descriptions (same length as images)

        Returns:
            Tuple of (image_embeddings, text_embeddings) as numpy arrays
        """
        all_image_embeddings = []
        all_text_embeddings = []

        # Process in batches for efficiency
        batch_size = 32
        num_batches = (len(images) + batch_size - 1) // batch_size

        with torch.no_grad():
            for batch_idx in tqdm(range(num_batches), desc="Extracting CLIP embeddings"):
                start_idx = batch_idx * batch_size
                end_idx = min(start_idx + batch_size, len(images))

                batch_images = images[start_idx:end_idx]
                batch_texts = texts[start_idx:end_idx]

                # Extract text embeddings
                text_features = self._extract_text_features(batch_texts)

                # Extract image embeddings
                image_features = self._extract_image_features(batch_images)

                # Store results on CPU to free device memory
                all_image_embeddings.append(image_features.cpu().numpy())
                all_text_embeddings.append(text_features.cpu().numpy())

                # Clear memory.
                # NOTE(review): only the CUDA cache is cleared here; on the
                # default 'mps' device this is a no-op — confirm whether an
                # MPS cache flush is also wanted.
                del text_features, image_features
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        return np.vstack(all_image_embeddings), np.vstack(all_text_embeddings)

    def _extract_text_features(self, texts: List[str]) -> torch.Tensor:
        """
        Extract text features using Fashion-CLIP.

        Args:
            texts: List of text descriptions

        Returns:
            L2-normalized text feature embeddings
        """
        # Process text through Fashion-CLIP processor (CLIP context limit is
        # 77 tokens, hence the truncation)
        text_inputs = self.clip_processor(
            text=texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77
        )
        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}

        # Get text features using dedicated method
        text_features = self.clip_model.get_text_features(**text_inputs)

        # Apply L2 normalization (critical for CLIP!)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        return text_features

    def _extract_image_features(
        self,
        images: List[Union[torch.Tensor, Image.Image]]
    ) -> torch.Tensor:
        """
        Extract image features using Fashion-CLIP.

        Args:
            images: List of images (tensors or PIL Images)

        Returns:
            L2-normalized image feature embeddings

        Raises:
            ValueError: If an image is neither a tensor nor a PIL Image
        """
        # Convert tensor images to PIL Images so the CLIP processor applies
        # its own resize/normalize pipeline consistently
        pil_images = []
        for img in images:
            if isinstance(img, torch.Tensor):
                pil_images.append(self._tensor_to_pil(img))
            elif isinstance(img, Image.Image):
                pil_images.append(img)
            else:
                raise ValueError(f"Unsupported image type: {type(img)}")

        # Process images through Fashion-CLIP processor
        image_inputs = self.clip_processor(
            images=pil_images,
            return_tensors="pt"
        )
        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}

        # Get image features using dedicated method
        image_features = self.clip_model.get_image_features(**image_inputs)

        # Apply L2 normalization (critical for CLIP!)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        return image_features

    def _tensor_to_pil(self, tensor: torch.Tensor) -> Image.Image:
        """
        Convert a normalized tensor to PIL Image.
Args:
            tensor: Image tensor (C, H, W)

        Returns:
            PIL Image

        Raises:
            ValueError: If the tensor is not 3-dimensional
        """
        if tensor.dim() != 3:
            raise ValueError(f"Expected 3D tensor, got {tensor.dim()}D")

        # Denormalize if normalized (undo ImageNet normalization).
        # NOTE(review): heuristic — any value outside [0, 1] is taken to mean
        # the tensor was ImageNet-normalized.
        if tensor.min() < 0 or tensor.max() > 1:
            mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
            tensor = tensor * std + mean
            tensor = torch.clamp(tensor, 0, 1)

        # Convert to PIL
        return transforms.ToPILImage()(tensor)


class EmbeddingEvaluator:
    """
    Comprehensive Embedding Evaluator for Hierarchy Classification.

    This class provides a complete evaluation pipeline for hierarchy
    classification models, including custom model evaluation and
    Fashion-CLIP baseline comparison. It supports multiple evaluation
    metrics, datasets, and advanced techniques.

    Key Features:
    - Custom model loading and evaluation
    - Fashion-CLIP baseline comparison
    - Multiple classification methods (nearest neighbor, centroid, Mahalanobis)
    - Advanced techniques (ZCA whitening, Test-Time Augmentation)
    - Comprehensive metrics (accuracy, F1, confusion matrices)

    Args:
        model_path: Path to the trained custom model checkpoint
        directory: Output directory for saving evaluation results
    """

    def __init__(self, model_path: str, directory: str):
        self.directory = directory
        self.device = device

        # Load and prepare dataset
        print(f"📁 Using dataset with local images: {local_dataset_path}")
        df = pd.read_csv(local_dataset_path)
        print(f"📁 Loaded {len(df)} samples")

        # Get unique hierarchy classes
        hierarchy_classes = sorted(df[hierarchy_column].unique().tolist())
        print(f"📋 Found {len(hierarchy_classes)} hierarchy classes")

        # Limit dataset size to prevent memory issues
        if len(df) > MAX_SAMPLES_EVALUATION:
            print(f"⚠️ Dataset too large ({len(df)} samples), sampling to {MAX_SAMPLES_EVALUATION} samples")
            df = self._stratified_sample(df, MAX_SAMPLES_EVALUATION)

        # Create validation split (20% of data).
        # NOTE(review): stratifies on the hardcoded 'hierarchy' column while
        # the classes above come from config's hierarchy_column — confirm the
        # two names always match.
        _, self.val_df = train_test_split(
            df,
            test_size=0.2,
            random_state=42,
            stratify=df['hierarchy']
        )

        # Load the custom model
        self._load_model(model_path)

        # Initialize Fashion-CLIP baseline
        self.clip_evaluator = CLIPBaselineEvaluator(device)

    def _stratified_sample(self, df: pd.DataFrame, max_samples: int) -> pd.DataFrame:
        """
        Perform stratified sampling to maintain class distribution.

        Args:
            df: Original DataFrame
            max_samples: Maximum number of samples to keep

        Returns:
            Sampled DataFrame
        """
        # Stratified sampling by hierarchy (each class contributes roughly
        # in proportion to its share of the full dataset)
        df_sampled = df.groupby('hierarchy', group_keys=False).apply(
            lambda x: x.sample(
                n=min(len(x), int(max_samples * len(x) / len(df))),
                random_state=42
            )
        ).reset_index(drop=True)

        # Adjust to reach exactly max_samples if necessary.
        # NOTE(review): the top-up draws from the full df, so rows already
        # sampled above can appear twice — confirm duplicates are acceptable.
        if len(df_sampled) < max_samples:
            remaining = max_samples - len(df_sampled)
            extra = df.sample(n=remaining, random_state=42)
            df_sampled = pd.concat([df_sampled, extra]).reset_index(drop=True)

        return df_sampled

    def _load_model(self, model_path: str):
        """
        Load the custom hierarchy classification model.

        Populates self.hierarchy_classes, self.vocab, and self.model.

        Args:
            model_path: Path to the model checkpoint

        Raises:
            FileNotFoundError: If model file doesn't exist
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file {model_path} not found")

        # Load checkpoint.
        # NOTE(review): torch.load without weights_only unpickles arbitrary
        # objects — only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location=self.device)

        # Extract configuration.
        # NOTE(review): config_dict defaults to {}, but 'embed_dim' /
        # 'dropout' below are accessed unconditionally and would raise
        # KeyError on a checkpoint without a 'config' entry.
        config_dict = checkpoint.get('config', {})
        saved_hierarchy_classes = checkpoint['hierarchy_classes']

        # Store hierarchy classes
        self.hierarchy_classes = saved_hierarchy_classes

        # Create hierarchy extractor
        self.vocab = HierarchyExtractor(saved_hierarchy_classes)

        # Create model with saved configuration
        self.model = Model(
            num_hierarchy_classes=len(saved_hierarchy_classes),
            embed_dim=config_dict['embed_dim'],
            dropout=config_dict['dropout']
        ).to(self.device)

        # Load model weights
        self.model.load_state_dict(checkpoint['model_state'])
        self.model.eval()

        # Print model information
        print(f"✅ Custom model loaded with:")
        print(f"📋 Hierarchy classes: {len(saved_hierarchy_classes)}")
        print(f"🎯 Embed dim: {config_dict['embed_dim']}")
        print(f"💧 Dropout: {config_dict['dropout']}")
        print(f"📅 Epoch: {checkpoint.get('epoch', 'unknown')}")

    def _collate_fn_wrapper(self, batch: List[Tuple]) -> Dict[str, torch.Tensor]:
        """
        Wrapper for collate_fn that can be pickled (required for DataLoader).

        Handles both formats:
        - (image, description, hierarchy) for HierarchyDataset
        - (image, description, color, hierarchy) for FashionMNISTDataset

        Args:
            batch: List of samples from dataset

        Returns:
            Collated batch dictionary
        """
        # Check batch format by tuple arity of the first sample
        if len(batch[0]) == 4:
            # FashionMNISTDataset format: drop the color field to match the
            # (image, description, hierarchy) layout collate_fn expects
            batch_converted = [(b[0], b[1], b[3]) for b in batch]
            return collate_fn(batch_converted, self.vocab)
        else:
            # HierarchyDataset format: use as is
            return collate_fn(batch, self.vocab)

    def create_dataloader(
        self,
        dataframe_or_dataset: Union[pd.DataFrame, Dataset],
        batch_size: int = 16
    ) -> DataLoader:
        """
        Create a DataLoader for the custom model.

        Aligned with main_model_evaluation.py for consistency.
Args:
            dataframe_or_dataset: Either a pandas DataFrame or a Dataset object
            batch_size: Batch size for the DataLoader

        Returns:
            Configured DataLoader

        Raises:
            ValueError: If the input is neither a DataFrame nor a Dataset
        """
        # Check if it's already a Dataset object
        if isinstance(dataframe_or_dataset, Dataset):
            dataset = dataframe_or_dataset
            print(f"🔍 Using pre-created Dataset object")
        # Otherwise create dataset from dataframe
        elif isinstance(dataframe_or_dataset, pd.DataFrame):
            # Check if this is Fashion-MNIST data (presence of pixel columns)
            if 'pixel1' in dataframe_or_dataset.columns:
                print(f"🔍 Detected Fashion-MNIST data, creating FashionMNISTDataset")
                dataset = FashionMNISTDataset(dataframe_or_dataset, image_size=224)
            else:
                dataset = HierarchyDataset(dataframe_or_dataset, image_size=224)
        else:
            raise ValueError(f"Unsupported type: {type(dataframe_or_dataset)}")

        # Create DataLoader
        # Note: num_workers=0 to avoid pickling issues on macOS
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=self._collate_fn_wrapper,
            num_workers=0,
            pin_memory=False
        )

        return dataloader

    def create_clip_dataloader(
        self,
        dataframe_or_dataset: Union[pd.DataFrame, Dataset],
        batch_size: int = 16
    ) -> DataLoader:
        """
        Create a DataLoader for Fashion-CLIP baseline.

        Unlike create_dataloader, this uses the default collate function
        (no custom hierarchy collation).

        Args:
            dataframe_or_dataset: Either a pandas DataFrame or a Dataset object
            batch_size: Batch size for the DataLoader

        Returns:
            Configured DataLoader

        Raises:
            ValueError: If the input is neither a DataFrame nor a Dataset
        """
        # Check if it's already a Dataset object
        if isinstance(dataframe_or_dataset, Dataset):
            dataset = dataframe_or_dataset
            print(f"🔍 Using pre-created Dataset object for CLIP")
        # Otherwise create dataset from dataframe
        elif isinstance(dataframe_or_dataset, pd.DataFrame):
            # Check if this is Fashion-MNIST data
            if 'pixel1' in dataframe_or_dataset.columns:
                print("🔍 Detected Fashion-MNIST data for Fashion-CLIP")
                dataset = FashionMNISTDataset(dataframe_or_dataset, image_size=224)
            else:
                dataset = CLIPDataset(dataframe_or_dataset)
        else:
            raise ValueError(f"Unsupported type: {type(dataframe_or_dataset)}")

        # Create DataLoader
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=False
        )

        return dataloader

    def extract_custom_embeddings(
        self,
        dataloader: DataLoader,
        embedding_type: str = 'text',
        use_tta: bool = False
    ) -> Tuple[np.ndarray, List[str], List[str]]:
        """
        Extract embeddings from custom model with optional Test-Time Augmentation.

        Args:
            dataloader: DataLoader for the dataset
            embedding_type: Type of embedding to extract ('text', 'image', or 'both')
            use_tta: Whether to use Test-Time Augmentation for images

        Returns:
            Tuple of (embeddings, labels, texts); labels and texts both hold
            the batch hierarchy labels
        """
        all_embeddings = []
        all_labels = []
        all_texts = []

        with torch.no_grad():
            for batch in tqdm(dataloader, desc=f"Extracting custom {embedding_type} embeddings{' with TTA' if use_tta else ''}"):
                images = batch['image'].to(self.device)
                hierarchy_indices = batch['hierarchy_indices'].to(self.device)
                hierarchy_labels = batch['hierarchy']

                # Handle Test-Time Augmentation (images arrive 5-D:
                # [batch, crops, C, H, W])
                if use_tta and embedding_type == 'image' and images.dim() == 5:
                    embeddings = self._extract_with_tta(images, hierarchy_indices)
                else:
                    # Standard forward pass
                    out = self.model(image=images, hierarchy_indices=hierarchy_indices)
                    embeddings = out['z_txt'] if embedding_type == 'text' else out['z_img']

                all_embeddings.append(embeddings.cpu().numpy())
                all_labels.extend(hierarchy_labels)
                all_texts.extend(hierarchy_labels)

                # Clear memory.
                # NOTE(review): on the TTA branch above, `out` is never bound
                # in this scope, so this `del` raises NameError the first time
                # that branch runs — confirm whether the TTA path has actually
                # been exercised.
                del images, hierarchy_indices, embeddings, out
                if str(self.device) != 'cpu':
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

        return np.vstack(all_embeddings), all_labels, all_texts

    def _extract_with_tta(
        self,
        images: torch.Tensor,
        hierarchy_indices: torch.Tensor
    ) -> torch.Tensor:
        """
        Extract embeddings using Test-Time Augmentation.

        Runs the model on every TTA crop and averages the resulting image
        embeddings per sample.

        Args:
            images: Images with TTA crops [batch_size, tta_crops, C, H, W]
            hierarchy_indices: Hierarchy class indices

        Returns:
            Averaged embeddings [batch_size, embed_dim]
        """
        batch_size, tta_crops, C, H, W = images.shape

        # Reshape to [batch_size * tta_crops, C, H, W]
        images_flat = images.view(batch_size * tta_crops, C, H, W)

        # Repeat hierarchy indices for each TTA crop
        hierarchy_indices_repeated = hierarchy_indices.unsqueeze(1).repeat(1, tta_crops).view(-1)

        # Forward pass on all TTA crops
        out = self.model(image=images_flat, hierarchy_indices=hierarchy_indices_repeated)
        embeddings_flat = out['z_img']

        # Reshape back to [batch_size, tta_crops, embed_dim]
        embeddings = embeddings_flat.view(batch_size, tta_crops, -1)

        # Average over TTA crops
        embeddings = embeddings.mean(dim=1)

        return embeddings

    def apply_whitening(
        self,
        embeddings: np.ndarray,
        epsilon: float = 1e-5
    ) -> np.ndarray:
        """
        Apply ZCA whitening to embeddings for better feature decorrelation.

        Whitening removes correlations between dimensions and can improve
        class separation by normalizing the feature space. The result is
        L2-normalized again afterwards.

        Args:
            embeddings: Input embeddings [N, D]
            epsilon: Small constant for numerical stability

        Returns:
            Whitened, L2-normalized embeddings [N, D]
        """
        # Center the data
        mean = np.mean(embeddings, axis=0, keepdims=True)
        centered = embeddings - mean

        # Compute covariance matrix
        cov = np.cov(centered.T)

        # Eigenvalue decomposition (eigh: covariance is symmetric)
        eigenvalues, eigenvectors = np.linalg.eigh(cov)

        # ZCA whitening transformation: E diag(1/sqrt(λ+ε)) Eᵀ
        d = np.diag(1.0 / np.sqrt(eigenvalues + epsilon))
        whiten_transform = eigenvectors @ d @ eigenvectors.T

        # Apply whitening
        whitened = centered @ whiten_transform

        # L2 normalize after whitening
        norms = np.linalg.norm(whitened, axis=1, keepdims=True)
        whitened = whitened / (norms + epsilon)

        return whitened

    def compute_similarity_metrics(
        self,
        embeddings: np.ndarray,
        labels: List[str],
        apply_whitening_norm: bool = False
    ) -> Dict[str, Any]:
        """
        Compute intra-class and inter-class similarity metrics.
Args:
            embeddings: Embedding vectors
            labels: Class labels
            apply_whitening_norm: Whether to apply ZCA whitening first

        Returns:
            Dictionary containing similarity lists, their means, a
            separation score (intra mean - inter mean), and nearest-neighbor
            and centroid classification accuracies
        """
        # Apply whitening if requested
        if apply_whitening_norm:
            embeddings = self.apply_whitening(embeddings)

        # Compute pairwise cosine similarities (O(n²) memory — inputs are
        # capped upstream by MAX_SAMPLES_EVALUATION)
        similarities = cosine_similarity(embeddings)

        # Group embedding indices by hierarchy label
        hierarchy_groups = defaultdict(list)
        for i, hierarchy in enumerate(labels):
            hierarchy_groups[hierarchy].append(i)

        # Calculate intra-class similarities (same hierarchy)
        intra_class_similarities = self._compute_intra_class_similarities(
            similarities, hierarchy_groups
        )

        # Calculate inter-class similarities (different hierarchies)
        inter_class_similarities = self._compute_inter_class_similarities(
            similarities, hierarchy_groups
        )

        # Calculate classification accuracies
        nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities)
        centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels)

        return {
            'intra_class_similarities': intra_class_similarities,
            'inter_class_similarities': inter_class_similarities,
            'intra_class_mean': np.mean(intra_class_similarities) if intra_class_similarities else 0,
            'inter_class_mean': np.mean(inter_class_similarities) if inter_class_similarities else 0,
            'separation_score': np.mean(intra_class_similarities) - np.mean(inter_class_similarities) if intra_class_similarities and inter_class_similarities else 0,
            'accuracy': nn_accuracy,
            'centroid_accuracy': centroid_accuracy
        }

    def _compute_intra_class_similarities(
        self,
        similarities: np.ndarray,
        hierarchy_groups: Dict[str, List[int]]
    ) -> List[float]:
        """
        Compute within-class similarities.

        All unordered pairs within each class are compared (no sampling).

        Args:
            similarities: Pairwise similarity matrix
            hierarchy_groups: Mapping from hierarchy to sample indices

        Returns:
            List of intra-class similarity values
        """
        intra_class_similarities = []

        for hierarchy, indices in hierarchy_groups.items():
            if len(indices) > 1:
                # Compare all pairs within the same class
                for i in range(len(indices)):
                    for j in range(i + 1, len(indices)):
                        sim = similarities[indices[i], indices[j]]
                        intra_class_similarities.append(sim)

        return intra_class_similarities

    def _compute_inter_class_similarities(
        self,
        similarities: np.ndarray,
        hierarchy_groups: Dict[str, List[int]]
    ) -> List[float]:
        """
        Compute between-class similarities with sampling for efficiency.

        To prevent O(n²) complexity on large datasets, at most 100 samples
        per class are drawn for each class pair, and the total number of
        comparisons is capped at MAX_INTER_CLASS_COMPARISONS.

        NOTE(review): np.random.choice is used without a fixed seed, so
        this metric is not reproducible across runs — confirm whether a
        seeded RNG is wanted.

        Args:
            similarities: Pairwise similarity matrix
            hierarchy_groups: Mapping from hierarchy to sample indices

        Returns:
            List of inter-class similarity values
        """
        inter_class_similarities = []
        hierarchies = list(hierarchy_groups.keys())
        comparison_count = 0

        for i in range(len(hierarchies)):
            for j in range(i + 1, len(hierarchies)):
                hierarchy1_indices = hierarchy_groups[hierarchies[i]]
                hierarchy2_indices = hierarchy_groups[hierarchies[j]]

                # Sample if too many comparisons
                max_samples_per_pair = min(100, len(hierarchy1_indices), len(hierarchy2_indices))

                sampled_idx1 = np.random.choice(
                    hierarchy1_indices,
                    size=min(max_samples_per_pair, len(hierarchy1_indices)),
                    replace=False
                )
                sampled_idx2 = np.random.choice(
                    hierarchy2_indices,
                    size=min(max_samples_per_pair, len(hierarchy2_indices)),
                    replace=False
                )

                # Compute similarities between sampled pairs; the cascaded
                # breaks bail out of all three loop levels once the global
                # comparison cap is hit
                for idx1 in sampled_idx1:
                    for idx2 in sampled_idx2:
                        if comparison_count >= MAX_INTER_CLASS_COMPARISONS:
                            break
                        sim = similarities[idx1, idx2]
                        inter_class_similarities.append(sim)
                        comparison_count += 1
                    if comparison_count >= MAX_INTER_CLASS_COMPARISONS:
                        break
                if comparison_count >= MAX_INTER_CLASS_COMPARISONS:
                    break
            if comparison_count >= MAX_INTER_CLASS_COMPARISONS:
                break

        return inter_class_similarities

    def compute_embedding_accuracy(
        self,
        embeddings: np.ndarray,
        labels: List[str],
        similarities: np.ndarray
    ) -> float:
        """
        Compute classification accuracy using nearest neighbor in embedding space.

        Each sample is classified by the label of its most similar other
        sample (leave-one-out 1-NN).

        Args:
            embeddings: Embedding vectors
            labels: True class labels
            similarities: Precomputed similarity matrix

        Returns:
            Classification accuracy (0 when there are no samples)
        """
        correct_predictions = 0
        total_predictions = len(labels)

        for i in range(len(embeddings)):
            true_label = labels[i]

            # Find the most similar embedding (excluding itself)
            similarities_row = similarities[i].copy()
            similarities_row[i] = -1  # Exclude self-similarity
            nearest_neighbor_idx = np.argmax(similarities_row)
            predicted_label = labels[nearest_neighbor_idx]

            if predicted_label == true_label:
                correct_predictions += 1

        return correct_predictions / total_predictions if total_predictions > 0 else 0

    def compute_centroid_accuracy(
        self,
        embeddings: np.ndarray,
        labels: List[str]
    ) -> float:
        """
        Compute classification accuracy using hierarchy centroids.
Args:
            embeddings: Embedding vectors
            labels: True class labels

        Returns:
            Classification accuracy (0 when there are no samples)
        """
        # Create centroids for each hierarchy (mean embedding per class)
        unique_hierarchies = list(set(labels))
        centroids = {}

        for hierarchy in unique_hierarchies:
            hierarchy_indices = [i for i, label in enumerate(labels) if label == hierarchy]
            hierarchy_embeddings = embeddings[hierarchy_indices]
            centroids[hierarchy] = np.mean(hierarchy_embeddings, axis=0)

        # Classify each embedding to nearest centroid.
        # NOTE(review): centroids are built from the same samples being
        # classified, so this is an optimistic (resubstitution) accuracy.
        correct_predictions = 0
        total_predictions = len(labels)

        for i, embedding in enumerate(embeddings):
            true_label = labels[i]

            # Find closest centroid by cosine similarity
            best_similarity = -1
            predicted_label = None

            for hierarchy, centroid in centroids.items():
                similarity = cosine_similarity([embedding], [centroid])[0][0]
                if similarity > best_similarity:
                    best_similarity = similarity
                    predicted_label = hierarchy

            if predicted_label == true_label:
                correct_predictions += 1

        return correct_predictions / total_predictions if total_predictions > 0 else 0

    def compute_mahalanobis_distance(
        self,
        point: np.ndarray,
        centroid: np.ndarray,
        cov_inv: np.ndarray
    ) -> float:
        """
        Compute Mahalanobis distance between a point and a centroid.

        The Mahalanobis distance takes into account the covariance structure
        of the data, making it more robust than Euclidean distance for
        high-dimensional spaces.

        Args:
            point: Query point
            centroid: Class centroid
            cov_inv: Inverse covariance matrix

        Returns:
            Mahalanobis distance sqrt((x-μ)ᵀ Σ⁻¹ (x-μ))
        """
        diff = point - centroid
        distance = np.sqrt(np.dot(np.dot(diff, cov_inv), diff.T))
        return distance

    def predict_hierarchy_from_embeddings(
        self,
        embeddings: np.ndarray,
        labels: List[str],
        use_mahalanobis: bool = False
    ) -> List[str]:
        """
        Predict hierarchy from embeddings using centroid-based classification.

        Args:
            embeddings: Embedding vectors
            labels: Training labels for computing centroids
            use_mahalanobis: Whether to use Mahalanobis distance instead of
                cosine similarity

        Returns:
            List of predicted hierarchy labels (one per embedding)
        """
        # Create hierarchy centroids from training data
        unique_hierarchies = list(set(labels))
        centroids = {}
        cov_inverses = {}

        for hierarchy in unique_hierarchies:
            hierarchy_indices = [i for i, label in enumerate(labels) if label == hierarchy]
            hierarchy_embeddings = embeddings[hierarchy_indices]
            centroids[hierarchy] = np.mean(hierarchy_embeddings, axis=0)

            # Compute covariance for Mahalanobis distance (needs >1 sample)
            if use_mahalanobis and len(hierarchy_embeddings) > 1:
                cov = np.cov(hierarchy_embeddings.T)
                # Add regularization for numerical stability
                cov += np.eye(cov.shape[0]) * 1e-6
                try:
                    cov_inverses[hierarchy] = np.linalg.inv(cov)
                except np.linalg.LinAlgError:
                    # If inversion fails, fallback to identity (Euclidean)
                    cov_inverses[hierarchy] = np.eye(cov.shape[0])

        # Predict hierarchy for all embeddings
        predictions = []

        for embedding in embeddings:
            if use_mahalanobis:
                predicted_hierarchy = self._predict_with_mahalanobis(
                    embedding, centroids, cov_inverses
                )
            else:
                predicted_hierarchy = self._predict_with_cosine(
                    embedding, centroids
                )
            predictions.append(predicted_hierarchy)

        return predictions

    def _predict_with_mahalanobis(
        self,
        embedding: np.ndarray,
        centroids: Dict[str, np.ndarray],
        cov_inverses: Dict[str, np.ndarray]
    ) -> str:
        """
        Predict class using Mahalanobis distance (lower is better).

        NOTE(review): for classes without an inverse covariance, the
        cosine-derived distance (1 - similarity) is compared directly with
        Mahalanobis distances of other classes — the two scales are not
        commensurate, which may bias predictions toward such classes.

        Args:
            embedding: Query embedding
            centroids: Class centroids
            cov_inverses: Inverse covariance matrices

        Returns:
            Predicted class label
        """
        best_distance = float('inf')
        predicted_hierarchy = None

        for hierarchy, centroid in centroids.items():
            if hierarchy in cov_inverses:
                distance = self.compute_mahalanobis_distance(
                    embedding, centroid, cov_inverses[hierarchy]
                )
            else:
                # Fallback to cosine similarity for classes with insufficient samples
                similarity = cosine_similarity([embedding], [centroid])[0][0]
                distance = 1 - similarity

            if distance < best_distance:
                best_distance = distance
                predicted_hierarchy = hierarchy

        return predicted_hierarchy

    def _predict_with_cosine(
        self,
        embedding: np.ndarray,
        centroids: Dict[str, np.ndarray]
    ) -> str:
        """
        Predict class using cosine similarity (higher is better).

        Args:
            embedding: Query embedding
            centroids: Class centroids

        Returns:
            Predicted class label
        """
        best_similarity = -1
        predicted_hierarchy = None

        for hierarchy, centroid in centroids.items():
            similarity = cosine_similarity([embedding], [centroid])[0][0]
            if similarity > best_similarity:
                best_similarity = similarity
                predicted_hierarchy = hierarchy

        return predicted_hierarchy

    def create_confusion_matrix(
        self,
        true_labels: List[str],
        predicted_labels: List[str],
        title: str = "Confusion Matrix"
    ) -> Tuple[plt.Figure, float, np.ndarray]:
        """
        Create and plot confusion matrix.
Args: true_labels: Ground truth labels predicted_labels: Predicted labels title: Plot title Returns: Tuple of (figure, accuracy, confusion_matrix) """ # Get unique labels unique_labels = sorted(list(set(true_labels + predicted_labels))) # Create confusion matrix cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels) # Calculate accuracy accuracy = accuracy_score(true_labels, predicted_labels) # Plot confusion matrix plt.figure(figsize=(12, 10)) sns.heatmap( cm, annot=True, fmt='d', cmap='Blues', xticklabels=unique_labels, yticklabels=unique_labels ) plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)') plt.ylabel('True Hierarchy') plt.xlabel('Predicted Hierarchy') plt.xticks(rotation=45) plt.yticks(rotation=0) plt.tight_layout() return plt.gcf(), accuracy, cm def evaluate_classification_performance( self, embeddings: np.ndarray, labels: List[str], embedding_type: str = "Embeddings", apply_whitening_norm: bool = False, use_mahalanobis: bool = False ) -> Dict[str, Any]: """ Evaluate classification performance and create confusion matrix. 
Args: embeddings: Embedding vectors labels: True class labels embedding_type: Description of embedding type for display apply_whitening_norm: Whether to apply ZCA whitening use_mahalanobis: Whether to use Mahalanobis distance Returns: Dictionary containing classification metrics and visualizations """ # Apply whitening if requested if apply_whitening_norm: embeddings = self.apply_whitening(embeddings) # Predict hierarchy predictions = self.predict_hierarchy_from_embeddings( embeddings, labels, use_mahalanobis=use_mahalanobis ) # Calculate accuracy accuracy = accuracy_score(labels, predictions) # Calculate F1 scores unique_labels = sorted(list(set(labels))) f1_macro = f1_score( labels, predictions, labels=unique_labels, average='macro', zero_division=0 ) f1_weighted = f1_score( labels, predictions, labels=unique_labels, average='weighted', zero_division=0 ) f1_per_class = f1_score( labels, predictions, labels=unique_labels, average=None, zero_division=0 ) # Create confusion matrix fig, acc, cm = self.create_confusion_matrix( labels, predictions, f"{embedding_type} - Hierarchy Classification" ) # Generate classification report report = classification_report( labels, predictions, labels=unique_labels, target_names=unique_labels, output_dict=True ) return { 'accuracy': accuracy, 'f1_macro': f1_macro, 'f1_weighted': f1_weighted, 'f1_per_class': f1_per_class, 'predictions': predictions, 'confusion_matrix': cm, 'classification_report': report, 'figure': fig } def evaluate_dataset_with_baselines( self, dataframe: Union[pd.DataFrame, Dataset], dataset_name: str = "Dataset", use_whitening: bool = False, use_mahalanobis: bool = False ) -> Dict[str, Dict[str, Any]]: """ Evaluate embeddings on a given dataset with both custom model and CLIP baseline. This is the main evaluation method that compares the custom model against the Fashion-CLIP baseline across multiple metrics and embedding types. Aligned with main_model_evaluation.py for consistency (no TTA for fair comparison). 
Args: dataframe: DataFrame or Dataset to evaluate on dataset_name: Name of the dataset for display use_whitening: Whether to apply ZCA whitening use_mahalanobis: Whether to use Mahalanobis distance Returns: Dictionary containing results for all models and embedding types """ print(f"\n{'='*60}") print(f"Evaluating {dataset_name}") if use_whitening: print(f"🎯 ZCA Whitening ENABLED for better feature decorrelation") if use_mahalanobis: print(f"🎯 Mahalanobis Distance ENABLED for classification") print(f"{'='*60}") results = {} # ===== CUSTOM MODEL EVALUATION ===== print(f"\n🔧 Evaluating Custom Model on {dataset_name}") print("-" * 40) # Create dataloader custom_dataloader = self.create_dataloader(dataframe, batch_size=16) # Evaluate text embeddings text_embeddings, text_labels, texts = self.extract_custom_embeddings( custom_dataloader, 'text', use_tta=False ) text_metrics = self.compute_similarity_metrics( text_embeddings, text_labels, apply_whitening_norm=use_whitening ) text_classification = self.evaluate_classification_performance( text_embeddings, text_labels, "Custom Text Embeddings", apply_whitening_norm=use_whitening, use_mahalanobis=use_mahalanobis ) text_metrics.update(text_classification) results['custom_text'] = text_metrics # Evaluate image embeddings # NOTE: TTA disabled for fair comparison image_embeddings, image_labels, _ = self.extract_custom_embeddings( custom_dataloader, 'image', use_tta=False ) image_metrics = self.compute_similarity_metrics( image_embeddings, image_labels, apply_whitening_norm=use_whitening ) whitening_suffix = " + Whitening" if use_whitening else "" mahalanobis_suffix = " + Mahalanobis" if use_mahalanobis else "" image_classification = self.evaluate_classification_performance( image_embeddings, image_labels, f"Custom Image Embeddings{whitening_suffix}{mahalanobis_suffix}", apply_whitening_norm=use_whitening, use_mahalanobis=use_mahalanobis ) image_metrics.update(image_classification) results['custom_image'] = image_metrics # ===== 
FASHION-CLIP BASELINE EVALUATION ===== print(f"\n🤗 Evaluating Fashion-CLIP Baseline on {dataset_name}") print("-" * 40) # Create dataloader for Fashion-CLIP clip_dataloader = self.create_clip_dataloader(dataframe, batch_size=8) # Extract data for Fashion-CLIP all_images = [] all_texts = [] all_labels = [] for batch in tqdm(clip_dataloader, desc="Preparing data for Fashion-CLIP"): # Handle different batch formats if len(batch) == 4: images, descriptions, colors, hierarchies = batch else: images, descriptions, hierarchies = batch all_images.extend(images) all_texts.extend(descriptions) all_labels.extend(hierarchies) # Get Fashion-CLIP embeddings clip_image_embeddings, clip_text_embeddings = self.clip_evaluator.extract_clip_embeddings( all_images, all_texts ) # Evaluate Fashion-CLIP text embeddings clip_text_metrics = self.compute_similarity_metrics( clip_text_embeddings, all_labels ) clip_text_classification = self.evaluate_classification_performance( clip_text_embeddings, all_labels, "Fashion-CLIP Text Embeddings" ) clip_text_metrics.update(clip_text_classification) results['clip_text'] = clip_text_metrics # Evaluate Fashion-CLIP image embeddings clip_image_metrics = self.compute_similarity_metrics( clip_image_embeddings, all_labels ) clip_image_classification = self.evaluate_classification_performance( clip_image_embeddings, all_labels, "Fashion-CLIP Image Embeddings" ) clip_image_metrics.update(clip_image_classification) results['clip_image'] = clip_image_metrics # ===== PRINT COMPARISON RESULTS ===== self._print_comparison_results(dataframe, dataset_name, results) # ===== SAVE VISUALIZATIONS ===== self._save_visualizations(dataset_name, results) return results def _print_comparison_results( self, dataframe: Union[pd.DataFrame, Dataset], dataset_name: str, results: Dict[str, Dict[str, Any]] ): """ Print formatted comparison results. 
Args: dataframe: Dataset being evaluated dataset_name: Name of the dataset results: Evaluation results dictionary """ dataset_size = len(dataframe) if hasattr(dataframe, '__len__') else "N/A" print(f"\n{dataset_name} Results Comparison:") print(f"Dataset size: {dataset_size} samples") print("=" * 80) print(f"{'Model':<20} {'Embedding':<10} {'Sep Score':<10} {'NN Acc':<8} {'Centroid Acc':<12} {'F1 Macro':<10}") print("-" * 80) for model_type in ['custom', 'clip']: for emb_type in ['text', 'image']: key = f"{model_type}_{emb_type}" if key in results: metrics = results[key] model_name = "Custom Model" if model_type == 'custom' else "Fashion-CLIP Baseline" print( f"{model_name:<20} " f"{emb_type.capitalize():<10} " f"{metrics['separation_score']:<10.4f} " f"{metrics['accuracy']*100:<8.1f}% " f"{metrics['centroid_accuracy']*100:<12.1f}% " f"{metrics['f1_macro']*100:<10.1f}%" ) def _save_visualizations( self, dataset_name: str, results: Dict[str, Dict[str, Any]] ): """ Save confusion matrices and other visualizations. Args: dataset_name: Name of the dataset results: Evaluation results dictionary """ os.makedirs(self.directory, exist_ok=True) # Save confusion matrices for key, metrics in results.items(): if 'figure' in metrics: filename = f'{self.directory}/{dataset_name.lower()}_{key}_confusion_matrix.png' metrics['figure'].savefig(filename, dpi=300, bbox_inches='tight') plt.close(metrics['figure']) # ============================================================================ # DATASET LOADING FUNCTIONS # ============================================================================ def load_fashion_mnist_dataset( evaluator: EmbeddingEvaluator, max_samples: int = 1000 ) -> FashionMNISTDataset: """ Load and prepare Fashion-MNIST test dataset. This function loads the Fashion-MNIST test set and creates appropriate mappings to the custom model's hierarchy classes. Exactly aligned with main_model_evaluation.py for consistency. 
Args: evaluator: EmbeddingEvaluator instance with loaded model max_samples: Maximum number of samples to use Returns: FashionMNISTDataset object """ print("📊 Loading Fashion-MNIST test dataset...") df = pd.read_csv(config.fashion_mnist_test_path) print(f"✅ Fashion-MNIST dataset loaded: {len(df)} samples") # Create mapping if hierarchy classes are provided label_mapping = None if evaluator.hierarchy_classes is not None: print("\n🔗 Creating mapping from Fashion-MNIST labels to hierarchy classes:") label_mapping = create_fashion_mnist_to_hierarchy_mapping( evaluator.hierarchy_classes ) # Filter dataset to only include samples that can be mapped valid_label_ids = [ label_id for label_id, hierarchy in label_mapping.items() if hierarchy is not None ] df_filtered = df[df['label'].isin(valid_label_ids)] print( f"\n📊 After filtering to mappable labels: " f"{len(df_filtered)} samples (from {len(df)})" ) # Apply max_samples limit after filtering df_sample = df_filtered.head(max_samples) else: df_sample = df.head(max_samples) print(f"📊 Using {len(df_sample)} samples for evaluation") return FashionMNISTDataset(df_sample, label_mapping=label_mapping) def load_kagl_marqo_dataset(evaluator: EmbeddingEvaluator) -> pd.DataFrame: """ Load and prepare Kaggle Marqo dataset for evaluation. This function loads the Marqo fashion dataset from Hugging Face and preprocesses it for evaluation with the custom model. 
Args: evaluator: EmbeddingEvaluator instance with loaded model Returns: Formatted pandas DataFrame ready for evaluation """ from datasets import load_dataset print("📊 Loading Kaggle Marqo dataset...") # Load the dataset from Hugging Face dataset = load_dataset("Marqo/KAGL") df = dataset["data"].to_pandas() print(f"✅ Dataset Kaggle loaded") print(f"📊 Before filtering: {len(df)} samples") print(f"📋 Available columns: {list(df.columns)}") print(f"🎨 Available categories: {sorted(df['category2'].unique())}") # Map categories to our hierarchy format df['hierarchy'] = df['category2'].str.lower() df['hierarchy'] = df['hierarchy'].replace({ 'bags': 'bag', 'topwear': 'top', 'flip flops': 'shoes', 'sandal': 'shoes' }) # Filter to only include valid hierarchies valid_hierarchies = df['hierarchy'].dropna().unique() print(f"🎯 Valid hierarchies found: {sorted(valid_hierarchies)}") print(f"🎯 Model hierarchies: {sorted(evaluator.hierarchy_classes)}") df = df[df['hierarchy'].isin(evaluator.hierarchy_classes)] print(f"📊 After filtering to model hierarchies: {len(df)} samples") if len(df) == 0: print("❌ No samples left after hierarchy filtering.") return pd.DataFrame() # Ensure we have text and image data df = df.dropna(subset=['text', 'image']) print(f"📊 After removing missing text/image: {len(df)} samples") # Show sample of text data to verify quality print(f"📝 Sample texts:") for i, (text, hierarchy) in enumerate(zip(df['text'].head(3), df['hierarchy'].head(3))): print(f" {i+1}. 
[{hierarchy}] {text[:100]}...") # Limit size to prevent memory overload max_samples = 1000 if len(df) > max_samples: print(f"⚠️ Dataset too large ({len(df)} samples), sampling to {max_samples} samples") df_test = df.sample(n=max_samples, random_state=42).reset_index(drop=True) else: df_test = df.copy() print(f"📊 After sampling: {len(df_test)} samples") print(f"📊 Samples per hierarchy:") for hierarchy in sorted(df_test['hierarchy'].unique()): count = len(df_test[df_test['hierarchy'] == hierarchy]) print(f" {hierarchy}: {count} samples") # Create formatted dataset with proper column names kagl_formatted = pd.DataFrame({ 'image_url': df_test['image'], 'text': df_test['text'], 'hierarchy': df_test['hierarchy'] }) print(f"📊 Final dataset size: {len(kagl_formatted)} samples") return kagl_formatted # ============================================================================ # MAIN EXECUTION # ============================================================================ def main(): """ Main evaluation function that runs comprehensive evaluation across multiple datasets. This function evaluates the custom hierarchy classification model against the Fashion-CLIP baseline on: 1. Validation dataset (from training data) 2. Fashion-MNIST test dataset 3. Kaggle Marqo dataset Results include detailed metrics, confusion matrices, and performance comparisons. 
""" # Setup output directory directory = "hierarchy_model_analysis" print(f"🚀 Starting evaluation with custom model: {hierarchy_model_path}") print(f"🤗 Including Fashion-CLIP baseline comparison") # Initialize evaluator evaluator = EmbeddingEvaluator(hierarchy_model_path, directory) print( f"📊 Final hierarchy classes after initialization: " f"{len(evaluator.vocab.hierarchy_classes)} classes" ) # ===== EVALUATION 1: VALIDATION DATASET ===== print("\n" + "="*60) print("EVALUATING VALIDATION DATASET - CUSTOM MODEL vs FASHION-CLIP BASELINE") print("="*60) val_results = evaluator.evaluate_dataset_with_baselines( evaluator.val_df, "Validation Dataset" ) # ===== EVALUATION 2: FASHION-MNIST TEST DATASET ===== print("\n" + "="*60) print("EVALUATING FASHION-MNIST TEST DATASET - CUSTOM MODEL vs FASHION-CLIP BASELINE") print("="*60) fashion_mnist_dataset = load_fashion_mnist_dataset(evaluator, max_samples=1000) if fashion_mnist_dataset is not None: # Aligned with main_model_evaluation.py: NO TTA for fair baseline comparison fashion_mnist_results = evaluator.evaluate_dataset_with_baselines( fashion_mnist_dataset, "Fashion-MNIST Test Dataset", use_whitening=False, # Disabled for fair comparison use_mahalanobis=False # Disabled for fair comparison ) else: fashion_mnist_results = {} # ===== EVALUATION 3: KAGGLE MARQO DATASET ===== print("\n" + "="*60) print("EVALUATING KAGGLE MARQO DATASET - CUSTOM MODEL vs FASHION-CLIP BASELINE") print("="*60) df_kagl_marqo = load_kagl_marqo_dataset(evaluator) if len(df_kagl_marqo) > 0: kagl_results = evaluator.evaluate_dataset_with_baselines( df_kagl_marqo, "Kaggle Marqo Dataset" ) else: kagl_results = {} # ===== FINAL SUMMARY ===== print(f"\n{'='*80}") print("FINAL EVALUATION SUMMARY - CUSTOM MODEL vs FASHION-CLIP BASELINE") print(f"{'='*80}") # Print validation results print("\n🔍 VALIDATION DATASET RESULTS:") _print_dataset_results(val_results, len(evaluator.val_df)) # Print Fashion-MNIST results if fashion_mnist_results: print("\n👗 
FASHION-MNIST TEST DATASET RESULTS:") _print_dataset_results(fashion_mnist_results, 1000) # Print Kaggle results if kagl_results: print("\n🌐 KAGGLE MARQO DATASET RESULTS:") _print_dataset_results( kagl_results, len(df_kagl_marqo) if df_kagl_marqo is not None else 'N/A' ) # Final completion message print(f"\n✅ Evaluation completed! Check '{directory}/' for visualization files.") print(f"📊 Custom model hierarchy classes: {len(evaluator.vocab.hierarchy_classes)} classes") print(f"🤗 Fashion-CLIP baseline comparison included") def _print_dataset_results(results: Dict[str, Dict[str, Any]], dataset_size: int): """ Print formatted results for a single dataset. Args: results: Dictionary containing evaluation results dataset_size: Number of samples in the dataset """ print(f"Dataset size: {dataset_size} samples") print(f"{'Model':<20} {'Embedding':<10} {'Sep Score':<12} {'NN Acc':<10} {'Centroid Acc':<12} {'F1 Macro':<10}") print("-" * 80) for model_type in ['custom', 'clip']: for emb_type in ['text', 'image']: key = f"{model_type}_{emb_type}" if key in results: metrics = results[key] model_name = "Custom Model" if model_type == 'custom' else "Fashion-CLIP Baseline" print( f"{model_name:<20} " f"{emb_type.capitalize():<10} " f"{metrics['separation_score']:<12.4f} " f"{metrics['accuracy']*100:<10.1f}% " f"{metrics['centroid_accuracy']*100:<12.1f}% " f"{metrics['f1_macro']*100:<10.1f}%" ) if __name__ == "__main__": main()