|
|
""" |
|
|
Hierarchy Embedding Evaluation with Fashion-CLIP Baseline Comparison |
|
|
|
|
|
This module provides comprehensive evaluation tools for hierarchy classification models, |
|
|
comparing custom model performance against the Fashion-CLIP baseline. It includes: |
|
|
|
|
|
- Embedding quality metrics (intra-class/inter-class similarity) |
|
|
- Classification accuracy with multiple methods (nearest neighbor, centroid-based) |
|
|
- Confusion matrix generation and visualization |
|
|
- Support for multiple datasets (validation set, Fashion-MNIST, Kaggle Marqo) |
|
|
- Advanced techniques: ZCA whitening, Mahalanobis distance, Test-Time Augmentation |
|
|
|
|
|
Key Features: |
|
|
- Custom model evaluation with full hierarchy classification pipeline |
|
|
- Fashion-CLIP baseline comparison for performance benchmarking |
|
|
- Multi-dataset evaluation (validation, Fashion-MNIST, Kaggle Marqo) |
|
|
- Flexible evaluation options (whitening, Mahalanobis distance) |
|
|
- Detailed metrics: accuracy, F1 scores, confusion matrices |
|
|
|
|
|
Author: Fashion Search Team |
|
|
License: Apache 2.0 |
|
|
""" |
|
|
|
|
|
|
|
|
import os |
|
|
import warnings |
|
|
from collections import defaultdict |
|
|
from io import BytesIO |
|
|
from typing import Dict, List, Tuple, Optional, Union, Any |
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import requests |
|
|
import torch |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
from PIL import Image |
|
|
from sklearn.metrics import ( |
|
|
accuracy_score, |
|
|
classification_report, |
|
|
confusion_matrix, |
|
|
f1_score, |
|
|
) |
|
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
from sklearn.model_selection import train_test_split |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torchvision import transforms |
|
|
from tqdm import tqdm |
|
|
from transformers import CLIPProcessor, CLIPModel as TransformersCLIPModel |
|
|
|
|
|
|
|
|
import config |
|
|
from config import device, hierarchy_model_path, hierarchy_column, local_dataset_path |
|
|
from hierarchy_model import Model, HierarchyExtractor, HierarchyDataset, collate_fn |
|
|
|
|
|
|
|
|
warnings.filterwarnings('ignore')  # silence noisy third-party warnings (sklearn/pandas/transformers) during evaluation runs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Cap on the total number of samples used in any single evaluation run;
# larger datasets are reduced via stratified sampling to keep class balance.
MAX_SAMPLES_EVALUATION = 10000

# Upper bound on inter-class similarity pairs computed, preventing O(n^2)
# blow-up in compute_similarity_metrics on large datasets.
MAX_INTER_CLASS_COMPARISONS = 10000

# Canonical Fashion-MNIST label-id -> class-name mapping (10 classes).
FASHION_MNIST_LABELS = {
    0: "T-shirt/top",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle boot"
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_fashion_mnist_to_image(pixel_values: np.ndarray) -> Image.Image:
    """
    Build an RGB PIL image from a flat Fashion-MNIST pixel vector.

    Args:
        pixel_values: Flat array of 784 grayscale pixel values (28x28)

    Returns:
        PIL Image in RGB format (the grayscale channel replicated three times)
    """
    grayscale = np.asarray(pixel_values).reshape(28, 28).astype(np.uint8)

    # Replicate the single channel so downstream RGB pipelines work unchanged.
    rgb = np.repeat(grayscale[:, :, np.newaxis], 3, axis=-1)

    return Image.fromarray(rgb)
|
|
|
|
|
|
|
|
def get_fashion_mnist_labels() -> Dict[int, str]:
    """
    Return the Fashion-MNIST class-label mapping.

    Returns:
        A fresh dict mapping label IDs (0-9) to class names; callers may
        mutate it freely without affecting the module-level constant.
    """
    return dict(FASHION_MNIST_LABELS)
|
|
|
|
|
|
|
|
def create_fashion_mnist_to_hierarchy_mapping(
    hierarchy_classes: List[str],
    fashion_mnist_labels: Optional[Dict[int, str]] = None
) -> Dict[int, Optional[str]]:
    """
    Create mapping from Fashion-MNIST labels to custom hierarchy classes.

    Matching proceeds in three stages per label, stopping at the first hit:
    1. Exact (case-insensitive) name match.
    2. Partial match (either name contains the other); the first hierarchy
       class in list order wins.
    3. Semantic fallback via hand-curated synonym lists.

    Args:
        hierarchy_classes: List of hierarchy class names from the custom model
        fashion_mnist_labels: Optional label-id -> name mapping; defaults to
            the module-level FASHION_MNIST_LABELS (parameter added for
            flexibility/testability, backward-compatible)

    Returns:
        Dictionary mapping Fashion-MNIST label IDs to hierarchy class names
        (None if no match found)
    """
    if fashion_mnist_labels is None:
        fashion_mnist_labels = FASHION_MNIST_LABELS

    hierarchy_classes_lower = [h.lower() for h in hierarchy_classes]

    # Synonyms tried in order when exact and partial matching both fail.
    semantic_fallbacks: Dict[str, List[str]] = {
        't-shirt/top': ['top', 'shirt'],
        'top': ['top', 'shirt'],
        'trouser': ['pant', 'pants', 'trousers', 'trouser', 'bottom'],
        'pullover': ['sweater', 'pullover', 'top'],
        'dress': ['dress'],
        'coat': ['coat', 'jacket', 'outerwear'],
        'sandal': ['shoes', 'shoe', 'footwear', 'sandal', 'sneaker', 'boot'],
        'sneaker': ['shoes', 'shoe', 'footwear', 'sandal', 'sneaker', 'boot'],
        'ankle boot': ['shoes', 'shoe', 'footwear', 'sandal', 'sneaker', 'boot'],
        'bag': ['bag'],
    }

    mapping: Dict[int, Optional[str]] = {}

    for fm_label_id, fm_label in fashion_mnist_labels.items():
        fm_label_lower = fm_label.lower()
        matched_hierarchy: Optional[str] = None

        # Stage 1: exact case-insensitive match.
        if fm_label_lower in hierarchy_classes_lower:
            matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(fm_label_lower)]
        else:
            # Stage 2: single pass over the classes. The previous version ran
            # a redundant any() pre-scan and then re-scanned for the winner;
            # one loop does both.
            for h_class in hierarchy_classes:
                h_lower = h_class.lower()
                if h_lower in fm_label_lower or fm_label_lower in h_lower:
                    matched_hierarchy = h_class
                    break

            # Stage 3: semantic synonym fallback.
            if matched_hierarchy is None:
                for synonym in semantic_fallbacks.get(fm_label_lower, []):
                    if synonym in hierarchy_classes_lower:
                        matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(synonym)]
                        break

        mapping[fm_label_id] = matched_hierarchy

        # Progress feedback so mismatches are visible during evaluation setup.
        if matched_hierarchy:
            print(f" {fm_label} ({fm_label_id}) -> {matched_hierarchy}")
        else:
            print(f" β οΈ {fm_label} ({fm_label_id}) -> NO MATCH (will be filtered out)")

    return mapping
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FashionMNISTDataset(Dataset):
    """
    Fashion-MNIST Dataset for evaluation.

    Wraps a DataFrame of raw Fashion-MNIST rows (784 pixel columns plus a
    'label' column), converts each row to a normalized RGB tensor, and maps
    the numeric label to description / hierarchy strings. Aligned with
    main_model_evaluation.py for consistent evaluation across scripts.

    Args:
        dataframe: DataFrame with columns pixel1..pixel784 and 'label'
        image_size: Target square size for image resizing (default: 224)
        label_mapping: Optional mapping from Fashion-MNIST label IDs to
            hierarchy class names
    """

    # Column names holding the flattened 28x28 image, shared by all instances.
    _PIXEL_COLUMNS = [f"pixel{i}" for i in range(1, 785)]

    def __init__(
        self,
        dataframe: pd.DataFrame,
        image_size: int = 224,
        label_mapping: Optional[Dict[int, str]] = None
    ):
        self.dataframe = dataframe
        self.image_size = image_size
        self.labels_map = get_fashion_mnist_labels()
        self.label_mapping = label_mapping

        # Standard ImageNet preprocessing: resize, tensorize, normalize.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
        ])

    def __len__(self) -> int:
        return self.dataframe.shape[0]

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, str, str]:
        """
        Fetch one sample.

        Args:
            idx: Row index into the DataFrame

        Returns:
            Tuple of (image_tensor, description, color, hierarchy); color is
            always "unknown" since Fashion-MNIST carries no color metadata.
        """
        row = self.dataframe.iloc[idx]

        # Rebuild the image from the flat pixel columns, then preprocess.
        image_tensor = self.transform(
            convert_fashion_mnist_to_image(row[self._PIXEL_COLUMNS].values)
        )

        label_id = int(row['label'])
        description = self.labels_map[label_id]

        # Prefer the caller-supplied hierarchy mapping; otherwise fall back
        # to the raw Fashion-MNIST class name.
        if self.label_mapping and label_id in self.label_mapping:
            hierarchy = self.label_mapping[label_id]
        else:
            hierarchy = self.labels_map[label_id]

        return image_tensor, description, "unknown", hierarchy
|
|
|
|
|
|
|
|
class CLIPDataset(Dataset):
    """
    Dataset for Fashion-CLIP baseline evaluation.

    Loads images from several possible sources (local file path, embedded
    bytes, raw pixel arrays, PIL Images, or remote URLs) and applies the
    standard validation transform without augmentation.

    Args:
        dataframe: DataFrame containing image and text columns

    Returns (per item):
        Tuple of (image_tensor, description, hierarchy)
    """

    def __init__(self, dataframe: pd.DataFrame):
        self.dataframe = dataframe

        # Deterministic validation preprocessing (ImageNet statistics).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def __len__(self) -> int:
        return self.dataframe.shape[0]

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str, str]:
        """
        Fetch one (image_tensor, description, hierarchy) triple.

        Args:
            idx: Row index into the DataFrame
        """
        row = self.dataframe.iloc[idx]
        pil_image = self._load_image(row, idx)
        return (
            self.transform(pil_image),
            row[config.text_column],
            row[config.hierarchy_column],
        )

    def _load_image(self, row: pd.Series, idx: int) -> Image.Image:
        """
        Load an image, trying sources in priority order.

        Order: local file path, dict-of-bytes, raw 28x28 pixel array, PIL
        Image, then URL download. Falls back to a solid gray placeholder so
        one bad sample never aborts evaluation.

        Args:
            row: DataFrame row containing image information
            idx: Index used in warning messages

        Returns:
            PIL Image in RGB format
        """
        # 1) Local file path, when present and readable.
        if config.column_local_image_path in row.index and pd.notna(row[config.column_local_image_path]):
            local_path = row[config.column_local_image_path]
            try:
                if os.path.exists(local_path):
                    return Image.open(local_path).convert("RGB")
                print(f"β οΈ Local image not found: {local_path}")
            except Exception as e:
                print(f"β οΈ Failed to load local image {idx}: {e}")

        image_data = row.get(config.column_url_image)

        # 2) Parquet-style dict holding raw encoded bytes.
        if isinstance(image_data, dict) and 'bytes' in image_data:
            return Image.open(BytesIO(image_data['bytes'])).convert('RGB')

        # 3) Raw 28x28 pixel array (Fashion-MNIST style).
        if isinstance(image_data, (list, np.ndarray)):
            grid = np.array(image_data).reshape(28, 28)
            return Image.fromarray(grid.astype(np.uint8)).convert("RGB")

        # 4) Already a PIL Image.
        if isinstance(image_data, Image.Image):
            return image_data.convert("RGB")

        # 5) Assume a URL and try to download it.
        try:
            response = requests.get(image_data, timeout=10)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            print(f"β οΈ Failed to load image {idx}: {e}")

        # Last resort: gray placeholder keeps batch shapes intact.
        return Image.new('RGB', (224, 224), color='gray')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CLIPBaselineEvaluator:
    """
    Fashion-CLIP Baseline Evaluator.

    This class handles the loading and evaluation of the Fashion-CLIP baseline model
    (patrickjohncyh/fashion-clip) for comparison with custom models.

    Args:
        device: Device to run the model on ('cuda', 'mps', or 'cpu')
    """

    def __init__(self, device: str = 'mps'):
        self.device = torch.device(device)

        # Fetch pretrained weights and the matching processor (tokenizer +
        # image preprocessor) from the HuggingFace hub or local cache.
        print("π€ Loading Fashion-CLIP baseline model from transformers...")
        model_name = "patrickjohncyh/fashion-clip"
        self.clip_model = TransformersCLIPModel.from_pretrained(model_name).to(self.device)
        self.clip_processor = CLIPProcessor.from_pretrained(model_name)

        # Inference only: freeze dropout / normalization statistics.
        self.clip_model.eval()
        print("β
 Fashion-CLIP model loaded successfully")

    def extract_clip_embeddings(
        self,
        images: List[Union[torch.Tensor, Image.Image]],
        texts: List[str]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Extract Fashion-CLIP embeddings for images and texts.

        Processes inputs in fixed-size batches and returns L2-normalized
        embeddings. Aligned with main_model_evaluation.py for consistency.

        Args:
            images: List of images (tensors or PIL Images)
            texts: List of text descriptions (assumed same length as images)

        Returns:
            Tuple of (image_embeddings, text_embeddings) as numpy arrays
        """
        all_image_embeddings = []
        all_text_embeddings = []

        # Fixed batch size keeps peak GPU memory bounded.
        batch_size = 32
        num_batches = (len(images) + batch_size - 1) // batch_size

        with torch.no_grad():
            for batch_idx in tqdm(range(num_batches), desc="Extracting CLIP embeddings"):
                start_idx = batch_idx * batch_size
                end_idx = min(start_idx + batch_size, len(images))

                batch_images = images[start_idx:end_idx]
                batch_texts = texts[start_idx:end_idx]

                text_features = self._extract_text_features(batch_texts)

                image_features = self._extract_image_features(batch_images)

                # Move to CPU immediately so device memory can be reclaimed.
                all_image_embeddings.append(image_features.cpu().numpy())
                all_text_embeddings.append(text_features.cpu().numpy())

                del text_features, image_features
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        return np.vstack(all_image_embeddings), np.vstack(all_text_embeddings)

    def _extract_text_features(self, texts: List[str]) -> torch.Tensor:
        """
        Extract text features using Fashion-CLIP.

        Args:
            texts: List of text descriptions

        Returns:
            L2-normalized text feature embeddings
        """
        # CLIP's text encoder accepts at most 77 tokens; longer inputs are
        # truncated by the processor.
        text_inputs = self.clip_processor(
            text=texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=77
        )
        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}

        text_features = self.clip_model.get_text_features(**text_inputs)

        # Unit-normalize so dot products equal cosine similarities.
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        return text_features

    def _extract_image_features(
        self,
        images: List[Union[torch.Tensor, Image.Image]]
    ) -> torch.Tensor:
        """
        Extract image features using Fashion-CLIP.

        Args:
            images: List of images (tensors or PIL Images)

        Returns:
            L2-normalized image feature embeddings

        Raises:
            ValueError: If an element is neither a tensor nor a PIL Image.
        """
        # The CLIP processor expects PIL Images; convert tensors back first.
        pil_images = []
        for img in images:
            if isinstance(img, torch.Tensor):
                pil_images.append(self._tensor_to_pil(img))
            elif isinstance(img, Image.Image):
                pil_images.append(img)
            else:
                raise ValueError(f"Unsupported image type: {type(img)}")

        image_inputs = self.clip_processor(
            images=pil_images,
            return_tensors="pt"
        )
        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}

        image_features = self.clip_model.get_image_features(**image_inputs)

        # Unit-normalize so dot products equal cosine similarities.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        return image_features

    def _tensor_to_pil(self, tensor: torch.Tensor) -> Image.Image:
        """
        Convert a (possibly ImageNet-normalized) tensor to a PIL Image.

        Args:
            tensor: Image tensor (C, H, W)

        Returns:
            PIL Image

        Raises:
            ValueError: If the tensor is not 3-dimensional.
        """
        if tensor.dim() != 3:
            raise ValueError(f"Expected 3D tensor, got {tensor.dim()}D")

        # Heuristic: values outside [0, 1] imply ImageNet normalization was
        # applied upstream, so invert it before conversion.
        if tensor.min() < 0 or tensor.max() > 1:
            mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
            tensor = tensor * std + mean
            tensor = torch.clamp(tensor, 0, 1)

        return transforms.ToPILImage()(tensor)
|
|
|
|
|
|
|
|
class EmbeddingEvaluator: |
|
|
""" |
|
|
Comprehensive Embedding Evaluator for Hierarchy Classification. |
|
|
|
|
|
This class provides a complete evaluation pipeline for hierarchy classification models, |
|
|
including custom model evaluation and Fashion-CLIP baseline comparison. It supports |
|
|
multiple evaluation metrics, datasets, and advanced techniques. |
|
|
|
|
|
Key Features: |
|
|
- Custom model loading and evaluation |
|
|
- Fashion-CLIP baseline comparison |
|
|
- Multiple classification methods (nearest neighbor, centroid, Mahalanobis) |
|
|
- Advanced techniques (ZCA whitening, Test-Time Augmentation) |
|
|
- Comprehensive metrics (accuracy, F1, confusion matrices) |
|
|
|
|
|
Args: |
|
|
model_path: Path to the trained custom model checkpoint |
|
|
directory: Output directory for saving evaluation results |
|
|
""" |
|
|
|
|
|
    def __init__(self, model_path: str, directory: str):
        """
        Load the dataset, the custom model, and the Fashion-CLIP baseline.

        Side effects: reads the CSV at local_dataset_path, optionally
        subsamples it, builds the validation split, and downloads/loads
        both models.

        Args:
            model_path: Path to the trained custom model checkpoint
            directory: Output directory for saving evaluation results
        """
        self.directory = directory
        self.device = device  # module-level device imported from config

        # Load the evaluation dataset (expects locally cached images).
        print(f"π Using dataset with local images: {local_dataset_path}")
        df = pd.read_csv(local_dataset_path)
        print(f"π Loaded {len(df)} samples")

        # Discover the label space from the data itself.
        hierarchy_classes = sorted(df[hierarchy_column].unique().tolist())
        print(f"π Found {len(hierarchy_classes)} hierarchy classes")

        # Keep evaluation tractable on large datasets.
        if len(df) > MAX_SAMPLES_EVALUATION:
            print(f"β οΈ Dataset too large ({len(df)} samples), sampling to {MAX_SAMPLES_EVALUATION} samples")
            df = self._stratified_sample(df, MAX_SAMPLES_EVALUATION)

        # Hold out 20% as the validation split, stratified by class.
        # NOTE(review): stratifies on the literal 'hierarchy' column while
        # class discovery above uses hierarchy_column — presumably the same
        # column; confirm against config.
        _, self.val_df = train_test_split(
            df,
            test_size=0.2,
            random_state=42,
            stratify=df['hierarchy']
        )

        # Restore the custom hierarchy model from its checkpoint.
        self._load_model(model_path)

        # Baseline model for side-by-side comparison.
        self.clip_evaluator = CLIPBaselineEvaluator(device)
|
|
|
|
|
def _stratified_sample(self, df: pd.DataFrame, max_samples: int) -> pd.DataFrame: |
|
|
""" |
|
|
Perform stratified sampling to maintain class distribution. |
|
|
|
|
|
Args: |
|
|
df: Original DataFrame |
|
|
max_samples: Maximum number of samples to keep |
|
|
|
|
|
Returns: |
|
|
Sampled DataFrame |
|
|
""" |
|
|
|
|
|
df_sampled = df.groupby('hierarchy', group_keys=False).apply( |
|
|
lambda x: x.sample( |
|
|
n=min(len(x), int(max_samples * len(x) / len(df))), |
|
|
random_state=42 |
|
|
) |
|
|
).reset_index(drop=True) |
|
|
|
|
|
|
|
|
if len(df_sampled) < max_samples: |
|
|
remaining = max_samples - len(df_sampled) |
|
|
extra = df.sample(n=remaining, random_state=42) |
|
|
df_sampled = pd.concat([df_sampled, extra]).reset_index(drop=True) |
|
|
|
|
|
return df_sampled |
|
|
|
|
|
    def _load_model(self, model_path: str):
        """
        Load the custom hierarchy classification model.

        Restores the hierarchy class list, label vocabulary, and model
        weights from a torch checkpoint produced during training.

        Args:
            model_path: Path to the model checkpoint

        Raises:
            FileNotFoundError: If model file doesn't exist
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file {model_path} not found")

        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects — only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=self.device)

        # Training-time hyperparameters and label space travel with the
        # checkpoint so evaluation reconstructs the exact architecture.
        config_dict = checkpoint.get('config', {})
        saved_hierarchy_classes = checkpoint['hierarchy_classes']

        self.hierarchy_classes = saved_hierarchy_classes

        # Rebuild the label-to-index vocabulary used by collate_fn.
        self.vocab = HierarchyExtractor(saved_hierarchy_classes)

        # Reconstruct the architecture with the saved hyperparameters.
        self.model = Model(
            num_hierarchy_classes=len(saved_hierarchy_classes),
            embed_dim=config_dict['embed_dim'],
            dropout=config_dict['dropout']
        ).to(self.device)

        # Restore weights and switch to inference mode.
        self.model.load_state_dict(checkpoint['model_state'])
        self.model.eval()

        print(f"β
 Custom model loaded with:")
        print(f"π Hierarchy classes: {len(saved_hierarchy_classes)}")
        print(f"π― Embed dim: {config_dict['embed_dim']}")
        print(f"π§ Dropout: {config_dict['dropout']}")
        print(f"π
 Epoch: {checkpoint.get('epoch', 'unknown')}")
|
|
|
|
|
def _collate_fn_wrapper(self, batch: List[Tuple]) -> Dict[str, torch.Tensor]: |
|
|
""" |
|
|
Wrapper for collate_fn that can be pickled (required for DataLoader). |
|
|
|
|
|
Handles both formats: |
|
|
- (image, description, hierarchy) for HierarchyDataset |
|
|
- (image, description, color, hierarchy) for FashionMNISTDataset |
|
|
|
|
|
Args: |
|
|
batch: List of samples from dataset |
|
|
|
|
|
Returns: |
|
|
Collated batch dictionary |
|
|
""" |
|
|
|
|
|
if len(batch[0]) == 4: |
|
|
|
|
|
batch_converted = [(b[0], b[1], b[3]) for b in batch] |
|
|
return collate_fn(batch_converted, self.vocab) |
|
|
else: |
|
|
|
|
|
return collate_fn(batch, self.vocab) |
|
|
|
|
|
def create_dataloader( |
|
|
self, |
|
|
dataframe_or_dataset: Union[pd.DataFrame, Dataset], |
|
|
batch_size: int = 16 |
|
|
) -> DataLoader: |
|
|
""" |
|
|
Create a DataLoader for the custom model. |
|
|
|
|
|
Aligned with main_model_evaluation.py for consistency. |
|
|
|
|
|
Args: |
|
|
dataframe_or_dataset: Either a pandas DataFrame or a Dataset object |
|
|
batch_size: Batch size for the DataLoader |
|
|
|
|
|
Returns: |
|
|
Configured DataLoader |
|
|
""" |
|
|
|
|
|
if isinstance(dataframe_or_dataset, Dataset): |
|
|
dataset = dataframe_or_dataset |
|
|
print(f"π Using pre-created Dataset object") |
|
|
|
|
|
|
|
|
elif isinstance(dataframe_or_dataset, pd.DataFrame): |
|
|
|
|
|
if 'pixel1' in dataframe_or_dataset.columns: |
|
|
print(f"π Detected Fashion-MNIST data, creating FashionMNISTDataset") |
|
|
dataset = FashionMNISTDataset(dataframe_or_dataset, image_size=224) |
|
|
else: |
|
|
dataset = HierarchyDataset(dataframe_or_dataset, image_size=224) |
|
|
else: |
|
|
raise ValueError(f"Unsupported type: {type(dataframe_or_dataset)}") |
|
|
|
|
|
|
|
|
|
|
|
dataloader = DataLoader( |
|
|
dataset, |
|
|
batch_size=batch_size, |
|
|
shuffle=False, |
|
|
collate_fn=self._collate_fn_wrapper, |
|
|
num_workers=0, |
|
|
pin_memory=False |
|
|
) |
|
|
|
|
|
return dataloader |
|
|
|
|
|
def create_clip_dataloader( |
|
|
self, |
|
|
dataframe_or_dataset: Union[pd.DataFrame, Dataset], |
|
|
batch_size: int = 16 |
|
|
) -> DataLoader: |
|
|
""" |
|
|
Create a DataLoader for Fashion-CLIP baseline. |
|
|
|
|
|
Args: |
|
|
dataframe_or_dataset: Either a pandas DataFrame or a Dataset object |
|
|
batch_size: Batch size for the DataLoader |
|
|
|
|
|
Returns: |
|
|
Configured DataLoader |
|
|
""" |
|
|
|
|
|
if isinstance(dataframe_or_dataset, Dataset): |
|
|
dataset = dataframe_or_dataset |
|
|
print(f"π Using pre-created Dataset object for CLIP") |
|
|
|
|
|
|
|
|
elif isinstance(dataframe_or_dataset, pd.DataFrame): |
|
|
|
|
|
if 'pixel1' in dataframe_or_dataset.columns: |
|
|
print("π Detected Fashion-MNIST data for Fashion-CLIP") |
|
|
dataset = FashionMNISTDataset(dataframe_or_dataset, image_size=224) |
|
|
else: |
|
|
dataset = CLIPDataset(dataframe_or_dataset) |
|
|
else: |
|
|
raise ValueError(f"Unsupported type: {type(dataframe_or_dataset)}") |
|
|
|
|
|
|
|
|
dataloader = DataLoader( |
|
|
dataset, |
|
|
batch_size=batch_size, |
|
|
shuffle=False, |
|
|
num_workers=0, |
|
|
pin_memory=False |
|
|
) |
|
|
|
|
|
return dataloader |
|
|
|
|
|
def extract_custom_embeddings( |
|
|
self, |
|
|
dataloader: DataLoader, |
|
|
embedding_type: str = 'text', |
|
|
use_tta: bool = False |
|
|
) -> Tuple[np.ndarray, List[str], List[str]]: |
|
|
""" |
|
|
Extract embeddings from custom model with optional Test-Time Augmentation. |
|
|
|
|
|
Args: |
|
|
dataloader: DataLoader for the dataset |
|
|
embedding_type: Type of embedding to extract ('text', 'image', or 'both') |
|
|
use_tta: Whether to use Test-Time Augmentation for images |
|
|
|
|
|
Returns: |
|
|
Tuple of (embeddings, labels, texts) |
|
|
""" |
|
|
all_embeddings = [] |
|
|
all_labels = [] |
|
|
all_texts = [] |
|
|
|
|
|
with torch.no_grad(): |
|
|
for batch in tqdm(dataloader, desc=f"Extracting custom {embedding_type} embeddings{' with TTA' if use_tta else ''}"): |
|
|
images = batch['image'].to(self.device) |
|
|
hierarchy_indices = batch['hierarchy_indices'].to(self.device) |
|
|
hierarchy_labels = batch['hierarchy'] |
|
|
|
|
|
|
|
|
if use_tta and embedding_type == 'image' and images.dim() == 5: |
|
|
embeddings = self._extract_with_tta(images, hierarchy_indices) |
|
|
else: |
|
|
|
|
|
out = self.model(image=images, hierarchy_indices=hierarchy_indices) |
|
|
embeddings = out['z_txt'] if embedding_type == 'text' else out['z_img'] |
|
|
|
|
|
all_embeddings.append(embeddings.cpu().numpy()) |
|
|
all_labels.extend(hierarchy_labels) |
|
|
all_texts.extend(hierarchy_labels) |
|
|
|
|
|
|
|
|
del images, hierarchy_indices, embeddings, out |
|
|
if str(self.device) != 'cpu': |
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
return np.vstack(all_embeddings), all_labels, all_texts |
|
|
|
|
|
def _extract_with_tta( |
|
|
self, |
|
|
images: torch.Tensor, |
|
|
hierarchy_indices: torch.Tensor |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Extract embeddings using Test-Time Augmentation. |
|
|
|
|
|
Args: |
|
|
images: Images with TTA crops [batch_size, tta_crops, C, H, W] |
|
|
hierarchy_indices: Hierarchy class indices |
|
|
|
|
|
Returns: |
|
|
Averaged embeddings [batch_size, embed_dim] |
|
|
""" |
|
|
batch_size, tta_crops, C, H, W = images.shape |
|
|
|
|
|
|
|
|
images_flat = images.view(batch_size * tta_crops, C, H, W) |
|
|
|
|
|
|
|
|
hierarchy_indices_repeated = hierarchy_indices.unsqueeze(1).repeat(1, tta_crops).view(-1) |
|
|
|
|
|
|
|
|
out = self.model(image=images_flat, hierarchy_indices=hierarchy_indices_repeated) |
|
|
embeddings_flat = out['z_img'] |
|
|
|
|
|
|
|
|
embeddings = embeddings_flat.view(batch_size, tta_crops, -1) |
|
|
|
|
|
|
|
|
embeddings = embeddings.mean(dim=1) |
|
|
|
|
|
return embeddings |
|
|
|
|
|
def apply_whitening( |
|
|
self, |
|
|
embeddings: np.ndarray, |
|
|
epsilon: float = 1e-5 |
|
|
) -> np.ndarray: |
|
|
""" |
|
|
Apply ZCA whitening to embeddings for better feature decorrelation. |
|
|
|
|
|
Whitening removes correlations between dimensions and can improve |
|
|
class separation by normalizing the feature space. |
|
|
|
|
|
Args: |
|
|
embeddings: Input embeddings [N, D] |
|
|
epsilon: Small constant for numerical stability |
|
|
|
|
|
Returns: |
|
|
Whitened embeddings [N, D] |
|
|
""" |
|
|
|
|
|
mean = np.mean(embeddings, axis=0, keepdims=True) |
|
|
centered = embeddings - mean |
|
|
|
|
|
|
|
|
cov = np.cov(centered.T) |
|
|
|
|
|
|
|
|
eigenvalues, eigenvectors = np.linalg.eigh(cov) |
|
|
|
|
|
|
|
|
d = np.diag(1.0 / np.sqrt(eigenvalues + epsilon)) |
|
|
whiten_transform = eigenvectors @ d @ eigenvectors.T |
|
|
|
|
|
|
|
|
whitened = centered @ whiten_transform |
|
|
|
|
|
|
|
|
norms = np.linalg.norm(whitened, axis=1, keepdims=True) |
|
|
whitened = whitened / (norms + epsilon) |
|
|
|
|
|
return whitened |
|
|
|
|
|
def compute_similarity_metrics( |
|
|
self, |
|
|
embeddings: np.ndarray, |
|
|
labels: List[str], |
|
|
apply_whitening_norm: bool = False |
|
|
) -> Dict[str, Any]: |
|
|
""" |
|
|
Compute intra-class and inter-class similarity metrics. |
|
|
|
|
|
Args: |
|
|
embeddings: Embedding vectors |
|
|
labels: Class labels |
|
|
apply_whitening_norm: Whether to apply ZCA whitening |
|
|
|
|
|
Returns: |
|
|
Dictionary containing similarity metrics and accuracies |
|
|
""" |
|
|
|
|
|
if apply_whitening_norm: |
|
|
embeddings = self.apply_whitening(embeddings) |
|
|
|
|
|
|
|
|
similarities = cosine_similarity(embeddings) |
|
|
|
|
|
|
|
|
hierarchy_groups = defaultdict(list) |
|
|
for i, hierarchy in enumerate(labels): |
|
|
hierarchy_groups[hierarchy].append(i) |
|
|
|
|
|
|
|
|
intra_class_similarities = self._compute_intra_class_similarities( |
|
|
similarities, hierarchy_groups |
|
|
) |
|
|
|
|
|
|
|
|
inter_class_similarities = self._compute_inter_class_similarities( |
|
|
similarities, hierarchy_groups |
|
|
) |
|
|
|
|
|
|
|
|
nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities) |
|
|
centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels) |
|
|
|
|
|
return { |
|
|
'intra_class_similarities': intra_class_similarities, |
|
|
'inter_class_similarities': inter_class_similarities, |
|
|
'intra_class_mean': np.mean(intra_class_similarities) if intra_class_similarities else 0, |
|
|
'inter_class_mean': np.mean(inter_class_similarities) if inter_class_similarities else 0, |
|
|
'separation_score': np.mean(intra_class_similarities) - np.mean(inter_class_similarities) if intra_class_similarities and inter_class_similarities else 0, |
|
|
'accuracy': nn_accuracy, |
|
|
'centroid_accuracy': centroid_accuracy |
|
|
} |
|
|
|
|
|
def _compute_intra_class_similarities( |
|
|
self, |
|
|
similarities: np.ndarray, |
|
|
hierarchy_groups: Dict[str, List[int]] |
|
|
) -> List[float]: |
|
|
""" |
|
|
Compute within-class similarities. |
|
|
|
|
|
Args: |
|
|
similarities: Pairwise similarity matrix |
|
|
hierarchy_groups: Mapping from hierarchy to sample indices |
|
|
|
|
|
Returns: |
|
|
List of intra-class similarity values |
|
|
""" |
|
|
intra_class_similarities = [] |
|
|
|
|
|
for hierarchy, indices in hierarchy_groups.items(): |
|
|
if len(indices) > 1: |
|
|
|
|
|
for i in range(len(indices)): |
|
|
for j in range(i + 1, len(indices)): |
|
|
sim = similarities[indices[i], indices[j]] |
|
|
intra_class_similarities.append(sim) |
|
|
|
|
|
return intra_class_similarities |
|
|
|
|
|
    def _compute_inter_class_similarities(
        self,
        similarities: np.ndarray,
        hierarchy_groups: Dict[str, List[int]]
    ) -> List[float]:
        """
        Compute between-class similarities with sampling for efficiency.

        For every unordered pair of classes, up to 100 indices per class are
        sampled without replacement and all cross pairs compared, subject to
        a global cap of MAX_INTER_CLASS_COMPARISONS to avoid O(n^2) cost.

        NOTE(review): np.random.choice is used without a fixed seed, so the
        sampled pairs (and thus the returned values) differ between runs —
        seed numpy externally if reproducibility matters.

        Args:
            similarities: Pairwise similarity matrix
            hierarchy_groups: Mapping from hierarchy to sample indices

        Returns:
            List of inter-class similarity values (at most
            MAX_INTER_CLASS_COMPARISONS entries)
        """
        inter_class_similarities = []
        hierarchies = list(hierarchy_groups.keys())
        comparison_count = 0

        # Iterate each unordered pair of distinct classes.
        for i in range(len(hierarchies)):
            for j in range(i + 1, len(hierarchies)):
                hierarchy1_indices = hierarchy_groups[hierarchies[i]]
                hierarchy2_indices = hierarchy_groups[hierarchies[j]]

                # Sample without replacement, at most 100 indices per class.
                max_samples_per_pair = min(100, len(hierarchy1_indices), len(hierarchy2_indices))
                sampled_idx1 = np.random.choice(
                    hierarchy1_indices,
                    size=min(max_samples_per_pair, len(hierarchy1_indices)),
                    replace=False
                )
                sampled_idx2 = np.random.choice(
                    hierarchy2_indices,
                    size=min(max_samples_per_pair, len(hierarchy2_indices)),
                    replace=False
                )

                # Cross-compare sampled indices; the cascaded breaks below
                # bail out of all loop levels once the global cap is hit.
                for idx1 in sampled_idx1:
                    for idx2 in sampled_idx2:
                        if comparison_count >= MAX_INTER_CLASS_COMPARISONS:
                            break
                        sim = similarities[idx1, idx2]
                        inter_class_similarities.append(sim)
                        comparison_count += 1
                    if comparison_count >= MAX_INTER_CLASS_COMPARISONS:
                        break
                if comparison_count >= MAX_INTER_CLASS_COMPARISONS:
                    break
            if comparison_count >= MAX_INTER_CLASS_COMPARISONS:
                break

        return inter_class_similarities
|
|
|
|
|
def compute_embedding_accuracy( |
|
|
self, |
|
|
embeddings: np.ndarray, |
|
|
labels: List[str], |
|
|
similarities: np.ndarray |
|
|
) -> float: |
|
|
""" |
|
|
Compute classification accuracy using nearest neighbor in embedding space. |
|
|
|
|
|
Args: |
|
|
embeddings: Embedding vectors |
|
|
labels: True class labels |
|
|
similarities: Precomputed similarity matrix |
|
|
|
|
|
Returns: |
|
|
Classification accuracy |
|
|
""" |
|
|
correct_predictions = 0 |
|
|
total_predictions = len(labels) |
|
|
|
|
|
for i in range(len(embeddings)): |
|
|
true_label = labels[i] |
|
|
|
|
|
|
|
|
similarities_row = similarities[i].copy() |
|
|
similarities_row[i] = -1 |
|
|
nearest_neighbor_idx = np.argmax(similarities_row) |
|
|
predicted_label = labels[nearest_neighbor_idx] |
|
|
|
|
|
if predicted_label == true_label: |
|
|
correct_predictions += 1 |
|
|
|
|
|
return correct_predictions / total_predictions if total_predictions > 0 else 0 |
|
|
|
|
|
def compute_centroid_accuracy( |
|
|
self, |
|
|
embeddings: np.ndarray, |
|
|
labels: List[str] |
|
|
) -> float: |
|
|
""" |
|
|
Compute classification accuracy using hierarchy centroids. |
|
|
|
|
|
Args: |
|
|
embeddings: Embedding vectors |
|
|
labels: True class labels |
|
|
|
|
|
Returns: |
|
|
Classification accuracy |
|
|
""" |
|
|
|
|
|
unique_hierarchies = list(set(labels)) |
|
|
centroids = {} |
|
|
|
|
|
for hierarchy in unique_hierarchies: |
|
|
hierarchy_indices = [i for i, label in enumerate(labels) if label == hierarchy] |
|
|
hierarchy_embeddings = embeddings[hierarchy_indices] |
|
|
centroids[hierarchy] = np.mean(hierarchy_embeddings, axis=0) |
|
|
|
|
|
|
|
|
correct_predictions = 0 |
|
|
total_predictions = len(labels) |
|
|
|
|
|
for i, embedding in enumerate(embeddings): |
|
|
true_label = labels[i] |
|
|
|
|
|
|
|
|
best_similarity = -1 |
|
|
predicted_label = None |
|
|
|
|
|
for hierarchy, centroid in centroids.items(): |
|
|
similarity = cosine_similarity([embedding], [centroid])[0][0] |
|
|
if similarity > best_similarity: |
|
|
best_similarity = similarity |
|
|
predicted_label = hierarchy |
|
|
|
|
|
if predicted_label == true_label: |
|
|
correct_predictions += 1 |
|
|
|
|
|
return correct_predictions / total_predictions if total_predictions > 0 else 0 |
|
|
|
|
|
def compute_mahalanobis_distance( |
|
|
self, |
|
|
point: np.ndarray, |
|
|
centroid: np.ndarray, |
|
|
cov_inv: np.ndarray |
|
|
) -> float: |
|
|
""" |
|
|
Compute Mahalanobis distance between a point and a centroid. |
|
|
|
|
|
The Mahalanobis distance takes into account the covariance structure |
|
|
of the data, making it more robust than Euclidean distance for |
|
|
high-dimensional spaces. |
|
|
|
|
|
Args: |
|
|
point: Query point |
|
|
centroid: Class centroid |
|
|
cov_inv: Inverse covariance matrix |
|
|
|
|
|
Returns: |
|
|
Mahalanobis distance |
|
|
""" |
|
|
diff = point - centroid |
|
|
distance = np.sqrt(np.dot(np.dot(diff, cov_inv), diff.T)) |
|
|
return distance |
|
|
|
|
|
def predict_hierarchy_from_embeddings( |
|
|
self, |
|
|
embeddings: np.ndarray, |
|
|
labels: List[str], |
|
|
use_mahalanobis: bool = False |
|
|
) -> List[str]: |
|
|
""" |
|
|
Predict hierarchy from embeddings using centroid-based classification. |
|
|
|
|
|
Args: |
|
|
embeddings: Embedding vectors |
|
|
labels: Training labels for computing centroids |
|
|
use_mahalanobis: Whether to use Mahalanobis distance |
|
|
|
|
|
Returns: |
|
|
List of predicted hierarchy labels |
|
|
""" |
|
|
|
|
|
unique_hierarchies = list(set(labels)) |
|
|
centroids = {} |
|
|
cov_inverses = {} |
|
|
|
|
|
for hierarchy in unique_hierarchies: |
|
|
hierarchy_indices = [i for i, label in enumerate(labels) if label == hierarchy] |
|
|
hierarchy_embeddings = embeddings[hierarchy_indices] |
|
|
centroids[hierarchy] = np.mean(hierarchy_embeddings, axis=0) |
|
|
|
|
|
|
|
|
if use_mahalanobis and len(hierarchy_embeddings) > 1: |
|
|
cov = np.cov(hierarchy_embeddings.T) |
|
|
|
|
|
cov += np.eye(cov.shape[0]) * 1e-6 |
|
|
try: |
|
|
cov_inverses[hierarchy] = np.linalg.inv(cov) |
|
|
except np.linalg.LinAlgError: |
|
|
|
|
|
cov_inverses[hierarchy] = np.eye(cov.shape[0]) |
|
|
|
|
|
|
|
|
predictions = [] |
|
|
|
|
|
for embedding in embeddings: |
|
|
if use_mahalanobis: |
|
|
predicted_hierarchy = self._predict_with_mahalanobis( |
|
|
embedding, centroids, cov_inverses |
|
|
) |
|
|
else: |
|
|
predicted_hierarchy = self._predict_with_cosine( |
|
|
embedding, centroids |
|
|
) |
|
|
predictions.append(predicted_hierarchy) |
|
|
|
|
|
return predictions |
|
|
|
|
|
def _predict_with_mahalanobis( |
|
|
self, |
|
|
embedding: np.ndarray, |
|
|
centroids: Dict[str, np.ndarray], |
|
|
cov_inverses: Dict[str, np.ndarray] |
|
|
) -> str: |
|
|
""" |
|
|
Predict class using Mahalanobis distance (lower is better). |
|
|
|
|
|
Args: |
|
|
embedding: Query embedding |
|
|
centroids: Class centroids |
|
|
cov_inverses: Inverse covariance matrices |
|
|
|
|
|
Returns: |
|
|
Predicted class label |
|
|
""" |
|
|
best_distance = float('inf') |
|
|
predicted_hierarchy = None |
|
|
|
|
|
for hierarchy, centroid in centroids.items(): |
|
|
if hierarchy in cov_inverses: |
|
|
distance = self.compute_mahalanobis_distance( |
|
|
embedding, centroid, cov_inverses[hierarchy] |
|
|
) |
|
|
else: |
|
|
|
|
|
similarity = cosine_similarity([embedding], [centroid])[0][0] |
|
|
distance = 1 - similarity |
|
|
|
|
|
if distance < best_distance: |
|
|
best_distance = distance |
|
|
predicted_hierarchy = hierarchy |
|
|
|
|
|
return predicted_hierarchy |
|
|
|
|
|
def _predict_with_cosine( |
|
|
self, |
|
|
embedding: np.ndarray, |
|
|
centroids: Dict[str, np.ndarray] |
|
|
) -> str: |
|
|
""" |
|
|
Predict class using cosine similarity (higher is better). |
|
|
|
|
|
Args: |
|
|
embedding: Query embedding |
|
|
centroids: Class centroids |
|
|
|
|
|
Returns: |
|
|
Predicted class label |
|
|
""" |
|
|
best_similarity = -1 |
|
|
predicted_hierarchy = None |
|
|
|
|
|
for hierarchy, centroid in centroids.items(): |
|
|
similarity = cosine_similarity([embedding], [centroid])[0][0] |
|
|
if similarity > best_similarity: |
|
|
best_similarity = similarity |
|
|
predicted_hierarchy = hierarchy |
|
|
|
|
|
return predicted_hierarchy |
|
|
|
|
|
def create_confusion_matrix( |
|
|
self, |
|
|
true_labels: List[str], |
|
|
predicted_labels: List[str], |
|
|
title: str = "Confusion Matrix" |
|
|
) -> Tuple[plt.Figure, float, np.ndarray]: |
|
|
""" |
|
|
Create and plot confusion matrix. |
|
|
|
|
|
Args: |
|
|
true_labels: Ground truth labels |
|
|
predicted_labels: Predicted labels |
|
|
title: Plot title |
|
|
|
|
|
Returns: |
|
|
Tuple of (figure, accuracy, confusion_matrix) |
|
|
""" |
|
|
|
|
|
unique_labels = sorted(list(set(true_labels + predicted_labels))) |
|
|
|
|
|
|
|
|
cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels) |
|
|
|
|
|
|
|
|
accuracy = accuracy_score(true_labels, predicted_labels) |
|
|
|
|
|
|
|
|
plt.figure(figsize=(12, 10)) |
|
|
sns.heatmap( |
|
|
cm, |
|
|
annot=True, |
|
|
fmt='d', |
|
|
cmap='Blues', |
|
|
xticklabels=unique_labels, |
|
|
yticklabels=unique_labels |
|
|
) |
|
|
plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)') |
|
|
plt.ylabel('True Hierarchy') |
|
|
plt.xlabel('Predicted Hierarchy') |
|
|
plt.xticks(rotation=45) |
|
|
plt.yticks(rotation=0) |
|
|
plt.tight_layout() |
|
|
|
|
|
return plt.gcf(), accuracy, cm |
|
|
|
|
|
def evaluate_classification_performance( |
|
|
self, |
|
|
embeddings: np.ndarray, |
|
|
labels: List[str], |
|
|
embedding_type: str = "Embeddings", |
|
|
apply_whitening_norm: bool = False, |
|
|
use_mahalanobis: bool = False |
|
|
) -> Dict[str, Any]: |
|
|
""" |
|
|
Evaluate classification performance and create confusion matrix. |
|
|
|
|
|
Args: |
|
|
embeddings: Embedding vectors |
|
|
labels: True class labels |
|
|
embedding_type: Description of embedding type for display |
|
|
apply_whitening_norm: Whether to apply ZCA whitening |
|
|
use_mahalanobis: Whether to use Mahalanobis distance |
|
|
|
|
|
Returns: |
|
|
Dictionary containing classification metrics and visualizations |
|
|
""" |
|
|
|
|
|
if apply_whitening_norm: |
|
|
embeddings = self.apply_whitening(embeddings) |
|
|
|
|
|
|
|
|
predictions = self.predict_hierarchy_from_embeddings( |
|
|
embeddings, labels, use_mahalanobis=use_mahalanobis |
|
|
) |
|
|
|
|
|
|
|
|
accuracy = accuracy_score(labels, predictions) |
|
|
|
|
|
|
|
|
unique_labels = sorted(list(set(labels))) |
|
|
f1_macro = f1_score( |
|
|
labels, predictions, labels=unique_labels, |
|
|
average='macro', zero_division=0 |
|
|
) |
|
|
f1_weighted = f1_score( |
|
|
labels, predictions, labels=unique_labels, |
|
|
average='weighted', zero_division=0 |
|
|
) |
|
|
f1_per_class = f1_score( |
|
|
labels, predictions, labels=unique_labels, |
|
|
average=None, zero_division=0 |
|
|
) |
|
|
|
|
|
|
|
|
fig, acc, cm = self.create_confusion_matrix( |
|
|
labels, predictions, |
|
|
f"{embedding_type} - Hierarchy Classification" |
|
|
) |
|
|
|
|
|
|
|
|
report = classification_report( |
|
|
labels, predictions, labels=unique_labels, |
|
|
target_names=unique_labels, output_dict=True |
|
|
) |
|
|
|
|
|
return { |
|
|
'accuracy': accuracy, |
|
|
'f1_macro': f1_macro, |
|
|
'f1_weighted': f1_weighted, |
|
|
'f1_per_class': f1_per_class, |
|
|
'predictions': predictions, |
|
|
'confusion_matrix': cm, |
|
|
'classification_report': report, |
|
|
'figure': fig |
|
|
} |
|
|
|
|
|
    def evaluate_dataset_with_baselines(
        self,
        dataframe: Union[pd.DataFrame, Dataset],
        dataset_name: str = "Dataset",
        use_whitening: bool = False,
        use_mahalanobis: bool = False
    ) -> Dict[str, Dict[str, Any]]:
        """
        Evaluate embeddings on a given dataset with both custom model and CLIP baseline.

        This is the main evaluation method that compares the custom model against
        the Fashion-CLIP baseline across multiple metrics and embedding types.
        Aligned with main_model_evaluation.py for consistency (no TTA for fair comparison).

        Args:
            dataframe: DataFrame or Dataset to evaluate on
            dataset_name: Name of the dataset for display
            use_whitening: Whether to apply ZCA whitening
            use_mahalanobis: Whether to use Mahalanobis distance

        Returns:
            Dictionary containing results for all models and embedding types,
            keyed 'custom_text', 'custom_image', 'clip_text', 'clip_image'
        """
        print(f"\n{'='*60}")
        print(f"Evaluating {dataset_name}")
        if use_whitening:
            print(f"π― ZCA Whitening ENABLED for better feature decorrelation")
        if use_mahalanobis:
            print(f"π― Mahalanobis Distance ENABLED for classification")
        print(f"{'='*60}")

        results = {}

        # ---- Custom model: text and image embeddings (TTA disabled). ----
        print(f"\nπ§ Evaluating Custom Model on {dataset_name}")
        print("-" * 40)

        custom_dataloader = self.create_dataloader(dataframe, batch_size=16)

        # Text branch: similarity metrics + centroid classification.
        text_embeddings, text_labels, texts = self.extract_custom_embeddings(
            custom_dataloader, 'text', use_tta=False
        )
        text_metrics = self.compute_similarity_metrics(
            text_embeddings, text_labels, apply_whitening_norm=use_whitening
        )
        text_classification = self.evaluate_classification_performance(
            text_embeddings, text_labels, "Custom Text Embeddings",
            apply_whitening_norm=use_whitening, use_mahalanobis=use_mahalanobis
        )
        text_metrics.update(text_classification)
        results['custom_text'] = text_metrics

        # Image branch: same pipeline, with whitening/Mahalanobis noted in title.
        image_embeddings, image_labels, _ = self.extract_custom_embeddings(
            custom_dataloader, 'image', use_tta=False
        )
        image_metrics = self.compute_similarity_metrics(
            image_embeddings, image_labels, apply_whitening_norm=use_whitening
        )
        whitening_suffix = " + Whitening" if use_whitening else ""
        mahalanobis_suffix = " + Mahalanobis" if use_mahalanobis else ""
        image_classification = self.evaluate_classification_performance(
            image_embeddings, image_labels,
            f"Custom Image Embeddings{whitening_suffix}{mahalanobis_suffix}",
            apply_whitening_norm=use_whitening, use_mahalanobis=use_mahalanobis
        )
        image_metrics.update(image_classification)
        results['custom_image'] = image_metrics

        # ---- Fashion-CLIP baseline on the same samples. ----
        print(f"\nπ€ Evaluating Fashion-CLIP Baseline on {dataset_name}")
        print("-" * 40)

        clip_dataloader = self.create_clip_dataloader(dataframe, batch_size=8)

        all_images = []
        all_texts = []
        all_labels = []

        for batch in tqdm(clip_dataloader, desc="Preparing data for Fashion-CLIP"):
            # Batches may or may not carry a 'colors' field; only the
            # images/descriptions/hierarchies triplet is used here.
            if len(batch) == 4:
                images, descriptions, colors, hierarchies = batch
            else:
                images, descriptions, hierarchies = batch

            all_images.extend(images)
            all_texts.extend(descriptions)
            all_labels.extend(hierarchies)

        clip_image_embeddings, clip_text_embeddings = self.clip_evaluator.extract_clip_embeddings(
            all_images, all_texts
        )

        # NOTE: baseline metrics deliberately skip whitening/Mahalanobis so the
        # CLIP numbers stay a plain reference point.
        clip_text_metrics = self.compute_similarity_metrics(
            clip_text_embeddings, all_labels
        )
        clip_text_classification = self.evaluate_classification_performance(
            clip_text_embeddings, all_labels, "Fashion-CLIP Text Embeddings"
        )
        clip_text_metrics.update(clip_text_classification)
        results['clip_text'] = clip_text_metrics

        clip_image_metrics = self.compute_similarity_metrics(
            clip_image_embeddings, all_labels
        )
        clip_image_classification = self.evaluate_classification_performance(
            clip_image_embeddings, all_labels, "Fashion-CLIP Image Embeddings"
        )
        clip_image_metrics.update(clip_image_classification)
        results['clip_image'] = clip_image_metrics

        # Side effects: formatted comparison table + confusion-matrix PNGs.
        self._print_comparison_results(dataframe, dataset_name, results)

        self._save_visualizations(dataset_name, results)

        return results
|
|
|
|
|
def _print_comparison_results( |
|
|
self, |
|
|
dataframe: Union[pd.DataFrame, Dataset], |
|
|
dataset_name: str, |
|
|
results: Dict[str, Dict[str, Any]] |
|
|
): |
|
|
""" |
|
|
Print formatted comparison results. |
|
|
|
|
|
Args: |
|
|
dataframe: Dataset being evaluated |
|
|
dataset_name: Name of the dataset |
|
|
results: Evaluation results dictionary |
|
|
""" |
|
|
dataset_size = len(dataframe) if hasattr(dataframe, '__len__') else "N/A" |
|
|
|
|
|
print(f"\n{dataset_name} Results Comparison:") |
|
|
print(f"Dataset size: {dataset_size} samples") |
|
|
print("=" * 80) |
|
|
print(f"{'Model':<20} {'Embedding':<10} {'Sep Score':<10} {'NN Acc':<8} {'Centroid Acc':<12} {'F1 Macro':<10}") |
|
|
print("-" * 80) |
|
|
|
|
|
for model_type in ['custom', 'clip']: |
|
|
for emb_type in ['text', 'image']: |
|
|
key = f"{model_type}_{emb_type}" |
|
|
if key in results: |
|
|
metrics = results[key] |
|
|
model_name = "Custom Model" if model_type == 'custom' else "Fashion-CLIP Baseline" |
|
|
print( |
|
|
f"{model_name:<20} " |
|
|
f"{emb_type.capitalize():<10} " |
|
|
f"{metrics['separation_score']:<10.4f} " |
|
|
f"{metrics['accuracy']*100:<8.1f}% " |
|
|
f"{metrics['centroid_accuracy']*100:<12.1f}% " |
|
|
f"{metrics['f1_macro']*100:<10.1f}%" |
|
|
) |
|
|
|
|
|
def _save_visualizations( |
|
|
self, |
|
|
dataset_name: str, |
|
|
results: Dict[str, Dict[str, Any]] |
|
|
): |
|
|
""" |
|
|
Save confusion matrices and other visualizations. |
|
|
|
|
|
Args: |
|
|
dataset_name: Name of the dataset |
|
|
results: Evaluation results dictionary |
|
|
""" |
|
|
os.makedirs(self.directory, exist_ok=True) |
|
|
|
|
|
|
|
|
for key, metrics in results.items(): |
|
|
if 'figure' in metrics: |
|
|
filename = f'{self.directory}/{dataset_name.lower()}_{key}_confusion_matrix.png' |
|
|
metrics['figure'].savefig(filename, dpi=300, bbox_inches='tight') |
|
|
plt.close(metrics['figure']) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_fashion_mnist_dataset(
    evaluator: EmbeddingEvaluator,
    max_samples: int = 1000
) -> FashionMNISTDataset:
    """
    Load and prepare Fashion-MNIST test dataset.

    This function loads the Fashion-MNIST test set and creates appropriate
    mappings to the custom model's hierarchy classes.
    Exactly aligned with main_model_evaluation.py for consistency.

    NOTE(review): despite the None-check at the call site in main(), this
    function always returns a FashionMNISTDataset — verify whether a None
    path was intended.

    Args:
        evaluator: EmbeddingEvaluator instance with loaded model
        max_samples: Maximum number of samples to use

    Returns:
        FashionMNISTDataset object
    """
    print("π Loading Fashion-MNIST test dataset...")
    # Reads the pre-exported CSV whose path comes from the project config.
    df = pd.read_csv(config.fashion_mnist_test_path)
    print(f"β Fashion-MNIST dataset loaded: {len(df)} samples")

    label_mapping = None
    if evaluator.hierarchy_classes is not None:
        print("\nπ Creating mapping from Fashion-MNIST labels to hierarchy classes:")
        label_mapping = create_fashion_mnist_to_hierarchy_mapping(
            evaluator.hierarchy_classes
        )

        # Keep only Fashion-MNIST labels that map to a known hierarchy;
        # unmapped labels have hierarchy None in the mapping.
        valid_label_ids = [
            label_id for label_id, hierarchy in label_mapping.items()
            if hierarchy is not None
        ]
        df_filtered = df[df['label'].isin(valid_label_ids)]
        print(
            f"\nπ After filtering to mappable labels: "
            f"{len(df_filtered)} samples (from {len(df)})"
        )

        # head() takes the first rows, not a random sample — ordering of the
        # source CSV therefore determines which samples are evaluated.
        df_sample = df_filtered.head(max_samples)
    else:
        # No hierarchy classes available: use the raw labels, capped.
        df_sample = df.head(max_samples)

    print(f"π Using {len(df_sample)} samples for evaluation")
    return FashionMNISTDataset(df_sample, label_mapping=label_mapping)
|
|
|
|
|
|
|
|
def load_kagl_marqo_dataset(evaluator: EmbeddingEvaluator) -> pd.DataFrame:
    """
    Load and prepare Kaggle Marqo dataset for evaluation.

    This function loads the Marqo fashion dataset from Hugging Face
    and preprocesses it for evaluation with the custom model.

    Args:
        evaluator: EmbeddingEvaluator instance with loaded model

    Returns:
        Formatted pandas DataFrame ready for evaluation (columns:
        image_url, text, hierarchy); empty DataFrame if no samples
        survive the hierarchy filtering
    """
    # Local import keeps the optional 'datasets' dependency out of module load.
    from datasets import load_dataset

    print("π Loading Kaggle Marqo dataset...")

    dataset = load_dataset("Marqo/KAGL")
    df = dataset["data"].to_pandas()

    print(f"β Dataset Kaggle loaded")
    print(f"π Before filtering: {len(df)} samples")
    print(f"π Available columns: {list(df.columns)}")
    print(f"π¨ Available categories: {sorted(df['category2'].unique())}")

    # Normalize dataset categories to the model's hierarchy vocabulary.
    df['hierarchy'] = df['category2'].str.lower()
    df['hierarchy'] = df['hierarchy'].replace({
        'bags': 'bag',
        'topwear': 'top',
        'flip flops': 'shoes',
        'sandal': 'shoes'
    })

    valid_hierarchies = df['hierarchy'].dropna().unique()
    print(f"π― Valid hierarchies found: {sorted(valid_hierarchies)}")
    print(f"π― Model hierarchies: {sorted(evaluator.hierarchy_classes)}")

    # Drop anything the model has no class for (isin also discards NaN rows).
    df = df[df['hierarchy'].isin(evaluator.hierarchy_classes)]
    print(f"π After filtering to model hierarchies: {len(df)} samples")

    if len(df) == 0:
        print("β No samples left after hierarchy filtering.")
        return pd.DataFrame()

    # Both modalities are required downstream.
    df = df.dropna(subset=['text', 'image'])
    print(f"π After removing missing text/image: {len(df)} samples")

    print(f"π Sample texts:")
    for i, (text, hierarchy) in enumerate(zip(df['text'].head(3), df['hierarchy'].head(3))):
        print(f"  {i+1}. [{hierarchy}] {text[:100]}...")

    # Cap dataset size; fixed seed keeps the subsample reproducible.
    max_samples = 1000
    if len(df) > max_samples:
        print(f"β οΈ Dataset too large ({len(df)} samples), sampling to {max_samples} samples")
        df_test = df.sample(n=max_samples, random_state=42).reset_index(drop=True)
    else:
        df_test = df.copy()

    print(f"π After sampling: {len(df_test)} samples")
    print(f"π Samples per hierarchy:")
    for hierarchy in sorted(df_test['hierarchy'].unique()):
        count = len(df_test[df_test['hierarchy'] == hierarchy])
        print(f"  {hierarchy}: {count} samples")

    # Reshape to the column names the evaluation dataloaders expect.
    kagl_formatted = pd.DataFrame({
        'image_url': df_test['image'],
        'text': df_test['text'],
        'hierarchy': df_test['hierarchy']
    })

    print(f"π Final dataset size: {len(kagl_formatted)} samples")
    return kagl_formatted
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """
    Main evaluation function that runs comprehensive evaluation across multiple datasets.

    This function evaluates the custom hierarchy classification model against the
    Fashion-CLIP baseline on:
    1. Validation dataset (from training data)
    2. Fashion-MNIST test dataset
    3. Kaggle Marqo dataset

    Results include detailed metrics, confusion matrices, and performance comparisons.
    """
    # Output folder for confusion-matrix PNGs and other artifacts.
    directory = "hierarchy_model_analysis"

    # NOTE(review): hierarchy_model_path is not defined in the visible module
    # scope — confirm it is declared elsewhere in this file (or in config).
    print(f"π Starting evaluation with custom model: {hierarchy_model_path}")
    print(f"π€ Including Fashion-CLIP baseline comparison")

    evaluator = EmbeddingEvaluator(hierarchy_model_path, directory)

    # NOTE(review): this reads evaluator.vocab.hierarchy_classes while
    # load_fashion_mnist_dataset uses evaluator.hierarchy_classes — confirm
    # both attributes refer to the same class list.
    print(
        f"π Final hierarchy classes after initialization: "
        f"{len(evaluator.vocab.hierarchy_classes)} classes"
    )

    # ---- 1. Validation dataset. ----
    print("\n" + "="*60)
    print("EVALUATING VALIDATION DATASET - CUSTOM MODEL vs FASHION-CLIP BASELINE")
    print("="*60)
    val_results = evaluator.evaluate_dataset_with_baselines(
        evaluator.val_df,
        "Validation Dataset"
    )

    # ---- 2. Fashion-MNIST test set (whitening/Mahalanobis explicitly off). ----
    print("\n" + "="*60)
    print("EVALUATING FASHION-MNIST TEST DATASET - CUSTOM MODEL vs FASHION-CLIP BASELINE")
    print("="*60)
    fashion_mnist_dataset = load_fashion_mnist_dataset(evaluator, max_samples=1000)
    if fashion_mnist_dataset is not None:
        fashion_mnist_results = evaluator.evaluate_dataset_with_baselines(
            fashion_mnist_dataset,
            "Fashion-MNIST Test Dataset",
            use_whitening=False,
            use_mahalanobis=False
        )
    else:
        fashion_mnist_results = {}

    # ---- 3. Kaggle Marqo dataset (may come back empty after filtering). ----
    print("\n" + "="*60)
    print("EVALUATING KAGGLE MARQO DATASET - CUSTOM MODEL vs FASHION-CLIP BASELINE")
    print("="*60)
    df_kagl_marqo = load_kagl_marqo_dataset(evaluator)
    if len(df_kagl_marqo) > 0:
        kagl_results = evaluator.evaluate_dataset_with_baselines(
            df_kagl_marqo,
            "Kaggle Marqo Dataset"
        )
    else:
        kagl_results = {}

    # ---- Summary tables across all evaluated datasets. ----
    print(f"\n{'='*80}")
    print("FINAL EVALUATION SUMMARY - CUSTOM MODEL vs FASHION-CLIP BASELINE")
    print(f"{'='*80}")

    print("\nπ VALIDATION DATASET RESULTS:")
    _print_dataset_results(val_results, len(evaluator.val_df))

    if fashion_mnist_results:
        # Size hard-coded to match the max_samples=1000 request above.
        print("\nπ FASHION-MNIST TEST DATASET RESULTS:")
        _print_dataset_results(fashion_mnist_results, 1000)

    if kagl_results:
        print("\nπ KAGGLE MARQO DATASET RESULTS:")
        _print_dataset_results(
            kagl_results,
            len(df_kagl_marqo) if df_kagl_marqo is not None else 'N/A'
        )

    print(f"\nβ Evaluation completed! Check '{directory}/' for visualization files.")
    print(f"π Custom model hierarchy classes: {len(evaluator.vocab.hierarchy_classes)} classes")
    print(f"π€ Fashion-CLIP baseline comparison included")
|
|
|
|
|
|
|
|
def _print_dataset_results(results: Dict[str, Dict[str, Any]], dataset_size: int):
    """
    Print formatted results for a single dataset.

    Args:
        results: Dictionary containing evaluation results
        dataset_size: Number of samples in the dataset
    """
    print(f"Dataset size: {dataset_size} samples")
    print(f"{'Model':<20} {'Embedding':<10} {'Sep Score':<12} {'NN Acc':<10} {'Centroid Acc':<12} {'F1 Macro':<10}")
    print("-" * 80)

    display_names = {'custom': "Custom Model", 'clip': "Fashion-CLIP Baseline"}
    for model_type, model_name in display_names.items():
        for emb_type in ('text', 'image'):
            metrics = results.get(f"{model_type}_{emb_type}")
            if metrics is None:
                continue
            print(
                f"{model_name:<20} "
                f"{emb_type.capitalize():<10} "
                f"{metrics['separation_score']:<12.4f} "
                f"{metrics['accuracy']*100:<10.1f}% "
                f"{metrics['centroid_accuracy']*100:<12.1f}% "
                f"{metrics['f1_macro']*100:<10.1f}%"
            )
|
|
|
|
|
|
|
|
# Script entry point: run the full multi-dataset evaluation.
if __name__ == "__main__":
    main()
|
|
|