def create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes):
    """Map Fashion-MNIST label ids (0-9) onto the provided hierarchy class names.

    Matching runs in four tiers: exact (case-insensitive) name match, substring
    overlap in either direction, a hand-written semantic table, and finally a
    difflib fuzzy match. Labels that survive no tier map to None (callers are
    expected to filter those out).
    """
    fashion_mnist_labels = {
        0: "T-shirt/top",
        1: "Trouser",
        2: "Pullover",
        3: "Dress",
        4: "Coat",
        5: "Sandal",
        6: "Shirt",
        7: "Sneaker",
        8: "Bag",
        9: "Ankle boot",
    }

    # Lowercase view used for all comparisons; originals are returned to callers.
    lowered = [h.lower() for h in hierarchy_classes]

    def resolve(candidate):
        # Original-cased hierarchy class for a lowercase candidate, else None.
        return hierarchy_classes[lowered.index(candidate)] if candidate in lowered else None

    # Ordered semantic fallbacks: (predicate over lowercase label, candidate names).
    semantic_rules = [
        (lambda n: n in ('t-shirt/top', 'top'), ['top']),
        (lambda n: 'trouser' in n, ['bottom', 'pants', 'trousers', 'trouser', 'pant']),
        (lambda n: 'pullover' in n, ['sweater', 'pullover']),
        (lambda n: 'dress' in n, ['dress']),
        (lambda n: 'coat' in n, ['jacket', 'outerwear', 'coat']),
        (lambda n: n in ('sandal', 'sneaker', 'ankle boot'), ['shoes', 'shoe', 'sandal', 'sneaker', 'boot']),
        (lambda n: 'bag' in n, ['bag']),
    ]

    mapping = {}
    for label_id, label_name in fashion_mnist_labels.items():
        name = label_name.lower()

        # Tier 1: exact match against the hierarchy vocabulary.
        match = resolve(name)

        # Tier 2: substring overlap either way; first hierarchy class in order wins.
        if match is None:
            match = next(
                (original for original, low in zip(hierarchy_classes, lowered)
                 if low in name or name in low),
                None,
            )

        # Tier 3: semantic table — first firing rule, first candidate present wins.
        if match is None:
            for predicate, candidates in semantic_rules:
                if predicate(name):
                    match = next((hit for hit in map(resolve, candidates) if hit), None)
                    break

        # Tier 4: fuzzy fallback.
        if match is None:
            close = difflib.get_close_matches(name, lowered, n=1, cutoff=0.6)
            if close:
                match = resolve(close[0])

        mapping[label_id] = match
        if match:
            print(f"  {label_name} ({label_id}) -> {match}")
        else:
            print(f"  ⚠️ {label_name} ({label_id}) -> NO MATCH (will be filtered out)")

    return mapping
def convert_fashion_mnist_to_image(pixel_values):
    """Convert 784 flat grayscale pixel values into a 28x28 RGB PIL image.

    The single grayscale channel is replicated into three channels so the
    output can be consumed by RGB-only vision models.
    """
    image_array = np.array(pixel_values).reshape(28, 28).astype(np.uint8)
    image_array = np.stack([image_array] * 3, axis=-1)  # (28, 28) -> (28, 28, 3)
    image = Image.fromarray(image_array)
    return image


def get_fashion_mnist_labels():
    """Return the canonical Fashion-MNIST label-id -> label-name mapping."""
    return {
        0: "T-shirt/top",
        1: "Trouser",
        2: "Pullover",
        3: "Dress",
        4: "Coat",
        5: "Sandal",
        6: "Shirt",
        7: "Sneaker",
        8: "Bag",
        9: "Ankle boot",
    }


class FashionMNISTDataset(Dataset):
    """Torch dataset over a Fashion-MNIST dataframe (784 pixel columns + 'label').

    Yields (image_tensor, description, color, hierarchy); color is always
    "unknown" because the source images are grayscale.
    """

    # Pixel column names are identical for every row; build the 784 strings once
    # instead of on every __getitem__ call (was rebuilt per item).
    PIXEL_COLUMNS = [f"pixel{i}" for i in range(1, 785)]

    def __init__(self, dataframe, image_size=224, label_mapping=None):
        """
        Args:
            dataframe: rows with a 'label' column and pixel1..pixel784 columns.
            image_size: side length images are resized to.
            label_mapping: optional {Fashion-MNIST label id -> hierarchy class};
                when absent, the plain Fashion-MNIST label name is used instead.
        """
        self.dataframe = dataframe
        self.image_size = image_size
        self.labels_map = get_fashion_mnist_labels()
        self.label_mapping = label_mapping  # Mapping from Fashion-MNIST label ID to hierarchy class

        # ImageNet normalization, matching the CLIP-style preprocessing used elsewhere.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]

        pixel_values = row[self.PIXEL_COLUMNS].values

        image = convert_fashion_mnist_to_image(pixel_values)
        image = self.transform(image)

        label_id = int(row['label'])
        description = self.labels_map[label_id]

        color = "unknown"  # grayscale source carries no color information
        # Use mapped hierarchy if available, otherwise use original label
        if self.label_mapping and label_id in self.label_mapping:
            hierarchy = self.label_mapping[label_id]
        else:
            hierarchy = self.labels_map[label_id]

        return image, description, color, hierarchy
# Default location of the Fashion-MNIST test CSV. Override via the csv_path
# parameter instead of editing this machine-specific absolute path.
DEFAULT_FASHION_MNIST_CSV = "/Users/leaattiasarfati/Desktop/docs/search/old/MainModel/data/fashion-mnist_test.csv"


def load_fashion_mnist_dataset(max_samples=1000, hierarchy_classes=None, csv_path=DEFAULT_FASHION_MNIST_CSV):
    """Load the Fashion-MNIST test split and wrap it in a FashionMNISTDataset.

    Fix: the CSV location was hard-coded to a user-specific absolute path inside
    the function body; it is now an overridable `csv_path` keyword parameter
    (the old path remains the default, so existing callers are unaffected).

    Args:
        max_samples: cap on the number of rows kept (applied after filtering).
        hierarchy_classes: optional list of hierarchy class names; when given,
            Fashion-MNIST labels are mapped onto them and unmappable rows dropped.
        csv_path: location of the fashion-mnist_test.csv file.

    Returns:
        FashionMNISTDataset over at most `max_samples` rows.
    """
    print("📊 Loading Fashion-MNIST test dataset...")
    df = pd.read_csv(csv_path)
    print(f"✅ Fashion-MNIST dataset loaded: {len(df)} samples")

    # Create mapping if hierarchy classes are provided
    label_mapping = None
    if hierarchy_classes is not None:
        print("\n🔗 Creating mapping from Fashion-MNIST labels to hierarchy classes:")
        label_mapping = create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes)

        # Keep only samples whose label maps onto a hierarchy class.
        valid_label_ids = [label_id for label_id, hierarchy in label_mapping.items() if hierarchy is not None]
        df_filtered = df[df['label'].isin(valid_label_ids)]
        print(f"\n📊 After filtering to mappable labels: {len(df_filtered)} samples (from {len(df)})")

        # Apply max_samples limit after filtering
        df_sample = df_filtered.head(max_samples)
    else:
        df_sample = df.head(max_samples)

    print(f"📊 Using {len(df_sample)} samples for evaluation")
    return FashionMNISTDataset(df_sample, label_mapping=label_mapping)
def create_kaggle_marqo_to_hierarchy_mapping(kaggle_labels, hierarchy_classes):
    """Create mapping from Kaggle Marqo categories to hierarchy classes.

    Each distinct label is matched through a chain of strategies, in order:
    direct/synonym lookup, substring overlap, per-token synonym lookup,
    synonym-key containment, and finally difflib fuzzy matching. Unmatched
    labels map to None.
    """
    hierarchy_classes = list(hierarchy_classes)
    lowered_classes = [h.lower() for h in hierarchy_classes]

    synonyms = {
        'topwear': 'top',
        'tops': 'top',
        'tee': 'top',
        'tees': 'top',
        't-shirt': 'top',
        'tshirt': 'top',
        'tshirts': 'top',
        'shirt': 'shirt',
        'shirts': 'shirt',
        'sweater': 'sweater',
        'sweaters': 'sweater',
        'outerwear': 'coat',
        'outer': 'coat',
        'coat': 'coat',
        'coats': 'coat',
        'jacket': 'coat',
        'jackets': 'coat',
        'blazer': 'coat',
        'blazers': 'coat',
        'hoodie': 'hoodie',
        'hoodies': 'hoodie',
        'bottomwear': 'bottom',
        'bottoms': 'bottom',
        'pants': 'bottom',
        'pant': 'bottom',
        'trouser': 'bottom',
        'trousers': 'bottom',
        'jeans': 'jeans',
        'denim': 'jeans',
        'short': 'shorts',
        'shorts': 'shorts',
        'skirt': 'skirt',
        'skirts': 'skirt',
        'dress': 'dress',
        'dresses': 'dress',
        'saree': 'saree',
        'lehenga': 'lehenga',
        'shoe': 'shoes',
        'shoes': 'shoes',
        'sandal': 'shoes',
        'sandals': 'shoes',
        'sneaker': 'shoes',
        'sneakers': 'shoes',
        'boot': 'shoes',
        'boots': 'shoes',
        'heel': 'shoes',
        'heels': 'shoes',
        'flip flops': 'shoes',
        'flip-flops': 'shoes',
        'loafer': 'shoes',
        'loafers': 'shoes',
        'bag': 'bag',
        'bags': 'bag',
        'backpack': 'bag',
        'backpacks': 'bag',
        'handbag': 'bag',
        'handbags': 'bag',
        'accessory': 'accessories',
        'accessories': 'accessories',
        'belt': 'belt',
        'belts': 'belt',
        'scarf': 'scarf',
        'scarves': 'scarf',
        'cap': 'cap',
        'caps': 'cap',
        'hat': 'cap',
        'hats': 'cap',
        'watch': 'watch',
        'watches': 'watch',
    }

    def match_candidate(candidate):
        # Exact (lowercase) hit against the hierarchy vocabulary, else None.
        try:
            return hierarchy_classes[lowered_classes.index(candidate)]
        except ValueError:
            return None

    def strategy_partial(candidate):
        # Substring overlap either way with any hierarchy class, first wins.
        for position, low in enumerate(lowered_classes):
            if low in candidate or candidate in low:
                return hierarchy_classes[position]
        return None

    def strategy_tokens(candidate):
        # Split on spaces/hyphens/slashes; try each token after synonym substitution.
        for token in set(candidate.replace('-', ' ').replace('/', ' ').split()):
            hit = match_candidate(synonyms.get(token, token))
            if hit:
                return hit
        return None

    def strategy_containment(candidate):
        # Any synonym key embedded anywhere in the candidate string.
        for key, value in synonyms.items():
            if key in candidate:
                hit = match_candidate(value)
                if hit:
                    return hit
        return None

    def strategy_fuzzy(candidate):
        close = difflib.get_close_matches(candidate, lowered_classes, n=1, cutoff=0.6)
        return match_candidate(close[0]) if close else None

    mapping = {}
    for raw_label in sorted(set(kaggle_labels)):
        label_text = '' if pd.isna(raw_label) else str(raw_label)
        key = label_text.strip().lower()

        if not key:
            mapping[key] = None
            continue

        # Direct match after synonym substitution, then the fallback chain.
        candidate = synonyms.get(key, key)
        matched = match_candidate(candidate)
        for strategy in (strategy_partial, strategy_tokens, strategy_containment, strategy_fuzzy):
            if matched is not None:
                break
            matched = strategy(candidate)

        mapping[key] = matched

        if matched:
            print(f"  {label_text} -> {matched}")
        else:
            print(f"  ⚠️ {label_text} -> NO MATCH (will be filtered out)")

    return mapping
class KaggleDataset(Dataset):
    """Dataset class for KAGL Marqo dataset.

    Expects a dataframe with columns 'image_url' (image payload), 'text',
    'hierarchy' and optionally 'color'; yields
    (image_tensor, description, color, hierarchy) tuples.
    """

    def __init__(self, dataframe, image_size=224):
        self.dataframe = dataframe
        self.image_size = image_size

        # Deterministic eval-time preprocessing (no augmentation), ImageNet stats.
        self.val_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.dataframe)

    @staticmethod
    def _decode_image(payload):
        """Normalize the stored payload (dict-of-bytes, PIL image, or raw bytes) to an RGB PIL image."""
        if isinstance(payload, dict) and 'bytes' in payload:
            return Image.open(BytesIO(payload['bytes'])).convert("RGB")
        if hasattr(payload, 'convert'):
            # Already a PIL image.
            return payload.convert("RGB")
        # Fall back to treating the payload as raw bytes.
        return Image.open(BytesIO(payload)).convert("RGB")

    def __getitem__(self, idx):
        record = self.dataframe.iloc[idx]

        pil_image = self._decode_image(record['image_url'])
        tensor = self.val_transform(pil_image)

        description = record['text']
        color = record.get('color', 'unknown')
        hierarchy = record['hierarchy']

        return tensor, description, color, hierarchy
def load_kaggle_marqo_dataset(evaluator, max_samples=5000):
    """Load and prepare Kaggle KAGL dataset with memory optimization.

    Pulls the Marqo/KAGL dataset from the HuggingFace hub, maps its `category2`
    labels onto the evaluator's validation hierarchy classes, drops rows that
    cannot be mapped or that lack text/image data, and wraps the first
    `max_samples` remaining rows in a KaggleDataset.

    Args:
        evaluator: object exposing `validation_hierarchy_classes` and
            `hierarchy_classes` (ColorHierarchyEvaluator), used as mapping target.
        max_samples: upper bound on the returned number of samples.

    Returns:
        KaggleDataset, or None when no samples survive the hierarchy mapping.
    """
    # Imported lazily so the module stays importable without `datasets` installed.
    from datasets import load_dataset
    print("📊 Loading Kaggle KAGL dataset...")

    # Load the dataset
    dataset = load_dataset("Marqo/KAGL")
    df = dataset["data"].to_pandas()
    print(f"✅ Dataset Kaggle loaded")
    print(f"   Before filtering: {len(df)} samples")
    print(f"   Available columns: {list(df.columns)}")

    # Check available categories and create mapping to validation hierarchies
    available_categories = sorted(df['category2'].dropna().unique())
    print(f"🎨 Available categories: {available_categories}")

    validation_hierarchies = evaluator.validation_hierarchy_classes or evaluator.hierarchy_classes
    print(f"📚 Validation hierarchies: {sorted(validation_hierarchies)}")

    print("\n🔗 Creating mapping from Kaggle categories to validation hierarchies:")
    category_mapping = create_kaggle_marqo_to_hierarchy_mapping(available_categories, validation_hierarchies)

    # Report any categories the mapping could not place.
    total_categories = {str(cat).strip().lower() for cat in df['category2'].dropna()}
    unmapped_categories = sorted(cat for cat in total_categories if category_mapping.get(cat) is None)
    if unmapped_categories:
        print(f"⚠️ Categories without mapping (will be dropped): {unmapped_categories}")

    # Per-row hierarchy assignment; unmappable or NaN categories become None.
    df['hierarchy'] = df['category2'].apply(
        lambda cat: category_mapping.get(str(cat).strip().lower()) if pd.notna(cat) else None
    )

    before_mapping_len = len(df)
    df = df[df['hierarchy'].notna()]
    print(f"   After mapping to validation hierarchies: {len(df)} samples (from {before_mapping_len})")

    if len(df) == 0:
        print("❌ No samples left after hierarchy mapping.")
        return None

    # Ensure we have text and image data
    df = df.dropna(subset=['text', 'image'])
    print(f"   After removing missing text/image: {len(df)} samples")

    # Show sample of text data to verify quality
    print(f"📝 Sample texts:")
    for i, (text, hierarchy) in enumerate(zip(df['text'].head(3), df['hierarchy'].head(3))):
        print(f"   {i+1}. [{hierarchy}] {text[:100]}...")

    df_test = df.copy()

    # Limit to max_samples
    if len(df_test) > max_samples:
        df_test = df_test.head(max_samples)

    print(f"📊 After sampling: {len(df_test)} samples")
    print(f"   Samples per hierarchy:")
    for hierarchy in sorted(df_test['hierarchy'].unique()):
        count = len(df_test[df_test['hierarchy'] == hierarchy])
        print(f"   {hierarchy}: {count} samples")

    # Create formatted dataset with proper column names
    kaggle_formatted = pd.DataFrame({
        'image_url': df_test['image'],  # This contains image data as bytes
        'text': df_test['text'],
        'hierarchy': df_test['hierarchy'],
        # NOTE(review): assumes baseColour has no NaN after the filters above —
        # the .str accessor would propagate NaN into 'color'; confirm upstream.
        'color': df_test['baseColour'].str.lower().str.replace("grey", "gray")  # Use actual colors
    })

    print(f"   Final dataset size: {len(kaggle_formatted)} samples")
    return KaggleDataset(kaggle_formatted)
class LocalDataset(Dataset):
    """Dataset class for local validation dataset.

    Reads images from disk paths stored in the dataframe's configured path
    column; yields (image_tensor, description, color, hierarchy) tuples.
    """

    def __init__(self, dataframe, image_size=224):
        self.dataframe = dataframe
        self.image_size = image_size

        # Deterministic eval-time preprocessing (no augmentation), ImageNet stats.
        self.val_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        record = self.dataframe.iloc[idx]

        # Load the image from its local path; fall back to a gray placeholder
        # so a single bad file does not abort the whole evaluation run.
        image_path = record[column_local_image_path]
        try:
            pil_image = Image.open(image_path).convert("RGB")
        except Exception as e:
            print(f"Error loading image at index {idx} from {image_path}: {e}")
            pil_image = Image.new('RGB', (224, 224), color='gray')

        tensor = self.val_transform(pil_image)

        description = record['text']
        color = record.get('color', 'unknown')
        hierarchy = record['hierarchy']

        return tensor, description, color, hierarchy
def load_local_validation_dataset(max_samples=5000):
    """Load and prepare local validation dataset.

    Reads the configured local CSV, drops rows without an image path, verifies
    the mandatory 'text'/'hierarchy' columns, caps the row count, and returns a
    LocalDataset (or None on any failure).
    """
    print("📊 Loading local validation dataset...")

    # Guard: the dataset file must exist.
    if not os.path.exists(local_dataset_path):
        print(f"❌ Local dataset file not found: {local_dataset_path}")
        return None

    frame = pd.read_csv(local_dataset_path)
    print(f"✅ Dataset loaded: {len(frame)} samples")

    # Rows without an image path are unusable.
    frame = frame.dropna(subset=[column_local_image_path])
    print(f"📊 After filtering NaN image paths: {len(frame)} samples")

    if len(frame) == 0:
        print("❌ No valid samples after filtering.")
        return None

    # Both text and hierarchy columns are mandatory.
    missing_cols = [name for name in ['text', 'hierarchy'] if name not in frame.columns]
    if missing_cols:
        print(f"❌ Missing required columns: {missing_cols}")
        return None

    # Cap the number of evaluation samples.
    if len(frame) > max_samples:
        frame = frame.head(max_samples)

    print(f"📊 Using {len(frame)} samples for evaluation")
    print(f"   Samples per hierarchy:")
    for label in sorted(frame['hierarchy'].unique()):
        count = len(frame[frame['hierarchy'] == label])
        print(f"   {label}: {count} samples")

    return LocalDataset(frame)
    def __init__(self, device='mps', directory='fashion_mnist_color_hierarchy_analysis'):
        """Load the fine-tuned main CLIP model, hierarchy classes, and a baseline model.

        Args:
            device: torch device string (defaults to Apple 'mps'; pass 'cpu'/'cuda' elsewhere).
            directory: output directory for analysis artifacts (created if missing).

        Raises:
            FileNotFoundError: when the main or hierarchy model checkpoint is absent.
        """
        self.device = torch.device(device)
        self.directory = directory
        # Sub-space sizes come from project config: color dims occupy the first
        # `color_emb_dim` indices, hierarchy dims the following `hierarchy_emb_dim`.
        self.color_emb_dim = color_emb_dim
        self.hierarchy_emb_dim = hierarchy_emb_dim
        os.makedirs(self.directory, exist_ok=True)

        print(f"🚀 Loading main model from {main_model_path}")
        if not os.path.exists(main_model_path):
            raise FileNotFoundError(f"Main model file {main_model_path} not found")

        # Load hierarchy classes from hierarchy model checkpoint
        print("📋 Loading hierarchy classes from hierarchy model...")
        if not os.path.exists(hierarchy_model_path):
            raise FileNotFoundError(f"Hierarchy model file {hierarchy_model_path} not found")

        # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
        hierarchy_checkpoint = torch.load(hierarchy_model_path, map_location=self.device)
        self.hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
        print(f"✅ Found {len(self.hierarchy_classes)} hierarchy classes: {sorted(self.hierarchy_classes)}")

        # Prefer the classes actually present in the validation dataset; fall
        # back to the checkpoint's classes when the CSV cannot be read.
        self.validation_hierarchy_classes = self._load_validation_hierarchy_classes()
        if self.validation_hierarchy_classes:
            print(f"✅ Validation dataset hierarchies ({len(self.validation_hierarchy_classes)} classes): {sorted(self.validation_hierarchy_classes)}")
        else:
            print("⚠️ Unable to load validation hierarchy classes, falling back to hierarchy model classes.")
            self.validation_hierarchy_classes = self.hierarchy_classes

        # Fine-tuned main model: LAION CLIP ViT-B/32 backbone + trained weights.
        checkpoint = torch.load(main_model_path, map_location=self.device)
        self.processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
        self.model = CLIPModel_transformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()
        print("✅ Main model loaded successfully")

        # Load baseline Fashion CLIP model
        print("📦 Loading baseline Fashion CLIP model...")
        patrick_model_name = "patrickjohncyh/fashion-clip"
        self.baseline_processor = CLIPProcessor.from_pretrained(patrick_model_name)
        self.baseline_model = CLIPModel_transformers.from_pretrained(patrick_model_name).to(self.device)
        self.baseline_model.eval()
        print("✅ Baseline Fashion CLIP model loaded successfully")
    def _load_validation_hierarchy_classes(self):
        """Load hierarchy classes present in the validation dataset.

        Returns the sorted, de-duplicated, non-empty hierarchy names found in
        the local validation CSV, or [] when the file/column is missing or
        unreadable.
        """
        if not os.path.exists(local_dataset_path):
            print(f"⚠️ Validation dataset not found at {local_dataset_path}")
            return []

        try:
            df = pd.read_csv(local_dataset_path)
        except Exception as exc:
            print(f"⚠️ Failed to read validation dataset: {exc}")
            return []

        if 'hierarchy' not in df.columns:
            print("⚠️ Validation dataset does not contain 'hierarchy' column.")
            return []

        hierarchies = (
            df['hierarchy']
            .dropna()
            .astype(str)
            .str.strip()
        )
        # Drop empty strings left over after stripping.
        hierarchies = [h for h in hierarchies if h]
        return sorted(set(hierarchies))

    def extract_color_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
        """Extract the color sub-space of the model embeddings.

        Keeps the first `color_emb_dim` dimensions (indices 0..color_emb_dim-1,
        i.e. dims 0-15 when color_emb_dim == 16) of either the text or image
        embeddings, for up to `max_samples` samples.

        Args:
            dataloader: yields (images, texts, colors, hierarchies) batches.
            embedding_type: 'text' or 'image' (anything else falls back to text).
            max_samples: stop after roughly this many samples.

        Returns:
            (embeddings ndarray [N, color_emb_dim], colors list, hierarchies list)
        """
        all_embeddings = []
        all_colors = []
        all_hierarchies = []

        sample_count = 0
        with torch.no_grad():
            # NOTE(review): the progress text says "dims 0-16" but the slice below
            # keeps indices 0..color_emb_dim-1 (0-15 for a 16-dim color space).
            for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} color embeddings (dims 0-16)"):
                if sample_count >= max_samples:
                    break

                images, texts, colors, hierarchies = batch
                images = images.to(self.device)
                # Presumably ensures 3 channels; a no-op for already-RGB input —
                # TODO confirm the dataloader's channel count.
                images = images.expand(-1, 3, -1, -1)

                text_inputs = self.processor(text=texts, padding=True, return_tensors="pt")
                text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}

                outputs = self.model(**text_inputs, pixel_values=images)

                if embedding_type == 'text':
                    embeddings = outputs.text_embeds
                elif embedding_type == 'image':
                    embeddings = outputs.image_embeds
                else:
                    embeddings = outputs.text_embeds

                # Extract only the color sub-space (first color_emb_dim dims).
                color_embeddings = embeddings[:, :self.color_emb_dim]

                all_embeddings.append(color_embeddings.cpu().numpy())
                all_colors.extend(colors)
                all_hierarchies.extend(hierarchies)

                sample_count += len(images)

                # Free per-batch tensors to keep peak memory down.
                del images, text_inputs, outputs, embeddings, color_embeddings
                torch.cuda.empty_cache() if torch.cuda.is_available() else None

        return np.vstack(all_embeddings), all_colors, all_hierarchies
desc=f"Extracting {embedding_type} hierarchy embeddings (dims 16-79)"): + if sample_count >= max_samples: + break + + images, texts, colors, hierarchies = batch + images = images.to(self.device) + images = images.expand(-1, 3, -1, -1) + + text_inputs = self.processor(text=texts, padding=True, return_tensors="pt") + text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()} + + outputs = self.model(**text_inputs, pixel_values=images) + + if embedding_type == 'text': + embeddings = outputs.text_embeds + elif embedding_type == 'image': + embeddings = outputs.image_embeds + else: + embeddings = outputs.text_embeds + + # Extract hierarchy embeddings (dims 17-79 -> indices 16:79) + hierarchy_embeddings = embeddings[:, 16:79] + + all_embeddings.append(hierarchy_embeddings.cpu().numpy()) + all_colors.extend(colors) + all_hierarchies.extend(hierarchies) + + sample_count += len(images) + + del images, text_inputs, outputs, embeddings, hierarchy_embeddings + torch.cuda.empty_cache() if torch.cuda.is_available() else None + + return np.vstack(all_embeddings), all_colors, all_hierarchies + + def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000): + """Extract embeddings from baseline Fashion CLIP model""" + all_embeddings = [] + all_colors = [] + all_hierarchies = [] + + sample_count = 0 + + with torch.no_grad(): + for batch in tqdm(dataloader, desc=f"Extracting baseline {embedding_type} embeddings"): + if sample_count >= max_samples: + break + + images, texts, colors, hierarchies = batch + images = images.to(self.device) + images = images.expand(-1, 3, -1, -1) # Ensure 3 channels + + # Process text inputs with baseline processor + text_inputs = self.baseline_processor(text=texts, padding=True, return_tensors="pt") + text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()} + + # Forward pass through baseline model + outputs = self.baseline_model(**text_inputs, pixel_values=images) + + # Extract embeddings based on 
    def extract_full_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
        """
        Extract ALL embedding dimensions of the trained model (not only the specialized sub-spaces).

        This allows comparing performance when using every available dimension,
        similar to the baseline, which uses all of its dimensions.

        Difference from extract_color_embeddings and extract_hierarchy_embeddings:
        - extract_color_embeddings: uses only dims 0-15 (16 dimensions)
        - extract_hierarchy_embeddings: uses only dims 16-79 (64 dimensions)
        - extract_full_embeddings: uses every dimension (e.g. 512)

        This can improve performance because all available information is kept.
        """
        all_embeddings = []
        all_colors = []
        all_hierarchies = []

        sample_count = 0
        with torch.no_grad():
            for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} FULL embeddings (all dims)"):
                if sample_count >= max_samples:
                    break

                images, texts, colors, hierarchies = batch
                images = images.to(self.device)
                images = images.expand(-1, 3, -1, -1)

                text_inputs = self.processor(text=texts, padding=True, return_tensors="pt")
                text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}

                outputs = self.model(**text_inputs, pixel_values=images)

                if embedding_type == 'text':
                    embeddings = outputs.text_embeds
                elif embedding_type == 'image':
                    embeddings = outputs.image_embeds
                else:
                    embeddings = outputs.text_embeds

                # Use ALL dimensions (not just a sub-space) so every bit of
                # information carried by the embedding remains available.
                all_embeddings.append(embeddings.cpu().numpy())
                all_colors.extend(colors)
                all_hierarchies.extend(hierarchies)

                sample_count += len(images)

                # Free per-batch tensors to keep peak memory down.
                del images, text_inputs, outputs, embeddings
                torch.cuda.empty_cache() if torch.cuda.is_available() else None

        return np.vstack(all_embeddings), all_colors, all_hierarchies
+ """ + all_embeddings = [] + all_colors = [] + all_hierarchies = [] + + sample_count = 0 + with torch.no_grad(): + for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} FULL embeddings (all dims)"): + if sample_count >= max_samples: + break + + images, texts, colors, hierarchies = batch + images = images.to(self.device) + images = images.expand(-1, 3, -1, -1) + + text_inputs = self.processor(text=texts, padding=True, return_tensors="pt") + text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()} + + outputs = self.model(**text_inputs, pixel_values=images) + + if embedding_type == 'text': + embeddings = outputs.text_embeds + elif embedding_type == 'image': + embeddings = outputs.image_embeds + else: + embeddings = outputs.text_embeds + + # Utiliser TOUTES les dimensions (pas seulement un sous-espace) + # Cela permet d'avoir accès à toute l'information disponible dans l'embedding + all_embeddings.append(embeddings.cpu().numpy()) + all_colors.extend(colors) + all_hierarchies.extend(hierarchies) + + sample_count += len(images) + + del images, text_inputs, outputs, embeddings + torch.cuda.empty_cache() if torch.cuda.is_available() else None + + return np.vstack(all_embeddings), all_colors, all_hierarchies + + def compute_similarity_metrics(self, embeddings, labels): + """Compute intra-class and inter-class similarities""" + max_samples = min(5000, len(embeddings)) + if len(embeddings) > max_samples: + indices = np.random.choice(len(embeddings), max_samples, replace=False) + embeddings = embeddings[indices] + labels = [labels[i] for i in indices] + + similarities = cosine_similarity(embeddings) + + label_groups = defaultdict(list) + for i, label in enumerate(labels): + label_groups[label].append(i) + + intra_class_similarities = [] + for label, indices in label_groups.items(): + if len(indices) > 1: + for i in range(len(indices)): + for j in range(i + 1, len(indices)): + sim = similarities[indices[i], indices[j]] + 
intra_class_similarities.append(sim) + + inter_class_similarities = [] + labels_list = list(label_groups.keys()) + for i in range(len(labels_list)): + for j in range(i + 1, len(labels_list)): + label1_indices = label_groups[labels_list[i]] + label2_indices = label_groups[labels_list[j]] + for idx1 in label1_indices: + for idx2 in label2_indices: + sim = similarities[idx1, idx2] + inter_class_similarities.append(sim) + + nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities) + centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels) + + return { + 'intra_class_similarities': intra_class_similarities, + 'inter_class_similarities': inter_class_similarities, + 'intra_class_mean': float(np.mean(intra_class_similarities)) if intra_class_similarities else 0.0, + 'inter_class_mean': float(np.mean(inter_class_similarities)) if inter_class_similarities else 0.0, + 'separation_score': float(np.mean(intra_class_similarities) - np.mean(inter_class_similarities)) if intra_class_similarities and inter_class_similarities else 0.0, + 'accuracy': nn_accuracy, + 'centroid_accuracy': centroid_accuracy, + } + + def compute_embedding_accuracy(self, embeddings, labels, similarities): + """Compute classification accuracy using nearest neighbor""" + correct_predictions = 0 + total_predictions = len(labels) + for i in range(len(embeddings)): + true_label = labels[i] + similarities_row = similarities[i].copy() + similarities_row[i] = -1 + nearest_neighbor_idx = int(np.argmax(similarities_row)) + predicted_label = labels[nearest_neighbor_idx] + if predicted_label == true_label: + correct_predictions += 1 + return correct_predictions / total_predictions if total_predictions > 0 else 0.0 + + def compute_centroid_accuracy(self, embeddings, labels): + """Compute classification accuracy using centroids""" + unique_labels = list(set(labels)) + centroids = {} + for label in unique_labels: + label_indices = [i for i, l in enumerate(labels) if l == label] + 
centroids[label] = np.mean(embeddings[label_indices], axis=0) + + correct_predictions = 0 + total_predictions = len(labels) + for i, embedding in enumerate(embeddings): + true_label = labels[i] + best_similarity = -1 + predicted_label = None + for label, centroid in centroids.items(): + similarity = cosine_similarity([embedding], [centroid])[0][0] + if similarity > best_similarity: + best_similarity = similarity + predicted_label = label + if predicted_label == true_label: + correct_predictions += 1 + return correct_predictions / total_predictions if total_predictions > 0 else 0.0 + + def predict_labels_from_embeddings(self, embeddings, labels): + """Predict labels from embeddings using centroid-based classification""" + unique_labels = list(set(labels)) + centroids = {} + for label in unique_labels: + label_indices = [i for i, l in enumerate(labels) if l == label] + centroids[label] = np.mean(embeddings[label_indices], axis=0) + + predictions = [] + for i, embedding in enumerate(embeddings): + best_similarity = -1 + predicted_label = None + for label, centroid in centroids.items(): + similarity = cosine_similarity([embedding], [centroid])[0][0] + if similarity > best_similarity: + best_similarity = similarity + predicted_label = label + predictions.append(predicted_label) + return predictions + + def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix", label_type="Label"): + """Create and plot confusion matrix""" + unique_labels = sorted(list(set(true_labels + predicted_labels))) + cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels) + accuracy = accuracy_score(true_labels, predicted_labels) + plt.figure(figsize=(12, 10)) + sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=unique_labels, yticklabels=unique_labels) + plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)') + plt.ylabel(f'True {label_type}') + plt.xlabel(f'Predicted {label_type}') + plt.xticks(rotation=45) + 
    def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix", label_type="Label"):
        """Create and plot confusion matrix.

        Builds a seaborn heatmap over every label appearing in either the true
        or predicted lists, and annotates the title with overall accuracy.

        Returns:
            (matplotlib Figure, accuracy float, confusion-matrix ndarray)
        """
        # Axes cover the union of true and predicted labels so no row/col is dropped.
        unique_labels = sorted(list(set(true_labels + predicted_labels)))
        cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
        accuracy = accuracy_score(true_labels, predicted_labels)
        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=unique_labels, yticklabels=unique_labels)
        plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
        plt.ylabel(f'True {label_type}')
        plt.xlabel(f'Predicted {label_type}')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()
        # NOTE(review): relies on pyplot's implicit "current figure" state;
        # interleaved figure creation elsewhere would mix plots.
        return plt.gcf(), accuracy, cm

    def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings", label_type="Label"):
        """Evaluate classification performance and create confusion matrix.

        Predicts via nearest class centroid, then packages accuracy, the
        confusion-matrix figure and a per-class sklearn classification report.

        Returns:
            dict with keys 'accuracy', 'predictions', 'confusion_matrix',
            'classification_report', 'figure'.
        """
        predictions = self.predict_labels_from_embeddings(embeddings, labels)
        accuracy = accuracy_score(labels, predictions)
        fig, acc, cm = self.create_confusion_matrix(labels, predictions, f"{embedding_type} - {label_type} Classification", label_type)
        unique_labels = sorted(list(set(labels)))
        # NOTE(review): target_names assumes every predicted label also occurs
        # among the true labels; a prediction outside that set would misalign
        # the report — confirm with the centroid predictor's behavior.
        report = classification_report(labels, predictions, labels=unique_labels, target_names=unique_labels, output_dict=True)
        return {
            'accuracy': accuracy,
            'predictions': predictions,
            'confusion_matrix': cm,
            'classification_report': report,
            'figure': fig,
        }
def evaluate_fashion_mnist(self, max_samples):
    """Evaluate the specialized subspace embeddings on Fashion-MNIST.

    Color uses dims 0-15 (16D) and hierarchy uses dims 16-79 (64D) of the
    trained model's embeddings, for both text and image modalities.

    Args:
        max_samples: Upper bound on the number of samples to evaluate.

    Returns:
        Dict keyed by 'text_color', 'image_color', 'text_hierarchy' and
        'image_hierarchy'; each value merges similarity metrics with
        classification metrics (accuracy, predictions, confusion matrix,
        classification report, saved-and-closed figure).
    """
    print(f"\n{'='*60}")
    print("Evaluating Fashion-MNIST")
    print(" Color embeddings: dims 0-15")
    print(" Hierarchy embeddings: dims 16-79")
    print(f"Max samples: {max_samples}")
    print(f"{'='*60}")

    # Fashion-MNIST integer labels get remapped onto the model's hierarchy classes.
    target_hierarchy_classes = self.validation_hierarchy_classes or self.hierarchy_classes
    fashion_dataset = load_fashion_mnist_dataset(max_samples, hierarchy_classes=target_hierarchy_classes)
    dataloader = DataLoader(fashion_dataset, batch_size=8, shuffle=False, num_workers=0)

    # Show how many samples landed in each hierarchy class after the mapping.
    if len(fashion_dataset.dataframe) > 0:
        print(f"\n📊 Hierarchy distribution in dataset:")
        if fashion_dataset.label_mapping:
            hierarchy_counts = {}
            for _, row in fashion_dataset.dataframe.iterrows():
                label_id = int(row['label'])
                hierarchy = fashion_dataset.label_mapping.get(label_id, 'unknown')
                hierarchy_counts[hierarchy] = hierarchy_counts.get(hierarchy, 0) + 1
            for hierarchy, count in sorted(hierarchy_counts.items()):
                print(f" {hierarchy}: {count} samples")

    results = {}
    # (section header, [(result key, modality, extractor, index of the label
    #  field in the extractor's return tuple, figure title, label type)])
    eval_plan = [
        ("🎨 COLOR EVALUATION (dims 0-15)", [
            ('text_color', 'text', self.extract_color_embeddings, 1,
             "Text Color Embeddings (16D)", "Color"),
            ('image_color', 'image', self.extract_color_embeddings, 1,
             "Image Color Embeddings (16D)", "Color"),
        ]),
        ("📋 HIERARCHY EVALUATION (dims 16-79)", [
            ('text_hierarchy', 'text', self.extract_hierarchy_embeddings, 2,
             "Text Hierarchy Embeddings (64D)", "Hierarchy"),
            ('image_hierarchy', 'image', self.extract_hierarchy_embeddings, 2,
             "Image Hierarchy Embeddings (64D)", "Hierarchy"),
        ]),
    ]
    for section_header, entries in eval_plan:
        print(f"\n{section_header}")
        print("=" * 50)
        for key, modality, extractor, label_idx, title, label_type in entries:
            emoji = "📝" if modality == 'text' else "🖼️"
            desc = key.replace('_', ' ')  # e.g. 'text color'
            print(f"\n{emoji} Extracting {desc} embeddings...")
            outputs = extractor(dataloader, modality, max_samples)
            embeddings, labels = outputs[0], outputs[label_idx]
            print(f" {desc.capitalize()} embeddings shape: {embeddings.shape}")
            metrics = self.compute_similarity_metrics(embeddings, labels)
            metrics.update(self.evaluate_classification_performance(embeddings, labels, title, label_type))
            results[key] = metrics
            # Free GPU memory between the four extraction passes.
            del embeddings
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Persist one confusion-matrix figure per evaluation, then close it.
    os.makedirs(self.directory, exist_ok=True)
    for key, metrics in results.items():
        metrics['figure'].savefig(
            f"{self.directory}/{key}_confusion_matrix.png",
            dpi=300,
            bbox_inches='tight',
        )
        plt.close(metrics['figure'])

    return results
def evaluate_kaggle_marqo(self, max_samples):
    """Evaluate the specialized subspace embeddings on the KAGL Marqo dataset.

    Color uses dims 0-15 (16D) and hierarchy uses dims 16-79 (64D), for both
    text and image modalities. The color section header previously said
    "dims 0-16"; it is corrected here to match the 16-dim subspace used
    everywhere else.

    Args:
        max_samples: Upper bound on the number of samples to evaluate.

    Returns:
        Dict keyed by 'text_color', 'image_color', 'text_hierarchy' and
        'image_hierarchy' with merged similarity + classification metrics,
        or None when the dataset cannot be loaded.
    """
    print(f"\n{'='*60}")
    print("Evaluating KAGL Marqo Dataset")
    print(" Color embeddings: dims 0-15")
    print(" Hierarchy embeddings: dims 16-79")
    print(f"Max samples: {max_samples}")
    print(f"{'='*60}")

    kaggle_dataset = load_kaggle_marqo_dataset(self, max_samples)
    if kaggle_dataset is None:
        print("❌ Failed to load KAGL dataset")
        return None

    dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0)

    # Report sample counts per hierarchy class.
    if len(kaggle_dataset.dataframe) > 0:
        print(f"\n📊 Hierarchy distribution in dataset:")
        hierarchy_counts = {}
        for _, row in kaggle_dataset.dataframe.iterrows():
            hierarchy = row['hierarchy']
            hierarchy_counts[hierarchy] = hierarchy_counts.get(hierarchy, 0) + 1
        for hierarchy, count in sorted(hierarchy_counts.items()):
            print(f" {hierarchy}: {count} samples")

    results = {}
    # (section header, [(result key, modality, extractor, index of the label
    #  field in the extractor's return tuple, figure title, label type)])
    eval_plan = [
        ("🎨 COLOR EVALUATION (dims 0-15)", [
            ('text_color', 'text', self.extract_color_embeddings, 1,
             "Text Color Embeddings (16D)", "Color"),
            ('image_color', 'image', self.extract_color_embeddings, 1,
             "Image Color Embeddings (16D)", "Color"),
        ]),
        ("📋 HIERARCHY EVALUATION (dims 16-79)", [
            ('text_hierarchy', 'text', self.extract_hierarchy_embeddings, 2,
             "Text Hierarchy Embeddings (64D)", "Hierarchy"),
            ('image_hierarchy', 'image', self.extract_hierarchy_embeddings, 2,
             "Image Hierarchy Embeddings (64D)", "Hierarchy"),
        ]),
    ]
    for section_header, entries in eval_plan:
        print(f"\n{section_header}")
        print("=" * 50)
        for key, modality, extractor, label_idx, title, label_type in entries:
            emoji = "📝" if modality == 'text' else "🖼️"
            desc = key.replace('_', ' ')  # e.g. 'text color'
            print(f"\n{emoji} Extracting {desc} embeddings...")
            outputs = extractor(dataloader, modality, max_samples)
            embeddings, labels = outputs[0], outputs[label_idx]
            print(f" {desc.capitalize()} embeddings shape: {embeddings.shape}")
            metrics = self.compute_similarity_metrics(embeddings, labels)
            metrics.update(self.evaluate_classification_performance(embeddings, labels, title, label_type))
            results[key] = metrics
            # Free GPU memory between the four extraction passes.
            del embeddings
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Persist one confusion-matrix figure per evaluation, then close it.
    os.makedirs(self.directory, exist_ok=True)
    for key, metrics in results.items():
        metrics['figure'].savefig(
            f"{self.directory}/kaggle_{key}_confusion_matrix.png",
            dpi=300,
            bbox_inches='tight',
        )
        plt.close(metrics['figure'])

    return results
def evaluate_local_validation(self, max_samples):
    """Evaluate the specialized subspace embeddings on the local validation set.

    Rows whose hierarchy is unknown to the trained model are filtered out
    first. Color uses dims 0-15 (16D) and hierarchy dims 16-79 (64D), for
    both text and image modalities.

    Args:
        max_samples: Upper bound on the number of samples to evaluate.

    Returns:
        Dict keyed by 'text_color', 'image_color', 'text_hierarchy' and
        'image_hierarchy' with merged similarity + classification metrics,
        or None when loading/filtering leaves no usable samples.
    """
    print(f"\n{'='*60}")
    print("Evaluating Local Validation Dataset")
    print(" Color embeddings: dims 0-15")
    print(" Hierarchy embeddings: dims 16-79")
    print(f"Max samples: {max_samples}")
    print(f"{'='*60}")

    local_dataset = load_local_validation_dataset(max_samples)
    if local_dataset is None:
        print("❌ Failed to load local validation dataset")
        return None

    # Keep only rows whose hierarchy the trained model actually knows about.
    if len(local_dataset.dataframe) > 0:
        valid_df = local_dataset.dataframe[local_dataset.dataframe['hierarchy'].isin(self.hierarchy_classes)]
        if len(valid_df) == 0:
            print("❌ No samples left after hierarchy filtering.")
            return None
        if len(valid_df) < len(local_dataset.dataframe):
            print(f"📊 Filtered to model hierarchies: {len(valid_df)} samples (from {len(local_dataset.dataframe)})")
            local_dataset = LocalDataset(valid_df)

    dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)

    # Report sample counts per hierarchy class.
    if len(local_dataset.dataframe) > 0:
        print(f"\n📊 Hierarchy distribution in dataset:")
        hierarchy_counts = {}
        for _, row in local_dataset.dataframe.iterrows():
            hierarchy = row['hierarchy']
            hierarchy_counts[hierarchy] = hierarchy_counts.get(hierarchy, 0) + 1
        for hierarchy, count in sorted(hierarchy_counts.items()):
            print(f" {hierarchy}: {count} samples")

    results = {}
    # (section header, [(result key, modality, extractor, index of the label
    #  field in the extractor's return tuple, figure title, label type)])
    eval_plan = [
        ("🎨 COLOR EVALUATION (dims 0-15)", [
            ('text_color', 'text', self.extract_color_embeddings, 1,
             "Text Color Embeddings (16D)", "Color"),
            ('image_color', 'image', self.extract_color_embeddings, 1,
             "Image Color Embeddings (16D)", "Color"),
        ]),
        ("📋 HIERARCHY EVALUATION (dims 16-79)", [
            ('text_hierarchy', 'text', self.extract_hierarchy_embeddings, 2,
             "Text Hierarchy Embeddings (64D)", "Hierarchy"),
            ('image_hierarchy', 'image', self.extract_hierarchy_embeddings, 2,
             "Image Hierarchy Embeddings (64D)", "Hierarchy"),
        ]),
    ]
    for section_header, entries in eval_plan:
        print(f"\n{section_header}")
        print("=" * 50)
        for key, modality, extractor, label_idx, title, label_type in entries:
            emoji = "📝" if modality == 'text' else "🖼️"
            desc = key.replace('_', ' ')  # e.g. 'text color'
            print(f"\n{emoji} Extracting {desc} embeddings...")
            outputs = extractor(dataloader, modality, max_samples)
            embeddings, labels = outputs[0], outputs[label_idx]
            print(f" {desc.capitalize()} embeddings shape: {embeddings.shape}")
            metrics = self.compute_similarity_metrics(embeddings, labels)
            metrics.update(self.evaluate_classification_performance(embeddings, labels, title, label_type))
            results[key] = metrics
            # Free GPU memory between the four extraction passes.
            del embeddings
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Persist one confusion-matrix figure per evaluation, then close it.
    os.makedirs(self.directory, exist_ok=True)
    for key, metrics in results.items():
        metrics['figure'].savefig(
            f"{self.directory}/local_{key}_confusion_matrix.png",
            dpi=300,
            bbox_inches='tight',
        )
        plt.close(metrics['figure'])

    return results
def evaluate_full_embeddings(self, dataloader, dataset_name, max_samples=10000):
    """Evaluate color and hierarchy classification using ALL 512 dimensions.

    Unlike the subspace evaluations (16 color dims / 64 hierarchy dims),
    this uses the trained model's full embedding so results are directly
    comparable with the baseline, which also uses all 512 dimensions.

    Args:
        dataloader: DataLoader over a dataset yielding color + hierarchy labels.
        dataset_name: Human-readable dataset name; also used (lowercased,
            spaces/hyphens replaced by underscores) as the figure file prefix.
        max_samples: Upper bound on the number of samples to evaluate.

    Returns:
        Dict keyed by 'text_color', 'image_color', 'text_hierarchy' and
        'image_hierarchy' with merged similarity + classification metrics.
    """
    print(f"\n{'='*60}")
    print(f"Evaluating {dataset_name} with FULL 512-dimensional embeddings (Our Model)")
    print(f"Max samples: {max_samples}")
    print(f"{'='*60}")

    results = {}
    # (section header, task name, index of the label field in the
    #  extractor's (embeddings, colors, hierarchies) return tuple, label type)
    eval_plan = [
        ("🎨 COLOR EVALUATION (512 dims - Full Embeddings)", 'color', 1, "Color"),
        ("📋 HIERARCHY EVALUATION (512 dims - Full Embeddings)", 'hierarchy', 2, "Hierarchy"),
    ]
    for section_header, task, label_idx, label_type in eval_plan:
        print(f"\n{section_header}")
        print("=" * 50)
        for modality, emoji in [('text', "📝"), ('image', "🖼️")]:
            print(f"\n{emoji} Extracting {modality} FULL embeddings for {task} classification...")
            outputs = self.extract_full_embeddings(dataloader, modality, max_samples)
            embeddings, labels = outputs[0], outputs[label_idx]
            print(f" {modality.capitalize()} full embeddings shape: {embeddings.shape} (using all {embeddings.shape[1]} dimensions)")
            metrics = self.compute_similarity_metrics(embeddings, labels)
            title = f"{modality.capitalize()} Full Embeddings (512D) - {label_type}"
            metrics.update(self.evaluate_classification_performance(embeddings, labels, title, label_type))
            results[f"{modality}_{task}"] = metrics
            # Free GPU memory between the four extraction passes.
            del embeddings
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Persist one confusion-matrix figure per evaluation, then close it.
    os.makedirs(self.directory, exist_ok=True)
    dataset_prefix = dataset_name.lower().replace(' ', '_').replace('-', '_')
    for key, metrics in results.items():
        metrics['figure'].savefig(
            f"{self.directory}/{dataset_prefix}_full_{key}_confusion_matrix.png",
            dpi=300,
            bbox_inches='tight',
        )
        plt.close(metrics['figure'])

    return results
def compare_subspace_vs_full_embeddings(self, results_subspace, results_full, dataset_name):
    """Compare specialized-subspace accuracy (16/64 dims) vs full 512-dim accuracy.

    A comparison entry is emitted for each of the four evaluations
    (text/image x color/hierarchy) whose accuracy is non-zero in BOTH
    result dicts; entries with a zero on either side are skipped as
    "not evaluated".

    Args:
        results_subspace: Dict from a subspace evaluation, keyed by
            'text_color' / 'image_color' / 'text_hierarchy' / 'image_hierarchy',
            each value holding an 'accuracy' entry.
        results_full: Same shape, from `evaluate_full_embeddings`.
        dataset_name: Name printed in the report headers.

    Returns:
        List of comparison dicts with keys 'type', 'subspace', 'full',
        'diff' (full - subspace), 'subspace_dims' and 'full_dims'.
    """
    print(f"\n{'='*60}")
    print(f"📊 COMPARISON: Subspace vs Full Embeddings - {dataset_name}")
    print(f"{'='*60}")

    # (display name, result key, subspace dimensionality description)
    comparison_specs = [
        ('Text Color', 'text_color', '0-15 (16 dims)'),
        ('Image Color', 'image_color', '0-15 (16 dims)'),
        ('Text Hierarchy', 'text_hierarchy', '16-79 (64 dims)'),
        ('Image Hierarchy', 'image_hierarchy', '16-79 (64 dims)'),
    ]
    comparisons = []
    for display_name, key, subspace_dims in comparison_specs:
        subspace_acc = results_subspace.get(key, {}).get('accuracy', 0)
        full_acc = results_full.get(key, {}).get('accuracy', 0)
        if subspace_acc > 0 and full_acc > 0:
            comparisons.append({
                'type': display_name,
                'subspace': subspace_acc,
                'full': full_acc,
                'diff': full_acc - subspace_acc,
                'subspace_dims': subspace_dims,
                'full_dims': 'All 512 dims'
            })

    # Display comparisons
    print("\n📈 PERFORMANCE COMPARISON:")
    print("-" * 60)
    for comp in comparisons:
        better = "✅ Full (512D)" if comp['diff'] > 0 else "✅ Subspace"
        print(f"\n{comp['type']}:")
        print(f" Subspace ({comp['subspace_dims']}): {comp['subspace']*100:.2f}%")
        print(f" Full ({comp['full_dims']}): {comp['full']*100:.2f}%")
        print(f" Difference: {comp['diff']*100:+.2f}% → {better}")

    print(f"\n{'='*60}")
    print("💡 INTERPRETATION:")
    print(f"{'='*60}")
    full_better_count = sum(1 for c in comparisons if c['diff'] > 0)

    if full_better_count > len(comparisons) / 2:
        print("\n✅ Full embeddings (512D) perform better on most metrics.")
        print(" This suggests that using all dimensions provides more information")
        print(" for classification, even though specialized subspaces offer interpretability.")
    else:
        print("\n✅ Specialized subspaces perform competitively or better.")
        print(" This validates the effectiveness of dimensional specialization")
        print(" while maintaining interpretability advantages.")

    print("\n📊 Trade-off summary:")
    print(" • Subspace (16/64 dims): Better interpretability, task-specific")
    print(" • Full (512 dims): More information, potentially better accuracy")
    print(" • Use case: Subspace for explainability, Full for maximum performance")

    return comparisons
def evaluate_baseline_fashion_mnist(self, max_samples=1000):
    """Evaluate the baseline Fashion CLIP model on Fashion-MNIST.

    The baseline uses its full embedding for both color and hierarchy
    classification (no subspace split), for text and image modalities.

    Args:
        max_samples: Upper bound on the number of samples to evaluate.

    Returns:
        Dict of the form {'text': {'color': ..., 'hierarchy': ...},
        'image': {...}} where each leaf merges similarity metrics with
        classification metrics.
    """
    print(f"\n{'='*60}")
    print("Evaluating Baseline Fashion CLIP on Fashion-MNIST")
    print(f"Max samples: {max_samples}")
    print(f"{'='*60}")

    # Fashion-MNIST integer labels get remapped onto the hierarchy classes.
    target_hierarchy_classes = self.validation_hierarchy_classes or self.hierarchy_classes
    fashion_dataset = load_fashion_mnist_dataset(max_samples, hierarchy_classes=target_hierarchy_classes)
    dataloader = DataLoader(
        fashion_dataset,
        batch_size=8,
        shuffle=False,
        num_workers=0
    )

    results = {}
    for modality, emoji in [('text', "📝"), ('image', "🖼️")]:
        print(f"\n{emoji} Extracting baseline {modality} embeddings from Fashion-MNIST...")
        embeddings, colors, hierarchies = self.extract_baseline_embeddings_batch(dataloader, modality, max_samples)
        print(f" Baseline {modality} embeddings shape: {embeddings.shape} (using all {embeddings.shape[1]} dimensions)")
        color_metrics = self.compute_similarity_metrics(embeddings, colors)
        hierarchy_metrics = self.compute_similarity_metrics(embeddings, hierarchies)

        # Same baseline embedding is classified against both label sets.
        color_metrics.update(self.evaluate_classification_performance(
            embeddings, colors, f"Baseline Fashion-MNIST {modality.capitalize()} Embeddings - Color", "Color"
        ))
        hierarchy_metrics.update(self.evaluate_classification_performance(
            embeddings, hierarchies, f"Baseline Fashion-MNIST {modality.capitalize()} Embeddings - Hierarchy", "Hierarchy"
        ))
        results[modality] = {
            'color': color_metrics,
            'hierarchy': hierarchy_metrics
        }

        # Free GPU memory between the two extraction passes.
        del embeddings
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Persist one confusion-matrix figure per (modality, task), then close it.
    os.makedirs(self.directory, exist_ok=True)
    for modality, by_task in results.items():
        for task, metrics in by_task.items():
            figure = metrics['figure']
            figure.savefig(
                f"{self.directory}/fashion_baseline_{modality}_{task}_confusion_matrix.png",
                dpi=300,
                bbox_inches='tight',
            )
            plt.close(figure)

    return results
def evaluate_baseline_kaggle_marqo(self, max_samples=5000):
    """Evaluate the baseline Fashion CLIP model on the KAGL Marqo dataset.

    The baseline uses its full embedding for both color and hierarchy
    classification (no subspace split), for text and image modalities.

    Args:
        max_samples: Upper bound on the number of samples to evaluate.

    Returns:
        Dict of the form {'text': {'color': ..., 'hierarchy': ...},
        'image': {...}}, or None when the dataset cannot be loaded.
    """
    print(f"\n{'='*60}")
    print("Evaluating Baseline Fashion CLIP on KAGL Marqo Dataset")
    print(f"Max samples: {max_samples}")
    print(f"{'='*60}")

    kaggle_dataset = load_kaggle_marqo_dataset(self, max_samples)
    if kaggle_dataset is None:
        print("❌ Failed to load KAGL dataset")
        return None

    dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0)

    results = {}
    for modality, emoji in [('text', "📝"), ('image', "🖼️")]:
        print(f"\n{emoji} Extracting baseline {modality} embeddings from KAGL Marqo...")
        embeddings, colors, hierarchies = self.extract_baseline_embeddings_batch(dataloader, modality, max_samples)
        print(f" Baseline {modality} embeddings shape: {embeddings.shape} (using all {embeddings.shape[1]} dimensions)")
        color_metrics = self.compute_similarity_metrics(embeddings, colors)
        hierarchy_metrics = self.compute_similarity_metrics(embeddings, hierarchies)

        # Same baseline embedding is classified against both label sets.
        color_metrics.update(self.evaluate_classification_performance(
            embeddings, colors, f"Baseline KAGL Marqo {modality.capitalize()} Embeddings - Color", "Color"
        ))
        hierarchy_metrics.update(self.evaluate_classification_performance(
            embeddings, hierarchies, f"Baseline KAGL Marqo {modality.capitalize()} Embeddings - Hierarchy", "Hierarchy"
        ))
        results[modality] = {
            'color': color_metrics,
            'hierarchy': hierarchy_metrics
        }

        # Free GPU memory between the two extraction passes.
        del embeddings
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Persist one confusion-matrix figure per (modality, task), then close it.
    os.makedirs(self.directory, exist_ok=True)
    for modality, by_task in results.items():
        for task, metrics in by_task.items():
            figure = metrics['figure']
            figure.savefig(
                f"{self.directory}/kaggle_baseline_{modality}_{task}_confusion_matrix.png",
                dpi=300,
                bbox_inches='tight',
            )
            plt.close(figure)

    return results
def evaluate_baseline_local_validation(self, max_samples=5000):
    """Evaluate the baseline Fashion CLIP model on the local validation dataset.

    Rows whose hierarchy is unknown to the trained model are filtered out
    first (so baseline and trained-model results cover the same samples).
    The baseline uses its full embedding for both color and hierarchy
    classification, for text and image modalities.

    Args:
        max_samples: Upper bound on the number of samples to evaluate.

    Returns:
        Dict of the form {'text': {'color': ..., 'hierarchy': ...},
        'image': {...}}, or None when loading/filtering leaves no
        usable samples.
    """
    print(f"\n{'='*60}")
    print("Evaluating Baseline Fashion CLIP on Local Validation Dataset")
    print(f"Max samples: {max_samples}")
    print(f"{'='*60}")

    local_dataset = load_local_validation_dataset(max_samples)
    if local_dataset is None:
        print("❌ Failed to load local validation dataset")
        return None

    # Keep only rows whose hierarchy the trained model actually knows about.
    if len(local_dataset.dataframe) > 0:
        valid_df = local_dataset.dataframe[local_dataset.dataframe['hierarchy'].isin(self.hierarchy_classes)]
        if len(valid_df) == 0:
            print("❌ No samples left after hierarchy filtering.")
            return None
        if len(valid_df) < len(local_dataset.dataframe):
            print(f"📊 Filtered to model hierarchies: {len(valid_df)} samples (from {len(local_dataset.dataframe)})")
            local_dataset = LocalDataset(valid_df)

    dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)

    results = {}
    for modality, emoji in [('text', "📝"), ('image', "🖼️")]:
        print(f"\n{emoji} Extracting baseline {modality} embeddings from Local Validation...")
        embeddings, colors, hierarchies = self.extract_baseline_embeddings_batch(dataloader, modality, max_samples)
        print(f" Baseline {modality} embeddings shape: {embeddings.shape} (using all {embeddings.shape[1]} dimensions)")
        color_metrics = self.compute_similarity_metrics(embeddings, colors)
        hierarchy_metrics = self.compute_similarity_metrics(embeddings, hierarchies)

        # Same baseline embedding is classified against both label sets.
        color_metrics.update(self.evaluate_classification_performance(
            embeddings, colors, f"Baseline Local Validation {modality.capitalize()} Embeddings - Color", "Color"
        ))
        hierarchy_metrics.update(self.evaluate_classification_performance(
            embeddings, hierarchies, f"Baseline Local Validation {modality.capitalize()} Embeddings - Hierarchy", "Hierarchy"
        ))
        results[modality] = {
            'color': color_metrics,
            'hierarchy': hierarchy_metrics
        }

        # Free GPU memory between the two extraction passes.
        del embeddings
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Persist one confusion-matrix figure per (modality, task), then close it.
    os.makedirs(self.directory, exist_ok=True)
    for modality, by_task in results.items():
        for task, metrics in by_task.items():
            figure = metrics['figure']
            figure.savefig(
                f"{self.directory}/local_baseline_{modality}_{task}_confusion_matrix.png",
                dpi=300,
                bbox_inches='tight',
            )
            plt.close(figure)

    return results
Perte d'information: Spécialisation excessive peut causer perte d'information générale + """ + print(f"\n{'='*60}") + print(f"📊 ANALYSE: Baseline vs Modèle Entraîné - {dataset_name}") + print(f"{'='*60}") + + # Comparer les métriques pour chaque type d'embedding + comparisons = [] + + # Text Color + trained_color_text_acc = results_trained.get('text_color', {}).get('accuracy', 0) + baseline_color_text_acc = results_baseline.get('text', {}).get('color', {}).get('accuracy', 0) + if trained_color_text_acc > 0 and baseline_color_text_acc > 0: + diff = baseline_color_text_acc - trained_color_text_acc + comparisons.append({ + 'type': 'Text Color', + 'trained': trained_color_text_acc, + 'baseline': baseline_color_text_acc, + 'diff': diff, + 'trained_dims': '0-15 (16 dims)', + 'baseline_dims': 'All dimensions (512 dims)' + }) + + # Image Color + trained_color_img_acc = results_trained.get('image_color', {}).get('accuracy', 0) + baseline_color_img_acc = results_baseline.get('image', {}).get('color', {}).get('accuracy', 0) + if trained_color_img_acc > 0 and baseline_color_img_acc > 0: + diff = baseline_color_img_acc - trained_color_img_acc + comparisons.append({ + 'type': 'Image Color', + 'trained': trained_color_img_acc, + 'baseline': baseline_color_img_acc, + 'diff': diff, + 'trained_dims': '0-16 (17 dims)', + 'baseline_dims': 'All dimensions (512 dims)' + }) + + # Text Hierarchy + trained_hier_text_acc = results_trained.get('text_hierarchy', {}).get('accuracy', 0) + baseline_hier_text_acc = results_baseline.get('text', {}).get('hierarchy', {}).get('accuracy', 0) + if trained_hier_text_acc > 0 and baseline_hier_text_acc > 0: + diff = baseline_hier_text_acc - trained_hier_text_acc + comparisons.append({ + 'type': 'Text Hierarchy', + 'trained': trained_hier_text_acc, + 'baseline': baseline_hier_text_acc, + 'diff': diff, + 'trained_dims': '16-79 (64 dims)', + 'baseline_dims': 'All dimensions (512 dims)' + }) + + # Image Hierarchy + trained_hier_img_acc = 
results_trained.get('image_hierarchy', {}).get('accuracy', 0) + baseline_hier_img_acc = results_baseline.get('image', {}).get('hierarchy', {}).get('accuracy', 0) + if trained_hier_img_acc > 0 and baseline_hier_img_acc > 0: + diff = baseline_hier_img_acc - trained_hier_img_acc + comparisons.append({ + 'type': 'Image Hierarchy', + 'trained': trained_hier_img_acc, + 'baseline': baseline_hier_img_acc, + 'diff': diff, + 'trained_dims': '16-79 (64 dims)', + 'baseline_dims': 'All dimensions (512 dims)' + }) + + # Afficher les comparaisons + print("\n📈 COMPARAISON DES PERFORMANCES:") + print("-" * 60) + for comp in comparisons: + better = "✅ Baseline" if comp['diff'] > 0 else "✅ Modèle Entraîné" + print(f"\n{comp['type']}:") + print(f" Modèle Entraîné ({comp['trained_dims']}): {comp['trained']*100:.2f}%") + print(f" Baseline ({comp['baseline_dims']}): {comp['baseline']*100:.2f}%") + print(f" Différence: {comp['diff']*100:+.2f}% → {better}") + + # Analyse des raisons + print(f"\n{'='*60}") + print("🔍 EXPLICATIONS POSSIBLES:") + print(f"{'='*60}") + + avg_diff = np.mean([abs(c['diff']) for c in comparisons]) if comparisons else 0 + baseline_better_count = sum(1 for c in comparisons if c['diff'] > 0) + + if baseline_better_count > len(comparisons) / 2: + print("\n⚠️ La baseline performe mieux sur la majorité des métriques.") + print("\nRaisons probables:") + print("\n1. 📐 CAPACITÉ DIMENSIONNELLE:") + print(" • Baseline: Utilise TOUTES les 512 dimensions des embeddings") + print(" • Modèle entraîné: Utilise seulement 17 dims (couleur) ou 64 dims (hiérarchie)") + print(" • Impact: La baseline a accès à plus d'information pour la classification") + + print("\n2. 🎯 SUR-SPÉCIALISATION:") + print(" • Le modèle entraîné a été spécialisé pour séparer couleur et hiérarchie") + print(" • Cette spécialisation peut causer une perte d'information générale") + print(" • Les dimensions non utilisées peuvent contenir de l'information utile") + + print("\n3. 
📊 DISTRIBUTION SHIFT:") + print(" • Le dataset de validation peut avoir une distribution différente") + print(" • Le modèle entraîné peut avoir overfitté sur le dataset d'entraînement") + print(" • La baseline pré-entraînée est plus robuste car entraînée sur plus de données") + + print("\n4. 🌐 GÉNÉRALISATION:") + print(" • Baseline Fashion CLIP: Entraînée sur un large dataset diversifié") + print(" • Modèle entraîné: Entraîné sur un dataset plus spécifique") + print(" • La baseline peut mieux généraliser à des distributions nouvelles") + + print("\n5. 🔄 TRADE-OFF SPÉCIALISATION vs CAPACITÉ:") + print(" • Spécialisation (modèle entraîné): Meilleure séparation explicable") + print(" • Capacité (baseline): Plus d'information pour meilleure performance brute") + print(" • C'est un compromis entre interprétabilité et performance") + + print(f"\n{'='*60}") + print("💡 RECOMMANDATIONS:") + print(f"{'='*60}") + print("\n1. Analyser les matrices de confusion pour voir les types d'erreurs") + print("2. Vérifier si le modèle entraîné performe mieux sur le dataset d'entraînement") + print("\n3. 🔧 CONSIDÉRER UTILISER TOUTES LES DIMENSIONS POUR LA CLASSIFICATION FINALE:") + print(" Actuellement:") + print(" • Modèle entraîné: utilise seulement dims 0-15 (couleur) ou dims 16-79 (hiérarchie)") + print(" • Baseline: utilise toutes les 512 dimensions") + print(" ") + print(" Solution proposée:") + print(" • Utiliser TOUTES les dimensions du modèle entraîné (ex: 512 dims) pour la classification") + print(" • Cela permet d'avoir accès à toute l'information disponible") + print(" • Méthode disponible: extract_full_embeddings() pour extraire toutes les dimensions") + print(" • Vous pouvez alors comparer:") + print(" - Spécialisé (16 ou 64 dims) → meilleur pour interprétabilité") + print(" - Complet (512 dims) → meilleur pour performance brute") + print("\n4. Utiliser les embeddings spécialisés pour l'interprétabilité, pas pour la classification brute") + print("5. 
Si la performance est critique, combiner spécialisé + général (ensemble)") + + return comparisons + + +if __name__ == "__main__": + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + print(f"Using device: {device}") + + directory = 'main_model_analysi' + max_samples = 10000 + + evaluator = ColorHierarchyEvaluator(device=device, directory=directory) + + # Evaluate Fashion-MNIST + print("\n" + "="*60) + print("🚀 Starting evaluation of Fashion-MNIST with Color & Hierarchy embeddings") + print("="*60) + results_fashion = evaluator.evaluate_fashion_mnist(max_samples=max_samples) + + print(f"\n{'='*60}") + print("FASHION-MNIST EVALUATION SUMMARY") + print(f"{'='*60}") + + print("\n🎨 COLOR CLASSIFICATION RESULTS (dims 0-15):") + print(f" Text - NN Acc: {results_fashion['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion['text_color']['separation_score']:.4f}") + print(f" Image - NN Acc: {results_fashion['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion['image_color']['separation_score']:.4f}") + + print("\n📋 HIERARCHY CLASSIFICATION RESULTS (dims 16-79):") + print(f" Text - NN Acc: {results_fashion['text_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion['text_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion['text_hierarchy']['separation_score']:.4f}") + print(f" Image - NN Acc: {results_fashion['image_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion['image_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion['image_hierarchy']['separation_score']:.4f}") + + # Evaluate Baseline Fashion CLIP on Fashion-MNIST + print("\n" + "="*60) + print("🚀 Starting evaluation of Baseline Fashion CLIP on Fashion-MNIST") + print("="*60) + results_baseline = 
evaluator.evaluate_baseline_fashion_mnist(max_samples=max_samples) + + print(f"\n{'='*60}") + print("BASELINE FASHION-MNIST EVALUATION SUMMARY") + print(f"{'='*60}") + + print("\n🎨 COLOR CLASSIFICATION RESULTS (Baseline):") + print(f" Text - NN Acc: {results_baseline['text']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline['text']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline['text']['color']['separation_score']:.4f}") + print(f" Image - NN Acc: {results_baseline['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline['image']['color']['separation_score']:.4f}") + + print("\n📋 HIERARCHY CLASSIFICATION RESULTS (Baseline):") + print(f" Text - NN Acc: {results_baseline['text']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline['text']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline['text']['hierarchy']['separation_score']:.4f}") + print(f" Image - NN Acc: {results_baseline['image']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline['image']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline['image']['hierarchy']['separation_score']:.4f}") + + # Analyse comparative pour Fashion-MNIST + evaluator.analyze_baseline_vs_trained_performance( + results_fashion, + results_baseline, + "Fashion-MNIST" + ) + + # Evaluate Fashion-MNIST with FULL 512-dimensional embeddings + print("\n" + "="*60) + print("🚀 Starting evaluation of Fashion-MNIST with FULL 512-dimensional embeddings") + print("="*60) + target_hierarchy_classes = evaluator.validation_hierarchy_classes or evaluator.hierarchy_classes + fashion_dataset = load_fashion_mnist_dataset(max_samples, hierarchy_classes=target_hierarchy_classes) + fashion_dataloader = DataLoader(fashion_dataset, batch_size=8, shuffle=False, num_workers=0) + results_fashion_full = 
evaluator.evaluate_full_embeddings(fashion_dataloader, "Fashion-MNIST", max_samples=max_samples) + + print(f"\n{'='*60}") + print("FASHION-MNIST FULL EMBEDDINGS (512D) EVALUATION SUMMARY") + print(f"{'='*60}") + print("\n🎨 COLOR CLASSIFICATION RESULTS (512 dims):") + print(f" Text - NN Acc: {results_fashion_full['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion_full['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion_full['text_color']['separation_score']:.4f}") + print(f" Image - NN Acc: {results_fashion_full['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion_full['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion_full['image_color']['separation_score']:.4f}") + print("\n📋 HIERARCHY CLASSIFICATION RESULTS (512 dims):") + print(f" Text - NN Acc: {results_fashion_full['text_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion_full['text_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion_full['text_hierarchy']['separation_score']:.4f}") + print(f" Image - NN Acc: {results_fashion_full['image_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_fashion_full['image_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_fashion_full['image_hierarchy']['separation_score']:.4f}") + + # Compare subspace vs full embeddings for Fashion-MNIST + evaluator.compare_subspace_vs_full_embeddings( + results_fashion, + results_fashion_full, + "Fashion-MNIST" + ) + + # Evaluate KAGL Marqo + print("\n" + "="*60) + print("🚀 Starting evaluation of KAGL Marqo with Color & Hierarchy embeddings") + print("="*60) + results_kaggle = evaluator.evaluate_kaggle_marqo(max_samples=max_samples) + + if results_kaggle is not None: + print(f"\n{'='*60}") + print("KAGL MARQO EVALUATION SUMMARY") + print(f"{'='*60}") + + print("\n🎨 COLOR CLASSIFICATION RESULTS (dims 0-15):") + print(f" Text - NN Acc: {results_kaggle['text_color']['accuracy']*100:.1f}% | Centroid 
Acc: {results_kaggle['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle['text_color']['separation_score']:.4f}") + print(f" Image - NN Acc: {results_kaggle['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle['image_color']['separation_score']:.4f}") + + print("\n📋 HIERARCHY CLASSIFICATION RESULTS (dims 16-79):") + print(f" Text - NN Acc: {results_kaggle['text_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle['text_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle['text_hierarchy']['separation_score']:.4f}") + print(f" Image - NN Acc: {results_kaggle['image_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle['image_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle['image_hierarchy']['separation_score']:.4f}") + + # Evaluate Baseline Fashion CLIP on KAGL Marqo + print("\n" + "="*60) + print("🚀 Starting evaluation of Baseline Fashion CLIP on KAGL Marqo") + print("="*60) + results_baseline_kaggle = evaluator.evaluate_baseline_kaggle_marqo(max_samples=max_samples) + + if results_baseline_kaggle is not None: + print(f"\n{'='*60}") + print("BASELINE KAGL MARQO EVALUATION SUMMARY") + print(f"{'='*60}") + + print("\n🎨 COLOR CLASSIFICATION RESULTS (Baseline):") + print(f" Text - NN Acc: {results_baseline_kaggle['text']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['text']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['text']['color']['separation_score']:.4f}") + print(f" Image - NN Acc: {results_baseline_kaggle['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['image']['color']['separation_score']:.4f}") + + print("\n📋 HIERARCHY CLASSIFICATION RESULTS (Baseline):") + print(f" Text - NN Acc: 
{results_baseline_kaggle['text']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['text']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['text']['hierarchy']['separation_score']:.4f}") + print(f" Image - NN Acc: {results_baseline_kaggle['image']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['image']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['image']['hierarchy']['separation_score']:.4f}") + + # Analyse comparative pour KAGL Marqo + if results_kaggle is not None: + evaluator.analyze_baseline_vs_trained_performance( + results_kaggle, + results_baseline_kaggle, + "KAGL Marqo Dataset" + ) + + # Evaluate KAGL Marqo with FULL 512-dimensional embeddings + print("\n" + "="*60) + print("🚀 Starting evaluation of KAGL Marqo with FULL 512-dimensional embeddings") + print("="*60) + kaggle_dataset = load_kaggle_marqo_dataset(evaluator, max_samples) + if kaggle_dataset is not None: + kaggle_dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0) + results_kaggle_full = evaluator.evaluate_full_embeddings(kaggle_dataloader, "KAGL Marqo", max_samples=max_samples) + + print(f"\n{'='*60}") + print("KAGL MARQO FULL EMBEDDINGS (512D) EVALUATION SUMMARY") + print(f"{'='*60}") + print("\n🎨 COLOR CLASSIFICATION RESULTS (512 dims):") + print(f" Text - NN Acc: {results_kaggle_full['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle_full['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle_full['text_color']['separation_score']:.4f}") + print(f" Image - NN Acc: {results_kaggle_full['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle_full['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle_full['image_color']['separation_score']:.4f}") + print("\n📋 HIERARCHY CLASSIFICATION RESULTS (512 dims):") + print(f" Text - NN Acc: 
{results_kaggle_full['text_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle_full['text_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle_full['text_hierarchy']['separation_score']:.4f}") + print(f" Image - NN Acc: {results_kaggle_full['image_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle_full['image_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle_full['image_hierarchy']['separation_score']:.4f}") + + # Compare subspace vs full embeddings for KAGL Marqo + evaluator.compare_subspace_vs_full_embeddings( + results_kaggle, + results_kaggle_full, + "KAGL Marqo" + ) + + # Evaluate Local Validation Dataset + print("\n" + "="*60) + print("🚀 Starting evaluation of Local Validation Dataset with Color & Hierarchy embeddings") + print("="*60) + results_local = evaluator.evaluate_local_validation(max_samples=max_samples) + + if results_local is not None: + print(f"\n{'='*60}") + print("LOCAL VALIDATION DATASET EVALUATION SUMMARY") + print(f"{'='*60}") + + print("\n🎨 COLOR CLASSIFICATION RESULTS (dims 0-15):") + print(f" Text - NN Acc: {results_local['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local['text_color']['separation_score']:.4f}") + print(f" Image - NN Acc: {results_local['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local['image_color']['separation_score']:.4f}") + + print("\n📋 HIERARCHY CLASSIFICATION RESULTS (dims 16-79):") + print(f" Text - NN Acc: {results_local['text_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_local['text_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_local['text_hierarchy']['separation_score']:.4f}") + print(f" Image - NN Acc: {results_local['image_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: 
{results_local['image_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_local['image_hierarchy']['separation_score']:.4f}") + + # Evaluate Baseline Fashion CLIP on Local Validation + print("\n" + "="*60) + print("🚀 Starting evaluation of Baseline Fashion CLIP on Local Validation") + print("="*60) + results_baseline_local = evaluator.evaluate_baseline_local_validation(max_samples=max_samples) + + if results_baseline_local is not None: + print(f"\n{'='*60}") + print("BASELINE LOCAL VALIDATION EVALUATION SUMMARY") + print(f"{'='*60}") + + print("\n🎨 COLOR CLASSIFICATION RESULTS (Baseline):") + print(f" Text - NN Acc: {results_baseline_local['text']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['text']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['text']['color']['separation_score']:.4f}") + print(f" Image - NN Acc: {results_baseline_local['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['image']['color']['separation_score']:.4f}") + + print("\n📋 HIERARCHY CLASSIFICATION RESULTS (Baseline):") + print(f" Text - NN Acc: {results_baseline_local['text']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['text']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['text']['hierarchy']['separation_score']:.4f}") + print(f" Image - NN Acc: {results_baseline_local['image']['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['image']['hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['image']['hierarchy']['separation_score']:.4f}") + + # Analyse comparative pour le dataset de validation local + if results_local is not None: + evaluator.analyze_baseline_vs_trained_performance( + results_local, + results_baseline_local, + "Local Validation Dataset" + ) + + # Evaluate Local Validation 
with FULL 512-dimensional embeddings + print("\n" + "="*60) + print("🚀 Starting evaluation of Local Validation with FULL 512-dimensional embeddings") + print("="*60) + local_dataset = load_local_validation_dataset(max_samples) + if local_dataset is not None: + # Filter to only include hierarchies that exist in our model + if len(local_dataset.dataframe) > 0: + valid_df = local_dataset.dataframe[local_dataset.dataframe['hierarchy'].isin(evaluator.hierarchy_classes)] + if len(valid_df) > 0: + if len(valid_df) < len(local_dataset.dataframe): + local_dataset = LocalDataset(valid_df) + + local_dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0) + results_local_full = evaluator.evaluate_full_embeddings(local_dataloader, "Local Validation", max_samples=max_samples) + + print(f"\n{'='*60}") + print("LOCAL VALIDATION FULL EMBEDDINGS (512D) EVALUATION SUMMARY") + print(f"{'='*60}") + print("\n🎨 COLOR CLASSIFICATION RESULTS (512 dims):") + print(f" Text - NN Acc: {results_local_full['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local_full['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local_full['text_color']['separation_score']:.4f}") + print(f" Image - NN Acc: {results_local_full['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local_full['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local_full['image_color']['separation_score']:.4f}") + print("\n📋 HIERARCHY CLASSIFICATION RESULTS (512 dims):") + print(f" Text - NN Acc: {results_local_full['text_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_local_full['text_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: {results_local_full['text_hierarchy']['separation_score']:.4f}") + print(f" Image - NN Acc: {results_local_full['image_hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {results_local_full['image_hierarchy']['centroid_accuracy']*100:.1f}% | Separation: 
{results_local_full['image_hierarchy']['separation_score']:.4f}") + + # Compare subspace vs full embeddings for Local Validation + evaluator.compare_subspace_vs_full_embeddings( + results_local, + results_local_full, + "Local Validation" + ) + + print(f"\n✅ Evaluation completed! Check '{directory}/' for visualization files.")