| | import os |
| | import json |
| | os.environ["TOKENIZERS_PARALLELISM"] = "false" |
| |
|
| | import torch |
| | import pandas as pd |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | import seaborn as sns |
| | import difflib |
| | from sklearn.metrics.pairwise import cosine_similarity |
| | from sklearn.metrics import confusion_matrix, classification_report, accuracy_score |
| | from collections import defaultdict |
| | from tqdm import tqdm |
| | from torch.utils.data import Dataset, DataLoader |
| | from torchvision import transforms |
| | from PIL import Image |
| | from io import BytesIO |
| | import warnings |
| | warnings.filterwarnings('ignore') |
| | from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers |
| |
|
| | from config import ( |
| | color_model_path, |
| | color_emb_dim, |
| | local_dataset_path, |
| | column_local_image_path, |
| | tokeniser_path, |
| | ) |
| | from color_model import ColorCLIP, Tokenizer |
| |
|
| |
|
class KaggleDataset(Dataset):
    """Dataset wrapper for the prepared Marqo/KAGL dataframe.

    Expects columns 'image_url' (a PIL image, raw bytes, or a dict holding a
    'bytes' key — the HF dataset delivers all three forms), 'text', and
    'color'. Yields (image_tensor, description, color) triples.
    """

    def __init__(self, dataframe, image_size=224, augment=False):
        """
        Args:
            dataframe: prepared pandas DataFrame (see load_kaggle_marqo_dataset).
            image_size: side length images are resized to (previously the
                parameter was accepted but Resize hard-coded 224).
            augment: apply random ColorJitter. Off by default — random color
                augmentation while *evaluating a color model* perturbs the
                attribute being measured and makes results non-deterministic.
                Pass augment=True to restore the old behavior.
        """
        self.dataframe = dataframe
        self.image_size = image_size

        ops = [transforms.Resize((image_size, image_size))]
        if augment:
            ops.append(transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2))
        ops.extend([
            transforms.ToTensor(),
            # ImageNet normalization statistics, as used by CLIP-style encoders.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.transform = transforms.Compose(ops)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        image_data = row['image_url']

        # Normalize the three possible image payload forms to an RGB PIL image.
        if isinstance(image_data, dict) and 'bytes' in image_data:
            image = Image.open(BytesIO(image_data['bytes'])).convert("RGB")
        elif hasattr(image_data, 'convert'):
            # Already a PIL image.
            image = image_data.convert("RGB")
        else:
            # Raw bytes.
            image = Image.open(BytesIO(image_data)).convert("RGB")

        image = self.transform(image)

        description = row['text']
        color = row['color']

        return image, description, color
|
def load_kaggle_marqo_dataset(max_samples=5000):
    """Load and prepare the Kaggle KAGL dataset with memory optimization.

    Downloads Marqo/KAGL, drops rows with missing text/image, subsamples to
    at most `max_samples` rows (seeded for reproducibility), normalizes the
    color column ('grey' -> 'gray'), and keeps only the 11 colors the color
    model is evaluated on.

    Returns:
        A KaggleDataset, or None if the download/parse fails — callers
        (evaluate_kaggle_marqo, evaluate_baseline_kaggle_marqo) already check
        for None, but previously any failure propagated as an exception.
    """
    from datasets import load_dataset
    print("๐ Loading Kaggle KAGL dataset...")

    try:
        dataset = load_dataset("Marqo/KAGL")
        df = dataset["data"].to_pandas()
    except Exception as e:
        print(f"โ Failed to load KAGL dataset: {e}")
        return None

    print(f"โ… Dataset Kaggle loaded")
    print(f" Before filtering: {len(df)} samples")
    print(f" Available columns: {list(df.columns)}")

    # Rows without a caption or an image are unusable for either encoder.
    df = df.dropna(subset=['text', 'image'])
    print(f" After removing missing text/image: {len(df)} samples")

    df_test = df.copy()

    # Fixed seed so repeated runs evaluate the same subsample.
    if len(df_test) > max_samples:
        df_test = df_test.sample(n=max_samples, random_state=42)
        print(f"๐ Randomly sampled {max_samples} samples from Kaggle dataset")

    # Reshape into the column names the Dataset class expects; unify the
    # British/American spelling so labels match the valid-color list below.
    kaggle_formatted = pd.DataFrame({
        'image_url': df_test['image'],
        'text': df_test['text'],
        'color': df_test['baseColour'].str.lower().str.replace("grey", "gray")
    })

    before_color_filter = len(kaggle_formatted)
    kaggle_formatted = kaggle_formatted.dropna(subset=['color'])
    if len(kaggle_formatted) < before_color_filter:
        print(f" After removing missing colors: {len(kaggle_formatted)} samples (removed {before_color_filter - len(kaggle_formatted)} samples)")

    # The closed label set the color model was trained/evaluated on.
    valid_colors = ['beige', 'black', 'blue', 'brown', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
    before_valid_filter = len(kaggle_formatted)
    kaggle_formatted = kaggle_formatted[kaggle_formatted['color'].isin(valid_colors)]
    print(f" After filtering for valid colors: {len(kaggle_formatted)} samples (removed {before_valid_filter - len(kaggle_formatted)} samples)")
    print(f" Valid colors found: {sorted(kaggle_formatted['color'].unique())}")

    print(f" Final dataset size: {len(kaggle_formatted)} samples")

    print(f"๐จ Color distribution in Kaggle dataset:")
    color_counts = kaggle_formatted['color'].value_counts()
    for color in color_counts.index:
        print(f" {color}: {color_counts[color]} samples")

    return KaggleDataset(kaggle_formatted)
|
class LocalDataset(Dataset):
    """Dataset wrapper for the local validation CSV.

    Reads images from the path column configured as `column_local_image_path`
    and yields (image_tensor, description, color) triples. A failed image
    load is replaced with a solid-gray placeholder so one bad file does not
    abort the whole evaluation run.
    """

    def __init__(self, dataframe, image_size=224, augment=False):
        """
        Args:
            dataframe: validation DataFrame (see load_local_validation_dataset).
            image_size: side length images are resized to (previously the
                parameter was accepted but Resize hard-coded 224).
            augment: apply random ColorJitter. Off by default — random color
                augmentation while *evaluating a color model* perturbs the
                attribute being measured and makes results non-deterministic.
                Pass augment=True to restore the old behavior.
        """
        self.dataframe = dataframe
        self.image_size = image_size

        ops = [transforms.Resize((image_size, image_size))]
        if augment:
            ops.append(transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2))
        ops.extend([
            transforms.ToTensor(),
            # ImageNet normalization statistics, as used by CLIP-style encoders.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.transform = transforms.Compose(ops)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]

        image_path = row[column_local_image_path]
        try:
            image = Image.open(image_path).convert("RGB")
        except Exception as e:
            print(f"Error loading image at index {idx} from {image_path}: {e}")
            # Placeholder keeps the batch shape intact; its label will be
            # wrong, but this path should be rare.
            image = Image.new('RGB', (224, 224), color='gray')

        image = self.transform(image)

        description = row['text']
        color = row['color']

        return image, description, color
|
def load_local_validation_dataset(max_samples=5000):
    """Load and prepare the local validation dataset.

    Reads the CSV at `local_dataset_path`, drops rows with missing image
    paths, restricts to the 11 evaluated colors (when a 'color' column is
    present), and subsamples to at most `max_samples` rows (seeded).

    Returns:
        A LocalDataset, or None if the CSV cannot be read — the baseline
        caller already checks for None, but previously any read failure
        propagated as an exception.
    """
    print("๐ Loading local validation dataset...")

    try:
        df = pd.read_csv(local_dataset_path)
    except Exception as e:
        print(f"โ Failed to load local validation dataset: {e}")
        return None
    print(f"โ… Dataset loaded: {len(df)} samples")

    # A row without an image path can never be evaluated.
    df_clean = df.dropna(subset=[column_local_image_path])
    print(f"๐ After filtering NaN image paths: {len(df_clean)} samples")

    # The closed label set the color model was trained/evaluated on; the
    # local CSV may legitimately lack a 'color' column, hence the guard.
    valid_colors = ['beige', 'black', 'blue', 'brown', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
    if 'color' in df_clean.columns:
        before_valid_filter = len(df_clean)
        df_clean = df_clean[df_clean['color'].isin(valid_colors)]
        print(f"๐ After filtering for valid colors: {len(df_clean)} samples (removed {before_valid_filter - len(df_clean)} samples)")
        print(f"๐จ Valid colors found: {sorted(df_clean['color'].unique())}")

    # Fixed seed so repeated runs evaluate the same subsample.
    if len(df_clean) > max_samples:
        df_clean = df_clean.sample(n=max_samples, random_state=42)
        print(f"๐ Randomly sampled {max_samples} samples")

    print(f"๐ Using {len(df_clean)} samples for evaluation")

    if 'color' in df_clean.columns:
        print(f"๐จ Color distribution in sampled data:")
        color_counts = df_clean['color'].value_counts()
        print(f" Total unique colors: {len(color_counts)}")
        for color in color_counts.index[:15]:
            print(f" {color}: {color_counts[color]} samples")

    return LocalDataset(df_clean)
|
def collate_fn_filter_none(batch):
    """Collate samples, dropping any that are None.

    Each kept sample is an (image_tensor, text, color) triple. Returns the
    stacked image batch plus parallel text/color lists; an all-None batch
    yields an empty tensor and two empty lists so callers can skip it.
    """
    kept = [sample for sample in batch if sample is not None]
    dropped = len(batch) - len(kept)

    if dropped:
        print(f"โ ๏ธ Filtered out {dropped} None values from batch (original: {len(batch)}, filtered: {len(kept)})")

    if not kept:
        print("โ ๏ธ Empty batch after filtering None values")
        return torch.tensor([]), [], []

    images, texts, colors = zip(*kept)
    return torch.stack(images, dim=0), list(texts), list(colors)
|
class ColorEvaluator:
    """Evaluate 16-D color embeddings against a baseline Fashion-CLIP model.

    Loads two models up front:
      * a baseline Fashion CLIP (patrickjohncyh/fashion-clip) via transformers,
      * the project's specialized 16-D ColorCLIP model + its Tokenizer.
    Provides extraction, similarity/classification metrics, confusion-matrix
    plotting, and per-dataset evaluation drivers. Confusion-matrix PNGs are
    written under `directory`.
    """

    def __init__(self, device='mps', directory="color_model_analysis"):
        """
        Args:
            device: torch device spec (string or torch.device; re-wrapped below).
            directory: output directory for confusion-matrix PNGs (created
                here if missing).
        """
        self.device = torch.device(device)
        self.directory = directory
        self.color_emb_dim = color_emb_dim  # from config; dimensionality of the color model
        os.makedirs(self.directory, exist_ok=True)

        # Baseline Fashion CLIP (downloaded from the HF hub on first use).
        print("๐ฆ Loading baseline Fashion CLIP model...")
        patrick_model_name = "patrickjohncyh/fashion-clip"
        self.baseline_processor = CLIPProcessor.from_pretrained(patrick_model_name)
        self.baseline_model = CLIPModel_transformers.from_pretrained(patrick_model_name).to(self.device)
        self.baseline_model.eval()
        print("โ… Baseline Fashion CLIP model loaded successfully")

        # Specialized color model; loaded eagerly so missing files fail fast.
        self.color_model = None
        self.color_tokenizer = None
        self._load_color_model()

    def _load_color_model(self):
        """Load the specialized 16D color model and tokenizer.

        Idempotent: returns immediately if both are already loaded.

        Raises:
            FileNotFoundError: if the checkpoint or tokenizer vocab is missing.
        """
        if self.color_model is not None and self.color_tokenizer is not None:
            return

        if not os.path.exists(color_model_path):
            raise FileNotFoundError(f"Color model file {color_model_path} not found")
        if not os.path.exists(tokeniser_path):
            raise FileNotFoundError(f"Tokenizer vocab file {tokeniser_path} not found")

        print("๐จ Loading specialized color model (16D)...")

        state_dict = torch.load(color_model_path, map_location=self.device)

        # Infer vocab size from the embedding table so the model can be
        # rebuilt without a separate config value.
        vocab_size = state_dict['text_encoder.embedding.weight'].shape[0]
        print(f" Detected vocab size from checkpoint: {vocab_size}")

        with open(tokeniser_path, "r") as f:
            vocab = json.load(f)

        self.color_tokenizer = Tokenizer()
        self.color_tokenizer.load_vocab(vocab)

        self.color_model = ColorCLIP(vocab_size=vocab_size, embedding_dim=self.color_emb_dim)

        self.color_model.load_state_dict(state_dict)
        self.color_model.to(self.device)
        self.color_model.eval()
        print("โ… Color model loaded successfully")

    def _tokenize_color_texts(self, texts):
        """Tokenize texts with the color tokenizer and return padded tensors.

        Returns:
            (input_ids, lengths): LongTensors on self.device, zero-padded to
            the longest sequence in the batch. Empty token lists keep a
            length of 1 (pointing at the pad id 0) so downstream pooling
            never divides by zero.
        """
        token_lists = [self.color_tokenizer(t) for t in texts]
        max_len = max((len(toks) for toks in token_lists), default=0)
        max_len = max_len if max_len > 0 else 1

        input_ids = torch.zeros(len(texts), max_len, dtype=torch.long, device=self.device)
        lengths = torch.zeros(len(texts), dtype=torch.long, device=self.device)

        for i, toks in enumerate(token_lists):
            if len(toks) > 0:
                input_ids[i, :len(toks)] = torch.tensor(toks, dtype=torch.long, device=self.device)
                lengths[i] = len(toks)
            else:
                lengths[i] = 1  # avoid zero-length sequences

        return input_ids, lengths

    def extract_color_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
        """Extract 16D color embeddings from the specialized color model.

        Args:
            dataloader: yields (images, texts, colors) batches; iterated from
                the start on each call (each extraction is a full pass).
            embedding_type: 'text' or 'image'; any other value falls back to
                text embeddings.
            max_samples: stop after roughly this many samples (checked per
                batch, so the total may overshoot by up to one batch).

        Returns:
            (embeddings, colors): a stacked numpy array and the parallel list
            of color labels, normalized to lowercase with 'grey' -> 'gray'.
        """
        self._load_color_model()
        all_embeddings = []
        all_colors = []

        sample_count = 0
        with torch.no_grad():
            for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} color embeddings"):
                if sample_count >= max_samples:
                    break

                images, texts, colors = batch
                images = images.to(self.device)
                # No-op for RGB input; kept for safety with 1-channel tensors.
                images = images.expand(-1, 3, -1, -1)

                if embedding_type == 'text':
                    input_ids, lengths = self._tokenize_color_texts(texts)
                    embeddings = self.color_model.text_encoder(input_ids, lengths)
                elif embedding_type == 'image':
                    embeddings = self.color_model.image_encoder(images)
                else:
                    # Fallback: treat unknown types as text.
                    input_ids, lengths = self._tokenize_color_texts(texts)
                    embeddings = self.color_model.text_encoder(input_ids, lengths)

                all_embeddings.append(embeddings.cpu().numpy())
                # Normalize label spelling to match the valid-color list.
                normalized_colors = [str(c).lower().strip().replace("grey", "gray") for c in colors]
                all_colors.extend(normalized_colors)

                sample_count += len(images)

                # Free GPU memory eagerly; input_ids/lengths only exist on
                # the text path.
                del images, embeddings
                if embedding_type != 'image':
                    del input_ids, lengths
                torch.cuda.empty_cache() if torch.cuda.is_available() else None

        return np.vstack(all_embeddings), all_colors

    def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000):
        """Extract embeddings from the baseline Fashion CLIP model.

        Same contract as extract_color_embeddings, but color labels are NOT
        normalized here (upstream loaders already lowercase them).
        Note: both text and image are always run through the model; only the
        requested modality's embeddings are kept.
        """
        all_embeddings = []
        all_colors = []

        sample_count = 0

        with torch.no_grad():
            for batch in tqdm(dataloader, desc=f"Extracting baseline {embedding_type} embeddings"):
                if sample_count >= max_samples:
                    break

                images, texts, colors = batch
                images = images.to(self.device)
                # No-op for RGB input; kept for safety with 1-channel tensors.
                images = images.expand(-1, 3, -1, -1)

                text_inputs = self.baseline_processor(text=texts, padding=True, return_tensors="pt")
                text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}

                outputs = self.baseline_model(**text_inputs, pixel_values=images)

                if embedding_type == 'text':
                    embeddings = outputs.text_embeds
                elif embedding_type == 'image':
                    embeddings = outputs.image_embeds
                else:
                    # Fallback: treat unknown types as text.
                    embeddings = outputs.text_embeds

                all_embeddings.append(embeddings.cpu().numpy())
                all_colors.extend(colors)

                sample_count += len(images)

                del images, text_inputs, outputs, embeddings
                torch.cuda.empty_cache() if torch.cuda.is_available() else None

        return np.vstack(all_embeddings), all_colors

    def compute_similarity_metrics(self, embeddings, labels):
        """Compute intra-class and inter-class similarities - optimized version.

        Subsamples to at most 5000 points (unseeded) to bound the O(n^2)
        cosine-similarity matrix, then reports mean intra/inter-class cosine
        similarity, their difference (separation score), and two accuracies.
        """
        max_samples = min(5000, len(embeddings))
        if len(embeddings) > max_samples:
            indices = np.random.choice(len(embeddings), max_samples, replace=False)
            embeddings = embeddings[indices]
            labels = [labels[i] for i in indices]

        similarities = cosine_similarity(embeddings)

        # Group sample indices by label once, up front.
        label_array = np.array(labels)
        unique_labels = np.unique(label_array)
        label_groups = {label: np.where(label_array == label)[0] for label in unique_labels}

        # Intra-class: upper-triangular pairs within each label group
        # (k=1 excludes the self-similarity diagonal).
        intra_class_similarities = []
        for label, indices in label_groups.items():
            if len(indices) > 1:
                class_similarities = similarities[np.ix_(indices, indices)]
                triu_indices = np.triu_indices_from(class_similarities, k=1)
                intra_class_similarities.extend(class_similarities[triu_indices].tolist())

        # Inter-class: all cross pairs for each unordered pair of labels.
        inter_class_similarities = []
        labels_list = list(label_groups.keys())
        for i in range(len(labels_list)):
            for j in range(i + 1, len(labels_list)):
                label1_indices = label_groups[labels_list[i]]
                label2_indices = label_groups[labels_list[j]]
                inter_sims = similarities[np.ix_(label1_indices, label2_indices)]
                inter_class_similarities.extend(inter_sims.flatten().tolist())

        nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities)
        centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels)

        return {
            'intra_class_similarities': intra_class_similarities,
            'inter_class_similarities': inter_class_similarities,
            'intra_class_mean': float(np.mean(intra_class_similarities)) if intra_class_similarities else 0.0,
            'inter_class_mean': float(np.mean(inter_class_similarities)) if inter_class_similarities else 0.0,
            # Higher is better: intra-class cohesion minus inter-class overlap.
            'separation_score': float(np.mean(intra_class_similarities) - np.mean(inter_class_similarities)) if intra_class_similarities and inter_class_similarities else 0.0,
            'accuracy': nn_accuracy,
            'centroid_accuracy': centroid_accuracy,
        }

    def compute_embedding_accuracy(self, embeddings, labels, similarities):
        """Compute classification accuracy using nearest neighbor.

        Leave-one-out 1-NN over the precomputed cosine-similarity matrix:
        each sample is predicted as the label of its most similar other
        sample (self masked out with -1).
        """
        correct_predictions = 0
        total_predictions = len(labels)
        for i in range(len(embeddings)):
            true_label = labels[i]
            similarities_row = similarities[i].copy()
            similarities_row[i] = -1  # exclude self-match
            nearest_neighbor_idx = int(np.argmax(similarities_row))
            predicted_label = labels[nearest_neighbor_idx]
            if predicted_label == true_label:
                correct_predictions += 1
        return correct_predictions / total_predictions if total_predictions > 0 else 0.0

    def compute_centroid_accuracy(self, embeddings, labels):
        """Compute classification accuracy using centroids - optimized vectorized version.

        NOTE(review): centroids are computed from the same samples being
        classified, so this is an in-sample (slightly optimistic) accuracy.
        """
        unique_labels = list(set(labels))

        # Per-label mean embedding.
        centroids = {}
        for label in unique_labels:
            label_mask = np.array(labels) == label
            centroids[label] = np.mean(embeddings[label_mask], axis=0)

        centroid_matrix = np.vstack([centroids[label] for label in unique_labels])

        # (n_samples, n_labels) cosine similarities in one shot.
        similarities = cosine_similarity(embeddings, centroid_matrix)

        predicted_indices = np.argmax(similarities, axis=1)
        predicted_labels = [unique_labels[idx] for idx in predicted_indices]

        correct_predictions = sum(pred == true for pred, true in zip(predicted_labels, labels))
        return correct_predictions / len(labels) if len(labels) > 0 else 0.0

    def predict_labels_from_embeddings(self, embeddings, labels):
        """Predict labels from embeddings using centroid-based classification - optimized vectorized version.

        Same centroid scheme as compute_centroid_accuracy but returns the
        per-sample predictions; None labels are excluded from centroid
        construction, and an all-None label list yields all-None predictions.
        """
        unique_labels = [l for l in set(labels) if l is not None]
        if len(unique_labels) == 0:
            return [None] * len(embeddings)

        centroids = {}
        for label in unique_labels:
            label_mask = np.array(labels) == label
            if np.any(label_mask):
                centroids[label] = np.mean(embeddings[label_mask], axis=0)

        centroid_labels = list(centroids.keys())
        centroid_matrix = np.vstack([centroids[label] for label in centroid_labels])

        similarities = cosine_similarity(embeddings, centroid_matrix)

        predicted_indices = np.argmax(similarities, axis=1)
        predictions = [centroid_labels[idx] for idx in predicted_indices]

        return predictions

    def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix", label_type="Label"):
        """Create and plot a confusion matrix.

        Returns:
            (figure, accuracy, confusion_matrix). The caller is responsible
            for saving and closing the returned matplotlib figure.
        """
        unique_labels = sorted(list(set(true_labels + predicted_labels)))
        cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
        accuracy = accuracy_score(true_labels, predicted_labels)
        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=unique_labels, yticklabels=unique_labels)
        plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
        plt.ylabel(f'True {label_type}')
        plt.xlabel(f'Predicted {label_type}')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()
        return plt.gcf(), accuracy, cm

    def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings", label_type="Label"):
        """
        Evaluate classification performance and create a confusion matrix.

        Args:
            embeddings: embedding matrix, one row per sample.
            labels: true labels (may contain None; those pairs are skipped).
            embedding_type: display name used in the plot title.
            label_type: display name for the label axis (e.g. "Color").

        Returns:
            dict with 'accuracy', 'predictions' (unfiltered, parallel to
            `labels`), 'confusion_matrix', 'classification_report', and the
            matplotlib 'figure' (None figure/matrix/report if nothing valid).
        """
        predictions = self.predict_labels_from_embeddings(embeddings, labels)
        title_suffix = ""

        # Drop pairs where either side is None before scoring.
        valid_indices = [i for i, (label, pred) in enumerate(zip(labels, predictions))
                         if label is not None and pred is not None]

        if len(valid_indices) == 0:
            print(f"โ ๏ธ Warning: No valid labels/predictions found (all are None)")
            return {
                'accuracy': 0.0,
                'predictions': predictions,
                'confusion_matrix': None,
                'classification_report': None,
                'figure': None,
            }

        filtered_labels = [labels[i] for i in valid_indices]
        filtered_predictions = [predictions[i] for i in valid_indices]

        accuracy = accuracy_score(filtered_labels, filtered_predictions)
        fig, acc, cm = self.create_confusion_matrix(
            filtered_labels, filtered_predictions,
            f"{embedding_type} - {label_type} Classification{title_suffix}",
            label_type
        )
        unique_labels = sorted(list(set(filtered_labels)))
        report = classification_report(filtered_labels, filtered_predictions, labels=unique_labels, target_names=unique_labels, output_dict=True)
        return {
            'accuracy': accuracy,
            'predictions': predictions,
            'confusion_matrix': cm,
            'classification_report': report,
            'figure': fig,
        }

    def evaluate_kaggle_marqo(self, max_samples):
        """Evaluate the color model's text and image embeddings on KAGL Marqo.

        Returns:
            dict with 'text_color' and 'image_color' metric dicts, or None
            if the dataset could not be loaded. Confusion-matrix PNGs are
            written under self.directory.
        """
        print(f"\n{'='*60}")
        print("Evaluating KAGL Marqo Dataset with Color embeddings")
        print(f"Max samples: {max_samples}")
        print(f"{'='*60}")

        kaggle_dataset = load_kaggle_marqo_dataset(max_samples)
        if kaggle_dataset is None:
            print("โ Failed to load KAGL dataset")
            return None

        # collate_fn filters None samples; num_workers=0 keeps HF image
        # objects picklable-safe.
        dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn_filter_none)

        results = {}

        # NOTE(review): despite the print, these are *color-model*
        # embeddings (extract_color_embeddings), not the baseline's.
        print("\n๐ฆ Extracting baseline embeddings...")
        text_full_embeddings, text_colors_full = self.extract_color_embeddings(dataloader, embedding_type='text', max_samples=max_samples)
        image_full_embeddings, image_colors_full = self.extract_color_embeddings(dataloader, embedding_type='image', max_samples=max_samples)
        text_color_metrics = self.compute_similarity_metrics(text_full_embeddings, text_colors_full)
        text_color_class = self.evaluate_classification_performance(
            text_full_embeddings, text_colors_full,
            "Text Color Embeddings (Baseline)", "Color",
        )
        text_color_metrics.update(text_color_class)
        results['text_color'] = text_color_metrics
        image_color_metrics = self.compute_similarity_metrics(image_full_embeddings, image_colors_full)
        image_color_class = self.evaluate_classification_performance(
            image_full_embeddings, image_colors_full,
            "Image Color Embeddings (Baseline)", "Color",
        )
        image_color_metrics.update(image_color_class)
        results['image_color'] = image_color_metrics
        del text_full_embeddings, image_full_embeddings
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Persist and close the figures produced above.
        os.makedirs(self.directory, exist_ok=True)
        for key in ['text_color', 'image_color']:
            # .replace('_', '_') is a no-op left from an earlier naming scheme.
            results[key]['figure'].savefig(
                f"{self.directory}/kaggle_{key.replace('_', '_')}_confusion_matrix.png",
                dpi=300,
                bbox_inches='tight',
            )
            plt.close(results[key]['figure'])

        return results

    def evaluate_local_validation(self, max_samples):
        """Evaluate the color model's text and image embeddings on the local dataset.

        Returns:
            dict with 'text_color' and 'image_color' metric dicts.
            NOTE(review): unlike the other drivers, the loader's return is
            not None-checked here.
        """
        print(f"\n{'='*60}")
        print("Evaluating Local Validation Dataset")
        print(" Color embeddings")
        print(f"Max samples: {max_samples}")
        print(f"{'='*60}")

        local_dataset = load_local_validation_dataset(max_samples)
        dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)

        results = {}

        print("\n๐จ COLOR EVALUATION ")
        print("=" * 50)

        # Text embeddings first; freed before the image pass to cap memory.
        print("\n๐ Extracting text color embeddings...")
        text_color_embeddings, text_colors = self.extract_color_embeddings(dataloader, 'text', max_samples)
        print(f" Text color embeddings shape: {text_color_embeddings.shape}")
        text_color_metrics = self.compute_similarity_metrics(text_color_embeddings, text_colors)
        text_color_class = self.evaluate_classification_performance(
            text_color_embeddings, text_colors, "Text Color Embeddings (Baseline)", "Color"
        )
        text_color_metrics.update(text_color_class)
        results['text_color'] = text_color_metrics

        del text_color_embeddings
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        print("\n๐ผ๏ธ Extracting image color embeddings...")
        image_color_embeddings, image_colors = self.extract_color_embeddings(dataloader, 'image', max_samples)
        print(f" Image color embeddings shape: {image_color_embeddings.shape}")
        image_color_metrics = self.compute_similarity_metrics(image_color_embeddings, image_colors)
        image_color_class = self.evaluate_classification_performance(
            image_color_embeddings, image_colors, "Image Color Embeddings (Baseline)", "Color"
        )
        image_color_metrics.update(image_color_class)
        results['image_color'] = image_color_metrics

        del image_color_embeddings
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Persist and close the figures produced above.
        os.makedirs(self.directory, exist_ok=True)
        for key in ['text_color', 'image_color']:
            # .replace('_', '_') is a no-op left from an earlier naming scheme.
            results[key]['figure'].savefig(
                f"{self.directory}/local_{key.replace('_', '_')}_confusion_matrix.png",
                dpi=300,
                bbox_inches='tight',
            )
            plt.close(results[key]['figure'])

        return results

    def evaluate_baseline_kaggle_marqo(self, max_samples=5000):
        """Evaluate the baseline Fashion CLIP model on the KAGL Marqo dataset.

        Returns:
            dict shaped {'text': {'color': metrics}, 'image': {'color':
            metrics}}, or None if the dataset could not be loaded.
        """
        print(f"\n{'='*60}")
        print("Evaluating Baseline Fashion CLIP on KAGL Marqo Dataset")
        print(f"Max samples: {max_samples}")
        print(f"{'='*60}")

        kaggle_dataset = load_kaggle_marqo_dataset(max_samples)
        if kaggle_dataset is None:
            print("โ Failed to load KAGL dataset")
            return None

        dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn_filter_none)

        results = {}

        # Text pass; embeddings freed before the image pass to cap memory.
        print("\n๐ Extracting baseline text embeddings from KAGL Marqo...")
        text_embeddings, text_colors = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
        print(f" Baseline text embeddings shape: {text_embeddings.shape} (using all {text_embeddings.shape[1]} dimensions)")
        text_color_metrics = self.compute_similarity_metrics(text_embeddings, text_colors)

        text_color_classification = self.evaluate_classification_performance(
            text_embeddings, text_colors, "Baseline KAGL Marqo Text Embeddings - Color", "Color"
        )
        text_color_metrics.update(text_color_classification)
        results['text'] = {
            'color': text_color_metrics
        }

        del text_embeddings
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        print("\n๐ผ๏ธ Extracting baseline image embeddings from KAGL Marqo...")
        image_embeddings, image_colors = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
        print(f" Baseline image embeddings shape: {image_embeddings.shape} (using all {image_embeddings.shape[1]} dimensions)")
        image_color_metrics = self.compute_similarity_metrics(image_embeddings, image_colors)

        image_color_classification = self.evaluate_classification_performance(
            image_embeddings, image_colors, "Baseline KAGL Marqo Image Embeddings - Color", "Color"
        )
        image_color_metrics.update(image_color_classification)
        results['image'] = {
            'color': image_color_metrics
        }

        del image_embeddings
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Persist and close the figures produced above.
        os.makedirs(self.directory, exist_ok=True)
        for key in ['text', 'image']:
            for subkey in ['color']:
                figure = results[key][subkey]['figure']
                figure.savefig(
                    f"{self.directory}/kaggle_baseline_{key}_{subkey}_confusion_matrix.png",
                    dpi=300,
                    bbox_inches='tight',
                )
                plt.close(figure)

        return results

    def evaluate_baseline_local_validation(self, max_samples=5000):
        """Evaluate the baseline Fashion CLIP model on the local validation dataset.

        Returns:
            dict shaped {'text': {'color': metrics}, 'image': {'color':
            metrics}}, or None if the dataset could not be loaded.
        """
        print(f"\n{'='*60}")
        print("Evaluating Baseline Fashion CLIP on Local Validation Dataset")
        print(f"Max samples: {max_samples}")
        print(f"{'='*60}")

        local_dataset = load_local_validation_dataset(max_samples)
        if local_dataset is None:
            print("โ Failed to load local validation dataset")
            return None

        dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)

        results = {}

        # Text pass; embeddings freed before the image pass to cap memory.
        print("\n๐ Extracting baseline text embeddings from Local Validation...")
        text_embeddings, text_colors = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
        print(f" Baseline text embeddings shape: {text_embeddings.shape} (using all {text_embeddings.shape[1]} dimensions)")
        text_color_metrics = self.compute_similarity_metrics(text_embeddings, text_colors)

        text_color_classification = self.evaluate_classification_performance(
            text_embeddings, text_colors, "Baseline Local Validation Text Embeddings - Color", "Color"
        )
        text_color_metrics.update(text_color_classification)
        results['text'] = {
            'color': text_color_metrics
        }

        del text_embeddings
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        print("\n๐ผ๏ธ Extracting baseline image embeddings from Local Validation...")
        image_embeddings, image_colors = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
        print(f" Baseline image embeddings shape: {image_embeddings.shape} (using all {image_embeddings.shape[1]} dimensions)")
        image_color_metrics = self.compute_similarity_metrics(image_embeddings, image_colors)

        image_color_classification = self.evaluate_classification_performance(
            image_embeddings, image_colors, "Baseline Local Validation Image Embeddings - Color", "Color"
        )
        image_color_metrics.update(image_color_classification)
        results['image'] = {
            'color': image_color_metrics
        }

        del image_embeddings
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Persist and close the figures produced above.
        os.makedirs(self.directory, exist_ok=True)
        for key in ['text', 'image']:
            for subkey in ['color']:
                figure = results[key][subkey]['figure']
                figure.savefig(
                    f"{self.directory}/local_baseline_{key}_{subkey}_confusion_matrix.png",
                    dpi=300,
                    bbox_inches='tight',
                )
                plt.close(figure)

        return results
|
| |
|
if __name__ == "__main__":
    # Prefer Apple MPS when available; everything also runs on CPU.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")

    directory = 'color_model_analysis'
    max_samples = 10000

    evaluator = ColorEvaluator(device=device, directory=directory)

    # --- Color model on KAGL Marqo ---
    print("\n" + "="*60)
    print("๐ Starting evaluation of KAGL Marqo with Color embeddings")
    print("="*60)
    results_kaggle = evaluator.evaluate_kaggle_marqo(max_samples=max_samples)

    # Guard added: the evaluator returns None when the dataset cannot be
    # loaded; previously only the local-validation results were guarded.
    if results_kaggle is not None:
        print(f"\n{'='*60}")
        print("KAGL MARQO EVALUATION SUMMARY")
        print(f"{'='*60}")

        print("\n๐จ COLOR CLASSIFICATION RESULTS:")
        print(f" Text - NN Acc: {results_kaggle['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle['text_color']['separation_score']:.4f}")
        print(f" Image - NN Acc: {results_kaggle['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle['image_color']['separation_score']:.4f}")

    # --- Baseline Fashion CLIP on KAGL Marqo ---
    print("\n" + "="*60)
    print("๐ Starting evaluation of Baseline Fashion CLIP on KAGL Marqo")
    print("="*60)
    results_baseline_kaggle = evaluator.evaluate_baseline_kaggle_marqo(max_samples=max_samples)

    # Guard added (same rationale as above).
    if results_baseline_kaggle is not None:
        print(f"\n{'='*60}")
        print("BASELINE KAGL MARQO EVALUATION SUMMARY")
        print(f"{'='*60}")

        print("\n๐จ COLOR CLASSIFICATION RESULTS (Baseline):")
        print(f" Text - NN Acc: {results_baseline_kaggle['text']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['text']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['text']['color']['separation_score']:.4f}")
        print(f" Image - NN Acc: {results_baseline_kaggle['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['image']['color']['separation_score']:.4f}")

    # --- Color model on local validation dataset ---
    print("\n" + "="*60)
    print("๐ Starting evaluation of Local Validation Dataset with Color embeddings")
    print("="*60)
    results_local = evaluator.evaluate_local_validation(max_samples=max_samples)

    if results_local is not None:
        print(f"\n{'='*60}")
        print("LOCAL VALIDATION DATASET EVALUATION SUMMARY")
        print(f"{'='*60}")

        print("\n๐จ COLOR CLASSIFICATION RESULTS:")
        print(f" Text - NN Acc: {results_local['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local['text_color']['separation_score']:.4f}")
        print(f" Image - NN Acc: {results_local['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local['image_color']['separation_score']:.4f}")

    # --- Baseline Fashion CLIP on local validation dataset ---
    print("\n" + "="*60)
    print("๐ Starting evaluation of Baseline Fashion CLIP on Local Validation")
    print("="*60)
    results_baseline_local = evaluator.evaluate_baseline_local_validation(max_samples=max_samples)

    if results_baseline_local is not None:
        print(f"\n{'='*60}")
        print("BASELINE LOCAL VALIDATION EVALUATION SUMMARY")
        print(f"{'='*60}")

        print("\n๐จ COLOR CLASSIFICATION RESULTS (Baseline):")
        print(f" Text - NN Acc: {results_baseline_local['text']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['text']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['text']['color']['separation_score']:.4f}")
        print(f" Image - NN Acc: {results_baseline_local['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['image']['color']['separation_score']:.4f}")

    print(f"\nโ… Evaluation completed! Check '{directory}/' for visualization files.")
|