|
|
import os |
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
|
|
|
import torch |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
import difflib |
|
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score |
|
|
from collections import defaultdict |
|
|
from tqdm import tqdm |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
from io import BytesIO |
|
|
import warnings |
|
|
warnings.filterwarnings('ignore') |
|
|
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers |
|
|
|
|
|
from config import main_model_path, hierarchy_model_path, color_model_path, color_emb_dim, hierarchy_emb_dim, local_dataset_path, column_local_image_path |
|
|
|
|
|
|
|
|
def create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes):
    """Create mapping from Fashion-MNIST labels to hierarchy classes.

    Each Fashion-MNIST label is matched against ``hierarchy_classes``
    (case-insensitively) using, in order of preference:

    1. an exact match;
    2. a substring match in either direction (the first hierarchy class, in
       the given order, wins);
    3. a small table of hand-written fallbacks (e.g. "Trouser" -> "bottom",
       "Sandal" -> "shoes"), used only when step 2 found nothing;
    4. ``difflib`` fuzzy matching with a 0.6 cutoff.

    Args:
        hierarchy_classes: iterable of hierarchy class names (list, tuple or
            set). Blank names are ignored by the substring and fuzzy steps,
            because an empty string is a substring of every label.

    Returns:
        dict mapping each Fashion-MNIST label id (0-9) to the matched hierarchy
        class name (original casing preserved), or ``None`` when nothing
        matched. Every decision is also printed.
    """
    fashion_mnist_labels = {
        0: "T-shirt/top",
        1: "Trouser",
        2: "Pullover",
        3: "Dress",
        4: "Coat",
        5: "Sandal",
        6: "Shirt",
        7: "Sneaker",
        8: "Bag",
        9: "Ankle boot",
    }

    shoe_candidates = ('shoes', 'shoe', 'sandal', 'sneaker', 'boot')
    # Fallback candidates per (lower-cased) Fashion-MNIST label, tried in order
    # when neither an exact nor a substring match exists.
    fallback_candidates = {
        't-shirt/top': ('top',),
        'trouser': ('bottom', 'pants', 'trousers', 'trouser', 'pant'),
        'pullover': ('sweater', 'pullover'),
        'dress': ('dress',),
        'coat': ('jacket', 'outerwear', 'coat'),
        'sandal': shoe_candidates,
        'sneaker': shoe_candidates,
        'ankle boot': shoe_candidates,
        'bag': ('bag',),
    }

    # Accept any iterable (e.g. a set); positional indexing is needed below.
    hierarchy_classes = list(hierarchy_classes)
    hierarchy_classes_lower = [h.lower() for h in hierarchy_classes]
    # Non-blank names only: "" would otherwise "contain" every label.
    searchable = [
        (h_class, h_lower)
        for h_class, h_lower in zip(hierarchy_classes, hierarchy_classes_lower)
        if h_lower.strip()
    ]

    def lookup(name_lower):
        """Return the hierarchy class whose lower-cased name equals name_lower, else None."""
        if name_lower in hierarchy_classes_lower:
            return hierarchy_classes[hierarchy_classes_lower.index(name_lower)]
        return None

    mapping = {}
    for fm_label_id, fm_label in fashion_mnist_labels.items():
        fm_label_lower = fm_label.lower()

        matched_hierarchy = lookup(fm_label_lower)

        if matched_hierarchy is None:
            matched_hierarchy = next(
                (h_class for h_class, h_lower in searchable
                 if h_lower in fm_label_lower or fm_label_lower in h_lower),
                None,
            )

        if matched_hierarchy is None:
            for possible in fallback_candidates.get(fm_label_lower, ()):
                matched_hierarchy = lookup(possible)
                if matched_hierarchy is not None:
                    break

        if matched_hierarchy is None:
            close_matches = difflib.get_close_matches(
                fm_label_lower, [h_lower for _, h_lower in searchable], n=1, cutoff=0.6
            )
            if close_matches:
                matched_hierarchy = lookup(close_matches[0])

        mapping[fm_label_id] = matched_hierarchy
        if matched_hierarchy:
            print(f" {fm_label} ({fm_label_id}) -> {matched_hierarchy}")
        else:
            print(f" ⚠️ {fm_label} ({fm_label_id}) -> NO MATCH (will be filtered out)")

    return mapping
|
|
|
|
|
|
|
|
def convert_fashion_mnist_to_image(pixel_values):
    """Build a 28x28 RGB PIL image from a flat row of 784 Fashion-MNIST pixels.

    The values are cast to ``uint8`` and the single grayscale channel is
    repeated three times so that RGB-only pipelines can consume it.
    """
    gray = np.array(pixel_values).reshape(28, 28).astype(np.uint8)
    rgb = np.repeat(gray[:, :, np.newaxis], 3, axis=2)
    return Image.fromarray(rgb)
|
|
|
|
|
|
|
|
def get_fashion_mnist_labels():
    """Return a new dict mapping Fashion-MNIST class ids (0-9) to class names."""
    class_names = (
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    )
    return dict(enumerate(class_names))
|
|
|
|
|
|
|
|
class FashionMNISTDataset(Dataset):
    """Torch dataset over a Fashion-MNIST CSV dataframe.

    Each row must provide a ``label`` column and 784 pixel columns named
    ``pixel1`` ... ``pixel784``. Items are
    ``(image_tensor, description, color, hierarchy)`` where ``description`` is
    the Fashion-MNIST class name, ``color`` is always ``"unknown"`` and
    ``hierarchy`` comes from ``label_mapping`` when it contains the label id,
    otherwise it is the Fashion-MNIST class name.
    """

    # Pixel column names are the same for every row, so build them once.
    # Must stay a list: pandas treats a tuple as a single key.
    _PIXEL_COLUMNS = [f"pixel{i}" for i in range(1, 785)]

    def __init__(self, dataframe, image_size=224, label_mapping=None):
        self.dataframe = dataframe
        self.image_size = image_size
        self.labels_map = get_fashion_mnist_labels()
        # Optional {label_id: hierarchy_name} produced by the mapping helper.
        self.label_mapping = label_mapping

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            # Standard ImageNet mean/std.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]

        pixel_values = row[self._PIXEL_COLUMNS].values
        image = convert_fashion_mnist_to_image(pixel_values)
        image = self.transform(image)

        label_id = int(row['label'])
        description = self.labels_map[label_id]

        # Fashion-MNIST carries no color annotation.
        color = "unknown"

        if self.label_mapping and label_id in self.label_mapping:
            hierarchy = self.label_mapping[label_id]
        else:
            hierarchy = self.labels_map[label_id]

        return image, description, color, hierarchy
|
|
|
|
|
|
|
|
def load_fashion_mnist_dataset(
    max_samples=1000,
    hierarchy_classes=None,
    csv_path="/Users/leaattiasarfati/Desktop/docs/search/old/MainModel/data/fashion-mnist_test.csv",
):
    """Load the Fashion-MNIST test CSV and wrap it in a FashionMNISTDataset.

    Args:
        max_samples: maximum number of rows kept. The first rows of the file
            are used, after optional filtering.
        hierarchy_classes: optional hierarchy class names. When given, labels
            are mapped onto them with ``create_fashion_mnist_to_hierarchy_mapping``
            and rows whose label has no match are dropped.
        csv_path: path to the Fashion-MNIST test CSV. The default is the path
            that used to be hard-coded here; pass another one to run elsewhere.

    Returns:
        FashionMNISTDataset over the selected rows.
    """
    print("π Loading Fashion-MNIST test dataset...")
    df = pd.read_csv(csv_path)
    print(f"✅ Fashion-MNIST dataset loaded: {len(df)} samples")

    label_mapping = None
    if hierarchy_classes is not None:
        print("\nπ Creating mapping from Fashion-MNIST labels to hierarchy classes:")
        label_mapping = create_fashion_mnist_to_hierarchy_mapping(hierarchy_classes)

        # Keep only rows whose label could be mapped to a hierarchy class.
        valid_label_ids = [label_id for label_id, hierarchy in label_mapping.items() if hierarchy is not None]
        df_filtered = df[df['label'].isin(valid_label_ids)]
        print(f"\nπ After filtering to mappable labels: {len(df_filtered)} samples (from {len(df)})")

        df_sample = df_filtered.head(max_samples)
    else:
        df_sample = df.head(max_samples)

    print(f"π Using {len(df_sample)} samples for evaluation")
    return FashionMNISTDataset(df_sample, label_mapping=label_mapping)
|
|
|
|
|
|
|
|
def create_kaggle_marqo_to_hierarchy_mapping(kaggle_labels, hierarchy_classes):
    """Create mapping from Kaggle Marqo categories to hierarchy classes.

    Each distinct label is lower-cased and stripped. It is then matched
    against ``hierarchy_classes`` using, in order:

    1. an exact match after synonym normalisation (e.g. "tshirts" -> "top");
    2. a substring match in either direction;
    3. a per-token synonym lookup (tokens split on spaces, '-' and '/',
       tried in order of appearance);
    4. any synonym key contained in the label;
    5. ``difflib`` fuzzy matching with a 0.6 cutoff.

    Args:
        kaggle_labels: iterable of raw category labels. NaN/None values are
            allowed and map to the key ``''``.
        hierarchy_classes: iterable of hierarchy class names. Blank names are
            ignored by the loose steps (2 and 5), since "" is a substring of
            every label.

    Returns:
        dict ``{normalised_label: hierarchy_class_or_None}``. Every decision
        is also printed.
    """
    hierarchy_classes = list(hierarchy_classes)
    hierarchy_classes_lower = [h.lower() for h in hierarchy_classes]
    # (index, lower_name) for non-blank hierarchy names only.
    searchable = [(idx, h_lower) for idx, h_lower in enumerate(hierarchy_classes_lower) if h_lower.strip()]

    synonyms = {
        'topwear': 'top',
        'tops': 'top',
        'tee': 'top',
        'tees': 'top',
        't-shirt': 'top',
        'tshirt': 'top',
        'tshirts': 'top',
        'shirt': 'shirt',
        'shirts': 'shirt',
        'sweater': 'sweater',
        'sweaters': 'sweater',
        'outerwear': 'coat',
        'outer': 'coat',
        'coat': 'coat',
        'coats': 'coat',
        'jacket': 'coat',
        'jackets': 'coat',
        'blazer': 'coat',
        'blazers': 'coat',
        'hoodie': 'hoodie',
        'hoodies': 'hoodie',
        'bottomwear': 'bottom',
        'bottoms': 'bottom',
        'pants': 'bottom',
        'pant': 'bottom',
        'trouser': 'bottom',
        'trousers': 'bottom',
        'jeans': 'jeans',
        'denim': 'jeans',
        'short': 'shorts',
        'shorts': 'shorts',
        'skirt': 'skirt',
        'skirts': 'skirt',
        'dress': 'dress',
        'dresses': 'dress',
        'saree': 'saree',
        'lehenga': 'lehenga',
        'shoe': 'shoes',
        'shoes': 'shoes',
        'sandal': 'shoes',
        'sandals': 'shoes',
        'sneaker': 'shoes',
        'sneakers': 'shoes',
        'boot': 'shoes',
        'boots': 'shoes',
        'heel': 'shoes',
        'heels': 'shoes',
        'flip flops': 'shoes',
        'flip-flops': 'shoes',
        'loafer': 'shoes',
        'loafers': 'shoes',
        'bag': 'bag',
        'bags': 'bag',
        'backpack': 'bag',
        'backpacks': 'bag',
        'handbag': 'bag',
        'handbags': 'bag',
        'accessory': 'accessories',
        'accessories': 'accessories',
        'belt': 'belt',
        'belts': 'belt',
        'scarf': 'scarf',
        'scarves': 'scarf',
        'cap': 'cap',
        'caps': 'cap',
        'hat': 'cap',
        'hats': 'cap',
        'watch': 'watch',
        'watches': 'watch',
    }

    def match_candidate(candidate):
        """Return the hierarchy class whose lower-cased name equals candidate, else None."""
        if candidate in hierarchy_classes_lower:
            return hierarchy_classes[hierarchy_classes_lower.index(candidate)]
        return None

    mapping = {}

    # key=str: plain sorted() raises TypeError when labels mix strings and NaN.
    for label in sorted(set(kaggle_labels), key=str):
        label_str = str(label) if not pd.isna(label) else ''
        label_lower = label_str.strip().lower()
        matched_hierarchy = None

        if not label_lower:
            mapping[label_lower] = None
            continue

        # 1. Exact match after synonym normalisation.
        candidate = synonyms.get(label_lower, label_lower)
        matched_hierarchy = match_candidate(candidate)

        # 2. Substring match in either direction (non-blank classes only).
        if matched_hierarchy is None:
            for idx, h_lower in searchable:
                if h_lower in candidate or candidate in h_lower:
                    matched_hierarchy = hierarchy_classes[idx]
                    break

        # 3. Token-level synonyms. dict.fromkeys de-duplicates while keeping
        # order of appearance, so the result does not depend on set hash order.
        if matched_hierarchy is None:
            tokens = dict.fromkeys(candidate.replace('-', ' ').replace('/', ' ').split())
            for token in tokens:
                matched_hierarchy = match_candidate(synonyms.get(token, token))
                if matched_hierarchy:
                    break

        # 4. Any synonym key contained anywhere in the label.
        if matched_hierarchy is None:
            for synonym_key, synonym_value in synonyms.items():
                if synonym_key in candidate:
                    matched_hierarchy = match_candidate(synonym_value)
                    if matched_hierarchy:
                        break

        # 5. Fuzzy match as a last resort.
        if matched_hierarchy is None:
            close_matches = difflib.get_close_matches(
                candidate, [h_lower for _, h_lower in searchable], n=1, cutoff=0.6
            )
            if close_matches:
                matched_hierarchy = hierarchy_classes[hierarchy_classes_lower.index(close_matches[0])]

        mapping[label_lower] = matched_hierarchy

        if matched_hierarchy:
            print(f" {label_str} -> {matched_hierarchy}")
        else:
            print(f" ⚠️ {label_str} -> NO MATCH (will be filtered out)")

    return mapping
|
|
|
|
|
|
|
|
class KaggleDataset(Dataset):
    """Dataset class for KAGL Marqo dataset.

    Expects a dataframe with ``image_url`` (raw image data, see ``_load_image``),
    ``text`` and ``hierarchy`` columns, plus an optional ``color`` column.
    Items are ``(image_tensor, description, color, hierarchy)``. A missing or
    NaN color is reported as ``"unknown"``.
    """

    def __init__(self, dataframe, image_size=224):
        self.dataframe = dataframe
        self.image_size = image_size

        # Evaluation transform: resize, tensorise, normalise with the standard
        # ImageNet mean/std.
        self.val_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.dataframe)

    @staticmethod
    def _load_image(image_data):
        """Decode one ``image_url`` cell into an RGB PIL image.

        Handles a dict with ``bytes`` (falling back to its ``path`` when the
        bytes are None), an object that already has ``.convert`` (e.g. a PIL
        image), or raw encoded bytes.
        """
        if isinstance(image_data, dict) and 'bytes' in image_data:
            raw = image_data['bytes']
            if raw is None and image_data.get('path'):
                with Image.open(image_data['path']) as img:
                    return img.convert("RGB")
            return Image.open(BytesIO(raw)).convert("RGB")
        if hasattr(image_data, 'convert'):
            return image_data.convert("RGB")
        return Image.open(BytesIO(image_data)).convert("RGB")

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]

        image = self.val_transform(self._load_image(row['image_url']))

        description = row['text']
        color = row.get('color', 'unknown')
        # Missing baseColour values arrive as NaN; report them as 'unknown' like an absent column.
        if pd.isna(color):
            color = 'unknown'
        hierarchy = row['hierarchy']

        return image, description, color, hierarchy
|
|
|
|
|
|
|
|
def load_kaggle_marqo_dataset(evaluator, max_samples=5000):
    """Load the Marqo/KAGL dataset and map it onto the evaluator's hierarchies.

    Steps:

    - Downloads ``Marqo/KAGL`` with ``datasets.load_dataset`` (split ``data``).
    - Maps ``category2`` onto ``evaluator.validation_hierarchy_classes``,
      falling back to ``evaluator.hierarchy_classes`` when that is empty.
    - Drops rows with no mapping or with missing ``text``/``image``.
    - Keeps the first ``max_samples`` rows.

    Returns:
        KaggleDataset, or None when no row survives hierarchy mapping.
    """
    from datasets import load_dataset

    print("π Loading Kaggle KAGL dataset...")

    dataset = load_dataset("Marqo/KAGL")
    df = dataset["data"].to_pandas()
    print("✅ Dataset Kaggle loaded")
    print(f" Before filtering: {len(df)} samples")
    print(f" Available columns: {list(df.columns)}")

    available_categories = sorted(df['category2'].dropna().unique())
    print(f"🎨 Available categories: {available_categories}")

    validation_hierarchies = evaluator.validation_hierarchy_classes or evaluator.hierarchy_classes
    print(f"π Validation hierarchies: {sorted(validation_hierarchies)}")

    print("\nπ Creating mapping from Kaggle categories to validation hierarchies:")
    category_mapping = create_kaggle_marqo_to_hierarchy_mapping(available_categories, validation_hierarchies)

    total_categories = {str(cat).strip().lower() for cat in df['category2'].dropna()}
    unmapped_categories = sorted(cat for cat in total_categories if category_mapping.get(cat) is None)
    if unmapped_categories:
        print(f"⚠️ Categories without mapping (will be dropped): {unmapped_categories}")

    df['hierarchy'] = df['category2'].apply(
        lambda cat: category_mapping.get(str(cat).strip().lower()) if pd.notna(cat) else None
    )

    before_mapping_len = len(df)
    df = df[df['hierarchy'].notna()]
    print(f" After mapping to validation hierarchies: {len(df)} samples (from {before_mapping_len})")

    if len(df) == 0:
        print("❌ No samples left after hierarchy mapping.")
        return None

    df = df.dropna(subset=['text', 'image'])
    print(f" After removing missing text/image: {len(df)} samples")

    print("π Sample texts:")
    for i, (text, hierarchy) in enumerate(zip(df['text'].head(3), df['hierarchy'].head(3))):
        print(f" {i+1}. [{hierarchy}] {str(text)[:100]}...")

    df_test = df.head(max_samples).copy()

    print(f"π After sampling: {len(df_test)} samples")
    print(" Samples per hierarchy:")
    # One value_counts pass instead of re-filtering the frame per hierarchy.
    for hierarchy, count in df_test['hierarchy'].value_counts().sort_index().items():
        print(f" {hierarchy}: {count} samples")

    if 'baseColour' in df_test.columns:
        colors = df_test['baseColour'].str.lower().str.replace("grey", "gray")
    else:
        # Without a color column every sample is reported as 'unknown'.
        colors = pd.Series('unknown', index=df_test.index)

    kaggle_formatted = pd.DataFrame({
        'image_url': df_test['image'],
        'text': df_test['text'],
        'hierarchy': df_test['hierarchy'],
        'color': colors,
    })

    print(f" Final dataset size: {len(kaggle_formatted)} samples")
    return KaggleDataset(kaggle_formatted)
|
|
|
|
|
|
|
|
class LocalDataset(Dataset):
    """Dataset class for local validation dataset.

    Rows must provide the image-path column named by ``column_local_image_path``
    (imported from ``config``) plus ``text`` and ``hierarchy``. ``color`` is
    optional. Items are ``(image_tensor, description, color, hierarchy)``. A
    missing or NaN color is reported as ``"unknown"``.
    """

    def __init__(self, dataframe, image_size=224):
        self.dataframe = dataframe
        self.image_size = image_size

        # Evaluation transform: resize, tensorise, normalise with the standard
        # ImageNet mean/std.
        self.val_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]

        image_path = row[column_local_image_path]
        try:
            # Context manager releases the file handle deterministically.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
        except Exception as e:
            # Best effort: log and substitute a plain gray placeholder so that
            # one unreadable file does not abort the whole evaluation.
            print(f"Error loading image at index {idx} from {image_path}: {e}")
            image = Image.new('RGB', (self.image_size, self.image_size), color='gray')

        image = self.val_transform(image)

        description = row['text']
        color = row.get('color', 'unknown')
        # Empty CSV cells arrive as NaN; treat them like a missing column.
        if pd.isna(color):
            color = 'unknown'
        hierarchy = row['hierarchy']

        return image, description, color, hierarchy
|
|
|
|
|
|
|
|
def load_local_validation_dataset(max_samples=5000):
    """Load and prepare the local validation dataset.

    Steps:

    - Reads the CSV at ``local_dataset_path`` (from ``config``).
    - Drops rows without an image path (``column_local_image_path``).
    - Checks that ``text`` and ``hierarchy`` columns exist.
    - Randomly subsamples to ``max_samples`` rows (``random_state=42``).
    - Returns a LocalDataset.

    Color and hierarchy distributions are printed along the way.

    Returns:
        LocalDataset, or None when the file is missing, no row has an image
        path, or a required column is absent.
    """
    print("π Loading local validation dataset...")

    if not os.path.exists(local_dataset_path):
        print(f"❌ Local dataset file not found: {local_dataset_path}")
        return None

    df = pd.read_csv(local_dataset_path)
    print(f"✅ Dataset loaded: {len(df)} samples")

    df_clean = df.dropna(subset=[column_local_image_path])
    print(f"π After filtering NaN image paths: {len(df_clean)} samples")

    if len(df_clean) == 0:
        print("❌ No valid samples after filtering.")
        return None

    if 'color' in df_clean.columns:
        unique_colors = df_clean['color'].unique()
        print(f"🎨 Total unique colors in dataset: {len(unique_colors)}")
        # key=str: a NaN color would otherwise make sorted() compare float and str and raise TypeError.
        print(f"🎨 Colors found: {sorted(unique_colors, key=str)}")
        print("🎨 Color distribution (top 15):")
        for color, count in df_clean['color'].value_counts().head(15).items():
            print(f" {color}: {count} samples")

    required_cols = ['text', 'hierarchy']
    missing_cols = [col for col in required_cols if col not in df_clean.columns]
    if missing_cols:
        print(f"❌ Missing required columns: {missing_cols}")
        return None

    if len(df_clean) > max_samples:
        df_clean = df_clean.sample(n=max_samples, random_state=42)
        print(f"π Randomly sampled {max_samples} samples")

    print(f"π Using {len(df_clean)} samples for evaluation")
    print(" Samples per hierarchy:")
    # dropna=False + key=str: NaN hierarchies are counted instead of crashing the sort.
    hierarchy_counts = df_clean['hierarchy'].value_counts(dropna=False)
    for hierarchy, count in sorted(hierarchy_counts.items(), key=lambda item: str(item[0])):
        print(f" {hierarchy}: {count} samples")

    if 'color' in df_clean.columns:
        print("\n🎨 Color distribution in sampled data:")
        color_counts = df_clean['color'].value_counts()
        print(f" Total unique colors: {len(color_counts)}")
        for color, count in color_counts.head(15).items():
            print(f" {color}: {count} samples")

    return LocalDataset(df_clean)
|
|
|
|
|
|
|
|
class ColorHierarchyEvaluator: |
|
|
"""Evaluate color (dims 0-15) and hierarchy (dims 16-79) embeddings on Fashion-MNIST""" |
|
|
|
|
|
def __init__(self, device='mps', directory='fashion_mnist_color_hierarchy_analysis'): |
|
|
self.device = torch.device(device) |
|
|
self.directory = directory |
|
|
self.color_emb_dim = color_emb_dim |
|
|
self.hierarchy_emb_dim = hierarchy_emb_dim |
|
|
os.makedirs(self.directory, exist_ok=True) |
|
|
|
|
|
print(f"π Loading main model from {main_model_path}") |
|
|
if not os.path.exists(main_model_path): |
|
|
raise FileNotFoundError(f"Main model file {main_model_path} not found") |
|
|
|
|
|
|
|
|
print("π Loading hierarchy classes from hierarchy model...") |
|
|
if not os.path.exists(hierarchy_model_path): |
|
|
raise FileNotFoundError(f"Hierarchy model file {hierarchy_model_path} not found") |
|
|
|
|
|
hierarchy_checkpoint = torch.load(hierarchy_model_path, map_location=self.device) |
|
|
self.hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', []) |
|
|
print(f"β
Found {len(self.hierarchy_classes)} hierarchy classes: {sorted(self.hierarchy_classes)}") |
|
|
|
|
|
self.validation_hierarchy_classes = self._load_validation_hierarchy_classes() |
|
|
if self.validation_hierarchy_classes: |
|
|
print(f"β
Validation dataset hierarchies ({len(self.validation_hierarchy_classes)} classes): {sorted(self.validation_hierarchy_classes)}") |
|
|
else: |
|
|
print("β οΈ Unable to load validation hierarchy classes, falling back to hierarchy model classes.") |
|
|
self.validation_hierarchy_classes = self.hierarchy_classes |
|
|
|
|
|
checkpoint = torch.load(main_model_path, map_location=self.device) |
|
|
self.processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K') |
|
|
self.model = CLIPModel_transformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K') |
|
|
self.model.load_state_dict(checkpoint['model_state_dict']) |
|
|
self.model.to(self.device) |
|
|
self.model.eval() |
|
|
print("β
Main model loaded successfully") |
|
|
|
|
|
|
|
|
print("π¦ Loading baseline Fashion CLIP model...") |
|
|
patrick_model_name = "patrickjohncyh/fashion-clip" |
|
|
self.baseline_processor = CLIPProcessor.from_pretrained(patrick_model_name) |
|
|
self.baseline_model = CLIPModel_transformers.from_pretrained(patrick_model_name).to(self.device) |
|
|
self.baseline_model.eval() |
|
|
print("β
Baseline Fashion CLIP model loaded successfully") |
|
|
|
|
|
def _load_validation_hierarchy_classes(self): |
|
|
"""Load hierarchy classes present in the validation dataset""" |
|
|
if not os.path.exists(local_dataset_path): |
|
|
print(f"β οΈ Validation dataset not found at {local_dataset_path}") |
|
|
return [] |
|
|
|
|
|
try: |
|
|
df = pd.read_csv(local_dataset_path) |
|
|
except Exception as exc: |
|
|
print(f"β οΈ Failed to read validation dataset: {exc}") |
|
|
return [] |
|
|
|
|
|
if 'hierarchy' not in df.columns: |
|
|
print("β οΈ Validation dataset does not contain 'hierarchy' column.") |
|
|
return [] |
|
|
|
|
|
hierarchies = ( |
|
|
df['hierarchy'] |
|
|
.dropna() |
|
|
.astype(str) |
|
|
.str.strip() |
|
|
) |
|
|
hierarchies = [h for h in hierarchies if h] |
|
|
return sorted(set(hierarchies)) |
|
|
|
|
|
def extract_color_embeddings(self, dataloader, embedding_type='text', max_samples=10000): |
|
|
"""Extract color embeddings from dims 0-15 (16 dimensions)""" |
|
|
all_embeddings = [] |
|
|
all_colors = [] |
|
|
all_hierarchies = [] |
|
|
|
|
|
sample_count = 0 |
|
|
with torch.no_grad(): |
|
|
for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} color embeddings (dims 0-15)"): |
|
|
if sample_count >= max_samples: |
|
|
break |
|
|
|
|
|
images, texts, colors, hierarchies = batch |
|
|
images = images.to(self.device) |
|
|
images = images.expand(-1, 3, -1, -1) |
|
|
|
|
|
text_inputs = self.processor(text=texts, padding=True, return_tensors="pt") |
|
|
text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()} |
|
|
|
|
|
outputs = self.model(**text_inputs, pixel_values=images) |
|
|
|
|
|
if embedding_type == 'text': |
|
|
embeddings = outputs.text_embeds |
|
|
elif embedding_type == 'image': |
|
|
embeddings = outputs.image_embeds |
|
|
else: |
|
|
embeddings = outputs.text_embeds |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
color_embeddings = embeddings |
|
|
all_embeddings.append(color_embeddings.cpu().numpy()) |
|
|
all_colors.extend(colors) |
|
|
all_hierarchies.extend(hierarchies) |
|
|
|
|
|
sample_count += len(images) |
|
|
|
|
|
del images, text_inputs, outputs, embeddings, color_embeddings |
|
|
torch.cuda.empty_cache() if torch.cuda.is_available() else None |
|
|
|
|
|
return np.vstack(all_embeddings), all_colors, all_hierarchies |
|
|
|
|
|
def extract_hierarchy_embeddings(self, dataloader, embedding_type='text', max_samples=10000): |
|
|
"""Extract hierarchy embeddings from dims 16-79 (indices 16:79)""" |
|
|
all_embeddings = [] |
|
|
all_colors = [] |
|
|
all_hierarchies = [] |
|
|
|
|
|
sample_count = 0 |
|
|
with torch.no_grad(): |
|
|
for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} hierarchy embeddings (dims 16-79)"): |
|
|
if sample_count >= max_samples: |
|
|
break |
|
|
|
|
|
images, texts, colors, hierarchies = batch |
|
|
images = images.to(self.device) |
|
|
images = images.expand(-1, 3, -1, -1) |
|
|
|
|
|
text_inputs = self.processor(text=texts, padding=True, return_tensors="pt") |
|
|
text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()} |
|
|
|
|
|
outputs = self.model(**text_inputs, pixel_values=images) |
|
|
|
|
|
if embedding_type == 'text': |
|
|
embeddings = outputs.text_embeds |
|
|
elif embedding_type == 'image': |
|
|
embeddings = outputs.image_embeds |
|
|
else: |
|
|
embeddings = outputs.text_embeds |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hierarchy_embeddings = embeddings |
|
|
all_embeddings.append(hierarchy_embeddings.cpu().numpy()) |
|
|
all_colors.extend(colors) |
|
|
all_hierarchies.extend(hierarchies) |
|
|
|
|
|
sample_count += len(images) |
|
|
|
|
|
del images, text_inputs, outputs, embeddings, hierarchy_embeddings |
|
|
torch.cuda.empty_cache() if torch.cuda.is_available() else None |
|
|
|
|
|
return np.vstack(all_embeddings), all_colors, all_hierarchies |
|
|
|
|
|
def extract_full_embeddings(self, dataloader, embedding_type='text', max_samples=10000): |
|
|
"""Extract full 512-dimensional embeddings (all dimensions)""" |
|
|
all_embeddings = [] |
|
|
all_colors = [] |
|
|
all_hierarchies = [] |
|
|
|
|
|
sample_count = 0 |
|
|
with torch.no_grad(): |
|
|
for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} FULL embeddings (all dims)"): |
|
|
if sample_count >= max_samples: |
|
|
break |
|
|
|
|
|
images, texts, colors, hierarchies = batch |
|
|
images = images.to(self.device) |
|
|
images = images.expand(-1, 3, -1, -1) |
|
|
|
|
|
text_inputs = self.processor(text=texts, padding=True, return_tensors="pt") |
|
|
text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()} |
|
|
|
|
|
outputs = self.model(**text_inputs, pixel_values=images) |
|
|
|
|
|
if embedding_type == 'text': |
|
|
embeddings = outputs.text_embeds |
|
|
elif embedding_type == 'image': |
|
|
embeddings = outputs.image_embeds |
|
|
else: |
|
|
embeddings = outputs.text_embeds |
|
|
|
|
|
|
|
|
all_embeddings.append(embeddings.cpu().numpy()) |
|
|
all_colors.extend(colors) |
|
|
all_hierarchies.extend(hierarchies) |
|
|
|
|
|
sample_count += len(images) |
|
|
|
|
|
del images, text_inputs, outputs, embeddings |
|
|
torch.cuda.empty_cache() if torch.cuda.is_available() else None |
|
|
|
|
|
return np.vstack(all_embeddings), all_colors, all_hierarchies |
|
|
|
|
|
def extract_baseline_embeddings_batch(self, dataloader, embedding_type='text', max_samples=10000): |
|
|
""" |
|
|
Extract embeddings from baseline Fashion CLIP model. |
|
|
|
|
|
This method properly processes images and text through the Fashion-CLIP processor |
|
|
and applies L2 normalization to embeddings, matching the evaluation in evaluate_color_embeddings.py |
|
|
""" |
|
|
all_embeddings = [] |
|
|
all_colors = [] |
|
|
all_hierarchies = [] |
|
|
|
|
|
sample_count = 0 |
|
|
|
|
|
with torch.no_grad(): |
|
|
for batch in tqdm(dataloader, desc=f"Extracting baseline {embedding_type} embeddings"): |
|
|
if sample_count >= max_samples: |
|
|
break |
|
|
|
|
|
images, texts, colors, hierarchies = batch |
|
|
|
|
|
|
|
|
if embedding_type == 'text': |
|
|
|
|
|
text_inputs = self.baseline_processor(text=texts, return_tensors="pt", padding=True, truncation=True, max_length=77) |
|
|
text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()} |
|
|
|
|
|
|
|
|
text_features = self.baseline_model.get_text_features(**text_inputs) |
|
|
|
|
|
|
|
|
text_features = text_features / text_features.norm(dim=-1, keepdim=True) |
|
|
embeddings = text_features |
|
|
|
|
|
elif embedding_type == 'image': |
|
|
|
|
|
pil_images = [] |
|
|
for i in range(images.shape[0]): |
|
|
img_tensor = images[i] |
|
|
|
|
|
|
|
|
|
|
|
if img_tensor.min() < 0 or img_tensor.max() > 1: |
|
|
|
|
|
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) |
|
|
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) |
|
|
img_tensor = img_tensor * std + mean |
|
|
img_tensor = torch.clamp(img_tensor, 0, 1) |
|
|
|
|
|
|
|
|
img_pil = transforms.ToPILImage()(img_tensor) |
|
|
pil_images.append(img_pil) |
|
|
|
|
|
|
|
|
image_inputs = self.baseline_processor(images=pil_images, return_tensors="pt") |
|
|
image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()} |
|
|
|
|
|
|
|
|
image_features = self.baseline_model.get_image_features(**image_inputs) |
|
|
|
|
|
|
|
|
image_features = image_features / image_features.norm(dim=-1, keepdim=True) |
|
|
embeddings = image_features |
|
|
|
|
|
else: |
|
|
|
|
|
text_inputs = self.baseline_processor(text=texts, return_tensors="pt", padding=True, truncation=True, max_length=77) |
|
|
text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()} |
|
|
text_features = self.baseline_model.get_text_features(**text_inputs) |
|
|
text_features = text_features / text_features.norm(dim=-1, keepdim=True) |
|
|
embeddings = text_features |
|
|
|
|
|
all_embeddings.append(embeddings.cpu().numpy()) |
|
|
all_colors.extend(colors) |
|
|
all_hierarchies.extend(hierarchies) |
|
|
|
|
|
sample_count += len(images) |
|
|
|
|
|
|
|
|
del embeddings |
|
|
if embedding_type == 'image': |
|
|
del pil_images, image_inputs |
|
|
else: |
|
|
del text_inputs |
|
|
torch.cuda.empty_cache() if torch.cuda.is_available() else None |
|
|
|
|
|
return np.vstack(all_embeddings), all_colors, all_hierarchies |
|
|
|
|
|
def compute_similarity_metrics(self, embeddings, labels): |
|
|
"""Compute intra-class and inter-class similarities""" |
|
|
max_samples = min(5000, len(embeddings)) |
|
|
if len(embeddings) > max_samples: |
|
|
indices = np.random.choice(len(embeddings), max_samples, replace=False) |
|
|
embeddings = embeddings[indices] |
|
|
labels = [labels[i] for i in indices] |
|
|
|
|
|
similarities = cosine_similarity(embeddings) |
|
|
|
|
|
label_groups = defaultdict(list) |
|
|
for i, label in enumerate(labels): |
|
|
label_groups[label].append(i) |
|
|
|
|
|
intra_class_similarities = [] |
|
|
for label, indices in label_groups.items(): |
|
|
if len(indices) > 1: |
|
|
for i in range(len(indices)): |
|
|
for j in range(i + 1, len(indices)): |
|
|
sim = similarities[indices[i], indices[j]] |
|
|
intra_class_similarities.append(sim) |
|
|
|
|
|
inter_class_similarities = [] |
|
|
labels_list = list(label_groups.keys()) |
|
|
for i in range(len(labels_list)): |
|
|
for j in range(i + 1, len(labels_list)): |
|
|
label1_indices = label_groups[labels_list[i]] |
|
|
label2_indices = label_groups[labels_list[j]] |
|
|
for idx1 in label1_indices: |
|
|
for idx2 in label2_indices: |
|
|
sim = similarities[idx1, idx2] |
|
|
inter_class_similarities.append(sim) |
|
|
|
|
|
nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities) |
|
|
centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels) |
|
|
|
|
|
return { |
|
|
'intra_class_similarities': intra_class_similarities, |
|
|
'inter_class_similarities': inter_class_similarities, |
|
|
'intra_class_mean': float(np.mean(intra_class_similarities)) if intra_class_similarities else 0.0, |
|
|
'inter_class_mean': float(np.mean(inter_class_similarities)) if inter_class_similarities else 0.0, |
|
|
'separation_score': float(np.mean(intra_class_similarities) - np.mean(inter_class_similarities)) if intra_class_similarities and inter_class_similarities else 0.0, |
|
|
'accuracy': nn_accuracy, |
|
|
'centroid_accuracy': centroid_accuracy, |
|
|
} |
|
|
|
|
|
def compute_embedding_accuracy(self, embeddings, labels, similarities): |
|
|
"""Compute classification accuracy using nearest neighbor""" |
|
|
correct_predictions = 0 |
|
|
total_predictions = len(labels) |
|
|
for i in range(len(embeddings)): |
|
|
true_label = labels[i] |
|
|
similarities_row = similarities[i].copy() |
|
|
similarities_row[i] = -1 |
|
|
nearest_neighbor_idx = int(np.argmax(similarities_row)) |
|
|
predicted_label = labels[nearest_neighbor_idx] |
|
|
if predicted_label == true_label: |
|
|
correct_predictions += 1 |
|
|
return correct_predictions / total_predictions if total_predictions > 0 else 0.0 |
|
|
|
|
|
def compute_centroid_accuracy(self, embeddings, labels): |
|
|
"""Compute classification accuracy using centroids""" |
|
|
unique_labels = list(set(labels)) |
|
|
centroids = {} |
|
|
for label in unique_labels: |
|
|
label_indices = [i for i, l in enumerate(labels) if l == label] |
|
|
centroids[label] = np.mean(embeddings[label_indices], axis=0) |
|
|
|
|
|
correct_predictions = 0 |
|
|
total_predictions = len(labels) |
|
|
for i, embedding in enumerate(embeddings): |
|
|
true_label = labels[i] |
|
|
best_similarity = -1 |
|
|
predicted_label = None |
|
|
for label, centroid in centroids.items(): |
|
|
similarity = cosine_similarity([embedding], [centroid])[0][0] |
|
|
if similarity > best_similarity: |
|
|
best_similarity = similarity |
|
|
predicted_label = label |
|
|
if predicted_label == true_label: |
|
|
correct_predictions += 1 |
|
|
return correct_predictions / total_predictions if total_predictions > 0 else 0.0 |
|
|
|
|
|
def predict_labels_from_embeddings(self, embeddings, labels):
    """Predict a label for every embedding by nearest centroid (cosine similarity).

    One centroid (mean embedding) is computed per distinct label in
    ``labels``; each embedding is assigned the label whose centroid it is
    most cosine-similar to.

    Args:
        embeddings: Array-like of shape (n_samples, dim), aligned with ``labels``.
        labels: Sequence of hashable labels used to build the centroids.

    Returns:
        List with one predicted label per embedding. Ties go to the label
        that first appears in ``labels``. If ``labels`` is empty no centroid
        exists and every prediction is ``None``.
    """
    n_samples = len(embeddings)
    if n_samples == 0 or len(labels) == 0:
        return [None] * n_samples

    embeddings = np.asarray(embeddings)

    # Group sample indices per label in one pass; first-occurrence order keeps
    # tie-breaking deterministic (set order varies between runs for str labels).
    indices_by_label = {}
    for idx, label in enumerate(labels):
        indices_by_label.setdefault(label, []).append(idx)
    unique_labels = list(indices_by_label)
    centroids = np.stack([embeddings[idx].mean(axis=0) for idx in indices_by_label.values()])

    # One batched (n_samples, n_labels) similarity matrix instead of n*k calls.
    # argmax always yields a label; the old `best_similarity = -1` start could
    # produce None predictions.
    best = cosine_similarity(embeddings, centroids).argmax(axis=1)
    return [unique_labels[j] for j in best]
|
|
|
|
|
def predict_labels_ensemble(self, specialized_embeddings, full_embeddings, labels,
                            specialized_weight=0.5):
    """Nearest-centroid prediction blending two embedding views.

    Centroids are computed per label separately for the specialized
    embeddings (e.g. a color or hierarchy slice) and for the full embeddings.
    For every sample the score against each label is::

        specialized_weight * cos(spec, spec_centroid)
            + (1 - specialized_weight) * cos(full, full_centroid)

    and the highest-scoring label is predicted.

    Args:
        specialized_embeddings: Array-like (n_samples, d_spec), aligned with ``labels``.
        full_embeddings: Array-like (n_samples, d_full), aligned with ``labels``.
        labels: True labels, used only to build the centroids.
        specialized_weight: Weight of the specialized view (0.0 = only full,
            1.0 = only specialized).

    Returns:
        List with one predicted label per sample. Ties go to the label that
        first appears in ``labels``. Predictions are ``None`` if ``labels`` is empty.
    """
    n_samples = len(specialized_embeddings)
    if n_samples == 0 or len(labels) == 0:
        return [None] * n_samples

    specialized_embeddings = np.asarray(specialized_embeddings)
    full_embeddings = np.asarray(full_embeddings)

    # Group sample indices per label once; first-occurrence order keeps
    # tie-breaking deterministic (set order varies between runs for str labels).
    indices_by_label = {}
    for idx, label in enumerate(labels):
        indices_by_label.setdefault(label, []).append(idx)
    unique_labels = list(indices_by_label)
    groups = list(indices_by_label.values())

    specialized_centroids = np.stack([specialized_embeddings[idx].mean(axis=0) for idx in groups])
    full_centroids = np.stack([full_embeddings[idx].mean(axis=0) for idx in groups])

    # Batched (n_samples, n_labels) score matrix. A view whose weight is zero
    # contributes nothing, so its similarities are not computed at all
    # (callers commonly pass weight 1, i.e. specialized only).
    combined = np.zeros((n_samples, len(unique_labels)))
    if specialized_weight != 0:
        combined += specialized_weight * cosine_similarity(specialized_embeddings, specialized_centroids)
    if specialized_weight != 1:
        combined += (1 - specialized_weight) * cosine_similarity(full_embeddings, full_centroids)

    best = combined.argmax(axis=1)
    return [unique_labels[j] for j in best]
|
|
|
|
|
def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix", label_type="Label"):
    """Build a confusion matrix and render it as an annotated heatmap.

    Args:
        true_labels: Ground-truth labels (list, tuple or array).
        predicted_labels: Predicted labels, aligned with ``true_labels``.
        title: Plot title; the accuracy is appended on a second line.
        label_type: Noun used in the axis captions ("True X" / "Predicted X").

    Returns:
        Tuple ``(figure, accuracy, cm)``:
        - ``figure`` is the newly created matplotlib figure holding the heatmap;
          the caller must save and close it.
        - ``accuracy`` is the value from ``accuracy_score``.
        - ``cm`` is the confusion matrix, with rows = true labels and
          columns = predicted labels, both ordered by the sorted union of labels.
    """
    # Union via sets so this works for lists, tuples and numpy arrays alike;
    # `true_labels + predicted_labels` only concatenates when both are lists.
    unique_labels = sorted(set(true_labels) | set(predicted_labels))
    cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
    accuracy = accuracy_score(true_labels, predicted_labels)

    # Keep a handle on the figure we create instead of relying on gcf().
    fig = plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=unique_labels, yticklabels=unique_labels)
    plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
    plt.ylabel(f'True {label_type}')
    plt.xlabel(f'Predicted {label_type}')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    return fig, accuracy, cm
|
|
|
|
|
def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings", label_type="Label",
                                        full_embeddings=None, ensemble_weight=0.5):
    """Classify samples by nearest centroid and summarise the outcome.

    Without ``full_embeddings`` the prediction uses ``embeddings`` alone.
    When ``full_embeddings`` is given, the specialized and full views are
    blended via ``predict_labels_ensemble``.

    Args:
        embeddings: Specialized embeddings (e.g. the color or hierarchy slice).
        labels: True labels, aligned with ``embeddings``.
        embedding_type: Display name used in the plot title.
        label_type: Kind of label (e.g. "Color" / "Hierarchy") for titles and axes.
        full_embeddings: Optional full embeddings that enable the ensemble.
        ensemble_weight: Weight of the specialized view in the ensemble
            (0.0 = only full, 1.0 = only specialized).

    Returns:
        dict with keys:
        - 'accuracy', 'predictions' and 'confusion_matrix'.
        - 'classification_report': the sklearn ``output_dict=True`` form.
        - 'figure': the confusion-matrix figure; the caller saves and closes it.
    """
    use_ensemble = full_embeddings is not None
    if use_ensemble:
        predictions = self.predict_labels_ensemble(embeddings, full_embeddings, labels, ensemble_weight)
        title_suffix = f" (Ensemble: {ensemble_weight:.1f} specialized + {1-ensemble_weight:.1f} full)"
    else:
        predictions = self.predict_labels_from_embeddings(embeddings, labels)
        title_suffix = ""

    accuracy = accuracy_score(labels, predictions)

    plot_title = f"{embedding_type} - {label_type} Classification{title_suffix}"
    figure, _, matrix = self.create_confusion_matrix(labels, predictions, plot_title, label_type)

    # The per-class report only covers labels that occur in the ground truth.
    report_labels = sorted(set(labels))
    report = classification_report(
        labels,
        predictions,
        labels=report_labels,
        target_names=report_labels,
        output_dict=True,
    )

    summary = {}
    summary['accuracy'] = accuracy
    summary['predictions'] = predictions
    summary['confusion_matrix'] = matrix
    summary['classification_report'] = report
    summary['figure'] = figure
    return summary
|
|
|
|
|
def evaluate_fashion_mnist(self, max_samples):
    """Evaluate text and image hierarchy embeddings on Fashion-MNIST.

    Fashion-MNIST samples carry only a class label, which is mapped onto a
    hierarchy class. This method therefore scores hierarchy embeddings only;
    no color evaluation is performed, despite the banner lines printed below.

    Steps:
        1. Load up to ``max_samples`` items, with labels mapped onto
           ``self.validation_hierarchy_classes``, falling back to
           ``self.hierarchy_classes`` when that is empty.
        2. Extract full embeddings for both modalities.
        3. Slice out the hierarchy dims
           ``[color_emb_dim, color_emb_dim + hierarchy_emb_dim)``.
        4. Compute similarity metrics and centroid classification on that slice.
        5. Save one confusion-matrix PNG per modality under ``self.directory``.

    Args:
        max_samples: Maximum number of samples to load and embed.

    Returns:
        dict with keys 'text_hierarchy' and 'image_hierarchy'. Each value
        merges the similarity metrics with the classification results; its
        'figure' has already been saved and closed.
    """
    from collections import Counter

    print(f"\n{'='*60}")
    print("Evaluating Fashion-MNIST")
    print(" Color embeddings: dims 0-15")
    print(" Hierarchy embeddings: dims 16-79")
    print(f"Max samples: {max_samples}")
    print(f"{'='*60}")

    target_hierarchy_classes = self.validation_hierarchy_classes or self.hierarchy_classes
    fashion_dataset = load_fashion_mnist_dataset(max_samples, hierarchy_classes=target_hierarchy_classes)
    dataloader = DataLoader(fashion_dataset, batch_size=8, shuffle=False, num_workers=0)

    # Show how the Fashion-MNIST label ids were distributed over hierarchy classes.
    if len(fashion_dataset.dataframe) > 0:
        print("\nπ Hierarchy distribution in dataset:")
        if fashion_dataset.label_mapping:
            hierarchy_counts = Counter(
                fashion_dataset.label_mapping.get(int(label_id), 'unknown')
                for label_id in fashion_dataset.dataframe['label']
            )
            for hierarchy, count in sorted(hierarchy_counts.items()):
                print(f" {hierarchy}: {count} samples")

    results = {}

    print("\nπ¦ Extracting full 512-dimensional embeddings for ensemble...")
    # Color labels are not evaluated for this dataset, so they are discarded.
    text_full_embeddings, _, text_hierarchies_full = self.extract_full_embeddings(dataloader, 'text', max_samples)
    image_full_embeddings, _, image_hierarchies_full = self.extract_full_embeddings(dataloader, 'image', max_samples)
    print(f" Text full embeddings shape: {text_full_embeddings.shape}")
    print(f" Image full embeddings shape: {image_full_embeddings.shape}")

    # The hierarchy block sits immediately after the color block.
    hierarchy_dims = slice(self.color_emb_dim, self.color_emb_dim + self.hierarchy_emb_dim)

    print("\nπ HIERARCHY EVALUATION (dims 16-79) - Using Ensemble")
    print("=" * 50)

    # NOTE(review): ensemble_weight=1 gives the full embeddings zero weight, so
    # the "Ensemble" results equal specialized-only classification.
    print("\nπ Extracting specialized text hierarchy embeddings (dims 16-79)...")
    text_hierarchy_embeddings_spec = text_full_embeddings[:, hierarchy_dims]
    print(f" Specialized text hierarchy embeddings shape: {text_hierarchy_embeddings_spec.shape}")
    text_hierarchy_metrics = self.compute_similarity_metrics(text_hierarchy_embeddings_spec, text_hierarchies_full)
    text_hierarchy_class = self.evaluate_classification_performance(
        text_hierarchy_embeddings_spec, text_hierarchies_full,
        "Text Hierarchy Embeddings (Ensemble)", "Hierarchy",
        full_embeddings=text_full_embeddings, ensemble_weight=1
    )
    text_hierarchy_metrics.update(text_hierarchy_class)
    results['text_hierarchy'] = text_hierarchy_metrics

    print("\nπΌοΈ Extracting specialized image hierarchy embeddings (dims 16-79)...")
    image_hierarchy_embeddings_spec = image_full_embeddings[:, hierarchy_dims]
    print(f" Specialized image hierarchy embeddings shape: {image_hierarchy_embeddings_spec.shape}")
    image_hierarchy_metrics = self.compute_similarity_metrics(image_hierarchy_embeddings_spec, image_hierarchies_full)
    image_hierarchy_class = self.evaluate_classification_performance(
        image_hierarchy_embeddings_spec, image_hierarchies_full,
        "Image Hierarchy Embeddings (Ensemble)", "Hierarchy",
        full_embeddings=image_full_embeddings, ensemble_weight=1
    )
    image_hierarchy_metrics.update(image_hierarchy_class)
    results['image_hierarchy'] = image_hierarchy_metrics

    # Drop the large embedding arrays before writing figures.
    del text_full_embeddings, image_full_embeddings
    del text_hierarchy_embeddings_spec, image_hierarchy_embeddings_spec
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    os.makedirs(self.directory, exist_ok=True)
    for key in ('text_hierarchy', 'image_hierarchy'):
        figure = results[key]['figure']
        figure.savefig(
            f"{self.directory}/fashion_{key}_confusion_matrix.png",
            dpi=300,
            bbox_inches='tight',
        )
        plt.close(figure)

    return results
|
|
|
|
|
def evaluate_kaggle_marqo(self, max_samples):
    """Evaluate color and hierarchy embeddings (text and image) on KAGL Marqo.

    Full embeddings are extracted once per modality and then sliced into:
    - color dims ``[0, color_emb_dim)``;
    - hierarchy dims ``[color_emb_dim, color_emb_dim + hierarchy_emb_dim)``.

    Each slice gets similarity metrics plus centroid classification. The
    color slices use ``ensemble_weight=1``. The hierarchy slices use
    ``ensemble_weight=0.4``, i.e. 0.4 specialized and 0.6 full embedding.
    One confusion-matrix PNG per slice/modality is saved under ``self.directory``.

    Args:
        max_samples: Maximum number of samples to load and embed.

    Returns:
        dict with keys 'text_color', 'image_color', 'text_hierarchy' and
        'image_hierarchy'. Each value holds the merged metrics and
        classification results; the 'figure' entries are already saved and
        closed. Returns None if the dataset could not be loaded.
    """
    from collections import Counter

    print(f"\n{'='*60}")
    print("Evaluating KAGL Marqo Dataset")
    print(" Color embeddings: dims 0-15")
    print(" Hierarchy embeddings: dims 16-79")
    print(f"Max samples: {max_samples}")
    print(f"{'='*60}")

    kaggle_dataset = load_kaggle_marqo_dataset(self, max_samples)
    if kaggle_dataset is None:
        print("β Failed to load KAGL dataset")
        return None

    dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0)

    if len(kaggle_dataset.dataframe) > 0:
        print("\nπ Hierarchy distribution in dataset:")
        hierarchy_counts = Counter(kaggle_dataset.dataframe['hierarchy'])
        for hierarchy, count in sorted(hierarchy_counts.items()):
            print(f" {hierarchy}: {count} samples")

    results = {}

    print("\nπ¦ Extracting full 512-dimensional embeddings for ensemble...")
    text_full_embeddings, text_colors_full, text_hierarchies_full = self.extract_full_embeddings(dataloader, 'text', max_samples)
    image_full_embeddings, image_colors_full, image_hierarchies_full = self.extract_full_embeddings(dataloader, 'image', max_samples)
    print(f" Text full embeddings shape: {text_full_embeddings.shape}")
    print(f" Image full embeddings shape: {image_full_embeddings.shape}")

    # Embedding layout assumed by these slices: color block first, hierarchy block right after it.
    color_dims = slice(None, self.color_emb_dim)
    hierarchy_dims = slice(self.color_emb_dim, self.color_emb_dim + self.hierarchy_emb_dim)

    # ---- Color -------------------------------------------------------------
    print("\nπ¨ COLOR EVALUATION (dims 0-15) - Using Ensemble")
    print("=" * 50)

    # NOTE(review): ensemble_weight=1 gives the full embeddings zero weight, so
    # the color "Ensemble" results equal specialized-only classification.
    print("\nπ Extracting specialized text color embeddings (dims 0-15)...")
    text_color_embeddings_spec = text_full_embeddings[:, color_dims]
    print(f" Specialized text color embeddings shape: {text_color_embeddings_spec.shape}")
    text_color_metrics = self.compute_similarity_metrics(text_color_embeddings_spec, text_colors_full)
    text_color_class = self.evaluate_classification_performance(
        text_color_embeddings_spec, text_colors_full,
        "Text Color Embeddings (Ensemble)", "Color",
        full_embeddings=text_full_embeddings, ensemble_weight=1
    )
    text_color_metrics.update(text_color_class)
    results['text_color'] = text_color_metrics

    print("\nπΌοΈ Extracting specialized image color embeddings (dims 0-15)...")
    image_color_embeddings_spec = image_full_embeddings[:, color_dims]
    print(f" Specialized image color embeddings shape: {image_color_embeddings_spec.shape}")
    image_color_metrics = self.compute_similarity_metrics(image_color_embeddings_spec, image_colors_full)
    image_color_class = self.evaluate_classification_performance(
        image_color_embeddings_spec, image_colors_full,
        "Image Color Embeddings (Ensemble)", "Color",
        full_embeddings=image_full_embeddings, ensemble_weight=1
    )
    image_color_metrics.update(image_color_class)
    results['image_color'] = image_color_metrics

    # ---- Hierarchy ---------------------------------------------------------
    print("\nπ HIERARCHY EVALUATION (dims 16-79) - Using Ensemble")
    print("=" * 50)

    print("\nπ Extracting specialized text hierarchy embeddings (dims 16-79)...")
    text_hierarchy_embeddings_spec = text_full_embeddings[:, hierarchy_dims]
    print(f" Specialized text hierarchy embeddings shape: {text_hierarchy_embeddings_spec.shape}")
    text_hierarchy_metrics = self.compute_similarity_metrics(text_hierarchy_embeddings_spec, text_hierarchies_full)
    text_hierarchy_class = self.evaluate_classification_performance(
        text_hierarchy_embeddings_spec, text_hierarchies_full,
        "Text Hierarchy Embeddings (Ensemble)", "Hierarchy",
        full_embeddings=text_full_embeddings, ensemble_weight=0.4
    )
    text_hierarchy_metrics.update(text_hierarchy_class)
    results['text_hierarchy'] = text_hierarchy_metrics

    print("\nπΌοΈ Extracting specialized image hierarchy embeddings (dims 16-79)...")
    image_hierarchy_embeddings_spec = image_full_embeddings[:, hierarchy_dims]
    print(f" Specialized image hierarchy embeddings shape: {image_hierarchy_embeddings_spec.shape}")
    image_hierarchy_metrics = self.compute_similarity_metrics(image_hierarchy_embeddings_spec, image_hierarchies_full)
    image_hierarchy_class = self.evaluate_classification_performance(
        image_hierarchy_embeddings_spec, image_hierarchies_full,
        "Image Hierarchy Embeddings (Ensemble)", "Hierarchy",
        full_embeddings=image_full_embeddings, ensemble_weight=0.4
    )
    image_hierarchy_metrics.update(image_hierarchy_class)
    results['image_hierarchy'] = image_hierarchy_metrics

    # Drop the large embedding arrays before writing figures.
    del text_full_embeddings, image_full_embeddings
    del text_color_embeddings_spec, image_color_embeddings_spec
    del text_hierarchy_embeddings_spec, image_hierarchy_embeddings_spec
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    os.makedirs(self.directory, exist_ok=True)
    for key in ('text_color', 'image_color', 'text_hierarchy', 'image_hierarchy'):
        figure = results[key]['figure']
        figure.savefig(
            f"{self.directory}/kaggle_{key}_confusion_matrix.png",
            dpi=300,
            bbox_inches='tight',
        )
        plt.close(figure)

    return results
|
|
|
|
|
def evaluate_local_validation(self, max_samples):
    """Evaluate color and hierarchy embeddings on the local validation set.

    No ensemble is used here. Color and hierarchy embeddings are obtained
    from ``extract_color_embeddings`` / ``extract_hierarchy_embeddings`` and
    scored on their own.

    Rows whose 'hierarchy' value is not in ``self.hierarchy_classes`` are
    dropped first; if none remain, the method returns None. One
    confusion-matrix PNG per evaluation is saved under ``self.directory``.

    Args:
        max_samples: Maximum number of samples to load and embed.

    Returns:
        dict with keys 'text_color', 'image_color', 'text_hierarchy' and
        'image_hierarchy'. Each value holds the merged metrics and
        classification results; the 'figure' entries are already saved and
        closed. Returns None if loading fails or filtering leaves no samples.
    """
    from collections import Counter

    print(f"\n{'='*60}")
    print("Evaluating Local Validation Dataset")
    print(" Color embeddings: dims 0-15 (specialized only, no ensemble)")
    print(" Hierarchy embeddings: dims 16-79 (specialized only, no ensemble)")
    print(f"Max samples: {max_samples}")
    print(f"{'='*60}")

    local_dataset = load_local_validation_dataset(max_samples)
    if local_dataset is None:
        print("β Failed to load local validation dataset")
        return None

    # Keep only rows whose hierarchy the model was trained on.
    if len(local_dataset.dataframe) > 0:
        valid_df = local_dataset.dataframe[local_dataset.dataframe['hierarchy'].isin(self.hierarchy_classes)]
        if len(valid_df) == 0:
            print("β No samples left after hierarchy filtering.")
            return None
        if len(valid_df) < len(local_dataset.dataframe):
            print(f"π Filtered to model hierarchies: {len(valid_df)} samples (from {len(local_dataset.dataframe)})")
            local_dataset = LocalDataset(valid_df)

    dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)

    if len(local_dataset.dataframe) > 0:
        print("\nπ Hierarchy distribution in dataset:")
        hierarchy_counts = Counter(local_dataset.dataframe['hierarchy'])
        for hierarchy, count in sorted(hierarchy_counts.items()):
            print(f" {hierarchy}: {count} samples")

    results = {}

    # ---- Color -------------------------------------------------------------
    print("\nπ¨ COLOR EVALUATION (dims 0-15) - Specialized embeddings only")
    print("=" * 50)

    print("\nπ Extracting text color embeddings...")
    text_color_embeddings, text_colors, _ = self.extract_color_embeddings(dataloader, 'text', max_samples)
    print(f" Text color embeddings shape: {text_color_embeddings.shape}")
    text_color_metrics = self.compute_similarity_metrics(text_color_embeddings, text_colors)
    text_color_class = self.evaluate_classification_performance(
        text_color_embeddings, text_colors, "Text Color Embeddings (16D)", "Color"
    )
    text_color_metrics.update(text_color_class)
    results['text_color'] = text_color_metrics
    del text_color_embeddings
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print("\nπΌοΈ Extracting image color embeddings...")
    image_color_embeddings, image_colors, _ = self.extract_color_embeddings(dataloader, 'image', max_samples)
    print(f" Image color embeddings shape: {image_color_embeddings.shape}")
    image_color_metrics = self.compute_similarity_metrics(image_color_embeddings, image_colors)
    image_color_class = self.evaluate_classification_performance(
        image_color_embeddings, image_colors, "Image Color Embeddings (16D)", "Color"
    )
    image_color_metrics.update(image_color_class)
    results['image_color'] = image_color_metrics
    del image_color_embeddings
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # ---- Hierarchy ---------------------------------------------------------
    print("\nπ HIERARCHY EVALUATION (dims 16-79) - Specialized embeddings only")
    print("=" * 50)

    print("\nπ Extracting text hierarchy embeddings...")
    text_hierarchy_embeddings, _, text_hierarchies = self.extract_hierarchy_embeddings(dataloader, 'text', max_samples)
    print(f" Text hierarchy embeddings shape: {text_hierarchy_embeddings.shape}")
    text_hierarchy_metrics = self.compute_similarity_metrics(text_hierarchy_embeddings, text_hierarchies)
    text_hierarchy_class = self.evaluate_classification_performance(
        text_hierarchy_embeddings, text_hierarchies, "Text Hierarchy Embeddings (64D)", "Hierarchy"
    )
    text_hierarchy_metrics.update(text_hierarchy_class)
    results['text_hierarchy'] = text_hierarchy_metrics
    del text_hierarchy_embeddings
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print("\nπΌοΈ Extracting image hierarchy embeddings...")
    image_hierarchy_embeddings, _, image_hierarchies = self.extract_hierarchy_embeddings(dataloader, 'image', max_samples)
    print(f" Image hierarchy embeddings shape: {image_hierarchy_embeddings.shape}")
    image_hierarchy_metrics = self.compute_similarity_metrics(image_hierarchy_embeddings, image_hierarchies)
    image_hierarchy_class = self.evaluate_classification_performance(
        image_hierarchy_embeddings, image_hierarchies, "Image Hierarchy Embeddings (64D)", "Hierarchy"
    )
    image_hierarchy_metrics.update(image_hierarchy_class)
    results['image_hierarchy'] = image_hierarchy_metrics
    del image_hierarchy_embeddings
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    os.makedirs(self.directory, exist_ok=True)
    for key in ('text_color', 'image_color', 'text_hierarchy', 'image_hierarchy'):
        figure = results[key]['figure']
        figure.savefig(
            f"{self.directory}/local_{key}_confusion_matrix.png",
            dpi=300,
            bbox_inches='tight',
        )
        plt.close(figure)

    return results
|
|
|
|
|
def evaluate_baseline_fashion_mnist(self, max_samples=1000):
    """Evaluate the baseline Fashion CLIP model on Fashion-MNIST (hierarchy only).

    All dimensions returned by ``extract_baseline_embeddings_batch`` are used.
    Fashion-MNIST provides only class labels, mapped onto
    ``self.validation_hierarchy_classes`` (or ``self.hierarchy_classes``), so
    only hierarchy metrics are computed. One confusion-matrix PNG per
    modality is saved under ``self.directory``.

    Args:
        max_samples: Maximum number of samples to load and embed.

    Returns:
        dict ``{'text': {'hierarchy': ...}, 'image': {'hierarchy': ...}}``.
        Each inner value holds the merged metrics and classification results;
        its 'figure' has already been saved and closed.
    """
    print(f"\n{'='*60}")
    print("Evaluating Baseline Fashion CLIP on Fashion-MNIST")
    print(f"Max samples: {max_samples}")
    print(f"{'='*60}")

    target_hierarchy_classes = self.validation_hierarchy_classes or self.hierarchy_classes
    fashion_dataset = load_fashion_mnist_dataset(max_samples, hierarchy_classes=target_hierarchy_classes)
    dataloader = DataLoader(
        fashion_dataset,
        batch_size=8,
        shuffle=False,
        num_workers=0
    )

    results = {}

    # ---- Text --------------------------------------------------------------
    print("\nπ Extracting baseline text embeddings from Fashion-MNIST...")
    text_embeddings, _, text_hierarchies = self.extract_baseline_embeddings_batch(dataloader, 'text', max_samples)
    print(f" Baseline text embeddings shape: {text_embeddings.shape} (using all {text_embeddings.shape[1]} dimensions)")
    text_hierarchy_metrics = self.compute_similarity_metrics(text_embeddings, text_hierarchies)
    text_hierarchy_classification = self.evaluate_classification_performance(
        text_embeddings, text_hierarchies, "Baseline Fashion-MNIST Text Embeddings - Hierarchy", "Hierarchy"
    )
    text_hierarchy_metrics.update(text_hierarchy_classification)
    results['text'] = {'hierarchy': text_hierarchy_metrics}
    del text_embeddings
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # ---- Image -------------------------------------------------------------
    print("\nπΌοΈ Extracting baseline image embeddings from Fashion-MNIST...")
    # Color labels are not evaluated for Fashion-MNIST, so they are discarded.
    image_embeddings, _, image_hierarchies = self.extract_baseline_embeddings_batch(dataloader, 'image', max_samples)
    print(f" Baseline image embeddings shape: {image_embeddings.shape} (using all {image_embeddings.shape[1]} dimensions)")
    image_hierarchy_metrics = self.compute_similarity_metrics(image_embeddings, image_hierarchies)
    image_hierarchy_classification = self.evaluate_classification_performance(
        image_embeddings, image_hierarchies, "Baseline Fashion-MNIST Image Embeddings - Hierarchy", "Hierarchy"
    )
    image_hierarchy_metrics.update(image_hierarchy_classification)
    results['image'] = {'hierarchy': image_hierarchy_metrics}
    del image_embeddings
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    os.makedirs(self.directory, exist_ok=True)
    for key in ('text', 'image'):
        figure = results[key]['hierarchy']['figure']
        figure.savefig(
            f"{self.directory}/fashion_baseline_{key}_hierarchy_confusion_matrix.png",
            dpi=300,
            bbox_inches='tight',
        )
        plt.close(figure)

    return results
|
|
|
|
|
def evaluate_baseline_kaggle_marqo(self, max_samples=5000):
    """Score the baseline Fashion CLIP model on KAGL Marqo (color and hierarchy).

    For each modality (text first, then image), all dimensions returned by
    ``extract_baseline_embeddings_batch`` are used. Similarity metrics and
    centroid classification are computed against both the color and the
    hierarchy labels. Four confusion-matrix PNGs are then written under
    ``self.directory``.

    Args:
        max_samples: Maximum number of samples to load and embed.

    Returns:
        Nested dict ``results[modality][attribute]``, with modality in
        {'text', 'image'} and attribute in {'color', 'hierarchy'}. Each leaf
        holds the merged metrics and classification results; its 'figure'
        has already been saved and closed. Returns None if the dataset
        cannot be loaded.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print("Evaluating Baseline Fashion CLIP on KAGL Marqo Dataset")
    print(f"Max samples: {max_samples}")
    print(banner)

    kaggle_dataset = load_kaggle_marqo_dataset(self, max_samples)
    if kaggle_dataset is None:
        print("β Failed to load KAGL dataset")
        return None

    dataloader = DataLoader(kaggle_dataset, batch_size=8, shuffle=False, num_workers=0)

    # (modality key, progress message, capitalised name used in plot titles)
    modality_plan = (
        ('text', "\nπ Extracting baseline text embeddings from KAGL Marqo...", 'Text'),
        ('image', "\nπΌοΈ Extracting baseline image embeddings from KAGL Marqo...", 'Image'),
    )

    results = {}
    for modality, progress_message, display_name in modality_plan:
        print(progress_message)
        vectors, color_labels, hierarchy_labels = self.extract_baseline_embeddings_batch(dataloader, modality, max_samples)
        print(f" Baseline {modality} embeddings shape: {vectors.shape} (using all {vectors.shape[1]} dimensions)")

        color_summary = self.compute_similarity_metrics(vectors, color_labels)
        hierarchy_summary = self.compute_similarity_metrics(vectors, hierarchy_labels)

        color_report = self.evaluate_classification_performance(
            vectors, color_labels, f"Baseline KAGL Marqo {display_name} Embeddings - Color", "Color"
        )
        hierarchy_report = self.evaluate_classification_performance(
            vectors, hierarchy_labels, f"Baseline KAGL Marqo {display_name} Embeddings - Hierarchy", "Hierarchy"
        )

        color_summary.update(color_report)
        hierarchy_summary.update(hierarchy_report)
        results[modality] = {
            'color': color_summary,
            'hierarchy': hierarchy_summary
        }

        # Release this modality's embeddings before the next extraction.
        del vectors
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    os.makedirs(self.directory, exist_ok=True)
    for modality in ('text', 'image'):
        for attribute in ('color', 'hierarchy'):
            fig = results[modality][attribute]['figure']
            fig.savefig(
                f"{self.directory}/kaggle_baseline_{modality}_{attribute}_confusion_matrix.png",
                dpi=300,
                bbox_inches='tight',
            )
            plt.close(fig)

    return results
|
|
|
|
|
def evaluate_baseline_local_validation(self, max_samples=5000):
    """Evaluate the baseline Fashion CLIP model on the local validation dataset.

    Loads the dataset with ``load_local_validation_dataset(max_samples)`` and
    keeps only rows whose ``hierarchy`` value is in ``self.hierarchy_classes``.
    Then, for the text and image modalities in turn, it:

    - extracts baseline embeddings;
    - scores them against the color and hierarchy labels with
      ``compute_similarity_metrics`` and ``evaluate_classification_performance``.

    The four resulting confusion-matrix figures are saved as PNGs under
    ``self.directory``.

    Args:
        max_samples: Forwarded to both the dataset loader and
            ``extract_baseline_embeddings_batch`` (presumably a sample cap).

    Returns:
        ``{'text': {'color': m, 'hierarchy': m}, 'image': {...}}`` where each
        ``m`` is the similarity-metrics dict updated with the classification
        results (which must include a ``'figure'`` entry).
        Returns ``None`` when the dataset fails to load, is empty, or has no
        rows left after hierarchy filtering.
    """
    print(f"\n{'='*60}")
    print("Evaluating Baseline Fashion CLIP on Local Validation Dataset")
    print(f"Max samples: {max_samples}")
    print(f"{'='*60}")

    local_dataset = load_local_validation_dataset(max_samples)
    if local_dataset is None:
        print("β Failed to load local validation dataset")
        return None

    total = len(local_dataset.dataframe)
    if total == 0:
        # Without this guard an empty dataset would flow into the embedding
        # extractor and classifiers; treat it like the other failure paths.
        print("Local validation dataset is empty; nothing to evaluate.")
        return None

    # Only labels the model was trained on can be scored meaningfully.
    valid_df = local_dataset.dataframe[local_dataset.dataframe['hierarchy'].isin(self.hierarchy_classes)]
    if len(valid_df) == 0:
        print("β No samples left after hierarchy filtering.")
        return None
    if len(valid_df) < total:
        print(f"π Filtered to model hierarchies: {len(valid_df)} samples (from {total})")
        # Rebuild the dataset only when filtering actually dropped rows.
        local_dataset = LocalDataset(valid_df)

    dataloader = DataLoader(local_dataset, batch_size=8, shuffle=False, num_workers=0)

    results = {}
    # (modality key passed to the extractor, progress banner, name used in plot titles)
    modalities = (
        ('text', "\nπ Extracting baseline text embeddings from Local Validation...", "Text"),
        ('image', "\nπΌοΈ Extracting baseline image embeddings from Local Validation...", "Image"),
    )
    for modality, banner, title in modalities:
        print(banner)
        embeddings, colors, hierarchies = self.extract_baseline_embeddings_batch(dataloader, modality, max_samples)
        print(f" Baseline {modality} embeddings shape: {embeddings.shape} (using all {embeddings.shape[1]} dimensions)")

        color_metrics = self.compute_similarity_metrics(embeddings, colors)
        hierarchy_metrics = self.compute_similarity_metrics(embeddings, hierarchies)

        color_classification = self.evaluate_classification_performance(
            embeddings, colors, f"Baseline Local Validation {title} Embeddings - Color", "Color"
        )
        hierarchy_classification = self.evaluate_classification_performance(
            embeddings, hierarchies, f"Baseline Local Validation {title} Embeddings - Hierarchy", "Hierarchy"
        )

        color_metrics.update(color_classification)
        hierarchy_metrics.update(hierarchy_classification)
        results[modality] = {
            'color': color_metrics,
            'hierarchy': hierarchy_metrics,
        }

        # Release this modality's embeddings before extracting the next one.
        del embeddings
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    os.makedirs(self.directory, exist_ok=True)
    for key in ('text', 'image'):
        for subkey in ('color', 'hierarchy'):
            figure = results[key][subkey]['figure']
            try:
                figure.savefig(
                    os.path.join(self.directory, f"local_baseline_{key}_{subkey}_confusion_matrix.png"),
                    dpi=300,
                    bbox_inches='tight',
                )
            finally:
                # Close even if saving fails so open figures don't accumulate.
                plt.close(figure)

    return results
|
|
|
|
|
|
|
|
|
|
|
def _print_banner(title):
    """Print *title* framed by two 60-character '=' rules, preceded by a blank line."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


def _print_text_image_metrics(header, text_metrics, image_metrics):
    """Print *header*, then one summary row for text and one for image metrics.

    Each metrics mapping must provide ``'accuracy'`` and ``'centroid_accuracy'``
    (multiplied by 100 and printed as percentages) and ``'separation_score'``.
    """
    print(header)
    for label, metrics in (("Text", text_metrics), ("Image", image_metrics)):
        print(
            f" {label} - NN Acc: {metrics['accuracy']*100:.1f}% | "
            f"Centroid Acc: {metrics['centroid_accuracy']*100:.1f}% | "
            f"Separation: {metrics['separation_score']:.4f}"
        )


if __name__ == "__main__":
    # Prefer CUDA, then Apple-silicon MPS, then CPU, so CUDA hosts do not
    # silently fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    directory = 'main_model_analysis'
    max_samples = 10000

    evaluator = ColorHierarchyEvaluator(device=device, directory=directory)

    # --- Fashion-MNIST: main model (hierarchy only) ---
    _print_banner("π Starting evaluation of Fashion-MNIST Hierarchy embeddings")
    results_fashion = evaluator.evaluate_fashion_mnist(max_samples=max_samples)
    # Guarded like every other section, so one failed evaluation does not abort the run.
    if results_fashion is not None:
        _print_banner("FASHION-MNIST EVALUATION SUMMARY")
        _print_text_image_metrics(
            "\nπ HIERARCHY CLASSIFICATION RESULTS (dims 16-79):",
            results_fashion['text_hierarchy'],
            results_fashion['image_hierarchy'],
        )

    # --- Fashion-MNIST: baseline Fashion CLIP ---
    _print_banner("π Starting evaluation of Baseline Fashion CLIP on Fashion-MNIST")
    results_baseline = evaluator.evaluate_baseline_fashion_mnist(max_samples=max_samples)
    if results_baseline is not None:
        _print_banner("BASELINE FASHION-MNIST EVALUATION SUMMARY")
        _print_text_image_metrics(
            "\nπ HIERARCHY CLASSIFICATION RESULTS (Baseline):",
            results_baseline['text']['hierarchy'],
            results_baseline['image']['hierarchy'],
        )

    # --- KAGL Marqo: main model ---
    _print_banner("π Starting evaluation of KAGL Marqo with Color & Hierarchy embeddings")
    results_kaggle = evaluator.evaluate_kaggle_marqo(max_samples=max_samples)
    if results_kaggle is not None:
        _print_banner("KAGL MARQO EVALUATION SUMMARY")
        _print_text_image_metrics(
            "\nπ¨ COLOR CLASSIFICATION RESULTS (dims 0-15):",
            results_kaggle['text_color'],
            results_kaggle['image_color'],
        )
        _print_text_image_metrics(
            "\nπ HIERARCHY CLASSIFICATION RESULTS (dims 16-79):",
            results_kaggle['text_hierarchy'],
            results_kaggle['image_hierarchy'],
        )

    # --- KAGL Marqo: baseline Fashion CLIP ---
    _print_banner("π Starting evaluation of Baseline Fashion CLIP on KAGL Marqo")
    results_baseline_kaggle = evaluator.evaluate_baseline_kaggle_marqo(max_samples=max_samples)
    if results_baseline_kaggle is not None:
        _print_banner("BASELINE KAGL MARQO EVALUATION SUMMARY")
        _print_text_image_metrics(
            "\nπ¨ COLOR CLASSIFICATION RESULTS (Baseline):",
            results_baseline_kaggle['text']['color'],
            results_baseline_kaggle['image']['color'],
        )
        _print_text_image_metrics(
            "\nπ HIERARCHY CLASSIFICATION RESULTS (Baseline):",
            results_baseline_kaggle['text']['hierarchy'],
            results_baseline_kaggle['image']['hierarchy'],
        )

    # --- Local validation set: main model ---
    _print_banner("π Starting evaluation of Local Validation Dataset with Color & Hierarchy embeddings")
    results_local = evaluator.evaluate_local_validation(max_samples=max_samples)
    if results_local is not None:
        _print_banner("LOCAL VALIDATION DATASET EVALUATION SUMMARY")
        _print_text_image_metrics(
            "\nπ¨ COLOR CLASSIFICATION RESULTS (dims 0-15):",
            results_local['text_color'],
            results_local['image_color'],
        )
        _print_text_image_metrics(
            "\nπ HIERARCHY CLASSIFICATION RESULTS (dims 16-79):",
            results_local['text_hierarchy'],
            results_local['image_hierarchy'],
        )

    # --- Local validation set: baseline Fashion CLIP ---
    _print_banner("π Starting evaluation of Baseline Fashion CLIP on Local Validation")
    results_baseline_local = evaluator.evaluate_baseline_local_validation(max_samples=max_samples)
    if results_baseline_local is not None:
        _print_banner("BASELINE LOCAL VALIDATION EVALUATION SUMMARY")
        _print_text_image_metrics(
            "\nπ¨ COLOR CLASSIFICATION RESULTS (Baseline):",
            results_baseline_local['text']['color'],
            results_baseline_local['image']['color'],
        )
        _print_text_image_metrics(
            "\nπ HIERARCHY CLASSIFICATION RESULTS (Baseline):",
            results_baseline_local['text']['hierarchy'],
            results_baseline_local['image']['hierarchy'],
        )
|
|
|