import os
import json

# Must be set before any tokenizer is created to silence fork warnings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import difflib
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from collections import defaultdict
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from io import BytesIO
import warnings

warnings.filterwarnings('ignore')

from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers

from config import (
    color_model_path,
    color_emb_dim,
    local_dataset_path,
    column_local_image_path,
    tokeniser_path,
)
from color_model import ColorCLIP, Tokenizer

# The 11 colors the specialized color model was trained on.
# Shared by every dataset loader so they cannot drift apart.
VALID_COLORS = ['beige', 'black', 'blue', 'brown', 'green', 'orange',
                'pink', 'purple', 'red', 'white', 'yellow']


def build_eval_transform(image_size=224):
    """Build the deterministic preprocessing pipeline used for evaluation.

    Args:
        image_size: Side length the image is resized to (square).

    Returns:
        A torchvision ``Compose`` transform: resize -> tensor -> ImageNet
        normalization.

    NOTE: the previous version also applied ``ColorJitter`` here, which
    randomly perturbs brightness/contrast/saturation.  Random *color*
    augmentation inside a *color* evaluation makes results stochastic and
    biased, and contradicted the "no augmentation" intent, so it was removed.
    """
    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


class KaggleDataset(Dataset):
    """Dataset class for the KAGL Marqo dataset (images carried in-memory)."""

    def __init__(self, dataframe, image_size=224):
        self.dataframe = dataframe
        self.image_size = image_size
        # Deterministic transforms for validation (no augmentation).
        self.transform = build_eval_transform(image_size)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image_tensor, description, color)`` for one row."""
        row = self.dataframe.iloc[idx]

        # row['image_url'] may be a dict holding raw bytes, an already
        # decoded PIL image, or raw bytes directly.
        image_data = row['image_url']
        if isinstance(image_data, dict) and 'bytes' in image_data:
            image = Image.open(BytesIO(image_data['bytes'])).convert("RGB")
        elif hasattr(image_data, 'convert'):
            # Already a PIL Image.
            image = image_data.convert("RGB")
        else:
            # Assume raw bytes.
            image = Image.open(BytesIO(image_data)).convert("RGB")

        image = self.transform(image)
        return image, row['text'], row['color']


def load_kaggle_marqo_dataset(max_samples=5000):
    """Load and prepare the Kaggle KAGL dataset with memory optimization.

    Downloads/loads the ``Marqo/KAGL`` dataset, drops rows with missing
    text/image/color, normalizes "grey" to "gray", keeps only the 11
    training colors, and randomly samples down to ``max_samples`` rows.

    Returns:
        A :class:`KaggleDataset` wrapping the filtered dataframe.
    """
    from datasets import load_dataset

    print("šŸ“Š Loading Kaggle KAGL dataset...")
    dataset = load_dataset("Marqo/KAGL")
    df = dataset["data"].to_pandas()
    print(f"āœ… Dataset Kaggle loaded")
    print(f"   Before filtering: {len(df)} samples")
    print(f"   Available columns: {list(df.columns)}")

    # Ensure we have text and image data.
    df = df.dropna(subset=['text', 'image'])
    print(f"   After removing missing text/image: {len(df)} samples")

    df_test = df.copy()
    # Limit to max_samples with RANDOM SAMPLING to get diverse colors.
    if len(df_test) > max_samples:
        df_test = df_test.sample(n=max_samples, random_state=42)
        print(f"šŸ“Š Randomly sampled {max_samples} samples from Kaggle dataset")

    # Create formatted dataset with proper column names.
    kaggle_formatted = pd.DataFrame({
        'image_url': df_test['image'],  # This contains image data as bytes
        'text': df_test['text'],
        # Normalize spelling so labels match the training vocabulary.
        'color': df_test['baseColour'].str.lower().str.replace("grey", "gray"),
    })

    # Filter out rows with None/NaN colors.
    before_color_filter = len(kaggle_formatted)
    kaggle_formatted = kaggle_formatted.dropna(subset=['color'])
    if len(kaggle_formatted) < before_color_filter:
        print(f"   After removing missing colors: {len(kaggle_formatted)} samples "
              f"(removed {before_color_filter - len(kaggle_formatted)} samples)")

    # Keep only the colors used during training (11 colors).
    before_valid_filter = len(kaggle_formatted)
    kaggle_formatted = kaggle_formatted[kaggle_formatted['color'].isin(VALID_COLORS)]
    print(f"   After filtering for valid colors: {len(kaggle_formatted)} samples "
          f"(removed {before_valid_filter - len(kaggle_formatted)} samples)")
    print(f"   Valid colors found: {sorted(kaggle_formatted['color'].unique())}")
    print(f"   Final dataset size: {len(kaggle_formatted)} samples")

    # Show color distribution in final dataset.
    print(f"šŸŽØ Color distribution in Kaggle dataset:")
    color_counts = kaggle_formatted['color'].value_counts()
    for color in color_counts.index:
        print(f"   {color}: {color_counts[color]} samples")

    return KaggleDataset(kaggle_formatted)


class LocalDataset(Dataset):
    """Dataset class for the local validation dataset (images on disk)."""

    def __init__(self, dataframe, image_size=224):
        self.dataframe = dataframe
        self.image_size = image_size
        # Deterministic transforms for validation (no augmentation).
        self.transform = build_eval_transform(image_size)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image_tensor, description, color)`` for one row.

        If the image file cannot be opened, a gray placeholder image is
        substituted so the batch is not lost.
        """
        row = self.dataframe.iloc[idx]
        image_path = row[column_local_image_path]
        try:
            image = Image.open(image_path).convert("RGB")
        except Exception as e:
            print(f"Error loading image at index {idx} from {image_path}: {e}")
            # Create a dummy image if loading fails.
            image = Image.new('RGB', (224, 224), color='gray')

        image = self.transform(image)
        return image, row['text'], row['color']


def load_local_validation_dataset(max_samples=5000):
    """Load and prepare the local validation dataset.

    Reads the CSV at ``local_dataset_path``, drops rows with missing image
    paths, keeps only the 11 training colors (when a ``color`` column
    exists), and randomly samples down to ``max_samples`` rows.

    Returns:
        A :class:`LocalDataset` wrapping the filtered dataframe.
    """
    print("šŸ“Š Loading local validation dataset...")
    df = pd.read_csv(local_dataset_path)
    print(f"āœ… Dataset loaded: {len(df)} samples")

    # Filter out rows with NaN values in the image path column.
    df_clean = df.dropna(subset=[column_local_image_path])
    print(f"šŸ“Š After filtering NaN image paths: {len(df_clean)} samples")

    # Keep only the colors used during training (11 colors).
    if 'color' in df_clean.columns:
        before_valid_filter = len(df_clean)
        df_clean = df_clean[df_clean['color'].isin(VALID_COLORS)]
        print(f"šŸ“Š After filtering for valid colors: {len(df_clean)} samples "
              f"(removed {before_valid_filter - len(df_clean)} samples)")
        print(f"šŸŽØ Valid colors found: {sorted(df_clean['color'].unique())}")

    # Limit to max_samples with RANDOM SAMPLING to get diverse colors.
    if len(df_clean) > max_samples:
        df_clean = df_clean.sample(n=max_samples, random_state=42)
        print(f"šŸ“Š Randomly sampled {max_samples} samples")

    print(f"šŸ“Š Using {len(df_clean)} samples for evaluation")

    # Show color distribution after sampling.
    if 'color' in df_clean.columns:
        print(f"šŸŽØ Color distribution in sampled data:")
        color_counts = df_clean['color'].value_counts()
        print(f"   Total unique colors: {len(color_counts)}")
        for color in color_counts.index[:15]:  # Show top 15
            print(f"   {color}: {color_counts[color]} samples")

    return LocalDataset(df_clean)


def collate_fn_filter_none(batch):
    """Collate function that drops ``None`` items from a batch (with debug prints).

    Returns:
        ``(images_tensor, texts_list, colors_list)``; an empty tensor and
        empty lists when every item was filtered out.
    """
    original_len = len(batch)
    batch = [item for item in batch if item is not None]
    if original_len > len(batch):
        print(f"āš ļø Filtered out {original_len - len(batch)} None values from batch "
              f"(original: {original_len}, filtered: {len(batch)})")
    if len(batch) == 0:
        # Return empty batch with correct structure.
        print("āš ļø Empty batch after filtering None values")
        return torch.tensor([]), [], []
    images, texts, colors = zip(*batch)
    images = torch.stack(images, dim=0)
    return images, list(texts), list(colors)


def _normalize_color_labels(colors):
    """Lower-case, strip, and map 'grey'->'gray' for a sequence of labels.

    Applied identically to both the color-model and baseline pipelines so
    their label spaces match (previously only the color-model path
    normalized, which could split grey/gray into two classes).
    """
    return [str(c).lower().strip().replace("grey", "gray") for c in colors]


class ColorEvaluator:
    """Evaluate 16-dimensional color embeddings against a Fashion CLIP baseline."""

    def __init__(self, device='mps', directory="color_model_analysis"):
        self.device = torch.device(device)
        self.directory = directory
        self.color_emb_dim = color_emb_dim
        os.makedirs(self.directory, exist_ok=True)

        # Load baseline Fashion CLIP model.
        print("šŸ“¦ Loading baseline Fashion CLIP model...")
        patrick_model_name = "patrickjohncyh/fashion-clip"
        self.baseline_processor = CLIPProcessor.from_pretrained(patrick_model_name)
        self.baseline_model = CLIPModel_transformers.from_pretrained(patrick_model_name).to(self.device)
        self.baseline_model.eval()
        print("āœ… Baseline Fashion CLIP model loaded successfully")

        # Load specialized color model (16D).
        self.color_model = None
        self.color_tokenizer = None
        self._load_color_model()

    def _load_color_model(self):
        """Load the specialized 16D color model and tokenizer (idempotent).

        Raises:
            FileNotFoundError: if the checkpoint or tokenizer vocab is missing.
        """
        if self.color_model is not None and self.color_tokenizer is not None:
            return
        if not os.path.exists(color_model_path):
            raise FileNotFoundError(f"Color model file {color_model_path} not found")
        if not os.path.exists(tokeniser_path):
            raise FileNotFoundError(f"Tokenizer vocab file {tokeniser_path} not found")

        print("šŸŽØ Loading specialized color model (16D)...")
        # Load checkpoint first to get the actual vocab size.
        # NOTE(review): torch.load without weights_only=True executes pickled
        # code — fine for a trusted local checkpoint, do not point at
        # untrusted files.
        state_dict = torch.load(color_model_path, map_location=self.device)

        # Vocab size comes from the embedding weight shape in the checkpoint,
        # which is authoritative (the tokenizer vocab may have drifted).
        vocab_size = state_dict['text_encoder.embedding.weight'].shape[0]
        print(f"   Detected vocab size from checkpoint: {vocab_size}")

        # Load tokenizer vocab.
        with open(tokeniser_path, "r") as f:
            vocab = json.load(f)
        self.color_tokenizer = Tokenizer()
        self.color_tokenizer.load_vocab(vocab)

        # Create model with the vocab size from checkpoint (not from tokenizer).
        self.color_model = ColorCLIP(vocab_size=vocab_size, embedding_dim=self.color_emb_dim)
        self.color_model.load_state_dict(state_dict)
        self.color_model.to(self.device)
        self.color_model.eval()
        print("āœ… Color model loaded successfully")

    def _tokenize_color_texts(self, texts):
        """Tokenize texts with the color tokenizer and return padded tensors.

        Returns:
            ``(input_ids, lengths)`` — a zero-padded ``(batch, max_len)``
            LongTensor and a ``(batch,)`` LongTensor of true lengths
            (zero-length texts get length 1 to avoid empty sequences).
        """
        token_lists = [self.color_tokenizer(t) for t in texts]
        max_len = max((len(toks) for toks in token_lists), default=0)
        max_len = max_len if max_len > 0 else 1
        input_ids = torch.zeros(len(texts), max_len, dtype=torch.long, device=self.device)
        lengths = torch.zeros(len(texts), dtype=torch.long, device=self.device)
        for i, toks in enumerate(token_lists):
            if len(toks) > 0:
                input_ids[i, :len(toks)] = torch.tensor(toks, dtype=torch.long, device=self.device)
                lengths[i] = len(toks)
            else:
                lengths[i] = 1  # avoid zero-length
        return input_ids, lengths

    def extract_color_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
        """Extract 16D color embeddings from the specialized color model.

        Args:
            dataloader: Yields ``(images, texts, colors)`` batches.
            embedding_type: 'text' or 'image' (anything else falls back to text).
            max_samples: Stop after roughly this many samples.

        Returns:
            ``(embeddings, colors)`` — a stacked numpy array and the matching
            normalized color labels.
        """
        self._load_color_model()
        all_embeddings = []
        all_colors = []
        sample_count = 0
        with torch.no_grad():
            for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} color embeddings"):
                if sample_count >= max_samples:
                    break
                images, texts, colors = batch
                images = images.to(self.device)
                images = images.expand(-1, 3, -1, -1)  # ensure 3 channels (no-op for RGB)

                if embedding_type == 'text':
                    input_ids, lengths = self._tokenize_color_texts(texts)
                    embeddings = self.color_model.text_encoder(input_ids, lengths)
                elif embedding_type == 'image':
                    embeddings = self.color_model.image_encoder(images)
                else:
                    # Unknown type: default to text embeddings.
                    input_ids, lengths = self._tokenize_color_texts(texts)
                    embeddings = self.color_model.text_encoder(input_ids, lengths)

                all_embeddings.append(embeddings.cpu().numpy())
                all_colors.extend(_normalize_color_labels(colors))
                sample_count += len(images)

                # Free per-batch tensors eagerly to cap memory use.
                del images, embeddings
                if embedding_type != 'image':
                    del input_ids, lengths
                torch.cuda.empty_cache() if torch.cuda.is_available() else None

        return np.vstack(all_embeddings), all_colors

    def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000):
        """Extract embeddings from the baseline Fashion CLIP model.

        Same contract as :meth:`extract_color_embeddings`, but uses the
        HuggingFace CLIP model; labels are normalized identically so both
        pipelines share one label space.
        """
        all_embeddings = []
        all_colors = []
        sample_count = 0
        with torch.no_grad():
            for batch in tqdm(dataloader, desc=f"Extracting baseline {embedding_type} embeddings"):
                if sample_count >= max_samples:
                    break
                images, texts, colors = batch
                images = images.to(self.device)
                images = images.expand(-1, 3, -1, -1)  # Ensure 3 channels

                # Process text inputs with the baseline processor.
                text_inputs = self.baseline_processor(text=texts, padding=True, return_tensors="pt")
                text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}

                # Forward pass through baseline model.
                outputs = self.baseline_model(**text_inputs, pixel_values=images)

                # Extract embeddings based on type (default: text).
                if embedding_type == 'text':
                    embeddings = outputs.text_embeds
                elif embedding_type == 'image':
                    embeddings = outputs.image_embeds
                else:
                    embeddings = outputs.text_embeds

                all_embeddings.append(embeddings.cpu().numpy())
                all_colors.extend(_normalize_color_labels(colors))
                sample_count += len(images)

                # Clear GPU memory.
                del images, text_inputs, outputs, embeddings
                torch.cuda.empty_cache() if torch.cuda.is_available() else None

        return np.vstack(all_embeddings), all_colors

    def compute_similarity_metrics(self, embeddings, labels):
        """Compute intra-/inter-class cosine similarities and accuracies.

        Sub-samples to at most 5000 points (unseeded, so repeated calls may
        differ slightly) to keep the pairwise similarity matrix tractable.

        Returns:
            Dict with raw similarity lists, their means, a separation score
            (intra mean minus inter mean), nearest-neighbor accuracy, and
            centroid accuracy.
        """
        max_samples = min(5000, len(embeddings))
        if len(embeddings) > max_samples:
            indices = np.random.choice(len(embeddings), max_samples, replace=False)
            embeddings = embeddings[indices]
            labels = [labels[i] for i in indices]

        similarities = cosine_similarity(embeddings)

        # Create label groups using numpy for fast indexing.
        label_array = np.array(labels)
        unique_labels = np.unique(label_array)
        label_groups = {label: np.where(label_array == label)[0] for label in unique_labels}

        # Intra-class similarities: upper triangle of each class submatrix.
        intra_class_similarities = []
        for label, indices in label_groups.items():
            if len(indices) > 1:
                class_similarities = similarities[np.ix_(indices, indices)]
                triu_indices = np.triu_indices_from(class_similarities, k=1)
                intra_class_similarities.extend(class_similarities[triu_indices].tolist())

        # Inter-class similarities: full cross-block for each label pair.
        inter_class_similarities = []
        labels_list = list(label_groups.keys())
        for i in range(len(labels_list)):
            for j in range(i + 1, len(labels_list)):
                label1_indices = label_groups[labels_list[i]]
                label2_indices = label_groups[labels_list[j]]
                inter_sims = similarities[np.ix_(label1_indices, label2_indices)]
                inter_class_similarities.extend(inter_sims.flatten().tolist())

        nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities)
        centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels)

        return {
            'intra_class_similarities': intra_class_similarities,
            'inter_class_similarities': inter_class_similarities,
            'intra_class_mean': float(np.mean(intra_class_similarities)) if intra_class_similarities else 0.0,
            'inter_class_mean': float(np.mean(inter_class_similarities)) if inter_class_similarities else 0.0,
            'separation_score': float(np.mean(intra_class_similarities) - np.mean(inter_class_similarities))
            if intra_class_similarities and inter_class_similarities else 0.0,
            'accuracy': nn_accuracy,
            'centroid_accuracy': centroid_accuracy,
        }

    def compute_embedding_accuracy(self, embeddings, labels, similarities):
        """Leave-one-out nearest-neighbor classification accuracy."""
        correct_predictions = 0
        total_predictions = len(labels)
        for i in range(len(embeddings)):
            true_label = labels[i]
            similarities_row = similarities[i].copy()
            similarities_row[i] = -1  # exclude the point itself
            nearest_neighbor_idx = int(np.argmax(similarities_row))
            if labels[nearest_neighbor_idx] == true_label:
                correct_predictions += 1
        return correct_predictions / total_predictions if total_predictions > 0 else 0.0

    def compute_centroid_accuracy(self, embeddings, labels):
        """Classification accuracy using per-class centroids (vectorized)."""
        unique_labels = list(set(labels))

        # Compute one centroid per class.
        centroids = {}
        for label in unique_labels:
            label_mask = np.array(labels) == label
            centroids[label] = np.mean(embeddings[label_mask], axis=0)

        # Stack centroids and compute all similarities at once.
        centroid_matrix = np.vstack([centroids[label] for label in unique_labels])
        similarities = cosine_similarity(embeddings, centroid_matrix)

        predicted_indices = np.argmax(similarities, axis=1)
        predicted_labels = [unique_labels[idx] for idx in predicted_indices]

        correct_predictions = sum(pred == true for pred, true in zip(predicted_labels, labels))
        return correct_predictions / len(labels) if len(labels) > 0 else 0.0

    def predict_labels_from_embeddings(self, embeddings, labels):
        """Predict a label per embedding via nearest class centroid.

        ``None`` labels are ignored when building centroids; if no valid
        label exists, every prediction is ``None``.
        """
        unique_labels = [l for l in set(labels) if l is not None]
        if len(unique_labels) == 0:
            return [None] * len(embeddings)

        centroids = {}
        for label in unique_labels:
            label_mask = np.array(labels) == label
            if np.any(label_mask):
                centroids[label] = np.mean(embeddings[label_mask], axis=0)

        centroid_labels = list(centroids.keys())
        centroid_matrix = np.vstack([centroids[label] for label in centroid_labels])

        similarities = cosine_similarity(embeddings, centroid_matrix)
        predicted_indices = np.argmax(similarities, axis=1)
        return [centroid_labels[idx] for idx in predicted_indices]

    def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix", label_type="Label"):
        """Plot a confusion matrix heatmap.

        Returns:
            ``(figure, accuracy, confusion_matrix)``.
        """
        unique_labels = sorted(list(set(true_labels + predicted_labels)))
        cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
        accuracy = accuracy_score(true_labels, predicted_labels)

        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=unique_labels, yticklabels=unique_labels)
        plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
        plt.ylabel(f'True {label_type}')
        plt.xlabel(f'Predicted {label_type}')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()
        return plt.gcf(), accuracy, cm

    def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings", label_type="Label"):
        """Evaluate centroid classification and build a confusion matrix.

        Args:
            embeddings: Embedding array, one row per sample.
            labels: True labels (``None`` entries are excluded from metrics).
            embedding_type: Display name for plot titles.
            label_type: Display name for the label axis (e.g. "Color").

        Returns:
            Dict with ``accuracy``, raw ``predictions``, ``confusion_matrix``,
            ``classification_report`` and the matplotlib ``figure`` (all
            ``None``/0.0 when no valid label pairs exist).
        """
        predictions = self.predict_labels_from_embeddings(embeddings, labels)
        title_suffix = ""

        # Keep only pairs where both the label and the prediction exist.
        valid_indices = [i for i, (label, pred) in enumerate(zip(labels, predictions))
                         if label is not None and pred is not None]
        if len(valid_indices) == 0:
            print(f"āš ļø Warning: No valid labels/predictions found (all are None)")
            return {
                'accuracy': 0.0,
                'predictions': predictions,
                'confusion_matrix': None,
                'classification_report': None,
                'figure': None,
            }

        filtered_labels = [labels[i] for i in valid_indices]
        filtered_predictions = [predictions[i] for i in valid_indices]

        accuracy = accuracy_score(filtered_labels, filtered_predictions)
        fig, acc, cm = self.create_confusion_matrix(
            filtered_labels, filtered_predictions,
            f"{embedding_type} - {label_type} Classification{title_suffix}",
            label_type
        )
        unique_labels = sorted(list(set(filtered_labels)))
        report = classification_report(filtered_labels, filtered_predictions,
                                       labels=unique_labels, target_names=unique_labels,
                                       output_dict=True)
        return {
            'accuracy': accuracy,
            'predictions': predictions,
            'confusion_matrix': cm,
            'classification_report': report,
            'figure': fig,
        }

    def evaluate_kaggle_marqo(self, max_samples):
        """Evaluate the color model's text and image embeddings on KAGL Marqo."""
        print(f"\n{'='*60}")
        print("Evaluating KAGL Marqo Dataset with Color embeddings")
        print(f"Max samples: {max_samples}")
        print(f"{'='*60}")

        kaggle_dataset = load_kaggle_marqo_dataset(max_samples)
        if kaggle_dataset is None:
            print("āŒ Failed to load KAGL dataset")
            return None

        dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False,
                                num_workers=0, collate_fn=collate_fn_filter_none)
        results = {}

        # ========== EXTRACT COLOR-MODEL EMBEDDINGS ==========
        # (The progress messages below say "baseline" but these are the
        # specialized color-model embeddings — kept for output stability.)
        print("\nšŸ“¦ Extracting baseline embeddings...")
        text_full_embeddings, text_colors_full = self.extract_color_embeddings(
            dataloader, embedding_type='text', max_samples=max_samples)
        image_full_embeddings, image_colors_full = self.extract_color_embeddings(
            dataloader, embedding_type='image', max_samples=max_samples)

        text_color_metrics = self.compute_similarity_metrics(text_full_embeddings, text_colors_full)
        text_color_class = self.evaluate_classification_performance(
            text_full_embeddings, text_colors_full,
            "Text Color Embeddings (Baseline)", "Color",
        )
        text_color_metrics.update(text_color_class)
        results['text_color'] = text_color_metrics

        image_color_metrics = self.compute_similarity_metrics(image_full_embeddings, image_colors_full)
        image_color_class = self.evaluate_classification_performance(
            image_full_embeddings, image_colors_full,
            "Image Color Embeddings (Baseline)", "Color",
        )
        image_color_metrics.update(image_color_class)
        results['image_color'] = image_color_metrics

        del text_full_embeddings, image_full_embeddings
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # ========== SAVE VISUALIZATIONS ==========
        os.makedirs(self.directory, exist_ok=True)
        for key in ['text_color', 'image_color']:
            results[key]['figure'].savefig(
                f"{self.directory}/kaggle_{key}_confusion_matrix.png",
                dpi=300, bbox_inches='tight',
            )
            plt.close(results[key]['figure'])

        return results

    def evaluate_local_validation(self, max_samples):
        """Evaluate the color model's text and image embeddings on the local dataset."""
        print(f"\n{'='*60}")
        print("Evaluating Local Validation Dataset")
        print(" Color embeddings")
        print(f"Max samples: {max_samples}")
        print(f"{'='*60}")

        local_dataset = load_local_validation_dataset(max_samples)
        dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)
        results = {}

        # ========== COLOR EVALUATION ==========
        print("\nšŸŽØ COLOR EVALUATION ")
        print("=" * 50)

        # Text color embeddings.
        print("\nšŸ“ Extracting text color embeddings...")
        text_color_embeddings, text_colors = self.extract_color_embeddings(dataloader, 'text', max_samples)
        print(f"   Text color embeddings shape: {text_color_embeddings.shape}")
        text_color_metrics = self.compute_similarity_metrics(text_color_embeddings, text_colors)
        text_color_class = self.evaluate_classification_performance(
            text_color_embeddings, text_colors, "Text Color Embeddings (Baseline)", "Color"
        )
        text_color_metrics.update(text_color_class)
        results['text_color'] = text_color_metrics
        del text_color_embeddings
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Image color embeddings.
        print("\nšŸ–¼ļø Extracting image color embeddings...")
        image_color_embeddings, image_colors = self.extract_color_embeddings(dataloader, 'image', max_samples)
        print(f"   Image color embeddings shape: {image_color_embeddings.shape}")
        image_color_metrics = self.compute_similarity_metrics(image_color_embeddings, image_colors)
        image_color_class = self.evaluate_classification_performance(
            image_color_embeddings, image_colors, "Image Color Embeddings (Baseline)", "Color"
        )
        image_color_metrics.update(image_color_class)
        results['image_color'] = image_color_metrics
        del image_color_embeddings
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # ========== SAVE VISUALIZATIONS ==========
        os.makedirs(self.directory, exist_ok=True)
        for key in ['text_color', 'image_color']:
            results[key]['figure'].savefig(
                f"{self.directory}/local_{key}_confusion_matrix.png",
                dpi=300, bbox_inches='tight',
            )
            plt.close(results[key]['figure'])

        return results

    def evaluate_baseline_kaggle_marqo(self, max_samples=5000):
        """Evaluate the baseline Fashion CLIP model on the KAGL Marqo dataset."""
        print(f"\n{'='*60}")
        print("Evaluating Baseline Fashion CLIP on KAGL Marqo Dataset")
        print(f"Max samples: {max_samples}")
        print(f"{'='*60}")

        # Load KAGL Marqo dataset.
        kaggle_dataset = load_kaggle_marqo_dataset(max_samples)
        if kaggle_dataset is None:
            print("āŒ Failed to load KAGL dataset")
            return None

        dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False,
                                num_workers=0, collate_fn=collate_fn_filter_none)
        results = {}

        # Evaluate text embeddings.
        print("\nšŸ“ Extracting baseline text embeddings from KAGL Marqo...")
        text_embeddings, text_colors = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
        print(f"   Baseline text embeddings shape: {text_embeddings.shape} "
              f"(using all {text_embeddings.shape[1]} dimensions)")
        text_color_metrics = self.compute_similarity_metrics(text_embeddings, text_colors)
        text_color_classification = self.evaluate_classification_performance(
            text_embeddings, text_colors, "Baseline KAGL Marqo Text Embeddings - Color", "Color"
        )
        text_color_metrics.update(text_color_classification)
        results['text'] = {'color': text_color_metrics}

        # Clear memory.
        del text_embeddings
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Evaluate image embeddings.
        print("\nšŸ–¼ļø Extracting baseline image embeddings from KAGL Marqo...")
        image_embeddings, image_colors = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
        print(f"   Baseline image embeddings shape: {image_embeddings.shape} "
              f"(using all {image_embeddings.shape[1]} dimensions)")
        image_color_metrics = self.compute_similarity_metrics(image_embeddings, image_colors)
        image_color_classification = self.evaluate_classification_performance(
            image_embeddings, image_colors, "Baseline KAGL Marqo Image Embeddings - Color", "Color"
        )
        image_color_metrics.update(image_color_classification)
        results['image'] = {'color': image_color_metrics}

        # Clear memory.
        del image_embeddings
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # ========== SAVE VISUALIZATIONS ==========
        os.makedirs(self.directory, exist_ok=True)
        for key in ['text', 'image']:
            for subkey in ['color']:
                figure = results[key][subkey]['figure']
                figure.savefig(
                    f"{self.directory}/kaggle_baseline_{key}_{subkey}_confusion_matrix.png",
                    dpi=300, bbox_inches='tight',
                )
                plt.close(figure)

        return results

    def evaluate_baseline_local_validation(self, max_samples=5000):
        """Evaluate the baseline Fashion CLIP model on the local validation dataset."""
        print(f"\n{'='*60}")
        print("Evaluating Baseline Fashion CLIP on Local Validation Dataset")
        print(f"Max samples: {max_samples}")
        print(f"{'='*60}")

        # Load local validation dataset.
        local_dataset = load_local_validation_dataset(max_samples)
        if local_dataset is None:
            print("āŒ Failed to load local validation dataset")
            return None

        dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)
        results = {}

        # Evaluate text embeddings.
        print("\nšŸ“ Extracting baseline text embeddings from Local Validation...")
        text_embeddings, text_colors = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
        print(f"   Baseline text embeddings shape: {text_embeddings.shape} "
              f"(using all {text_embeddings.shape[1]} dimensions)")
        text_color_metrics = self.compute_similarity_metrics(text_embeddings, text_colors)
        text_color_classification = self.evaluate_classification_performance(
            text_embeddings, text_colors, "Baseline Local Validation Text Embeddings - Color", "Color"
        )
        text_color_metrics.update(text_color_classification)
        results['text'] = {'color': text_color_metrics}

        # Clear memory.
        del text_embeddings
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Evaluate image embeddings.
        print("\nšŸ–¼ļø Extracting baseline image embeddings from Local Validation...")
        image_embeddings, image_colors = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
        print(f"   Baseline image embeddings shape: {image_embeddings.shape} "
              f"(using all {image_embeddings.shape[1]} dimensions)")
        image_color_metrics = self.compute_similarity_metrics(image_embeddings, image_colors)
        image_color_classification = self.evaluate_classification_performance(
            image_embeddings, image_colors, "Baseline Local Validation Image Embeddings - Color", "Color"
        )
        image_color_metrics.update(image_color_classification)
        results['image'] = {'color': image_color_metrics}

        # Clear memory.
        del image_embeddings
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # ========== SAVE VISUALIZATIONS ==========
        os.makedirs(self.directory, exist_ok=True)
        for key in ['text', 'image']:
            for subkey in ['color']:
                figure = results[key][subkey]['figure']
                figure.savefig(
                    f"{self.directory}/local_baseline_{key}_{subkey}_confusion_matrix.png",
                    dpi=300, bbox_inches='tight',
                )
                plt.close(figure)

        return results


if __name__ == "__main__":
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")
    directory = 'color_model_analysis'
    max_samples = 10000
    evaluator = ColorEvaluator(device=device, directory=directory)

    # Evaluate KAGL Marqo.
    print("\n" + "="*60)
    print("šŸš€ Starting evaluation of KAGL Marqo with Color embeddings")
    print("="*60)
    results_kaggle = evaluator.evaluate_kaggle_marqo(max_samples=max_samples)
    print(f"\n{'='*60}")
    print("KAGL MARQO EVALUATION SUMMARY")
    print(f"{'='*60}")
    print("\nšŸŽØ COLOR CLASSIFICATION RESULTS:")
    print(f"   Text  - NN Acc: {results_kaggle['text_color']['accuracy']*100:.1f}% | "
          f"Centroid Acc: {results_kaggle['text_color']['centroid_accuracy']*100:.1f}% | "
          f"Separation: {results_kaggle['text_color']['separation_score']:.4f}")
    print(f"   Image - NN Acc: {results_kaggle['image_color']['accuracy']*100:.1f}% | "
          f"Centroid Acc: {results_kaggle['image_color']['centroid_accuracy']*100:.1f}% | "
          f"Separation: {results_kaggle['image_color']['separation_score']:.4f}")

    # Evaluate Baseline Fashion CLIP on KAGL Marqo.
    print("\n" + "="*60)
    print("šŸš€ Starting evaluation of Baseline Fashion CLIP on KAGL Marqo")
    print("="*60)
    results_baseline_kaggle = evaluator.evaluate_baseline_kaggle_marqo(max_samples=max_samples)
    print(f"\n{'='*60}")
    print("BASELINE KAGL MARQO EVALUATION SUMMARY")
    print(f"{'='*60}")
    print("\nšŸŽØ COLOR CLASSIFICATION RESULTS (Baseline):")
    print(f"   Text  - NN Acc: {results_baseline_kaggle['text']['color']['accuracy']*100:.1f}% | "
          f"Centroid Acc: {results_baseline_kaggle['text']['color']['centroid_accuracy']*100:.1f}% | "
          f"Separation: {results_baseline_kaggle['text']['color']['separation_score']:.4f}")
    print(f"   Image - NN Acc: {results_baseline_kaggle['image']['color']['accuracy']*100:.1f}% | "
          f"Centroid Acc: {results_baseline_kaggle['image']['color']['centroid_accuracy']*100:.1f}% | "
          f"Separation: {results_baseline_kaggle['image']['color']['separation_score']:.4f}")

    # Evaluate Local Validation Dataset.
    print("\n" + "="*60)
    print("šŸš€ Starting evaluation of Local Validation Dataset with Color embeddings")
    print("="*60)
    results_local = evaluator.evaluate_local_validation(max_samples=max_samples)
    if results_local is not None:
        print(f"\n{'='*60}")
        print("LOCAL VALIDATION DATASET EVALUATION SUMMARY")
        print(f"{'='*60}")
        print("\nšŸŽØ COLOR CLASSIFICATION RESULTS:")
        print(f"   Text  - NN Acc: {results_local['text_color']['accuracy']*100:.1f}% | "
              f"Centroid Acc: {results_local['text_color']['centroid_accuracy']*100:.1f}% | "
              f"Separation: {results_local['text_color']['separation_score']:.4f}")
        print(f"   Image - NN Acc: {results_local['image_color']['accuracy']*100:.1f}% | "
              f"Centroid Acc: {results_local['image_color']['centroid_accuracy']*100:.1f}% | "
              f"Separation: {results_local['image_color']['separation_score']:.4f}")

    # Evaluate Baseline Fashion CLIP on Local Validation.
    print("\n" + "="*60)
    print("šŸš€ Starting evaluation of Baseline Fashion CLIP on Local Validation")
    print("="*60)
    results_baseline_local = evaluator.evaluate_baseline_local_validation(max_samples=max_samples)
    if results_baseline_local is not None:
        print(f"\n{'='*60}")
        print("BASELINE LOCAL VALIDATION EVALUATION SUMMARY")
        print(f"{'='*60}")
        print("\nšŸŽØ COLOR CLASSIFICATION RESULTS (Baseline):")
        print(f"   Text  - NN Acc: {results_baseline_local['text']['color']['accuracy']*100:.1f}% | "
              f"Centroid Acc: {results_baseline_local['text']['color']['centroid_accuracy']*100:.1f}% | "
              f"Separation: {results_baseline_local['text']['color']['separation_score']:.4f}")
        print(f"   Image - NN Acc: {results_baseline_local['image']['color']['accuracy']*100:.1f}% | "
              f"Centroid Acc: {results_baseline_local['image']['color']['centroid_accuracy']*100:.1f}% | "
              f"Separation: {results_baseline_local['image']['color']['separation_score']:.4f}")

    print(f"\nāœ… Evaluation completed! Check '{directory}/' for visualization files.")