import os
import json
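# Disable HuggingFace tokenizers parallelism to avoid fork-related warnings from DataLoader workers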
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import difflib
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from collections import defaultdict
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from io import BytesIO
import warnings
warnings.filterwarnings('ignore')
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
from config import (
color_model_path,
color_emb_dim,
local_dataset_path,
column_local_image_path,
tokeniser_path,
)
from color_model import ColorCLIP, Tokenizer
class KaggleDataset(Dataset):
"""Dataset class for KAGL Marqo dataset"""
def __init__(self, dataframe, image_size=224):
self.dataframe = dataframe
self.image_size = image_size
        # Transforms for evaluation (no augmentation: color jitter would corrupt the color labels)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
def __len__(self):
return len(self.dataframe)
def __getitem__(self, idx):
row = self.dataframe.iloc[idx]
        # The 'image_url' column holds the image data itself (a dict with a 'bytes' key, a PIL Image, or raw bytes)
image_data = row['image_url']
# Check if image_data has 'bytes' key or is already PIL Image
if isinstance(image_data, dict) and 'bytes' in image_data:
image = Image.open(BytesIO(image_data['bytes'])).convert("RGB")
elif hasattr(image_data, 'convert'): # Already a PIL Image
image = image_data.convert("RGB")
else:
# Assume it's raw bytes
image = Image.open(BytesIO(image_data)).convert("RGB")
# Apply validation transform
image = self.transform(image)
# Get text and labels
description = row['text']
color = row['color']
return image, description, color
def load_kaggle_marqo_dataset(max_samples=5000):
"""Load and prepare Kaggle KAGL dataset with memory optimization"""
from datasets import load_dataset
print("๐Ÿ“Š Loading Kaggle KAGL dataset...")
# Load the dataset
dataset = load_dataset("Marqo/KAGL")
df = dataset["data"].to_pandas()
print(f"โœ… Dataset Kaggle loaded")
print(f" Before filtering: {len(df)} samples")
print(f" Available columns: {list(df.columns)}")
# Ensure we have text and image data
df = df.dropna(subset=['text', 'image'])
print(f" After removing missing text/image: {len(df)} samples")
df_test = df.copy()
# Limit to max_samples with RANDOM SAMPLING to get diverse colors
if len(df_test) > max_samples:
df_test = df_test.sample(n=max_samples, random_state=42)
print(f"๐Ÿ“Š Randomly sampled {max_samples} samples from Kaggle dataset")
# Create formatted dataset with proper column names
kaggle_formatted = pd.DataFrame({
'image_url': df_test['image'], # This contains image data as bytes
'text': df_test['text'],
        'color': df_test['baseColour'].str.lower().str.replace("grey", "gray")  # Normalize spelling: 'grey' -> 'gray'
})
# Filter out rows with None/NaN colors
before_color_filter = len(kaggle_formatted)
kaggle_formatted = kaggle_formatted.dropna(subset=['color'])
if len(kaggle_formatted) < before_color_filter:
print(f" After removing missing colors: {len(kaggle_formatted)} samples (removed {before_color_filter - len(kaggle_formatted)} samples)")
# Filter for colors that were used during training (11 colors)
valid_colors = ['beige', 'black', 'blue', 'brown', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
before_valid_filter = len(kaggle_formatted)
kaggle_formatted = kaggle_formatted[kaggle_formatted['color'].isin(valid_colors)]
print(f" After filtering for valid colors: {len(kaggle_formatted)} samples (removed {before_valid_filter - len(kaggle_formatted)} samples)")
print(f" Valid colors found: {sorted(kaggle_formatted['color'].unique())}")
print(f" Final dataset size: {len(kaggle_formatted)} samples")
# Show color distribution in final dataset
print(f"๐ŸŽจ Color distribution in Kaggle dataset:")
color_counts = kaggle_formatted['color'].value_counts()
for color in color_counts.index:
print(f" {color}: {color_counts[color]} samples")
return KaggleDataset(kaggle_formatted)
class LocalDataset(Dataset):
"""Dataset class for local validation dataset"""
def __init__(self, dataframe, image_size=224):
self.dataframe = dataframe
self.image_size = image_size
        # Transforms for evaluation (no augmentation: color jitter would corrupt the color labels)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
def __len__(self):
return len(self.dataframe)
def __getitem__(self, idx):
row = self.dataframe.iloc[idx]
# Load image from local path
image_path = row[column_local_image_path]
try:
image = Image.open(image_path).convert("RGB")
except Exception as e:
print(f"Error loading image at index {idx} from {image_path}: {e}")
# Create a dummy image if loading fails
image = Image.new('RGB', (224, 224), color='gray')
# Apply validation transform
image = self.transform(image)
# Get text and labels
description = row['text']
color = row['color']
return image, description, color
def load_local_validation_dataset(max_samples=5000):
"""Load and prepare local validation dataset"""
print("๐Ÿ“Š Loading local validation dataset...")
df = pd.read_csv(local_dataset_path)
print(f"โœ… Dataset loaded: {len(df)} samples")
# Filter out rows with NaN values in image path
df_clean = df.dropna(subset=[column_local_image_path])
print(f"๐Ÿ“Š After filtering NaN image paths: {len(df_clean)} samples")
# Filter for colors that were used during training (11 colors)
valid_colors = ['beige', 'black', 'blue', 'brown', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
if 'color' in df_clean.columns:
before_valid_filter = len(df_clean)
df_clean = df_clean[df_clean['color'].isin(valid_colors)]
print(f"๐Ÿ“Š After filtering for valid colors: {len(df_clean)} samples (removed {before_valid_filter - len(df_clean)} samples)")
print(f"๐ŸŽจ Valid colors found: {sorted(df_clean['color'].unique())}")
# Limit to max_samples with RANDOM SAMPLING to get diverse colors
if len(df_clean) > max_samples:
df_clean = df_clean.sample(n=max_samples, random_state=42)
print(f"๐Ÿ“Š Randomly sampled {max_samples} samples")
print(f"๐Ÿ“Š Using {len(df_clean)} samples for evaluation")
# Show color distribution after sampling
if 'color' in df_clean.columns:
print(f"๐ŸŽจ Color distribution in sampled data:")
color_counts = df_clean['color'].value_counts()
print(f" Total unique colors: {len(color_counts)}")
for color in color_counts.index[:15]: # Show top 15
print(f" {color}: {color_counts[color]} samples")
return LocalDataset(df_clean)
def collate_fn_filter_none(batch):
"""Collate function that filters out None values from batch with debug print"""
# Filter out None values
original_len = len(batch)
batch = [item for item in batch if item is not None]
if original_len > len(batch):
print(f"โš ๏ธ Filtered out {original_len - len(batch)} None values from batch (original: {original_len}, filtered: {len(batch)})")
if len(batch) == 0:
# Return empty batch with correct structure
print("โš ๏ธ Empty batch after filtering None values")
return torch.tensor([]), [], []
images, texts, colors = zip(*batch)
images = torch.stack(images, dim=0)
return images, list(texts), list(colors)
class ColorEvaluator:
"""Evaluate color 16 embeddings"""
def __init__(self, device='mps', directory="color_model_analysis"):
self.device = torch.device(device)
self.directory = directory
self.color_emb_dim = color_emb_dim
os.makedirs(self.directory, exist_ok=True)
# Load baseline Fashion CLIP model
print("๐Ÿ“ฆ Loading baseline Fashion CLIP model...")
patrick_model_name = "patrickjohncyh/fashion-clip"
self.baseline_processor = CLIPProcessor.from_pretrained(patrick_model_name)
self.baseline_model = CLIPModel_transformers.from_pretrained(patrick_model_name).to(self.device)
self.baseline_model.eval()
print("โœ… Baseline Fashion CLIP model loaded successfully")
# Load specialized color model (16D)
self.color_model = None
self.color_tokenizer = None
self._load_color_model()
def _load_color_model(self):
"""Load the specialized 16D color model and tokenizer."""
if self.color_model is not None and self.color_tokenizer is not None:
return
if not os.path.exists(color_model_path):
raise FileNotFoundError(f"Color model file {color_model_path} not found")
if not os.path.exists(tokeniser_path):
raise FileNotFoundError(f"Tokenizer vocab file {tokeniser_path} not found")
print("๐ŸŽจ Loading specialized color model (16D)...")
# Load checkpoint first to get the actual vocab size
state_dict = torch.load(color_model_path, map_location=self.device)
# Get vocab size from the embedding weight shape in checkpoint
vocab_size = state_dict['text_encoder.embedding.weight'].shape[0]
print(f" Detected vocab size from checkpoint: {vocab_size}")
# Load tokenizer vocab
with open(tokeniser_path, "r") as f:
vocab = json.load(f)
self.color_tokenizer = Tokenizer()
self.color_tokenizer.load_vocab(vocab)
# Create model with the vocab size from checkpoint (not from tokenizer)
self.color_model = ColorCLIP(vocab_size=vocab_size, embedding_dim=self.color_emb_dim)
# Load state dict
self.color_model.load_state_dict(state_dict)
self.color_model.to(self.device)
self.color_model.eval()
print("โœ… Color model loaded successfully")
def _tokenize_color_texts(self, texts):
"""Tokenize texts with the color tokenizer and return padded tensors."""
token_lists = [self.color_tokenizer(t) for t in texts]
max_len = max((len(toks) for toks in token_lists), default=0)
max_len = max_len if max_len > 0 else 1
input_ids = torch.zeros(len(texts), max_len, dtype=torch.long, device=self.device)
lengths = torch.zeros(len(texts), dtype=torch.long, device=self.device)
for i, toks in enumerate(token_lists):
if len(toks) > 0:
input_ids[i, :len(toks)] = torch.tensor(toks, dtype=torch.long, device=self.device)
lengths[i] = len(toks)
else:
lengths[i] = 1 # avoid zero-length
return input_ids, lengths
def extract_color_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
"""Extract 16D color embeddings from specialized color model."""
self._load_color_model()
all_embeddings = []
all_colors = []
sample_count = 0
with torch.no_grad():
for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} color embeddings"):
if sample_count >= max_samples:
break
images, texts, colors = batch
images = images.to(self.device)
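                # Ensure 3 channels: broadcasts grayscale (B, 1, H, W) batches to (B, 3, H, W); no-op for RGB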
images = images.expand(-1, 3, -1, -1)
if embedding_type == 'text':
input_ids, lengths = self._tokenize_color_texts(texts)
embeddings = self.color_model.text_encoder(input_ids, lengths)
elif embedding_type == 'image':
embeddings = self.color_model.image_encoder(images)
else:
input_ids, lengths = self._tokenize_color_texts(texts)
embeddings = self.color_model.text_encoder(input_ids, lengths)
all_embeddings.append(embeddings.cpu().numpy())
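                # Normalize color labels (lowercase, strip, 'grey' -> 'gray') to match the training vocabulary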
normalized_colors = [str(c).lower().strip().replace("grey", "gray") for c in colors]
all_colors.extend(normalized_colors)
sample_count += len(images)
del images, embeddings
if embedding_type != 'image':
del input_ids, lengths
torch.cuda.empty_cache() if torch.cuda.is_available() else None
return np.vstack(all_embeddings), all_colors
def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000):
"""Extract embeddings from baseline Fashion CLIP model"""
all_embeddings = []
all_colors = []
sample_count = 0
with torch.no_grad():
for batch in tqdm(dataloader, desc=f"Extracting baseline {embedding_type} embeddings"):
if sample_count >= max_samples:
break
images, texts, colors = batch
images = images.to(self.device)
images = images.expand(-1, 3, -1, -1) # Ensure 3 channels
# Process text inputs with baseline processor
text_inputs = self.baseline_processor(text=texts, padding=True, return_tensors="pt")
text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
# Forward pass through baseline model
outputs = self.baseline_model(**text_inputs, pixel_values=images)
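                # A single forward pass returns both projected embeddings: outputs.text_embeds and outputs.image_embeds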
# Extract embeddings based on type
if embedding_type == 'text':
embeddings = outputs.text_embeds
elif embedding_type == 'image':
embeddings = outputs.image_embeds
else:
embeddings = outputs.text_embeds
all_embeddings.append(embeddings.cpu().numpy())
all_colors.extend(colors)
sample_count += len(images)
# Clear GPU memory
del images, text_inputs, outputs, embeddings
torch.cuda.empty_cache() if torch.cuda.is_available() else None
return np.vstack(all_embeddings), all_colors
def compute_similarity_metrics(self, embeddings, labels):
"""Compute intra-class and inter-class similarities - optimized version"""
max_samples = min(5000, len(embeddings))
if len(embeddings) > max_samples:
indices = np.random.choice(len(embeddings), max_samples, replace=False)
embeddings = embeddings[indices]
labels = [labels[i] for i in indices]
similarities = cosine_similarity(embeddings)
# Create label groups using numpy for faster indexing
label_array = np.array(labels)
unique_labels = np.unique(label_array)
label_groups = {label: np.where(label_array == label)[0] for label in unique_labels}
# Compute intra-class similarities using vectorized operations
intra_class_similarities = []
for label, indices in label_groups.items():
if len(indices) > 1:
# Extract submatrix for this class
class_similarities = similarities[np.ix_(indices, indices)]
# Get upper triangle (excluding diagonal)
triu_indices = np.triu_indices_from(class_similarities, k=1)
intra_class_similarities.extend(class_similarities[triu_indices].tolist())
# Compute inter-class similarities using vectorized operations
inter_class_similarities = []
labels_list = list(label_groups.keys())
for i in range(len(labels_list)):
for j in range(i + 1, len(labels_list)):
label1_indices = label_groups[labels_list[i]]
label2_indices = label_groups[labels_list[j]]
# Extract submatrix between two classes
inter_sims = similarities[np.ix_(label1_indices, label2_indices)]
inter_class_similarities.extend(inter_sims.flatten().tolist())
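        # Separation score = mean intra-class similarity - mean inter-class similarity;
        # higher values mean embeddings of the same color cluster together and away from other colors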
nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities)
centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels)
return {
'intra_class_similarities': intra_class_similarities,
'inter_class_similarities': inter_class_similarities,
'intra_class_mean': float(np.mean(intra_class_similarities)) if intra_class_similarities else 0.0,
'inter_class_mean': float(np.mean(inter_class_similarities)) if inter_class_similarities else 0.0,
'separation_score': float(np.mean(intra_class_similarities) - np.mean(inter_class_similarities)) if intra_class_similarities and inter_class_similarities else 0.0,
'accuracy': nn_accuracy,
'centroid_accuracy': centroid_accuracy,
}
def compute_embedding_accuracy(self, embeddings, labels, similarities):
"""Compute classification accuracy using nearest neighbor"""
correct_predictions = 0
total_predictions = len(labels)
for i in range(len(embeddings)):
true_label = labels[i]
similarities_row = similarities[i].copy()
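            # Exclude the sample itself so the prediction comes from its nearest other neighbor (leave-one-out 1-NN)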
similarities_row[i] = -1
nearest_neighbor_idx = int(np.argmax(similarities_row))
predicted_label = labels[nearest_neighbor_idx]
if predicted_label == true_label:
correct_predictions += 1
return correct_predictions / total_predictions if total_predictions > 0 else 0.0
def compute_centroid_accuracy(self, embeddings, labels):
"""Compute classification accuracy using centroids - optimized vectorized version"""
unique_labels = list(set(labels))
# Compute centroids efficiently
centroids = {}
for label in unique_labels:
label_mask = np.array(labels) == label
centroids[label] = np.mean(embeddings[label_mask], axis=0)
# Stack centroids for vectorized similarity computation
centroid_matrix = np.vstack([centroids[label] for label in unique_labels])
# Compute all similarities at once
similarities = cosine_similarity(embeddings, centroid_matrix)
# Get predicted labels
predicted_indices = np.argmax(similarities, axis=1)
predicted_labels = [unique_labels[idx] for idx in predicted_indices]
# Compute accuracy
correct_predictions = sum(pred == true for pred, true in zip(predicted_labels, labels))
return correct_predictions / len(labels) if len(labels) > 0 else 0.0
def predict_labels_from_embeddings(self, embeddings, labels):
"""Predict labels from embeddings using centroid-based classification - optimized vectorized version"""
# Filter out None labels when computing centroids
unique_labels = [l for l in set(labels) if l is not None]
if len(unique_labels) == 0:
# If no valid labels, return None for all predictions
return [None] * len(embeddings)
# Compute centroids efficiently
centroids = {}
for label in unique_labels:
label_mask = np.array(labels) == label
if np.any(label_mask):
centroids[label] = np.mean(embeddings[label_mask], axis=0)
# Stack centroids for vectorized similarity computation
centroid_labels = list(centroids.keys())
centroid_matrix = np.vstack([centroids[label] for label in centroid_labels])
# Compute all similarities at once
similarities = cosine_similarity(embeddings, centroid_matrix)
# Get predicted labels
predicted_indices = np.argmax(similarities, axis=1)
predictions = [centroid_labels[idx] for idx in predicted_indices]
return predictions
def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix", label_type="Label"):
"""Create and plot confusion matrix"""
unique_labels = sorted(list(set(true_labels + predicted_labels)))
cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
accuracy = accuracy_score(true_labels, predicted_labels)
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=unique_labels, yticklabels=unique_labels)
plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
plt.ylabel(f'True {label_type}')
plt.xlabel(f'Predicted {label_type}')
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.tight_layout()
return plt.gcf(), accuracy, cm
def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings", label_type="Label"):
"""
Evaluate classification performance and create confusion matrix.
Args:
            embeddings: Array of embeddings to classify, shape (n_samples, dim)
            labels: True label for each embedding
            embedding_type: Name of the embeddings, used in plot titles
            label_type: Name of the label type (e.g. "Color")
"""
predictions = self.predict_labels_from_embeddings(embeddings, labels)
title_suffix = ""
# Filter out None values from labels and predictions
valid_indices = [i for i, (label, pred) in enumerate(zip(labels, predictions))
if label is not None and pred is not None]
if len(valid_indices) == 0:
print(f"โš ๏ธ Warning: No valid labels/predictions found (all are None)")
return {
'accuracy': 0.0,
'predictions': predictions,
'confusion_matrix': None,
'classification_report': None,
'figure': None,
}
filtered_labels = [labels[i] for i in valid_indices]
filtered_predictions = [predictions[i] for i in valid_indices]
accuracy = accuracy_score(filtered_labels, filtered_predictions)
fig, acc, cm = self.create_confusion_matrix(
filtered_labels, filtered_predictions,
f"{embedding_type} - {label_type} Classification{title_suffix}",
label_type
)
unique_labels = sorted(list(set(filtered_labels)))
report = classification_report(filtered_labels, filtered_predictions, labels=unique_labels, target_names=unique_labels, output_dict=True)
return {
'accuracy': accuracy,
'predictions': predictions,
'confusion_matrix': cm,
'classification_report': report,
'figure': fig,
}
def evaluate_kaggle_marqo(self, max_samples):
"""Evaluate both color embeddings on KAGL Marqo dataset"""
print(f"\n{'='*60}")
print("Evaluating KAGL Marqo Dataset with Color embeddings")
print(f"Max samples: {max_samples}")
print(f"{'='*60}")
kaggle_dataset = load_kaggle_marqo_dataset(max_samples)
if kaggle_dataset is None:
print("โŒ Failed to load KAGL dataset")
return None
dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn_filter_none)
results = {}
        # ========== EXTRACT COLOR MODEL (16D) EMBEDDINGS ==========
        print("\n🎨 Extracting color model embeddings...")
        text_color_embeddings, text_colors = self.extract_color_embeddings(dataloader, embedding_type='text', max_samples=max_samples)
        image_color_embeddings, image_colors = self.extract_color_embeddings(dataloader, embedding_type='image', max_samples=max_samples)
        text_color_metrics = self.compute_similarity_metrics(text_color_embeddings, text_colors)
        text_color_class = self.evaluate_classification_performance(
            text_color_embeddings, text_colors,
            "Text Color Embeddings (Color Model)", "Color",
        )
        text_color_metrics.update(text_color_class)
        results['text_color'] = text_color_metrics
        image_color_metrics = self.compute_similarity_metrics(image_color_embeddings, image_colors)
        image_color_class = self.evaluate_classification_performance(
            image_color_embeddings, image_colors,
            "Image Color Embeddings (Color Model)", "Color",
        )
        image_color_metrics.update(image_color_class)
        results['image_color'] = image_color_metrics
        del text_color_embeddings, image_color_embeddings
torch.cuda.empty_cache() if torch.cuda.is_available() else None
# ========== SAVE VISUALIZATIONS ==========
os.makedirs(self.directory, exist_ok=True)
for key in ['text_color', 'image_color']:
results[key]['figure'].savefig(
f"{self.directory}/kaggle_{key.replace('_', '_')}_confusion_matrix.png",
dpi=300,
bbox_inches='tight',
)
plt.close(results[key]['figure'])
return results
def evaluate_local_validation(self, max_samples):
"""Evaluate both color embeddings on local validation dataset"""
print(f"\n{'='*60}")
print("Evaluating Local Validation Dataset")
print(" Color embeddings")
print(f"Max samples: {max_samples}")
print(f"{'='*60}")
local_dataset = load_local_validation_dataset(max_samples)
dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)
results = {}
# ========== COLOR EVALUATION ==========
print("\n๐ŸŽจ COLOR EVALUATION ")
print("=" * 50)
# Text color embeddings
print("\n๐Ÿ“ Extracting text color embeddings...")
text_color_embeddings, text_colors = self.extract_color_embeddings(dataloader, 'text', max_samples)
print(f" Text color embeddings shape: {text_color_embeddings.shape}")
text_color_metrics = self.compute_similarity_metrics(text_color_embeddings, text_colors)
text_color_class = self.evaluate_classification_performance(
text_color_embeddings, text_colors, "Text Color Embeddings (Baseline)", "Color"
)
text_color_metrics.update(text_color_class)
results['text_color'] = text_color_metrics
del text_color_embeddings
torch.cuda.empty_cache() if torch.cuda.is_available() else None
# Image color embeddings
print("\n๐Ÿ–ผ๏ธ Extracting image color embeddings...")
image_color_embeddings, image_colors = self.extract_color_embeddings(dataloader, 'image', max_samples)
print(f" Image color embeddings shape: {image_color_embeddings.shape}")
image_color_metrics = self.compute_similarity_metrics(image_color_embeddings, image_colors)
image_color_class = self.evaluate_classification_performance(
image_color_embeddings, image_colors, "Image Color Embeddings (Baseline)", "Color"
)
image_color_metrics.update(image_color_class)
results['image_color'] = image_color_metrics
del image_color_embeddings
torch.cuda.empty_cache() if torch.cuda.is_available() else None
# ========== SAVE VISUALIZATIONS ==========
os.makedirs(self.directory, exist_ok=True)
for key in ['text_color', 'image_color']:
results[key]['figure'].savefig(
f"{self.directory}/local_{key.replace('_', '_')}_confusion_matrix.png",
dpi=300,
bbox_inches='tight',
)
plt.close(results[key]['figure'])
return results
def evaluate_baseline_kaggle_marqo(self, max_samples=5000):
"""Evaluate baseline Fashion CLIP model on KAGL Marqo dataset"""
print(f"\n{'='*60}")
print("Evaluating Baseline Fashion CLIP on KAGL Marqo Dataset")
print(f"Max samples: {max_samples}")
print(f"{'='*60}")
# Load KAGL Marqo dataset
kaggle_dataset = load_kaggle_marqo_dataset(max_samples)
if kaggle_dataset is None:
print("โŒ Failed to load KAGL dataset")
return None
# Create dataloader
dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn_filter_none)
results = {}
# Evaluate text embeddings
print("\n๐Ÿ“ Extracting baseline text embeddings from KAGL Marqo...")
text_embeddings, text_colors = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
print(f" Baseline text embeddings shape: {text_embeddings.shape} (using all {text_embeddings.shape[1]} dimensions)")
text_color_metrics = self.compute_similarity_metrics(text_embeddings, text_colors)
text_color_classification = self.evaluate_classification_performance(
text_embeddings, text_colors, "Baseline KAGL Marqo Text Embeddings - Color", "Color"
)
text_color_metrics.update(text_color_classification)
results['text'] = {
'color': text_color_metrics
}
# Clear memory
del text_embeddings
torch.cuda.empty_cache() if torch.cuda.is_available() else None
# Evaluate image embeddings
print("\n๐Ÿ–ผ๏ธ Extracting baseline image embeddings from KAGL Marqo...")
image_embeddings, image_colors = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
print(f" Baseline image embeddings shape: {image_embeddings.shape} (using all {image_embeddings.shape[1]} dimensions)")
image_color_metrics = self.compute_similarity_metrics(image_embeddings, image_colors)
image_color_classification = self.evaluate_classification_performance(
image_embeddings, image_colors, "Baseline KAGL Marqo Image Embeddings - Color", "Color"
)
image_color_metrics.update(image_color_classification)
results['image'] = {
'color': image_color_metrics
}
# Clear memory
del image_embeddings
torch.cuda.empty_cache() if torch.cuda.is_available() else None
# ========== SAVE VISUALIZATIONS ==========
os.makedirs(self.directory, exist_ok=True)
for key in ['text', 'image']:
for subkey in ['color']:
figure = results[key][subkey]['figure']
figure.savefig(
f"{self.directory}/kaggle_baseline_{key}_{subkey}_confusion_matrix.png",
dpi=300,
bbox_inches='tight',
)
plt.close(figure)
return results
def evaluate_baseline_local_validation(self, max_samples=5000):
"""Evaluate baseline Fashion CLIP model on local validation dataset"""
print(f"\n{'='*60}")
print("Evaluating Baseline Fashion CLIP on Local Validation Dataset")
print(f"Max samples: {max_samples}")
print(f"{'='*60}")
# Load local validation dataset
local_dataset = load_local_validation_dataset(max_samples)
if local_dataset is None:
print("โŒ Failed to load local validation dataset")
return None
# Create dataloader
dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)
results = {}
# Evaluate text embeddings
print("\n๐Ÿ“ Extracting baseline text embeddings from Local Validation...")
text_embeddings, text_colors = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
print(f" Baseline text embeddings shape: {text_embeddings.shape} (using all {text_embeddings.shape[1]} dimensions)")
text_color_metrics = self.compute_similarity_metrics(text_embeddings, text_colors)
text_color_classification = self.evaluate_classification_performance(
text_embeddings, text_colors, "Baseline Local Validation Text Embeddings - Color", "Color"
)
text_color_metrics.update(text_color_classification)
results['text'] = {
'color': text_color_metrics
}
# Clear memory
del text_embeddings
torch.cuda.empty_cache() if torch.cuda.is_available() else None
# Evaluate image embeddings
print("\n๐Ÿ–ผ๏ธ Extracting baseline image embeddings from Local Validation...")
image_embeddings, image_colors = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
print(f" Baseline image embeddings shape: {image_embeddings.shape} (using all {image_embeddings.shape[1]} dimensions)")
image_color_metrics = self.compute_similarity_metrics(image_embeddings, image_colors)
image_color_classification = self.evaluate_classification_performance(
image_embeddings, image_colors, "Baseline Local Validation Image Embeddings - Color", "Color"
)
image_color_metrics.update(image_color_classification)
results['image'] = {
'color': image_color_metrics
}
# Clear memory
del image_embeddings
torch.cuda.empty_cache() if torch.cuda.is_available() else None
# ========== SAVE VISUALIZATIONS ==========
os.makedirs(self.directory, exist_ok=True)
for key in ['text', 'image']:
for subkey in ['color']:
figure = results[key][subkey]['figure']
figure.savefig(
f"{self.directory}/local_baseline_{key}_{subkey}_confusion_matrix.png",
dpi=300,
bbox_inches='tight',
)
plt.close(figure)
return results
if __name__ == "__main__":
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
directory = 'color_model_analysis'
max_samples = 10000
evaluator = ColorEvaluator(device=device, directory=directory)
# Evaluate KAGL Marqo
print("\n" + "="*60)
print("๐Ÿš€ Starting evaluation of KAGL Marqo with Color embeddings")
print("="*60)
results_kaggle = evaluator.evaluate_kaggle_marqo(max_samples=max_samples)
print(f"\n{'='*60}")
print("KAGL MARQO EVALUATION SUMMARY")
print(f"{'='*60}")
print("\n๐ŸŽจ COLOR CLASSIFICATION RESULTS:")
print(f" Text - NN Acc: {results_kaggle['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle['text_color']['separation_score']:.4f}")
print(f" Image - NN Acc: {results_kaggle['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_kaggle['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_kaggle['image_color']['separation_score']:.4f}")
# Evaluate Baseline Fashion CLIP on KAGL Marqo
print("\n" + "="*60)
print("๐Ÿš€ Starting evaluation of Baseline Fashion CLIP on KAGL Marqo")
print("="*60)
results_baseline_kaggle = evaluator.evaluate_baseline_kaggle_marqo(max_samples=max_samples)
print(f"\n{'='*60}")
print("BASELINE KAGL MARQO EVALUATION SUMMARY")
print(f"{'='*60}")
print("\n๐ŸŽจ COLOR CLASSIFICATION RESULTS (Baseline):")
print(f" Text - NN Acc: {results_baseline_kaggle['text']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['text']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['text']['color']['separation_score']:.4f}")
print(f" Image - NN Acc: {results_baseline_kaggle['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_kaggle['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_kaggle['image']['color']['separation_score']:.4f}")
# Evaluate Local Validation Dataset
print("\n" + "="*60)
print("๐Ÿš€ Starting evaluation of Local Validation Dataset with Color embeddings")
print("="*60)
results_local = evaluator.evaluate_local_validation(max_samples=max_samples)
if results_local is not None:
print(f"\n{'='*60}")
print("LOCAL VALIDATION DATASET EVALUATION SUMMARY")
print(f"{'='*60}")
print("\n๐ŸŽจ COLOR CLASSIFICATION RESULTS:")
print(f" Text - NN Acc: {results_local['text_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local['text_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local['text_color']['separation_score']:.4f}")
print(f" Image - NN Acc: {results_local['image_color']['accuracy']*100:.1f}% | Centroid Acc: {results_local['image_color']['centroid_accuracy']*100:.1f}% | Separation: {results_local['image_color']['separation_score']:.4f}")
# Evaluate Baseline Fashion CLIP on Local Validation
print("\n" + "="*60)
print("๐Ÿš€ Starting evaluation of Baseline Fashion CLIP on Local Validation")
print("="*60)
results_baseline_local = evaluator.evaluate_baseline_local_validation(max_samples=max_samples)
if results_baseline_local is not None:
print(f"\n{'='*60}")
print("BASELINE LOCAL VALIDATION EVALUATION SUMMARY")
print(f"{'='*60}")
print("\n๐ŸŽจ COLOR CLASSIFICATION RESULTS (Baseline):")
print(f" Text - NN Acc: {results_baseline_local['text']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['text']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['text']['color']['separation_score']:.4f}")
print(f" Image - NN Acc: {results_baseline_local['image']['color']['accuracy']*100:.1f}% | Centroid Acc: {results_baseline_local['image']['color']['centroid_accuracy']*100:.1f}% | Separation: {results_baseline_local['image']['color']['separation_score']:.4f}")
print(f"\nโœ… Evaluation completed! Check '{directory}/' for visualization files.")