# Source: gap-clip / evaluation/hierarchy_evaluation.py
# Uploaded with huggingface_hub (revision d8d8ac9, verified)
"""
Hierarchy Embedding Evaluation with Fashion-CLIP Baseline Comparison
This module provides comprehensive evaluation tools for hierarchy classification models,
comparing custom model performance against the Fashion-CLIP baseline. It includes:
- Embedding quality metrics (intra-class/inter-class similarity)
- Classification accuracy with multiple methods (nearest neighbor, centroid-based)
- Confusion matrix generation and visualization
- Support for multiple datasets (validation set, Fashion-MNIST, Kaggle Marqo)
- Advanced techniques: ZCA whitening, Mahalanobis distance, Test-Time Augmentation
Key Features:
- Custom model evaluation with full hierarchy classification pipeline
- Fashion-CLIP baseline comparison for performance benchmarking
- Multi-dataset evaluation (validation, Fashion-MNIST, Kaggle Marqo)
- Flexible evaluation options (whitening, Mahalanobis distance)
- Detailed metrics: accuracy, F1 scores, confusion matrices
Author: Fashion Search Team
License: Apache 2.0
"""
# Standard library imports
import os
import warnings
from collections import defaultdict
from io import BytesIO
from typing import Dict, List, Tuple, Optional, Union, Any
# Third-party imports
import numpy as np
import pandas as pd
import requests
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from sklearn.metrics import (
accuracy_score,
classification_report,
confusion_matrix,
f1_score,
)
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel as TransformersCLIPModel
# Local imports
import config
from config import device, hierarchy_model_path, hierarchy_column, local_dataset_path
from hierarchy_model import Model, HierarchyExtractor, HierarchyDataset, collate_fn
# Suppress warnings for cleaner output
# NOTE(review): this silences *all* warnings process-wide (including
# deprecation warnings from sklearn/torch) — consider narrowing the filter.
warnings.filterwarnings('ignore')
# ============================================================================
# CONSTANTS AND CONFIGURATION
# ============================================================================
# Hard cap on evaluation set size; larger datasets are stratified-sampled
# down to this many rows to prevent memory issues.
MAX_SAMPLES_EVALUATION = 10000
# Hard cap on the number of inter-class similarity comparisons, bounding
# the otherwise O(n^2) pairwise work on large datasets.
MAX_INTER_CLASS_COMPARISONS = 10000
# Canonical Fashion-MNIST label mapping: class id -> human-readable name.
FASHION_MNIST_LABELS = {
    0: "T-shirt/top",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle boot"
}
# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================
def convert_fashion_mnist_to_image(pixel_values: np.ndarray) -> Image.Image:
    """
    Convert a flat Fashion-MNIST pixel vector into an RGB PIL Image.

    Args:
        pixel_values: Flat array of 784 pixel values (one 28x28 image)

    Returns:
        PIL Image in RGB format (the grayscale plane replicated across
        three channels)
    """
    # Rebuild the 28x28 grayscale plane as uint8
    gray = np.asarray(pixel_values, dtype=np.uint8).reshape(28, 28)
    # Replicate the single channel three times -> (28, 28, 3)
    rgb = np.repeat(gray[:, :, np.newaxis], 3, axis=-1)
    return Image.fromarray(rgb)
def get_fashion_mnist_labels() -> Dict[int, str]:
    """
    Return a fresh copy of the Fashion-MNIST id -> class-name mapping.

    Returns:
        Dictionary mapping label IDs to class names
    """
    # dict() produces the same shallow copy as .copy()
    return dict(FASHION_MNIST_LABELS)
def create_fashion_mnist_to_hierarchy_mapping(
    hierarchy_classes: List[str]
) -> Dict[int, Optional[str]]:
    """
    Create mapping from Fashion-MNIST labels to custom hierarchy classes.

    Matching is attempted in three stages of decreasing confidence:
      1. Exact (case-insensitive) name match.
      2. Partial match (either name contains the other).
      3. Semantic fallback via a hand-curated candidate table
         (first listed candidate present in the hierarchy wins).

    Args:
        hierarchy_classes: List of hierarchy class names from the custom model

    Returns:
        Dictionary mapping Fashion-MNIST label IDs to hierarchy class names
        (None if no match found)
    """
    # Normalize hierarchy classes to lowercase for matching
    hierarchy_classes_lower = [h.lower() for h in hierarchy_classes]

    # Semantic fallback candidates per Fashion-MNIST label, in priority
    # order; replaces the previous deeply nested elif chain.
    footwear = ['shoes', 'shoe', 'footwear', 'sandal', 'sneaker', 'boot']
    semantic_candidates: Dict[str, List[str]] = {
        't-shirt/top': ['top', 'shirt'],
        'top': ['top', 'shirt'],
        'trouser': ['pant', 'pants', 'trousers', 'trouser', 'bottom'],
        'pullover': ['sweater', 'pullover', 'top'],
        'dress': ['dress'],
        'coat': ['coat', 'jacket', 'outerwear'],
        'sandal': footwear,
        'sneaker': footwear,
        'ankle boot': footwear,
        'bag': ['bag'],
    }

    def _lookup(name_lower: str) -> Optional[str]:
        """Resolve one lowercase Fashion-MNIST name to a hierarchy class."""
        # Strategy 1: exact match
        if name_lower in hierarchy_classes_lower:
            return hierarchy_classes[hierarchy_classes_lower.index(name_lower)]
        # Strategy 2: partial match (substring in either direction)
        for h_class, h_lower in zip(hierarchy_classes, hierarchy_classes_lower):
            if h_lower in name_lower or name_lower in h_lower:
                return h_class
        # Strategy 3: semantic fallback table
        for candidate in semantic_candidates.get(name_lower, []):
            if candidate in hierarchy_classes_lower:
                return hierarchy_classes[hierarchy_classes_lower.index(candidate)]
        return None

    mapping = {}
    for fm_label_id, fm_label in FASHION_MNIST_LABELS.items():
        matched_hierarchy = _lookup(fm_label.lower())
        mapping[fm_label_id] = matched_hierarchy
        # Print mapping result
        if matched_hierarchy:
            print(f" {fm_label} ({fm_label_id}) -> {matched_hierarchy}")
        else:
            print(f" ⚠️ {fm_label} ({fm_label_id}) -> NO MATCH (will be filtered out)")
    return mapping
# ============================================================================
# DATASET CLASSES
# ============================================================================
class FashionMNISTDataset(Dataset):
    """
    Dataset wrapper around a Fashion-MNIST dataframe for evaluation.

    Each row carries 784 flattened pixel columns ("pixel1".."pixel784")
    plus a 'label' column. Images are upscaled to ``image_size`` and
    normalized with ImageNet statistics; labels can optionally be remapped
    to custom hierarchy classes via ``label_mapping``. Aligned with
    main_model_evaluation.py for consistent evaluation across scripts.

    Args:
        dataframe: DataFrame containing Fashion-MNIST data with pixel columns
        image_size: Target size for image resizing (default: 224)
        label_mapping: Optional mapping from Fashion-MNIST label IDs to
            hierarchy classes

    Yields per item:
        Tuple of (image_tensor, description, color, hierarchy)
    """

    def __init__(
        self,
        dataframe: pd.DataFrame,
        image_size: int = 224,
        label_mapping: Optional[Dict[int, str]] = None
    ):
        self.dataframe = dataframe
        self.image_size = image_size
        self.labels_map = get_fashion_mnist_labels()
        self.label_mapping = label_mapping
        # ImageNet mean/std normalization (standard for transfer learning)
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
        ])

    def __len__(self) -> int:
        return len(self.dataframe)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, str, str]:
        """Return (image_tensor, description, color, hierarchy) for row idx."""
        row = self.dataframe.iloc[idx]
        # Rebuild the 28x28 image from the 784 flattened pixel columns
        pixel_columns = [f"pixel{i}" for i in range(1, 785)]
        pil_image = convert_fashion_mnist_to_image(row[pixel_columns].values)
        tensor = self.transform(pil_image)
        label_id = int(row['label'])
        description = self.labels_map[label_id]
        # Fashion-MNIST carries no color annotation
        color = "unknown"
        # Prefer the remapped hierarchy class when a mapping is supplied
        if self.label_mapping and label_id in self.label_mapping:
            hierarchy = self.label_mapping[label_id]
        else:
            hierarchy = self.labels_map[label_id]
        return tensor, description, color, hierarchy
class CLIPDataset(Dataset):
    """
    Dataset for Fashion-CLIP baseline evaluation.

    Images may come from a local path, embedded raw bytes, a numpy pixel
    array, a PIL Image, or a URL; standard (augmentation-free) validation
    transforms are applied so the comparison with the custom model is fair.

    Args:
        dataframe: Pandas DataFrame containing image and text data

    Yields per item:
        Tuple of (image_tensor, description, hierarchy)
    """

    def __init__(self, dataframe: pd.DataFrame):
        self.dataframe = dataframe
        # Deterministic validation pipeline (no augmentation)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def __len__(self) -> int:
        return len(self.dataframe)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, str]:
        """Return (image_tensor, description, hierarchy) for row idx."""
        row = self.dataframe.iloc[idx]
        tensor = self.transform(self._load_image(row, idx))
        return tensor, row[config.text_column], row[config.hierarchy_column]

    def _load_image(self, row: pd.Series, idx: int) -> Image.Image:
        """
        Load one image, trying local path, embedded data, then URL.

        Falls back to a gray placeholder when every source fails so a
        single bad row cannot break the whole batch.

        Args:
            row: DataFrame row containing image information
            idx: Index for error reporting

        Returns:
            PIL Image in RGB format
        """
        # 1) Local file path, when the column is present and populated
        if config.column_local_image_path in row.index and pd.notna(row[config.column_local_image_path]):
            path = row[config.column_local_image_path]
            try:
                if os.path.exists(path):
                    return Image.open(path).convert("RGB")
                print(f"⚠️ Local image not found: {path}")
            except Exception as e:
                print(f"⚠️ Failed to load local image {idx}: {e}")
        payload = row.get(config.column_url_image)
        # 2) Dict payload carrying raw encoded bytes
        if isinstance(payload, dict) and 'bytes' in payload:
            return Image.open(BytesIO(payload['bytes'])).convert('RGB')
        # 3) Flat pixel array (Fashion-MNIST style 28x28)
        if isinstance(payload, (list, np.ndarray)):
            grid = np.array(payload).reshape(28, 28)
            return Image.fromarray(grid.astype(np.uint8)).convert("RGB")
        # 4) Already a PIL Image
        if isinstance(payload, Image.Image):
            return payload.convert("RGB")
        # 5) Assume a URL and download it
        try:
            resp = requests.get(payload, timeout=10)
            resp.raise_for_status()
            return Image.open(BytesIO(resp.content)).convert("RGB")
        except Exception as e:
            print(f"⚠️ Failed to load image {idx}: {e}")
            # Last resort: gray placeholder keeps the batch shape intact
            return Image.new('RGB', (224, 224), color='gray')
# ============================================================================
# EVALUATOR CLASSES
# ============================================================================
class CLIPBaselineEvaluator:
    """
    Fashion-CLIP Baseline Evaluator.

    Loads the pretrained Fashion-CLIP model (patrickjohncyh/fashion-clip)
    from the Hugging Face hub and exposes batched, L2-normalized image/text
    embedding extraction for comparison with custom models.

    Args:
        device: Device to run the model on ('cuda', 'mps', or 'cpu')
    """
    def __init__(self, device: str = 'mps'):
        self.device = torch.device(device)
        # Load Fashion-CLIP model and processor
        # NOTE(review): from_pretrained downloads weights on first use —
        # requires network access or a populated HF cache.
        print("πŸ€— Loading Fashion-CLIP baseline model from transformers...")
        model_name = "patrickjohncyh/fashion-clip"
        self.clip_model = TransformersCLIPModel.from_pretrained(model_name).to(self.device)
        self.clip_processor = CLIPProcessor.from_pretrained(model_name)
        # Inference only: disable dropout / training-mode layers
        self.clip_model.eval()
        print("βœ… Fashion-CLIP model loaded successfully")
    def extract_clip_embeddings(
        self,
        images: List[Union[torch.Tensor, Image.Image]],
        texts: List[str]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Extract Fashion-CLIP embeddings for images and texts.

        Processes both modalities batch-by-batch (batch size 32) and
        returns the stacked, L2-normalized embeddings. Aligned with
        main_model_evaluation.py for consistency.

        Args:
            images: List of images (tensors or PIL Images), same length as texts
            texts: List of text descriptions

        Returns:
            Tuple of (image_embeddings, text_embeddings) as numpy arrays,
            one row per input sample
        """
        all_image_embeddings = []
        all_text_embeddings = []
        # Process in batches for efficiency
        batch_size = 32
        # Ceiling division so a partial final batch is still processed
        num_batches = (len(images) + batch_size - 1) // batch_size
        with torch.no_grad():
            for batch_idx in tqdm(range(num_batches), desc="Extracting CLIP embeddings"):
                start_idx = batch_idx * batch_size
                end_idx = min(start_idx + batch_size, len(images))
                batch_images = images[start_idx:end_idx]
                batch_texts = texts[start_idx:end_idx]
                # Extract text embeddings
                text_features = self._extract_text_features(batch_texts)
                # Extract image embeddings
                image_features = self._extract_image_features(batch_images)
                # Move to CPU immediately so GPU memory is freed per batch
                all_image_embeddings.append(image_features.cpu().numpy())
                all_text_embeddings.append(text_features.cpu().numpy())
                # Clear memory
                del text_features, image_features
                # NOTE(review): only the CUDA cache is cleared here; on MPS
                # this is a no-op — confirm whether that matters for memory.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        return np.vstack(all_image_embeddings), np.vstack(all_text_embeddings)
    def _extract_text_features(self, texts: List[str]) -> torch.Tensor:
        """
        Extract text features using Fashion-CLIP.

        Args:
            texts: List of text descriptions

        Returns:
            L2-normalized text feature embeddings on self.device
        """
        # Tokenize; CLIP's context window is 77 tokens, longer text is truncated
        text_inputs = self.clip_processor(
            text=texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77
        )
        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
        # Get text features using dedicated method
        text_features = self.clip_model.get_text_features(**text_inputs)
        # Apply L2 normalization (critical for CLIP cosine-similarity use!)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features
    def _extract_image_features(
        self,
        images: List[Union[torch.Tensor, Image.Image]]
    ) -> torch.Tensor:
        """
        Extract image features using Fashion-CLIP.

        Args:
            images: List of images (tensors or PIL Images)

        Returns:
            L2-normalized image feature embeddings on self.device

        Raises:
            ValueError: If an image is neither a Tensor nor a PIL Image
        """
        # Convert tensor images to PIL Images so the CLIP processor can
        # apply its own resize/normalization consistently
        pil_images = []
        for img in images:
            if isinstance(img, torch.Tensor):
                pil_images.append(self._tensor_to_pil(img))
            elif isinstance(img, Image.Image):
                pil_images.append(img)
            else:
                raise ValueError(f"Unsupported image type: {type(img)}")
        # Process images through Fashion-CLIP processor
        image_inputs = self.clip_processor(
            images=pil_images,
            return_tensors="pt"
        )
        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
        # Get image features using dedicated method
        image_features = self.clip_model.get_image_features(**image_inputs)
        # Apply L2 normalization (critical for CLIP cosine-similarity use!)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features
    def _tensor_to_pil(self, tensor: torch.Tensor) -> Image.Image:
        """
        Convert a (possibly ImageNet-normalized) tensor to a PIL Image.

        Args:
            tensor: Image tensor of shape (C, H, W)

        Returns:
            PIL Image

        Raises:
            ValueError: If the tensor is not 3-dimensional
        """
        if tensor.dim() != 3:
            raise ValueError(f"Expected 3D tensor, got {tensor.dim()}D")
        # Heuristic: values outside [0, 1] imply ImageNet normalization, so
        # undo it. NOTE(review): a normalized tensor whose values happen to
        # fall inside [0, 1] would be left as-is — confirm acceptable.
        if tensor.min() < 0 or tensor.max() > 1:
            mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
            tensor = tensor * std + mean
            tensor = torch.clamp(tensor, 0, 1)
        # Convert to PIL
        return transforms.ToPILImage()(tensor)
class EmbeddingEvaluator:
"""
Comprehensive Embedding Evaluator for Hierarchy Classification.
This class provides a complete evaluation pipeline for hierarchy classification models,
including custom model evaluation and Fashion-CLIP baseline comparison. It supports
multiple evaluation metrics, datasets, and advanced techniques.
Key Features:
- Custom model loading and evaluation
- Fashion-CLIP baseline comparison
- Multiple classification methods (nearest neighbor, centroid, Mahalanobis)
- Advanced techniques (ZCA whitening, Test-Time Augmentation)
- Comprehensive metrics (accuracy, F1, confusion matrices)
Args:
model_path: Path to the trained custom model checkpoint
directory: Output directory for saving evaluation results
"""
def __init__(self, model_path: str, directory: str):
    """
    Build the evaluator: load the dataset, split off a stratified
    validation set, load the custom model, and initialize the
    Fashion-CLIP baseline.

    Args:
        model_path: Path to the trained custom model checkpoint
        directory: Output directory for saving evaluation results
    """
    self.directory = directory
    self.device = device
    # Load and prepare dataset
    print(f"πŸ“ Using dataset with local images: {local_dataset_path}")
    df = pd.read_csv(local_dataset_path)
    print(f"πŸ“ Loaded {len(df)} samples")
    # Get unique hierarchy classes
    hierarchy_classes = sorted(df[hierarchy_column].unique().tolist())
    print(f"πŸ“‹ Found {len(hierarchy_classes)} hierarchy classes")
    # Limit dataset size to prevent memory issues
    if len(df) > MAX_SAMPLES_EVALUATION:
        print(f"⚠️ Dataset too large ({len(df)} samples), sampling to {MAX_SAMPLES_EVALUATION} samples")
        df = self._stratified_sample(df, MAX_SAMPLES_EVALUATION)
    # Create validation split (20% of data).
    # Fix: stratify on the configured hierarchy column rather than the
    # hard-coded 'hierarchy' literal, consistent with the lookup above.
    _, self.val_df = train_test_split(
        df,
        test_size=0.2,
        random_state=42,
        stratify=df[hierarchy_column]
    )
    # Load the custom model
    self._load_model(model_path)
    # Initialize Fashion-CLIP baseline
    self.clip_evaluator = CLIPBaselineEvaluator(device)
def _stratified_sample(self, df: pd.DataFrame, max_samples: int) -> pd.DataFrame:
"""
Perform stratified sampling to maintain class distribution.
Args:
df: Original DataFrame
max_samples: Maximum number of samples to keep
Returns:
Sampled DataFrame
"""
# Stratified sampling by hierarchy
df_sampled = df.groupby('hierarchy', group_keys=False).apply(
lambda x: x.sample(
n=min(len(x), int(max_samples * len(x) / len(df))),
random_state=42
)
).reset_index(drop=True)
# Adjust to reach exactly max_samples if necessary
if len(df_sampled) < max_samples:
remaining = max_samples - len(df_sampled)
extra = df.sample(n=remaining, random_state=42)
df_sampled = pd.concat([df_sampled, extra]).reset_index(drop=True)
return df_sampled
def _load_model(self, model_path: str):
    """
    Load the custom hierarchy classification model from a checkpoint.

    Restores the saved hierarchy class list, rebuilds the
    HierarchyExtractor vocabulary, instantiates the Model with the saved
    configuration, and loads its weights in eval mode.

    Args:
        model_path: Path to the model checkpoint

    Raises:
        FileNotFoundError: If model file doesn't exist
        KeyError: If the checkpoint lacks 'hierarchy_classes',
            'model_state', or the 'embed_dim'/'dropout' config entries
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file {model_path} not found")
    # Load checkpoint.
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # trusted checkpoints (consider weights_only=True on newer torch).
    checkpoint = torch.load(model_path, map_location=self.device)
    # Extract configuration (falls back to empty dict, but embed_dim and
    # dropout below are accessed unconditionally and will KeyError if absent)
    config_dict = checkpoint.get('config', {})
    saved_hierarchy_classes = checkpoint['hierarchy_classes']
    # Store hierarchy classes
    self.hierarchy_classes = saved_hierarchy_classes
    # Create hierarchy extractor
    self.vocab = HierarchyExtractor(saved_hierarchy_classes)
    # Create model with saved configuration
    self.model = Model(
        num_hierarchy_classes=len(saved_hierarchy_classes),
        embed_dim=config_dict['embed_dim'],
        dropout=config_dict['dropout']
    ).to(self.device)
    # Load model weights
    self.model.load_state_dict(checkpoint['model_state'])
    # Inference only: disable dropout
    self.model.eval()
    # Print model information
    print(f"βœ… Custom model loaded with:")
    print(f"πŸ“‹ Hierarchy classes: {len(saved_hierarchy_classes)}")
    print(f"🎯 Embed dim: {config_dict['embed_dim']}")
    print(f"πŸ’§ Dropout: {config_dict['dropout']}")
    print(f"πŸ“… Epoch: {checkpoint.get('epoch', 'unknown')}")
def _collate_fn_wrapper(self, batch: List[Tuple]) -> Dict[str, torch.Tensor]:
    """
    Picklable collate wrapper used by DataLoader.

    Accepts both sample layouts:
      - (image, description, hierarchy) from HierarchyDataset
      - (image, description, color, hierarchy) from FashionMNISTDataset,
        whose color field is dropped before collating.

    Args:
        batch: List of samples from dataset

    Returns:
        Collated batch dictionary
    """
    # 4-tuples come from FashionMNISTDataset; drop the color field (index 2)
    if len(batch[0]) == 4:
        batch = [(image, description, hierarchy)
                 for image, description, _color, hierarchy in batch]
    return collate_fn(batch, self.vocab)
def create_dataloader(
    self,
    dataframe_or_dataset: Union[pd.DataFrame, Dataset],
    batch_size: int = 16
) -> DataLoader:
    """
    Build a DataLoader feeding the custom model.

    Aligned with main_model_evaluation.py for consistency.

    Args:
        dataframe_or_dataset: Either a pandas DataFrame or a Dataset object
        batch_size: Batch size for the DataLoader

    Returns:
        Configured DataLoader (no shuffling, single worker)

    Raises:
        ValueError: If the input is neither a DataFrame nor a Dataset
    """
    if isinstance(dataframe_or_dataset, Dataset):
        # Caller already built a Dataset; use it directly
        dataset = dataframe_or_dataset
        print(f"πŸ” Using pre-created Dataset object")
    elif isinstance(dataframe_or_dataset, pd.DataFrame):
        # Pixel columns identify a Fashion-MNIST frame
        if 'pixel1' in dataframe_or_dataset.columns:
            print(f"πŸ” Detected Fashion-MNIST data, creating FashionMNISTDataset")
            dataset = FashionMNISTDataset(dataframe_or_dataset, image_size=224)
        else:
            dataset = HierarchyDataset(dataframe_or_dataset, image_size=224)
    else:
        raise ValueError(f"Unsupported type: {type(dataframe_or_dataset)}")
    # num_workers=0 avoids pickling issues with the bound collate fn on macOS
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=self._collate_fn_wrapper,
        num_workers=0,
        pin_memory=False
    )
def create_clip_dataloader(
    self,
    dataframe_or_dataset: Union[pd.DataFrame, Dataset],
    batch_size: int = 16
) -> DataLoader:
    """
    Build a DataLoader feeding the Fashion-CLIP baseline.

    Args:
        dataframe_or_dataset: Either a pandas DataFrame or a Dataset object
        batch_size: Batch size for the DataLoader

    Returns:
        Configured DataLoader (no shuffling, single worker)

    Raises:
        ValueError: If the input is neither a DataFrame nor a Dataset
    """
    if isinstance(dataframe_or_dataset, Dataset):
        # Caller already built a Dataset; use it directly
        dataset = dataframe_or_dataset
        print(f"πŸ” Using pre-created Dataset object for CLIP")
    elif isinstance(dataframe_or_dataset, pd.DataFrame):
        # Pixel columns identify a Fashion-MNIST frame
        if 'pixel1' in dataframe_or_dataset.columns:
            print("πŸ” Detected Fashion-MNIST data for Fashion-CLIP")
            dataset = FashionMNISTDataset(dataframe_or_dataset, image_size=224)
        else:
            dataset = CLIPDataset(dataframe_or_dataset)
    else:
        raise ValueError(f"Unsupported type: {type(dataframe_or_dataset)}")
    # Default collation suffices here (plain tuples, no vocab lookup)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=False
    )
def extract_custom_embeddings(
    self,
    dataloader: DataLoader,
    embedding_type: str = 'text',
    use_tta: bool = False
) -> Tuple[np.ndarray, List[str], List[str]]:
    """
    Extract embeddings from the custom model, optionally with TTA.

    Args:
        dataloader: DataLoader for the dataset
        embedding_type: 'text' selects z_txt; any other value selects z_img
        use_tta: Use Test-Time Augmentation when batches carry TTA crops
            (5D image tensors) and embedding_type == 'image'

    Returns:
        Tuple of (embeddings, labels, texts); labels and texts both hold
        the batch hierarchy labels.
    """
    all_embeddings = []
    all_labels = []
    all_texts = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Extracting custom {embedding_type} embeddings{' with TTA' if use_tta else ''}"):
            images = batch['image'].to(self.device)
            hierarchy_indices = batch['hierarchy_indices'].to(self.device)
            hierarchy_labels = batch['hierarchy']
            # Handle Test-Time Augmentation (5D: [B, crops, C, H, W])
            if use_tta and embedding_type == 'image' and images.dim() == 5:
                embeddings = self._extract_with_tta(images, hierarchy_indices)
            else:
                # Standard forward pass
                out = self.model(image=images, hierarchy_indices=hierarchy_indices)
                embeddings = out['z_txt'] if embedding_type == 'text' else out['z_img']
                # Fix: delete 'out' only where it is bound — the TTA branch
                # never defines it, so the old unconditional `del out` after
                # the loop body raised NameError on the first TTA batch.
                del out
            all_embeddings.append(embeddings.cpu().numpy())
            all_labels.extend(hierarchy_labels)
            all_texts.extend(hierarchy_labels)
            # Clear memory
            del images, hierarchy_indices, embeddings
            if str(self.device) != 'cpu':
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
    return np.vstack(all_embeddings), all_labels, all_texts
def _extract_with_tta(
self,
images: torch.Tensor,
hierarchy_indices: torch.Tensor
) -> torch.Tensor:
"""
Extract embeddings using Test-Time Augmentation.
Args:
images: Images with TTA crops [batch_size, tta_crops, C, H, W]
hierarchy_indices: Hierarchy class indices
Returns:
Averaged embeddings [batch_size, embed_dim]
"""
batch_size, tta_crops, C, H, W = images.shape
# Reshape to [batch_size * tta_crops, C, H, W]
images_flat = images.view(batch_size * tta_crops, C, H, W)
# Repeat hierarchy indices for each TTA crop
hierarchy_indices_repeated = hierarchy_indices.unsqueeze(1).repeat(1, tta_crops).view(-1)
# Forward pass on all TTA crops
out = self.model(image=images_flat, hierarchy_indices=hierarchy_indices_repeated)
embeddings_flat = out['z_img']
# Reshape back to [batch_size, tta_crops, embed_dim]
embeddings = embeddings_flat.view(batch_size, tta_crops, -1)
# Average over TTA crops
embeddings = embeddings.mean(dim=1)
return embeddings
def apply_whitening(
    self,
    embeddings: np.ndarray,
    epsilon: float = 1e-5
) -> np.ndarray:
    """
    ZCA-whiten embeddings, then L2-normalize each row.

    Whitening decorrelates feature dimensions, which can sharpen class
    separation in the embedding space.

    Args:
        embeddings: Input embeddings [N, D]
        epsilon: Small constant for numerical stability

    Returns:
        Whitened, row-normalized embeddings [N, D]
    """
    # Remove the per-dimension mean
    centered = embeddings - np.mean(embeddings, axis=0, keepdims=True)
    # Eigendecompose the feature covariance matrix
    eigvals, eigvecs = np.linalg.eigh(np.cov(centered.T))
    # ZCA transform: rotate, rescale by 1/sqrt(eigenvalue), rotate back
    scale = np.diag(1.0 / np.sqrt(eigvals + epsilon))
    zca_transform = eigvecs @ scale @ eigvecs.T
    whitened = centered @ zca_transform
    # Row-wise L2 normalization (epsilon guards zero-norm rows)
    row_norms = np.linalg.norm(whitened, axis=1, keepdims=True)
    return whitened / (row_norms + epsilon)
def compute_similarity_metrics(
    self,
    embeddings: np.ndarray,
    labels: List[str],
    apply_whitening_norm: bool = False
) -> Dict[str, Any]:
    """
    Compute intra-/inter-class similarity statistics and accuracies.

    Args:
        embeddings: Embedding vectors
        labels: Class labels
        apply_whitening_norm: Whether to ZCA-whiten the embeddings first

    Returns:
        Dictionary with the raw similarity lists, their means, the
        separation score (intra mean - inter mean), nearest-neighbor
        accuracy, and centroid accuracy.
    """
    if apply_whitening_norm:
        embeddings = self.apply_whitening(embeddings)
    # Full pairwise cosine similarity matrix
    similarities = cosine_similarity(embeddings)
    # Bucket sample indices by their hierarchy label
    hierarchy_groups = defaultdict(list)
    for index, hierarchy in enumerate(labels):
        hierarchy_groups[hierarchy].append(index)
    # Same-class vs different-class similarity populations
    intra = self._compute_intra_class_similarities(similarities, hierarchy_groups)
    inter = self._compute_inter_class_similarities(similarities, hierarchy_groups)
    intra_mean = np.mean(intra) if intra else 0
    inter_mean = np.mean(inter) if inter else 0
    # Separation only meaningful when both populations are non-empty
    separation = intra_mean - inter_mean if intra and inter else 0
    return {
        'intra_class_similarities': intra,
        'inter_class_similarities': inter,
        'intra_class_mean': intra_mean,
        'inter_class_mean': inter_mean,
        'separation_score': separation,
        'accuracy': self.compute_embedding_accuracy(embeddings, labels, similarities),
        'centroid_accuracy': self.compute_centroid_accuracy(embeddings, labels)
    }
def _compute_intra_class_similarities(
self,
similarities: np.ndarray,
hierarchy_groups: Dict[str, List[int]]
) -> List[float]:
"""
Compute within-class similarities.
Args:
similarities: Pairwise similarity matrix
hierarchy_groups: Mapping from hierarchy to sample indices
Returns:
List of intra-class similarity values
"""
intra_class_similarities = []
for hierarchy, indices in hierarchy_groups.items():
if len(indices) > 1:
# Compare all pairs within the same class
for i in range(len(indices)):
for j in range(i + 1, len(indices)):
sim = similarities[indices[i], indices[j]]
intra_class_similarities.append(sim)
return intra_class_similarities
def _compute_inter_class_similarities(
    self,
    similarities: np.ndarray,
    hierarchy_groups: Dict[str, List[int]]
) -> List[float]:
    """
    Compute between-class similarities with sampling for efficiency.

    To avoid O(n^2) work on large datasets, at most 100 samples per class
    are drawn for each class pair, and the total number of comparisons is
    capped at MAX_INTER_CLASS_COMPARISONS via the cascade of breaks below.

    NOTE(review): np.random.choice is used without a fixed seed, so the
    sampled pairs (and hence the returned values) vary between runs —
    confirm whether reproducibility matters here.

    Args:
        similarities: Pairwise similarity matrix
        hierarchy_groups: Mapping from hierarchy to sample indices

    Returns:
        List of inter-class similarity values (sampled subset)
    """
    inter_class_similarities = []
    hierarchies = list(hierarchy_groups.keys())
    # Running total across ALL class pairs, enforcing the global cap
    comparison_count = 0
    for i in range(len(hierarchies)):
        for j in range(i + 1, len(hierarchies)):
            hierarchy1_indices = hierarchy_groups[hierarchies[i]]
            hierarchy2_indices = hierarchy_groups[hierarchies[j]]
            # Sample if too many comparisons (at most 100 per class per pair)
            max_samples_per_pair = min(100, len(hierarchy1_indices), len(hierarchy2_indices))
            sampled_idx1 = np.random.choice(
                hierarchy1_indices,
                size=min(max_samples_per_pair, len(hierarchy1_indices)),
                replace=False
            )
            sampled_idx2 = np.random.choice(
                hierarchy2_indices,
                size=min(max_samples_per_pair, len(hierarchy2_indices)),
                replace=False
            )
            # Compute similarities between sampled cross-class pairs
            for idx1 in sampled_idx1:
                for idx2 in sampled_idx2:
                    if comparison_count >= MAX_INTER_CLASS_COMPARISONS:
                        break
                    sim = similarities[idx1, idx2]
                    inter_class_similarities.append(sim)
                    comparison_count += 1
                # Propagate the cap outward through each loop level
                if comparison_count >= MAX_INTER_CLASS_COMPARISONS:
                    break
            if comparison_count >= MAX_INTER_CLASS_COMPARISONS:
                break
        if comparison_count >= MAX_INTER_CLASS_COMPARISONS:
            break
    return inter_class_similarities
def compute_embedding_accuracy(
    self,
    embeddings: np.ndarray,
    labels: List[str],
    similarities: np.ndarray
) -> float:
    """
    Leave-one-out nearest-neighbor classification accuracy.

    Each sample is assigned the label of its most similar *other* sample;
    accuracy is the fraction assigned their own true label.

    Args:
        embeddings: Embedding vectors (used only for the sample count)
        labels: True class labels
        similarities: Precomputed pairwise similarity matrix

    Returns:
        Classification accuracy in [0, 1] (0 when there are no samples)
    """
    correct = 0
    total = len(labels)
    for i in range(len(embeddings)):
        # Mask self-similarity so a sample cannot vote for itself
        row = similarities[i].copy()
        row[i] = -1
        if labels[np.argmax(row)] == labels[i]:
            correct += 1
    return correct / total if total > 0 else 0
def compute_centroid_accuracy(
    self,
    embeddings: np.ndarray,
    labels: List[str]
) -> float:
    """
    Classification accuracy using per-class centroid matching.

    Each sample is assigned the class whose mean embedding (centroid) is
    most cosine-similar to it; accuracy is the fraction assigned their
    own true label.

    Args:
        embeddings: Embedding vectors
        labels: True class labels

    Returns:
        Classification accuracy in [0, 1] (0 when there are no samples)
    """
    total = len(labels)
    if total == 0:
        return 0
    # Build one centroid per hierarchy class
    unique_hierarchies = list(set(labels))
    label_array = np.asarray(labels)
    centroid_matrix = np.vstack([
        embeddings[label_array == hierarchy].mean(axis=0)
        for hierarchy in unique_hierarchies
    ])
    # Perf fix: one batched cosine-similarity call instead of the previous
    # per-sample/per-centroid sklearn invocations (O(n*k) calls).
    sims = cosine_similarity(embeddings, centroid_matrix)
    # argmax keeps the first maximum, matching the old strict '>' scan
    predicted = [unique_hierarchies[j] for j in np.argmax(sims, axis=1)]
    correct = sum(1 for pred, true in zip(predicted, labels) if pred == true)
    return correct / total
def compute_mahalanobis_distance(
    self,
    point: np.ndarray,
    centroid: np.ndarray,
    cov_inv: np.ndarray
) -> float:
    """
    Compute the Mahalanobis distance between a point and a centroid.

    The Mahalanobis distance takes into account the covariance structure
    of the data, making it more robust than Euclidean distance for
    high-dimensional spaces.

    Args:
        point: Query point, shape (dim,)
        centroid: Class centroid, shape (dim,)
        cov_inv: Inverse covariance matrix, shape (dim, dim)

    Returns:
        Non-negative Mahalanobis distance
    """
    diff = point - centroid
    # The quadratic form can come out marginally negative through
    # floating-point error when cov_inv is nearly singular; clamp at 0
    # so sqrt never yields NaN.
    quad = float(diff @ cov_inv @ diff)
    return float(np.sqrt(max(quad, 0.0)))
def predict_hierarchy_from_embeddings(
    self,
    embeddings: np.ndarray,
    labels: List[str],
    use_mahalanobis: bool = False
) -> List[str]:
    """
    Predict hierarchy from embeddings using centroid-based classification.

    Centroids (and, when Mahalanobis distance is requested, per-class
    inverse covariance matrices) are estimated from the given labels;
    every embedding is then assigned to its closest class.

    Args:
        embeddings: Embedding vectors
        labels: Training labels for computing centroids
        use_mahalanobis: Whether to use Mahalanobis distance

    Returns:
        List of predicted hierarchy labels
    """
    # Group sample indices by class.
    members = defaultdict(list)
    for idx, label in enumerate(labels):
        members[label].append(idx)

    centroids = {}
    cov_inverses = {}
    for hierarchy, indices in members.items():
        class_embeddings = embeddings[indices]
        centroids[hierarchy] = class_embeddings.mean(axis=0)
        # Covariance is only meaningful with at least two samples.
        if use_mahalanobis and len(class_embeddings) > 1:
            cov = np.cov(class_embeddings.T)
            # Regularize for numerical stability before inverting.
            cov += np.eye(cov.shape[0]) * 1e-6
            try:
                cov_inverses[hierarchy] = np.linalg.inv(cov)
            except np.linalg.LinAlgError:
                # Singular covariance: fall back to identity (Euclidean).
                cov_inverses[hierarchy] = np.eye(cov.shape[0])

    # Choose the classifier once, then apply it to every embedding.
    if use_mahalanobis:
        classify = lambda emb: self._predict_with_mahalanobis(
            emb, centroids, cov_inverses
        )
    else:
        classify = lambda emb: self._predict_with_cosine(emb, centroids)
    return [classify(embedding) for embedding in embeddings]
def _predict_with_mahalanobis(
    self,
    embedding: np.ndarray,
    centroids: Dict[str, np.ndarray],
    cov_inverses: Dict[str, np.ndarray]
) -> str:
    """
    Predict the class whose centroid is nearest in Mahalanobis distance.

    Classes without an inverse covariance matrix (too few samples) are
    scored with cosine distance (1 - similarity) instead.
    NOTE(review): the two distance scales are not strictly comparable;
    confirm this mixed fallback is intended.

    Args:
        embedding: Query embedding
        centroids: Class centroids
        cov_inverses: Inverse covariance matrices

    Returns:
        Predicted class label (None when no centroids are given)
    """
    predicted_hierarchy = None
    best_distance = float('inf')
    for hierarchy, centroid in centroids.items():
        cov_inv = cov_inverses.get(hierarchy)
        if cov_inv is not None:
            distance = self.compute_mahalanobis_distance(
                embedding, centroid, cov_inv
            )
        else:
            # Cosine-distance fallback for classes with insufficient samples.
            distance = 1 - cosine_similarity([embedding], [centroid])[0][0]
        if distance < best_distance:
            predicted_hierarchy = hierarchy
            best_distance = distance
    return predicted_hierarchy
def _predict_with_cosine(
    self,
    embedding: np.ndarray,
    centroids: Dict[str, np.ndarray]
) -> str:
    """
    Predict class using cosine similarity (higher is better).

    Similarities against all centroids are computed in one vectorized
    numpy pass instead of one sklearn call per class.

    Args:
        embedding: Query embedding
        centroids: Class centroids

    Returns:
        Predicted class label, or None when no centroids are provided
    """
    if not centroids:
        return None
    names = list(centroids)
    matrix = np.stack([centroids[name] for name in names]).astype(float)
    query = np.asarray(embedding, dtype=float)
    # Zero-norm vectors are treated as norm 1 so their similarity
    # degrades to 0 (matching sklearn's cosine_similarity behavior).
    query_norm = float(np.linalg.norm(query)) or 1.0
    centroid_norms = np.linalg.norm(matrix, axis=1)
    centroid_norms[centroid_norms == 0] = 1.0
    sims = matrix @ query / (centroid_norms * query_norm)
    # argmax keeps the first maximum, mirroring the strict ">" update of
    # an iterative scan over the centroids in insertion order.
    return names[int(np.argmax(sims))]
def create_confusion_matrix(
    self,
    true_labels: List[str],
    predicted_labels: List[str],
    title: str = "Confusion Matrix"
) -> Tuple[plt.Figure, float, np.ndarray]:
    """
    Build a confusion matrix, render it as a heatmap, and return both.

    Args:
        true_labels: Ground truth labels
        predicted_labels: Predicted labels
        title: Plot title (accuracy is appended on a second line)

    Returns:
        Tuple of (figure, accuracy, confusion_matrix)
    """
    # Axis labels cover every class appearing in either label list.
    axis_labels = sorted(set(true_labels) | set(predicted_labels))
    matrix = confusion_matrix(true_labels, predicted_labels, labels=axis_labels)
    acc = accuracy_score(true_labels, predicted_labels)
    # Render the matrix as an annotated heatmap on a fresh figure.
    figure = plt.figure(figsize=(12, 10))
    sns.heatmap(
        matrix,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=axis_labels,
        yticklabels=axis_labels
    )
    plt.title(f'{title}\nAccuracy: {acc:.3f} ({acc*100:.1f}%)')
    plt.ylabel('True Hierarchy')
    plt.xlabel('Predicted Hierarchy')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    return figure, acc, matrix
def evaluate_classification_performance(
    self,
    embeddings: np.ndarray,
    labels: List[str],
    embedding_type: str = "Embeddings",
    apply_whitening_norm: bool = False,
    use_mahalanobis: bool = False
) -> Dict[str, Any]:
    """
    Run centroid-based classification on the embeddings and report metrics.

    Args:
        embeddings: Embedding vectors
        labels: True class labels
        embedding_type: Description of embedding type for display
        apply_whitening_norm: Whether to apply ZCA whitening first
        use_mahalanobis: Whether to use Mahalanobis distance

    Returns:
        Dictionary with accuracy, F1 scores, predictions, confusion
        matrix, classification report, and the rendered figure
    """
    # Optionally decorrelate features before classifying.
    if apply_whitening_norm:
        embeddings = self.apply_whitening(embeddings)
    predictions = self.predict_hierarchy_from_embeddings(
        embeddings, labels, use_mahalanobis=use_mahalanobis
    )
    unique_labels = sorted(set(labels))

    def f1_with(average):
        # The three F1 variants share everything except the averaging mode.
        return f1_score(
            labels, predictions, labels=unique_labels,
            average=average, zero_division=0
        )

    figure, _, matrix = self.create_confusion_matrix(
        labels, predictions,
        f"{embedding_type} - Hierarchy Classification"
    )
    report = classification_report(
        labels, predictions, labels=unique_labels,
        target_names=unique_labels, output_dict=True
    )
    return {
        'accuracy': accuracy_score(labels, predictions),
        'f1_macro': f1_with('macro'),
        'f1_weighted': f1_with('weighted'),
        'f1_per_class': f1_with(None),
        'predictions': predictions,
        'confusion_matrix': matrix,
        'classification_report': report,
        'figure': figure
    }
def evaluate_dataset_with_baselines(
    self,
    dataframe: Union[pd.DataFrame, Dataset],
    dataset_name: str = "Dataset",
    use_whitening: bool = False,
    use_mahalanobis: bool = False
) -> Dict[str, Dict[str, Any]]:
    """
    Evaluate embeddings on a given dataset with both custom model and CLIP baseline.

    This is the main evaluation method that compares the custom model against
    the Fashion-CLIP baseline across multiple metrics and embedding types.
    Aligned with main_model_evaluation.py for consistency (no TTA for fair comparison).

    For each of the four combinations (custom/CLIP x text/image) it merges
    similarity metrics with classification metrics, prints a comparison
    table, and saves confusion-matrix figures.

    Args:
        dataframe: DataFrame or Dataset to evaluate on
        dataset_name: Name of the dataset for display
        use_whitening: Whether to apply ZCA whitening (applied to the
            custom model only; not passed to the baseline evaluations)
        use_mahalanobis: Whether to use Mahalanobis distance (custom
            model only, like use_whitening)

    Returns:
        Dictionary keyed by 'custom_text', 'custom_image', 'clip_text',
        'clip_image'; each value holds the merged metrics for that
        model/embedding combination.
    """
    print(f"\n{'='*60}")
    print(f"Evaluating {dataset_name}")
    if use_whitening:
        print(f"🎯 ZCA Whitening ENABLED for better feature decorrelation")
    if use_mahalanobis:
        print(f"🎯 Mahalanobis Distance ENABLED for classification")
    print(f"{'='*60}")
    results = {}
    # ===== CUSTOM MODEL EVALUATION =====
    print(f"\nπŸ”§ Evaluating Custom Model on {dataset_name}")
    print("-" * 40)
    # Create dataloader
    custom_dataloader = self.create_dataloader(dataframe, batch_size=16)
    # Evaluate text embeddings
    text_embeddings, text_labels, texts = self.extract_custom_embeddings(
        custom_dataloader, 'text', use_tta=False
    )
    text_metrics = self.compute_similarity_metrics(
        text_embeddings, text_labels, apply_whitening_norm=use_whitening
    )
    text_classification = self.evaluate_classification_performance(
        text_embeddings, text_labels, "Custom Text Embeddings",
        apply_whitening_norm=use_whitening, use_mahalanobis=use_mahalanobis
    )
    # Merge similarity metrics and classification metrics into one record.
    text_metrics.update(text_classification)
    results['custom_text'] = text_metrics
    # Evaluate image embeddings
    # NOTE: TTA disabled for fair comparison
    image_embeddings, image_labels, _ = self.extract_custom_embeddings(
        custom_dataloader, 'image', use_tta=False
    )
    image_metrics = self.compute_similarity_metrics(
        image_embeddings, image_labels, apply_whitening_norm=use_whitening
    )
    # Reflect the enabled options in the figure title.
    whitening_suffix = " + Whitening" if use_whitening else ""
    mahalanobis_suffix = " + Mahalanobis" if use_mahalanobis else ""
    image_classification = self.evaluate_classification_performance(
        image_embeddings, image_labels,
        f"Custom Image Embeddings{whitening_suffix}{mahalanobis_suffix}",
        apply_whitening_norm=use_whitening, use_mahalanobis=use_mahalanobis
    )
    image_metrics.update(image_classification)
    results['custom_image'] = image_metrics
    # ===== FASHION-CLIP BASELINE EVALUATION =====
    print(f"\nπŸ€— Evaluating Fashion-CLIP Baseline on {dataset_name}")
    print("-" * 40)
    # Create dataloader for Fashion-CLIP
    clip_dataloader = self.create_clip_dataloader(dataframe, batch_size=8)
    # Extract data for Fashion-CLIP
    all_images = []
    all_texts = []
    all_labels = []
    for batch in tqdm(clip_dataloader, desc="Preparing data for Fashion-CLIP"):
        # Handle different batch formats
        # 4-tuple batches carry an extra color field that is unused here.
        if len(batch) == 4:
            images, descriptions, colors, hierarchies = batch
        else:
            images, descriptions, hierarchies = batch
        all_images.extend(images)
        all_texts.extend(descriptions)
        all_labels.extend(hierarchies)
    # Get Fashion-CLIP embeddings
    clip_image_embeddings, clip_text_embeddings = self.clip_evaluator.extract_clip_embeddings(
        all_images, all_texts
    )
    # Evaluate Fashion-CLIP text embeddings
    # (whitening/Mahalanobis options are not passed to the baseline)
    clip_text_metrics = self.compute_similarity_metrics(
        clip_text_embeddings, all_labels
    )
    clip_text_classification = self.evaluate_classification_performance(
        clip_text_embeddings, all_labels, "Fashion-CLIP Text Embeddings"
    )
    clip_text_metrics.update(clip_text_classification)
    results['clip_text'] = clip_text_metrics
    # Evaluate Fashion-CLIP image embeddings
    clip_image_metrics = self.compute_similarity_metrics(
        clip_image_embeddings, all_labels
    )
    clip_image_classification = self.evaluate_classification_performance(
        clip_image_embeddings, all_labels, "Fashion-CLIP Image Embeddings"
    )
    clip_image_metrics.update(clip_image_classification)
    results['clip_image'] = clip_image_metrics
    # ===== PRINT COMPARISON RESULTS =====
    self._print_comparison_results(dataframe, dataset_name, results)
    # ===== SAVE VISUALIZATIONS =====
    self._save_visualizations(dataset_name, results)
    return results
def _print_comparison_results(
    self,
    dataframe: Union[pd.DataFrame, Dataset],
    dataset_name: str,
    results: Dict[str, Dict[str, Any]]
):
    """
    Print formatted comparison results.

    Args:
        dataframe: Dataset being evaluated (only its length is used)
        dataset_name: Name of the dataset
        results: Evaluation results dictionary keyed by
            "<model>_<embedding>" (e.g. "custom_text")
    """
    dataset_size = len(dataframe) if hasattr(dataframe, '__len__') else "N/A"
    print(f"\n{dataset_name} Results Comparison:")
    print(f"Dataset size: {dataset_size} samples")
    print("=" * 80)
    # Model column is 22 wide so "Fashion-CLIP Baseline" (21 chars) does
    # not overflow and shift the row; numeric widths match
    # _print_dataset_results so both tables line up identically.
    print(f"{'Model':<22} {'Embedding':<10} {'Sep Score':<12} {'NN Acc':<10} {'Centroid Acc':<12} {'F1 Macro':<10}")
    print("-" * 80)
    for model_type in ['custom', 'clip']:
        for emb_type in ['text', 'image']:
            key = f"{model_type}_{emb_type}"
            if key in results:
                metrics = results[key]
                model_name = "Custom Model" if model_type == 'custom' else "Fashion-CLIP Baseline"
                print(
                    f"{model_name:<22} "
                    f"{emb_type.capitalize():<10} "
                    f"{metrics['separation_score']:<12.4f} "
                    f"{metrics['accuracy']*100:<10.1f}% "
                    f"{metrics['centroid_accuracy']*100:<12.1f}% "
                    f"{metrics['f1_macro']*100:<10.1f}%"
                )
def _save_visualizations(
    self,
    dataset_name: str,
    results: Dict[str, Dict[str, Any]]
):
    """
    Save confusion-matrix figures for every evaluated model/embedding pair.

    Args:
        dataset_name: Name of the dataset (used in the output file names)
        results: Evaluation results dictionary
    """
    os.makedirs(self.directory, exist_ok=True)
    # Write each result's confusion-matrix figure to disk.
    for key, metrics in results.items():
        if 'figure' not in metrics:
            continue
        target = f'{self.directory}/{dataset_name.lower()}_{key}_confusion_matrix.png'
        metrics['figure'].savefig(target, dpi=300, bbox_inches='tight')
        # Close the figure after saving to free matplotlib memory.
        plt.close(metrics['figure'])
# ============================================================================
# DATASET LOADING FUNCTIONS
# ============================================================================
def load_fashion_mnist_dataset(
    evaluator: EmbeddingEvaluator,
    max_samples: int = 1000
) -> FashionMNISTDataset:
    """
    Load and prepare the Fashion-MNIST test dataset.

    Reads the test CSV, optionally maps Fashion-MNIST labels onto the
    custom model's hierarchy classes (dropping unmappable rows), and
    truncates the result to at most max_samples rows.
    Exactly aligned with main_model_evaluation.py for consistency.

    Args:
        evaluator: EmbeddingEvaluator instance with loaded model
        max_samples: Maximum number of samples to use

    Returns:
        FashionMNISTDataset object
    """
    print("πŸ“Š Loading Fashion-MNIST test dataset...")
    df = pd.read_csv(config.fashion_mnist_test_path)
    print(f"βœ… Fashion-MNIST dataset loaded: {len(df)} samples")

    if evaluator.hierarchy_classes is None:
        # No hierarchy classes available: skip mapping, just truncate.
        df_sample = df.head(max_samples)
        print(f"πŸ“Š Using {len(df_sample)} samples for evaluation")
        return FashionMNISTDataset(df_sample, label_mapping=None)

    print("\nπŸ”— Creating mapping from Fashion-MNIST labels to hierarchy classes:")
    label_mapping = create_fashion_mnist_to_hierarchy_mapping(
        evaluator.hierarchy_classes
    )
    # Keep only the Fashion-MNIST labels that map onto a model hierarchy.
    mappable_ids = [
        label_id for label_id, hierarchy in label_mapping.items()
        if hierarchy is not None
    ]
    df_filtered = df[df['label'].isin(mappable_ids)]
    print(
        f"\nπŸ“Š After filtering to mappable labels: "
        f"{len(df_filtered)} samples (from {len(df)})"
    )
    # Apply the size cap only after filtering out unmappable rows.
    df_sample = df_filtered.head(max_samples)
    print(f"πŸ“Š Using {len(df_sample)} samples for evaluation")
    return FashionMNISTDataset(df_sample, label_mapping=label_mapping)
def load_kagl_marqo_dataset(evaluator: EmbeddingEvaluator) -> pd.DataFrame:
    """
    Load and prepare Kaggle Marqo dataset for evaluation.

    This function loads the Marqo fashion dataset from Hugging Face
    and preprocesses it for evaluation with the custom model: category
    names are mapped onto the model's hierarchy labels, rows with an
    unknown hierarchy or missing text/image data are dropped, and the
    result is capped at 1000 samples.

    Args:
        evaluator: EmbeddingEvaluator instance with loaded model

    Returns:
        Formatted pandas DataFrame with columns 'image_url', 'text',
        'hierarchy'; an empty DataFrame if no rows survive filtering
    """
    # Imported lazily so the heavy `datasets` dependency is only needed
    # when this loader is actually used.
    from datasets import load_dataset
    print("πŸ“Š Loading Kaggle Marqo dataset...")
    # Load the dataset from Hugging Face
    dataset = load_dataset("Marqo/KAGL")
    df = dataset["data"].to_pandas()
    print(f"βœ… Dataset Kaggle loaded")
    print(f"πŸ“Š Before filtering: {len(df)} samples")
    print(f"πŸ“‹ Available columns: {list(df.columns)}")
    print(f"🎨 Available categories: {sorted(df['category2'].unique())}")
    # Map categories to our hierarchy format
    df['hierarchy'] = df['category2'].str.lower()
    # Collapse dataset-specific category names onto the model's labels.
    df['hierarchy'] = df['hierarchy'].replace({
        'bags': 'bag',
        'topwear': 'top',
        'flip flops': 'shoes',
        'sandal': 'shoes'
    })
    # Filter to only include valid hierarchies
    valid_hierarchies = df['hierarchy'].dropna().unique()
    print(f"🎯 Valid hierarchies found: {sorted(valid_hierarchies)}")
    print(f"🎯 Model hierarchies: {sorted(evaluator.hierarchy_classes)}")
    df = df[df['hierarchy'].isin(evaluator.hierarchy_classes)]
    print(f"πŸ“Š After filtering to model hierarchies: {len(df)} samples")
    if len(df) == 0:
        print("❌ No samples left after hierarchy filtering.")
        return pd.DataFrame()
    # Ensure we have text and image data
    df = df.dropna(subset=['text', 'image'])
    print(f"πŸ“Š After removing missing text/image: {len(df)} samples")
    # Show sample of text data to verify quality
    print(f"πŸ“ Sample texts:")
    for i, (text, hierarchy) in enumerate(zip(df['text'].head(3), df['hierarchy'].head(3))):
        print(f" {i+1}. [{hierarchy}] {text[:100]}...")
    # Limit size to prevent memory overload
    max_samples = 1000
    if len(df) > max_samples:
        print(f"⚠️ Dataset too large ({len(df)} samples), sampling to {max_samples} samples")
        # Fixed seed keeps the evaluation subset reproducible across runs.
        df_test = df.sample(n=max_samples, random_state=42).reset_index(drop=True)
    else:
        df_test = df.copy()
    print(f"πŸ“Š After sampling: {len(df_test)} samples")
    print(f"πŸ“Š Samples per hierarchy:")
    for hierarchy in sorted(df_test['hierarchy'].unique()):
        count = len(df_test[df_test['hierarchy'] == hierarchy])
        print(f" {hierarchy}: {count} samples")
    # Create formatted dataset with proper column names
    kagl_formatted = pd.DataFrame({
        'image_url': df_test['image'],
        'text': df_test['text'],
        'hierarchy': df_test['hierarchy']
    })
    print(f"πŸ“Š Final dataset size: {len(kagl_formatted)} samples")
    return kagl_formatted
# ============================================================================
# MAIN EXECUTION
# ============================================================================
def main():
    """
    Main evaluation function that runs comprehensive evaluation across multiple datasets.

    This function evaluates the custom hierarchy classification model against the
    Fashion-CLIP baseline on:
    1. Validation dataset (from training data)
    2. Fashion-MNIST test dataset
    3. Kaggle Marqo dataset

    Results include detailed metrics, confusion matrices, and performance comparisons.
    Confusion-matrix figures are written to the 'hierarchy_model_analysis' directory.
    """
    # Setup output directory
    directory = "hierarchy_model_analysis"
    print(f"πŸš€ Starting evaluation with custom model: {hierarchy_model_path}")
    print(f"πŸ€— Including Fashion-CLIP baseline comparison")
    # Initialize evaluator
    evaluator = EmbeddingEvaluator(hierarchy_model_path, directory)
    print(
        f"πŸ“Š Final hierarchy classes after initialization: "
        f"{len(evaluator.vocab.hierarchy_classes)} classes"
    )
    # ===== EVALUATION 1: VALIDATION DATASET =====
    print("\n" + "="*60)
    print("EVALUATING VALIDATION DATASET - CUSTOM MODEL vs FASHION-CLIP BASELINE")
    print("="*60)
    val_results = evaluator.evaluate_dataset_with_baselines(
        evaluator.val_df,
        "Validation Dataset"
    )
    # ===== EVALUATION 2: FASHION-MNIST TEST DATASET =====
    print("\n" + "="*60)
    print("EVALUATING FASHION-MNIST TEST DATASET - CUSTOM MODEL vs FASHION-CLIP BASELINE")
    print("="*60)
    fashion_mnist_dataset = load_fashion_mnist_dataset(evaluator, max_samples=1000)
    # NOTE(review): load_fashion_mnist_dataset never returns None as
    # written, so this guard is defensive only.
    if fashion_mnist_dataset is not None:
        # Aligned with main_model_evaluation.py: NO TTA for fair baseline comparison
        fashion_mnist_results = evaluator.evaluate_dataset_with_baselines(
            fashion_mnist_dataset,
            "Fashion-MNIST Test Dataset",
            use_whitening=False,  # Disabled for fair comparison
            use_mahalanobis=False  # Disabled for fair comparison
        )
    else:
        fashion_mnist_results = {}
    # ===== EVALUATION 3: KAGGLE MARQO DATASET =====
    print("\n" + "="*60)
    print("EVALUATING KAGGLE MARQO DATASET - CUSTOM MODEL vs FASHION-CLIP BASELINE")
    print("="*60)
    df_kagl_marqo = load_kagl_marqo_dataset(evaluator)
    # The loader returns an empty DataFrame when filtering removed everything.
    if len(df_kagl_marqo) > 0:
        kagl_results = evaluator.evaluate_dataset_with_baselines(
            df_kagl_marqo,
            "Kaggle Marqo Dataset"
        )
    else:
        kagl_results = {}
    # ===== FINAL SUMMARY =====
    print(f"\n{'='*80}")
    print("FINAL EVALUATION SUMMARY - CUSTOM MODEL vs FASHION-CLIP BASELINE")
    print(f"{'='*80}")
    # Print validation results
    print("\nπŸ” VALIDATION DATASET RESULTS:")
    _print_dataset_results(val_results, len(evaluator.val_df))
    # Print Fashion-MNIST results
    if fashion_mnist_results:
        print("\nπŸ‘— FASHION-MNIST TEST DATASET RESULTS:")
        # NOTE(review): reports the requested max_samples, not the actual
        # (possibly smaller) filtered dataset size.
        _print_dataset_results(fashion_mnist_results, 1000)
    # Print Kaggle results
    if kagl_results:
        print("\n🌐 KAGGLE MARQO DATASET RESULTS:")
        _print_dataset_results(
            kagl_results,
            len(df_kagl_marqo) if df_kagl_marqo is not None else 'N/A'
        )
    # Final completion message
    print(f"\nβœ… Evaluation completed! Check '{directory}/' for visualization files.")
    print(f"πŸ“Š Custom model hierarchy classes: {len(evaluator.vocab.hierarchy_classes)} classes")
    print(f"πŸ€— Fashion-CLIP baseline comparison included")
def _print_dataset_results(results: Dict[str, Dict[str, Any]], dataset_size: int):
    """
    Print formatted results for a single dataset.

    Args:
        results: Dictionary containing evaluation results keyed by
            "<model>_<embedding>" (e.g. "custom_text", "clip_image")
        dataset_size: Number of samples in the dataset
    """
    print(f"Dataset size: {dataset_size} samples")
    # Model column is 22 wide so "Fashion-CLIP Baseline" (21 chars) does
    # not overflow its column and misalign the rest of the row.
    print(f"{'Model':<22} {'Embedding':<10} {'Sep Score':<12} {'NN Acc':<10} {'Centroid Acc':<12} {'F1 Macro':<10}")
    print("-" * 80)
    for model_type in ['custom', 'clip']:
        for emb_type in ['text', 'image']:
            key = f"{model_type}_{emb_type}"
            if key in results:
                metrics = results[key]
                model_name = "Custom Model" if model_type == 'custom' else "Fashion-CLIP Baseline"
                print(
                    f"{model_name:<22} "
                    f"{emb_type.capitalize():<10} "
                    f"{metrics['separation_score']:<12.4f} "
                    f"{metrics['accuracy']*100:<10.1f}% "
                    f"{metrics['centroid_accuracy']*100:<12.1f}% "
                    f"{metrics['f1_macro']*100:<10.1f}%"
                )
# Script entry point: run the full multi-dataset evaluation when executed directly.
if __name__ == "__main__":
    main()