Delete Evaluation

- Evaluation/0_shot_classification.py +0 -512
- Evaluation/basic_test_generalized.py +0 -425
- Evaluation/evaluate_color_embeddings.py +0 -1124
- Evaluation/fashion_search.py +0 -365
- Evaluation/hierarchy_evaluation.py +0 -589
- Evaluation/hierarchy_evaluation_with_clip_baseline.py +0 -808
- Evaluation/main_model_evaluation.py +0 -0
- Evaluation/tsne_images.py +0 -569
Evaluation/0_shot_classification.py
DELETED
@@ -1,512 +0,0 @@
"""
Zero-shot classification evaluation on a new dataset.
This file evaluates the main model's performance on unseen data by performing
zero-shot classification. It compares three methods: color-to-color, text-to-text,
and image-to-text classification, and generates confusion matrices and
classification reports for each method to analyze the model's generalization
capability.
"""

import os
# Set environment variable to disable tokenizers parallelism warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
import warnings
import config
from tqdm import tqdm
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import seaborn as sns
from color_model import CLIPModel as ColorModel
from hierarchy_model import Model, HierarchyExtractor

# Suppress warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

def load_trained_model(model_path, device):
    """
    Load the trained CLIP model from checkpoint
    """
    print(f"Loading trained model from: {model_path}")

    # Load checkpoint
    checkpoint = torch.load(model_path, map_location=device)

    # Create the base CLIP model
    model = CLIPModel_transformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')

    # Load the trained weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    print(f"✅ Model loaded successfully!")
    print(f"📊 Training epoch: {checkpoint['epoch']}")
    print(f"📉 Best validation loss: {checkpoint['best_val_loss']:.4f}")

    return model, checkpoint

def load_feature_models(device):
    """Load feature models (color and hierarchy)"""

    # Load color model (embed_dim=16)
    color_checkpoint = torch.load(config.color_model_path, map_location=device, weights_only=True)
    color_model = ColorModel(embed_dim=config.color_emb_dim).to(device)
    color_model.load_state_dict(color_checkpoint)
    color_model.eval()
    color_model.name = 'color'

    # Load hierarchy model (embed_dim=64)
    hierarchy_checkpoint = torch.load(config.hierarchy_model_path, map_location=device)
    hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
    hierarchy_model = Model(
        num_hierarchy_classes=len(hierarchy_classes),
        embed_dim=config.hierarchy_emb_dim
    ).to(device)
    hierarchy_model.load_state_dict(hierarchy_checkpoint['model_state'])

    # Set up hierarchy extractor
    hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
    hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
    hierarchy_model.eval()
    hierarchy_model.name = 'hierarchy'

    feature_models = {model.name: model for model in [color_model, hierarchy_model]}
    return feature_models

def get_image_embedding(model, image, device):
    """Get image embedding from the trained model"""
    model.eval()
    with torch.no_grad():
        # Ensure image has 3 channels
        if image.dim() == 3 and image.size(0) == 1:
            image = image.expand(3, -1, -1)
        elif image.dim() == 4 and image.size(1) == 1:
            image = image.expand(-1, 3, -1, -1)

        # Add batch dimension if missing
        if image.dim() == 3:
            image = image.unsqueeze(0)  # (C, H, W) -> (1, C, H, W)

        image = image.to(device)

        # Use vision model directly to get image embeddings
        vision_outputs = model.vision_model(pixel_values=image)
        image_features = model.visual_projection(vision_outputs.pooler_output)

        return F.normalize(image_features, dim=-1)

def get_text_embedding(model, text, processor, device):
    """Get text embedding from the trained model"""
    model.eval()
    with torch.no_grad():
        text_inputs = processor(text=text, padding=True, return_tensors="pt")
        text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

        # Use text model directly to get text embeddings
        text_outputs = model.text_model(**text_inputs)
        text_features = model.text_projection(text_outputs.pooler_output)

        return F.normalize(text_features, dim=-1)

def evaluate_custom_csv_accuracy(model, dataset, processor, method='similarity'):
    """
    Evaluate the accuracy of the model on your custom CSV using text-to-text similarity

    Args:
        model: The trained CLIP model
        dataset: CustomCSVDataset
        processor: CLIPProcessor
        method: 'similarity' or 'classification'
    """
    print(f"\n📊 === Evaluation of the accuracy on custom CSV (TEXT-TO-TEXT method) ===")

    model.eval()

    # Get all unique colors for classification
    all_colors = set()
    for i in range(len(dataset)):
        _, _, color = dataset[i]
        all_colors.add(color)

    color_list = sorted(all_colors)
    print(f"🎨 Colors found: {color_list}")

    true_labels = []
    predicted_labels = []

    # Pre-compute the embeddings of the color descriptions
    print("🔄 Pre-computing the color embeddings...")
    color_embeddings = {}
    for color in color_list:
        color_embeddings[color] = get_text_embedding(model, color, processor, config.device)

    print("🔄 Evaluation in progress...")
    correct_predictions = 0

    for idx in tqdm(range(len(dataset)), desc="Evaluation"):
        image, text, true_color = dataset[idx]

        # Get text embedding instead of image embedding
        text_emb = get_text_embedding(model, text, processor, config.device)

        # Compute the similarity with each possible color
        best_similarity = -1
        predicted_color = color_list[0]

        for color, color_emb in color_embeddings.items():
            similarity = F.cosine_similarity(text_emb, color_emb, dim=1).item()
            if similarity > best_similarity:
                best_similarity = similarity
                predicted_color = color

        true_labels.append(true_color)
        predicted_labels.append(predicted_color)

        if true_color == predicted_color:
            correct_predictions += 1

    # Compute the accuracy
    accuracy = accuracy_score(true_labels, predicted_labels)

    print(f"\n✅ Evaluation results:")
    print(f"🎯 Global accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"📊 Correct predictions: {correct_predictions}/{len(true_labels)}")

    return true_labels, predicted_labels, accuracy

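# --- Illustrative sketch (editor's addition, not part of the deleted file). ---
# The per-color Python loop above can be collapsed into a single matrix
# multiplication: with L2-normalized vectors, cosine similarity is just a dot
# product, so stacking the candidate color embeddings gives the argmax in one
# shot. The helper name and arguments below are hypothetical.
def classify_by_nearest_color(query_embs, color_embs, color_names):
    """Vectorized nearest-color assignment.

    query_embs: [N, D] embeddings to classify; color_embs: [C, D], one row per
    candidate color. Both are normalized so dot product == cosine similarity.
    """
    query_embs = F.normalize(query_embs, dim=-1)
    color_embs = F.normalize(color_embs, dim=-1)
    sims = query_embs @ color_embs.T  # [N, C] cosine similarities
    return [color_names[i] for i in sims.argmax(dim=1)]
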
def evaluate_custom_csv_accuracy_image(model, dataset, processor, method='similarity'):
    """
    Evaluate the accuracy of the model on your custom CSV using image-to-text similarity

    Args:
        model: The trained CLIP model
        dataset: CustomCSVDataset with images loaded
        processor: CLIPProcessor
        method: 'similarity' or 'classification'
    """
    print(f"\n📊 === Evaluation of the accuracy on custom CSV (IMAGE-TO-TEXT method) ===")

    model.eval()

    # Get all unique colors for classification
    all_colors = set()
    for i in range(len(dataset)):
        _, _, color = dataset[i]
        all_colors.add(color)

    color_list = sorted(all_colors)
    print(f"🎨 Colors found: {color_list}")

    true_labels = []
    predicted_labels = []

    # Pre-compute the embeddings of the color descriptions
    print("🔄 Pre-computing the color embeddings...")
    color_embeddings = {}
    for color in color_list:
        color_embeddings[color] = get_text_embedding(model, color, processor, config.device)

    print("🔄 Evaluation in progress...")
    correct_predictions = 0

    for idx in tqdm(range(len(dataset)), desc="Evaluation"):
        image, text, true_color = dataset[idx]

        # Get image embedding (this is the key difference from text-to-text)
        image_emb = get_image_embedding(model, image, config.device)

        # Compute the similarity with each possible color
        best_similarity = -1
        predicted_color = color_list[0]

        for color, color_emb in color_embeddings.items():
            similarity = F.cosine_similarity(image_emb, color_emb, dim=1).item()
            if similarity > best_similarity:
                best_similarity = similarity
                predicted_color = color

        true_labels.append(true_color)
        predicted_labels.append(predicted_color)

        if true_color == predicted_color:
            correct_predictions += 1

    # Compute the accuracy
    accuracy = accuracy_score(true_labels, predicted_labels)

    print(f"\n✅ Evaluation results:")
    print(f"🎯 Global accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"📊 Correct predictions: {correct_predictions}/{len(true_labels)}")

    return true_labels, predicted_labels, accuracy

def evaluate_custom_csv_accuracy_color_only(model, dataset, processor):
    """
    Evaluate the accuracy by encoding ONLY the color name (not the full text).
    This tests whether the embedding space is consistent for colors.

    Args:
        model: The trained CLIP model
        dataset: CustomCSVDataset
        processor: CLIPProcessor
    """
    print(f"\n📊 === Evaluation of the accuracy on custom CSV (COLOR-TO-COLOR method) ===")
    print("🔬 This test encodes ONLY the color name, not the full text")

    model.eval()

    # Get all unique colors for classification
    all_colors = set()
    for i in range(len(dataset)):
        _, _, color = dataset[i]
        all_colors.add(color)

    color_list = sorted(all_colors)
    print(f"🎨 Colors found: {color_list}")

    true_labels = []
    predicted_labels = []

    # Pre-compute the embeddings of the color descriptions
    print("🔄 Pre-computing the color embeddings...")
    color_embeddings = {}
    for color in color_list:
        color_embeddings[color] = get_text_embedding(model, color, processor, config.device)

    print("🔄 Evaluation in progress...")
    correct_predictions = 0

    for idx in tqdm(range(len(dataset)), desc="Evaluation"):
        image, text, true_color = dataset[idx]

        # KEY DIFFERENCE: embed the TRUE COLOR only (not the full text)
        true_color_emb = get_text_embedding(model, true_color, processor, config.device)

        # Compute the similarity with each possible color
        best_similarity = -1
        predicted_color = color_list[0]

        for color, color_emb in color_embeddings.items():
            similarity = F.cosine_similarity(true_color_emb, color_emb, dim=1).item()
            if similarity > best_similarity:
                best_similarity = similarity
                predicted_color = color

        true_labels.append(true_color)
        predicted_labels.append(predicted_color)

        if true_color == predicted_color:
            correct_predictions += 1

    # Compute the accuracy
    accuracy = accuracy_score(true_labels, predicted_labels)

    print(f"\n✅ Evaluation results:")
    print(f"🎯 Global accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"📊 Correct predictions: {correct_predictions}/{len(true_labels)}")

    return true_labels, predicted_labels, accuracy

def search_custom_csv_by_text(model, dataset, query, processor, top_k=5):
    """Search in your CSV by text query"""
    print(f"\n🔍 Search in custom CSV: '{query}'")

    # Get the embedding of the query
    query_emb = get_text_embedding(model, query, processor, config.device)

    similarities = []

    print("🔄 Computing similarities...")
    for idx in tqdm(range(len(dataset)), desc="Processing"):
        # CustomCSVDataset returns (image, text, color)
        image, text, color = dataset[idx]

        # Get the embedding of the image
        image_emb = get_image_embedding(model, image, config.device)

        # Compute the similarity
        similarity = F.cosine_similarity(query_emb, image_emb, dim=1).item()

        similarities.append((idx, similarity, text, color))

    # Sort by similarity
    similarities.sort(key=lambda x: x[1], reverse=True)

    return similarities[:top_k]

def plot_confusion_matrix(true_labels, predicted_labels, save_path=None, title_suffix="text"):
    """
    Display and save the confusion matrix
    """
    print("\n📈 === Generating the confusion matrix ===")

    # Get unique labels in sorted order
    unique_labels = sorted(set(true_labels) | set(predicted_labels))

    # Compute the confusion matrix with a fixed label order
    cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)

    # Compute accuracy
    accuracy = accuracy_score(true_labels, predicted_labels)

    # Convert counts to row percentages and round to integers
    cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100
    cm_percent = np.around(cm_percent).astype(int)

    # Create the figure
    plt.figure(figsize=(12, 10))

    # Confusion matrix with percentages and labels (no decimal points)
    sns.heatmap(cm_percent,
                annot=True,
                fmt='d',
                cmap='Blues',
                cbar_kws={'label': 'Percentage (%)'},
                xticklabels=unique_labels,
                yticklabels=unique_labels)

    plt.title(f"Confusion Matrix for {title_suffix} - new data - accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)", fontsize=16)
    plt.xlabel('Predictions', fontsize=12)
    plt.ylabel('True colors', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"💾 Confusion matrix saved: {save_path}")

    plt.show()

    return cm

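# --- Illustrative sketch (editor's addition, not part of the deleted file). ---
# Caveat for the row normalization above: a color that appears only among the
# predictions (never as a true label) yields an all-zero row in `cm` and hence
# a division by zero. A guarded variant (hypothetical helper name):
def safe_row_percentages(cm):
    """Counts -> integer row percentages; all-zero rows stay 0 instead of NaN."""
    row_sums = cm.sum(axis=1, keepdims=True)
    return np.rint(cm / np.where(row_sums == 0, 1, row_sums) * 100).astype(int)
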
class CustomCSVDataset(Dataset):
    def __init__(self, dataframe, image_size=224, load_images=True):
        self.dataframe = dataframe
        self.image_size = image_size
        self.load_images = load_images

        # Define image transformations
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        text = row[config.text_column]
        colors = row[config.color_column]

        if self.load_images and config.column_local_image_path in row:
            # Load the actual image
            try:
                image = Image.open(row[config.column_local_image_path]).convert('RGB')
                image = self.transform(image)
            except Exception as e:
                print(f"Warning: Could not load image {row.get(config.column_local_image_path, 'unknown')}: {e}")
                image = torch.zeros(3, self.image_size, self.image_size)
        else:
            # Return dummy image if not loading images
            image = torch.zeros(3, self.image_size, self.image_size)

        return image, text, colors

if __name__ == "__main__":
    # Main evaluation script
    print("🚀 === Test and evaluation of the model on a new dataset ===")

    # Load model
    print("🔧 Loading the model...")
    model, checkpoint = load_trained_model(config.main_model_path, config.device)

    # Create processor
    processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')

    # Load new dataset
    print("📊 Loading the new dataset...")
    df = pd.read_csv(config.local_dataset_path)  # replace local_dataset_path with a new df

    print("\n" + "="*80)
    print("🎨 COLOR-TO-COLOR CLASSIFICATION (Control Test)")
    print("="*80)

    # Create dataset without loading images
    dataset_color = CustomCSVDataset(df, load_images=False)

    # 0. Evaluation encoding ONLY the color (control test)
    true_labels_color, predicted_labels_color, accuracy_color = evaluate_custom_csv_accuracy_color_only(
        model, dataset_color, processor
    )

    # Confusion matrix for color-only
    confusion_matrix_color = plot_confusion_matrix(
        true_labels_color, predicted_labels_color,
        save_path="confusion_matrix_color_only.png",
        title_suffix="color-only"
    )

    print("\n" + "="*80)
    print("📝 TEXT-TO-TEXT CLASSIFICATION")
    print("="*80)

    # Create dataset without loading images for text-to-text
    dataset_text = CustomCSVDataset(df, load_images=False)

    # 1. Evaluation of the accuracy (text-to-text)
    true_labels_text, predicted_labels_text, accuracy_text = evaluate_custom_csv_accuracy(
        model, dataset_text, processor, method='similarity'
    )

    # 2. Confusion matrix for text
    confusion_matrix_text = plot_confusion_matrix(
        true_labels_text, predicted_labels_text,
        save_path="confusion_matrix_text.png",
        title_suffix="text"
    )

    print("\n" + "="*80)
    print("🖼️ IMAGE-TO-TEXT CLASSIFICATION")
    print("="*80)

    # Create dataset with images loaded for image-to-text
    dataset_image = CustomCSVDataset(df, load_images=True)

    # 3. Evaluation of the accuracy (image-to-text)
    true_labels_image, predicted_labels_image, accuracy_image = evaluate_custom_csv_accuracy_image(
        model, dataset_image, processor, method='similarity'
    )

    # 4. Confusion matrix for images
    confusion_matrix_image = plot_confusion_matrix(
        true_labels_image, predicted_labels_image,
        save_path="confusion_matrix_image.png",
        title_suffix="image"
    )

    # 5. Summary comparison
    print("\n" + "="*80)
    print("📊 SUMMARY")
    print("="*80)
    print(f"🎨 Color-to-Color Accuracy (Control): {accuracy_color:.4f} ({accuracy_color*100:.2f}%)")
    print(f"📝 Text-to-Text Accuracy: {accuracy_text:.4f} ({accuracy_text*100:.2f}%)")
    print(f"🖼️ Image-to-Text Accuracy: {accuracy_image:.4f} ({accuracy_image*100:.2f}%)")
    print(f"\n📊 Analysis:")
    print(f"  • Gap between color-only and full-text accuracy: {abs(accuracy_color - accuracy_text):.4f} ({abs(accuracy_color - accuracy_text)*100:.2f}%)")
    print(f"  • Gap between text and image accuracy: {abs(accuracy_text - accuracy_image):.4f} ({abs(accuracy_text - accuracy_image)*100:.2f}%)")
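For reference, the zero-shot comparison that 0_shot_classification.py implements by hand can also be exercised directly through the Hugging Face CLIP interface. A minimal, self-contained sketch, assuming the same checkpoint name used above; the label set and the synthetic image are illustrative:

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

ckpt = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
model = CLIPModel.from_pretrained(ckpt).eval()
processor = CLIPProcessor.from_pretrained(ckpt)

labels = ["red", "green", "blue", "black", "white"]
image = Image.new("RGB", (224, 224), color="red")  # stand-in for a real photo

inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)

# logits_per_image has shape [num_images, num_texts]; softmax over texts
probs = outputs.logits_per_image.softmax(dim=-1)
print(labels[probs.argmax().item()])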
Evaluation/basic_test_generalized.py
DELETED
@@ -1,425 +0,0 @@
"""
Generalized evaluation of the main model with sub-module comparison.
This file evaluates the main model's performance by comparing its specialized parts
(color and hierarchy) with the corresponding specialized models. It computes similarity
matrices, linear projections between embedding spaces, and detailed statistics
on the alignment between the different representations.
"""

import os
import json
import argparse
import config
import torch
import torch.nn.functional as F
import pandas as pd
from collections import defaultdict
from PIL import Image
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel as CLIPModelTransformers
from tqdm.auto import tqdm

# Local imports
from color_model import ColorCLIP as ColorModel, ColorDataset, Tokenizer
from config import color_model_path, color_emb_dim, device, hierarchy_model_path, hierarchy_emb_dim
from hierarchy_model import Model as HierarchyModel, HierarchyExtractor


def load_color_model(color_model_path, color_emb_dim, device):
    # Load color model
    color_checkpoint = torch.load(color_model_path, map_location=device, weights_only=True)
    color_model = ColorModel(vocab_size=39, embedding_dim=color_emb_dim).to(device)
    color_model.load_state_dict(color_checkpoint)

    # Load the saved vocabulary and install it in the tokenizer
    tokenizer = Tokenizer()
    with open(config.tokeniser_path, 'r') as f:
        vocab_dict = json.load(f)
    tokenizer.word2idx = defaultdict(lambda: 0, {k: int(v) for k, v in vocab_dict.items()})
    tokenizer.idx2word = {int(v): k for k, v in vocab_dict.items() if int(v) > 0}
    color_model.tokenizer = tokenizer

    color_model.eval()
    return color_model


def get_emb_color_model(color_model, image_path_to_encode, text_to_encode):
    # Load and preprocess image
    image = Image.open(image_path_to_encode).convert('RGB')

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    processed_image = transform(image)

    # Get embeddings
    processed_image_batch = processed_image.unsqueeze(0).to(device)  # Shape: [1, 3, 224, 224]
    with torch.no_grad():
        image_emb = color_model.image_encoder(processed_image_batch)

    # Text embedding via tokenizer + text_encoder
    token_ids = torch.tensor([color_model.tokenizer(text_to_encode)], dtype=torch.long, device=device)
    lengths = torch.tensor([token_ids.size(1) if token_ids.dim() > 1 else token_ids.size(0)], dtype=torch.long, device=device)
    with torch.no_grad():
        txt_emb = color_model.text_encoder(token_ids, lengths)

    return image_emb, txt_emb

def load_main_model(main_model_path, device):
    checkpoint = torch.load(main_model_path, map_location=device)
    main_model = CLIPModelTransformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
    state = checkpoint['model_state_dict'] if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint else checkpoint
    try:
        main_model.load_state_dict(state, strict=False)
    except Exception:
        # Fallback: filter matching keys
        model_state = main_model.state_dict()
        filtered = {k: v for k, v in state.items() if k in model_state and model_state[k].shape == v.shape}
        main_model.load_state_dict(filtered, strict=False)
    main_model.to(device)
    main_model.eval()
    processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
    return main_model, processor


def load_hierarchy_model(hierarchy_model_path, device):
    checkpoint = torch.load(hierarchy_model_path, map_location=device)
    hierarchy_classes = checkpoint.get('hierarchy_classes', [])
    model = HierarchyModel(num_hierarchy_classes=len(hierarchy_classes), embed_dim=config.hierarchy_emb_dim).to(device)
    model.load_state_dict(checkpoint['model_state'])
    extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
    model.set_hierarchy_extractor(extractor)
    model.eval()
    return model


def get_emb_hierarchy_model(hierarchy_model, image_path_to_encode, text_to_encode):
    image = Image.open(image_path_to_encode).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    image_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        img_emb = hierarchy_model.get_image_embeddings(image_tensor)
        txt_emb = hierarchy_model.get_text_embeddings(text_to_encode)

    return img_emb, txt_emb

def get_emb_main_model(main_model, processor, image_path_to_encode, text_to_encode):
    image = Image.open(image_path_to_encode).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image = transform(image)
    image = image.unsqueeze(0).to(device)
    # Prepare text inputs via processor
    text_inputs = processor(text=[text_to_encode], return_tensors="pt", padding=True)
    text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
    with torch.no_grad():
        outputs = main_model(**text_inputs, pixel_values=image)
    text_emb = outputs.text_embeds
    image_emb = outputs.image_embeds

    return text_emb, image_emb

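# --- Illustrative sketch (editor's addition, not part of the deleted file). ---
# The evaluation below slices the main model's joint embedding into a color
# block and a hierarchy block purely by position. The convention, isolated as a
# hypothetical helper (dimensions follow the defaults used in this script):
def split_embedding(emb, color_dim=16, hier_dim=64):
    """Split a [B, D] joint embedding into re-normalized color / hierarchy parts."""
    color_part = F.normalize(emb[:, :color_dim], dim=1)
    hier_part = F.normalize(emb[:, color_dim:color_dim + hier_dim], dim=1)
    return color_part, hier_part
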
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate main model parts vs small models and build similarity matrices')
    parser.add_argument('--main-checkpoint', type=str, default='models/laion_explicable_model.pth')
    parser.add_argument('--color-checkpoint', type=str, default='models/color_model.pt')
    parser.add_argument('--csv', type=str, default='data/data_with_local_paths.csv')
    parser.add_argument('--color-emb-dim', type=int, default=16)
    parser.add_argument('--num-samples', type=int, default=200)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--primary-metric', type=str, default='sim_color_txt_img',
                        choices=['sim_txt_color_part', 'sim_img_color_part', 'sim_color_txt_img', 'sim_small_txt_img',
                                 'sim_txt_hierarchy_part', 'sim_img_hierarchy_part'])
    parser.add_argument('--top-k', type=int, default=30)
    parser.add_argument('--heatmap', action='store_true')
    parser.add_argument('--l2-grid', type=str, default='1e-5,1e-4,1e-3,1e-2,1e-1')
    args = parser.parse_args()

    main_checkpoint = args.main_checkpoint
    color_checkpoint = args.color_checkpoint
    csv = args.csv
    color_emb_dim = args.color_emb_dim
    num_samples = args.num_samples
    seed = args.seed
    primary_metric = args.primary_metric
    top_k = args.top_k
    l2_grid = [float(x) for x in args.l2_grid.split(',') if x]
    device = torch.device("mps")  # hardcoded to Apple MPS; overrides the device imported from config

    df = pd.read_csv(csv)

    # Normalize colors (reduce aliasing and sparsity)
    def normalize_color(c):
        if pd.isna(c):
            return c
        s = str(c).strip().lower()
        aliases = {
            'grey': 'gray',
            'navy blue': 'navy',
            'light blue': 'blue',
            'dark blue': 'blue',
            'light grey': 'gray',
            'dark grey': 'gray',
            'light gray': 'gray',
            'dark gray': 'gray',
        }
        return aliases.get(s, s)

    if config.color_column in df.columns:
        df[config.color_column] = df[config.color_column].apply(normalize_color)

    color_model = load_color_model(color_checkpoint, color_emb_dim, device)
    main_model, processor = load_main_model(main_checkpoint, device)
    hierarchy_model = load_hierarchy_model(hierarchy_model_path, device)

    # Results container
    results = []

    # Accumulators for projection (A: main part, B: small model)
    color_txt_As, color_txt_Bs = [], []
    color_img_As, color_img_Bs = [], []
    hier_txt_As, hier_txt_Bs = [], []
    hier_img_As, hier_img_Bs = [], []

    # Ensure determinism for sampling
    pd.options.mode.copy_on_write = True
    rng = pd.Series(range(len(df)), dtype=int)
    _ = rng  # silence lint
    torch.manual_seed(seed)

    unique_hiers = sorted(df[config.hierarchy_column].dropna().unique())
    unique_colors = sorted(df[config.color_column].dropna().unique())

    # Progress bar across all (hierarchy, color) pairs
    total_pairs = len(unique_hiers) * len(unique_colors)
    pair_pbar = tqdm(total=total_pairs, desc="Evaluating pairs", leave=False)
    for hierarchy in unique_hiers:
        for color in unique_colors:
            group = df[(df[config.hierarchy_column] == hierarchy) & (df[config.color_column] == color)]

            # Sample up to num_samples per (hierarchy, color)
            k = min(num_samples, len(group))
            group_iter = group.sample(n=k, random_state=seed) if len(group) > k else group.iloc[:k]

            # Progress bar for samples within the pair
            inner_pbar = tqdm(total=len(group_iter), desc=f"{hierarchy}/{color}", leave=False)
            for row_idx, (_, example) in enumerate(group_iter.iterrows()):
                try:
                    image_emb, txt_emb = get_emb_color_model(color_model, example['local_image_path'], example['text'])
                    image_emb_hier, txt_emb_hier = get_emb_hierarchy_model(hierarchy_model, example['local_image_path'], example['text'])
                    text_emb_main_model, image_emb_main_model = get_emb_main_model(
                        main_model, processor, example['local_image_path'], example['text']
                    )

                    color_part_txt = text_emb_main_model[:, :color_emb_dim]
                    color_part_img = image_emb_main_model[:, :color_emb_dim]
                    hier_part_txt = text_emb_main_model[:, color_emb_dim:color_emb_dim + hierarchy_emb_dim]
                    hier_part_img = image_emb_main_model[:, color_emb_dim:color_emb_dim + hierarchy_emb_dim]

                    # L2-normalize parts and small-model embeddings for stable cosine
                    color_part_txt = F.normalize(color_part_txt, dim=1)
                    color_part_img = F.normalize(color_part_img, dim=1)
                    hier_part_txt = F.normalize(hier_part_txt, dim=1)
                    hier_part_img = F.normalize(hier_part_img, dim=1)
                    txt_emb = F.normalize(txt_emb, dim=1)
                    image_emb = F.normalize(image_emb, dim=1)
                    txt_emb_hier = F.normalize(txt_emb_hier, dim=1)
                    image_emb_hier = F.normalize(image_emb_hier, dim=1)

                    sim_txt_color_part = F.cosine_similarity(txt_emb, color_part_txt).item()
                    sim_img_color_part = F.cosine_similarity(image_emb, color_part_img).item()
                    sim_color_txt_img = F.cosine_similarity(color_part_txt, color_part_img).item()
                    sim_small_txt_img = F.cosine_similarity(txt_emb, image_emb).item()

                    sim_txt_hierarchy_part = F.cosine_similarity(txt_emb_hier, hier_part_txt).item()
                    sim_img_hierarchy_part = F.cosine_similarity(image_emb_hier, hier_part_img).item()

                    # Accumulate for projection fitting later
                    color_txt_As.append(color_part_txt.squeeze(0).detach().cpu())
                    color_txt_Bs.append(txt_emb.squeeze(0).detach().cpu())
                    color_img_As.append(color_part_img.squeeze(0).detach().cpu())
                    color_img_Bs.append(image_emb.squeeze(0).detach().cpu())

                    hier_txt_As.append(hier_part_txt.squeeze(0).detach().cpu())
                    hier_txt_Bs.append(txt_emb_hier.squeeze(0).detach().cpu())
                    hier_img_As.append(hier_part_img.squeeze(0).detach().cpu())
                    hier_img_Bs.append(image_emb_hier.squeeze(0).detach().cpu())

                    results.append({
                        'hierarchy': hierarchy,
                        'color': color,
                        'row_index': int(row_idx),
                        'sim_txt_color_part': float(sim_txt_color_part),
                        'sim_img_color_part': float(sim_img_color_part),
                        'sim_color_txt_img': float(sim_color_txt_img),
                        'sim_small_txt_img': float(sim_small_txt_img),
                        'sim_txt_hierarchy_part': float(sim_txt_hierarchy_part),
                        'sim_img_hierarchy_part': float(sim_img_hierarchy_part),
                    })
                except Exception as e:
                    print(f"Skipping example due to error: {e}")
                finally:
                    inner_pbar.update(1)
            inner_pbar.close()
            pair_pbar.update(1)
    pair_pbar.close()

    results_df = pd.DataFrame(results)

    # Save raw results
    os.makedirs('evaluation_outputs', exist_ok=True)
    raw_path = os.path.join('evaluation_outputs', 'similarities_raw.csv')
    results_df.to_csv(raw_path, index=False)
    print(f"Saved raw similarities to {raw_path}")

    # Aggregated averages
    metrics = ['sim_txt_color_part', 'sim_img_color_part', 'sim_color_txt_img', 'sim_small_txt_img',
               'sim_txt_hierarchy_part', 'sim_img_hierarchy_part']

    # Overall means
    overall_means = results_df[metrics].mean().to_frame(name='mean').T
    overall_means.insert(0, 'level', 'overall')

    # By hierarchy (results_df uses the 'hierarchy' and 'color' keys set above)
    by_hierarchy = results_df.groupby('hierarchy')[metrics].mean().reset_index()
    by_hierarchy.insert(0, 'level', 'hierarchy')

    # By color
    by_color = results_df.groupby('color')[metrics].mean().reset_index()
    by_color.insert(0, 'level', 'color')

    # By hierarchy+color
    by_pair = results_df.groupby(['hierarchy', 'color'])[metrics].mean().reset_index()
    by_pair.insert(0, 'level', 'hierarchy_color')

    summary_df = pd.concat([overall_means, by_hierarchy, by_color, by_pair], ignore_index=True)
    summary_path = os.path.join('evaluation_outputs', 'similarities_summary.csv')
    summary_df.to_csv(summary_path, index=False)
    print(f"Saved summary statistics to {summary_path}")

    # =====================
    # Similarity matrices for best hierarchy-color combinations
    # =====================
    try:
        by_pair_core = results_df.groupby(['hierarchy', 'color'])[metrics].mean().reset_index()
        top_pairs = by_pair_core.nlargest(top_k, primary_metric)
        matrix = top_pairs.pivot(index='hierarchy', columns='color', values=primary_metric)
        os.makedirs('evaluation_outputs', exist_ok=True)
        matrix_csv_path = os.path.join('evaluation_outputs', f'similarity_matrix_{primary_metric}_top{top_k}.csv')
        matrix.to_csv(matrix_csv_path)
        print(f"Saved similarity matrix to {matrix_csv_path}")

        if args.heatmap:
            try:
                import seaborn as sns
                import matplotlib.pyplot as plt
                plt.figure(figsize=(max(6, 0.5 * len(matrix.columns)), max(4, 0.5 * len(matrix.index))))
                sns.heatmap(matrix, annot=False, cmap='viridis')
                plt.title(f'Similarity matrix (top {top_k}) - {primary_metric}')
                heatmap_path = os.path.join('evaluation_outputs', f'similarity_matrix_{primary_metric}_top{top_k}.png')
                plt.tight_layout()
                plt.savefig(heatmap_path, dpi=200)
                plt.close()
                print(f"Saved similarity heatmap to {heatmap_path}")
            except Exception as e:
                print(f"Skipping heatmap generation: {e}")
    except Exception as e:
        print(f"Skipping matrix generation: {e}")

    # =====================
    # Learn projections A->B and report projected cosine means
    # =====================
    def fit_ridge_projection(A, B, l2_reg=1e-3):
        # A: [N, D_in], B: [N, D_out]
        A = torch.stack(A)  # [N, D_in]
        B = torch.stack(B)  # [N, D_out]
        # Closed-form ridge: W = (A^T A + λI)^-1 A^T B
        AtA = A.T @ A
        D_in = AtA.shape[0]
        AtA_reg = AtA + l2_reg * torch.eye(D_in)
        W = torch.linalg.solve(AtA_reg, A.T @ B)
        return W  # [D_in, D_out]

    def fit_ridge_with_cv(A, B, l2_values):
        # Simple holdout CV: 80/20 split
        if len(A) < 10:
            # Not enough data for a split; fall back to the middle lambda
            best_l2 = l2_values[min(len(l2_values) // 2, len(l2_values) - 1)]
            W = fit_ridge_projection(A, B, best_l2)
            return W, best_l2, None

        N = len(A)
        idx = torch.randperm(N)
        split = int(0.8 * N)
        train_idx = idx[:split]
        val_idx = idx[split:]

        A_tensor = torch.stack(A)
        B_tensor = torch.stack(B)

        A_train, B_train = A_tensor[train_idx], B_tensor[train_idx]
        A_val, B_val = A_tensor[val_idx], B_tensor[val_idx]

        def to_list(t):
            return [row for row in t]

        best_l2 = None
        best_score = -1.0
        for l2 in l2_values:
            W = fit_ridge_projection(to_list(A_train), to_list(B_train), l2)
            score = mean_projected_cosine(to_list(A_val), to_list(B_val), W)
            if score > best_score:
                best_score = score
                best_l2 = l2

        # Refit on all data with best_l2
        W_best = fit_ridge_projection(A, B, best_l2)
        return W_best, best_l2, best_score

    def mean_projected_cosine(A, B, W):
        A = torch.stack(A)
        B = torch.stack(B)
        A_proj = A @ W
        A_proj = F.normalize(A_proj, dim=1)
        B = F.normalize(B, dim=1)
        return torch.mean(torch.sum(A_proj * B, dim=1)).item()

    projection_report = {}

    if len(color_txt_As) >= 8:
        W_ct, best_l2_ct, cv_ct = fit_ridge_with_cv(color_txt_As, color_txt_Bs, l2_grid)
        projection_report['proj_sim_txt_color_part_mean'] = mean_projected_cosine(color_txt_As, color_txt_Bs, W_ct)
        projection_report['proj_txt_color_part_best_l2'] = best_l2_ct
        if cv_ct is not None:
            projection_report['proj_txt_color_part_cv_val'] = cv_ct
    if len(color_img_As) >= 8:
        W_ci, best_l2_ci, cv_ci = fit_ridge_with_cv(color_img_As, color_img_Bs, l2_grid)
        projection_report['proj_sim_img_color_part_mean'] = mean_projected_cosine(color_img_As, color_img_Bs, W_ci)
        projection_report['proj_img_color_part_best_l2'] = best_l2_ci
        if cv_ci is not None:
            projection_report['proj_img_color_part_cv_val'] = cv_ci
    if len(hier_txt_As) >= 8:
        W_ht, best_l2_ht, cv_ht = fit_ridge_with_cv(hier_txt_As, hier_txt_Bs, l2_grid)
        projection_report['proj_sim_txt_hierarchy_part_mean'] = mean_projected_cosine(hier_txt_As, hier_txt_Bs, W_ht)
        projection_report['proj_txt_hierarchy_part_best_l2'] = best_l2_ht
        if cv_ht is not None:
            projection_report['proj_txt_hierarchy_part_cv_val'] = cv_ht
    if len(hier_img_As) >= 8:
        W_hi, best_l2_hi, cv_hi = fit_ridge_with_cv(hier_img_As, hier_img_Bs, l2_grid)
        projection_report['proj_sim_img_hierarchy_part_mean'] = mean_projected_cosine(hier_img_As, hier_img_Bs, W_hi)
        projection_report['proj_img_hierarchy_part_best_l2'] = best_l2_hi
        if cv_hi is not None:
            projection_report['proj_img_hierarchy_part_cv_val'] = cv_hi

    proj_summary_path = os.path.join('evaluation_outputs', 'projection_summary.json')
    with open(proj_summary_path, 'w') as f:
        json.dump(projection_report, f, indent=2)
    print(f"Saved projection summary to {proj_summary_path}")
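The closed-form ridge fit above is the standard normal-equation solution. A self-contained sanity check on synthetic data (all names here are illustrative, not from the file), showing that the recovered W approximates a known linear map:

import torch

def fit_ridge(A: torch.Tensor, B: torch.Tensor, l2: float = 1e-3) -> torch.Tensor:
    """Closed-form ridge regression: W = (A^T A + l2*I)^-1 A^T B."""
    d = A.shape[1]
    return torch.linalg.solve(A.T @ A + l2 * torch.eye(d), A.T @ B)

torch.manual_seed(0)
W_true = torch.randn(16, 64)
A = torch.randn(500, 16)
B = A @ W_true + 0.01 * torch.randn(500, 64)  # noisy linear relation

W_hat = fit_ridge(A, B, l2=1e-4)
err = (W_hat - W_true).norm() / W_true.norm()
print(f"relative error: {err:.4f}")  # small, since the relation is nearly linear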
Evaluation/evaluate_color_embeddings.py
DELETED
@@ -1,1124 +0,0 @@
"""
Comprehensive evaluation of color embeddings with Fashion-CLIP comparison.
This file evaluates the quality of color embeddings generated by the ColorCLIP model
by computing intra-class and inter-class similarity metrics, classification accuracies,
and confusion matrices. It also compares results with Fashion-CLIP as a baseline
to measure relative performance.
"""

import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from collections import defaultdict
import os
import json
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import requests
from io import BytesIO
from PIL import Image
import warnings
warnings.filterwarnings('ignore')
from color_model import ColorCLIP, Tokenizer, ImageEncoder, TextEncoder, collate_batch
from transformers import CLIPProcessor, CLIPModel as TransformersCLIPModel
import config

class ColorDataset(Dataset):
    """
    Dataset class for color embedding evaluation.

    Handles loading images from various sources (local paths, URLs, bytes) and
    applying appropriate transformations for evaluation.
    """
    def __init__(self, dataframe):
        """
        Initialize the color dataset.

        Args:
            dataframe: DataFrame containing image paths/URLs, text, and color labels
        """
        self.dataframe = dataframe
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]

        # Handle image - row[config.column_url_image] may hold bytes, a PIL image, or a path/URL
        image_data = row[config.column_url_image]

        try:
            # Check if image_data has a 'bytes' key or is already a PIL Image
            if isinstance(image_data, dict) and 'bytes' in image_data:
                image = Image.open(BytesIO(image_data['bytes'])).convert("RGB")
            elif hasattr(image_data, 'convert'):  # Already a PIL Image
                image = image_data.convert("RGB")
            elif isinstance(image_data, str):
                # It's a file path (local or URL)
                if image_data.startswith('http'):
                    # It's a URL - download the image
                    response = requests.get(image_data, timeout=10)
                    response.raise_for_status()
                    image = Image.open(BytesIO(response.content)).convert("RGB")
                else:
                    # It's a local file path
                    image = Image.open(image_data).convert("RGB")
            else:
                # Assume it's raw bytes
                image = Image.open(BytesIO(image_data)).convert("RGB")

            # Apply transform
            image = self.transform(image)

        except Exception as e:
            print(f"⚠️ Failed to load image {idx}: {e}")
            # Return a placeholder image
            image = torch.zeros(3, 224, 224)

        # Get text and color
        description = row[config.text_column]
        color = row[config.color_column]

        return image, description, color

class EmbeddingEvaluator:
    """
    Evaluator for color embeddings generated by the ColorCLIP model.

    This class provides methods to evaluate the quality of color embeddings by computing
    similarity metrics, classification accuracies, and generating visualizations.
    """

    def __init__(self, model_path, embed_dim):
        """
        Initialize the embedding evaluator.

        Args:
            model_path: Path to the trained ColorCLIP model checkpoint
            embed_dim: Embedding dimension for the model
        """
        self.device = config.device

        # Initialize tokenizer first to get vocab size
        self.tokenizer = Tokenizer()
        vocab_size = None

        # Load vocabulary if available to determine vocab_size
        if os.path.exists(config.tokeniser_path):
            with open(config.tokeniser_path, 'r') as f:
                vocab_dict = json.load(f)
            # Manually load vocabulary
            self.tokenizer.word2idx = defaultdict(lambda: 0, {k: int(v) for k, v in vocab_dict.items()})
            self.tokenizer.idx2word = {int(v): k for k, v in vocab_dict.items() if int(v) > 0}
            self.tokenizer.counter = max(self.tokenizer.word2idx.values(), default=0) + 1
            vocab_size = self.tokenizer.counter
            print(f"Tokenizer vocabulary loaded from {config.tokeniser_path}")
        else:
            print(f"Warning: {config.tokeniser_path} not found. Using default tokenizer.")

        # Load checkpoint to get vocab_size and state_dict
        checkpoint = None
        if os.path.exists(model_path):
            checkpoint = torch.load(model_path, map_location=self.device)

        # Try to get vocab_size from the checkpoint if not already determined
        if vocab_size is None:
            # Try to get vocab_size from metadata
            if isinstance(checkpoint, dict) and 'vocab_size' in checkpoint:
                vocab_size = checkpoint['vocab_size']
            # Otherwise, try to infer it from the model state dict
            elif isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
                if 'text_encoder.embedding.weight' in state_dict:
                    vocab_size = state_dict['text_encoder.embedding.weight'].shape[0]
            elif isinstance(checkpoint, dict) and 'text_encoder.embedding.weight' in checkpoint:
                vocab_size = checkpoint['text_encoder.embedding.weight'].shape[0]

        # Fall back to a default if still not determined
        if vocab_size is None:
            vocab_size = 39  # Default fallback
            print(f"Warning: Could not determine vocab_size, using default: {vocab_size}")

        # Initialize model with the determined vocab_size
        self.model = ColorCLIP(vocab_size=vocab_size, embedding_dim=embed_dim).to(self.device)

        # Load trained model state dict
        if checkpoint is not None:
            state_dict = checkpoint.get('model_state_dict', checkpoint)
            self.model.load_state_dict(state_dict)
            print(f"Model loaded from {model_path}")
        else:
            print(f"Warning: Model file {model_path} not found. Using untrained model.")

        self.model.eval()

    def extract_embeddings(self, dataloader, embedding_type='text'):
        """
        Extract embeddings from the model for a given dataloader.

        Args:
            dataloader: DataLoader containing images, texts, and colors
            embedding_type: Type of embeddings to extract ('text', 'image', or 'color')

        Returns:
            Tuple of (embeddings array, labels list, texts list)
        """
        all_embeddings = []
        all_labels = []
        all_texts = []

        with torch.no_grad():
            for images, texts, colors in tqdm(dataloader, desc=f"Extracting {embedding_type} embeddings"):
                if embedding_type == 'text':
                    # Tokenize texts using the tokenizer
                    tokenized_texts = [self.tokenizer(text) for text in texts]
                    # Convert to tensors and pad sequences
                    text_tensors = [torch.tensor(t, dtype=torch.long) for t in tokenized_texts]
                    text_tokens = nn.utils.rnn.pad_sequence(text_tensors, batch_first=True, padding_value=0).to(self.device)
                    lengths = torch.tensor([len(t) for t in tokenized_texts], dtype=torch.long).to(self.device)
                    embeddings = self.model.text_encoder(text_tokens, lengths)
                    labels = colors
                elif embedding_type == 'image':
                    images = images.to(self.device)
                    embeddings = self.model.image_encoder(images)
                    labels = colors
                elif embedding_type == 'color':
                    # Tokenize color names using the tokenizer
                    tokenized_colors = [self.tokenizer(color) for color in colors]
                    # Convert to tensors and pad sequences
                    color_tensors = [torch.tensor(t, dtype=torch.long) for t in tokenized_colors]
                    color_tokens = nn.utils.rnn.pad_sequence(color_tensors, batch_first=True, padding_value=0).to(self.device)
                    lengths = torch.tensor([len(t) for t in tokenized_colors], dtype=torch.long).to(self.device)
                    embeddings = self.model.text_encoder(color_tokens, lengths)
                    labels = colors

                all_embeddings.append(embeddings.cpu().numpy())
                all_labels.extend(labels)
                all_texts.extend(texts)

        return np.vstack(all_embeddings), all_labels, all_texts

    def compute_similarity_metrics(self, embeddings, labels):
        """Compute intra-class and inter-class similarities"""
        similarities = cosine_similarity(embeddings)

        # Group embeddings by color
        color_groups = defaultdict(list)
        for i, color in enumerate(labels):
            color_groups[color].append(i)

        # Calculate intra-class similarities (same color)
        intra_class_similarities = []
        for color, indices in color_groups.items():
            if len(indices) > 1:
                for i in range(len(indices)):
                    for j in range(i + 1, len(indices)):
                        sim = similarities[indices[i], indices[j]]
                        intra_class_similarities.append(sim)

        # Calculate inter-class similarities (different colors)
        inter_class_similarities = []
        colors = list(color_groups.keys())
        for i in range(len(colors)):
            for j in range(i + 1, len(colors)):
                color1_indices = color_groups[colors[i]]
                color2_indices = color_groups[colors[j]]

                for idx1 in color1_indices:
                    for idx2 in color2_indices:
                        sim = similarities[idx1, idx2]
                        inter_class_similarities.append(sim)

        # Calculate classification accuracy using nearest neighbor in embedding space
        nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities)

        # Calculate classification accuracy using centroids
        centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels)

        return {
            'intra_class_similarities': intra_class_similarities,
            'inter_class_similarities': inter_class_similarities,
            'intra_class_mean': np.mean(intra_class_similarities) if intra_class_similarities else 0,
            'inter_class_mean': np.mean(inter_class_similarities) if inter_class_similarities else 0,
            'separation_score': np.mean(intra_class_similarities) - np.mean(inter_class_similarities) if intra_class_similarities and inter_class_similarities else 0,
            'accuracy': nn_accuracy,
            'centroid_accuracy': centroid_accuracy
        }

    def compute_embedding_accuracy(self, embeddings, labels, similarities):
        """Compute classification accuracy using nearest neighbor in embedding space"""
        correct_predictions = 0
        total_predictions = len(labels)

        for i in range(len(embeddings)):
            true_label = labels[i]

            # Find the most similar embedding (excluding itself)
            similarities_row = similarities[i].copy()
            similarities_row[i] = -1  # Exclude self-similarity
            nearest_neighbor_idx = np.argmax(similarities_row)
            predicted_label = labels[nearest_neighbor_idx]

            if predicted_label == true_label:
                correct_predictions += 1

        return correct_predictions / total_predictions if total_predictions > 0 else 0

    def compute_centroid_accuracy(self, embeddings, labels):
        """Compute classification accuracy using color centroids"""
        # Create centroids for each color
        unique_colors = list(set(labels))
        centroids = {}

        for color in unique_colors:
            color_indices = [i for i, label in enumerate(labels) if label == color]
            color_embeddings = embeddings[color_indices]
            centroids[color] = np.mean(color_embeddings, axis=0)

        # Classify each embedding to the nearest centroid
        correct_predictions = 0
        total_predictions = len(labels)

        for i, embedding in enumerate(embeddings):
            true_label = labels[i]

            # Find the closest centroid
            best_similarity = -1
            predicted_label = None

            for color, centroid in centroids.items():
                similarity = cosine_similarity([embedding], [centroid])[0][0]
                if similarity > best_similarity:
                    best_similarity = similarity
                    predicted_label = color

            if predicted_label == true_label:
                correct_predictions += 1

        return correct_predictions / total_predictions if total_predictions > 0 else 0

    def predict_colors_from_embeddings(self, embeddings, labels):
        """Predict colors from embeddings using centroid-based classification"""
        # Create color centroids from training data
        unique_colors = list(set(labels))
        centroids = {}

        for color in unique_colors:
            color_indices = [i for i, label in enumerate(labels) if label == color]
            color_embeddings = embeddings[color_indices]
            centroids[color] = np.mean(color_embeddings, axis=0)

        # Predict colors for all embeddings
        predictions = []

        for i, embedding in enumerate(embeddings):
            # Find the closest centroid
            best_similarity = -1
            predicted_color = None

            for color, centroid in centroids.items():
                similarity = cosine_similarity([embedding], [centroid])[0][0]
                if similarity > best_similarity:
                    best_similarity = similarity
                    predicted_color = color

            predictions.append(predicted_color)

        return predictions

    def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix"):
        """Create and plot a confusion matrix"""
        # Get unique labels
        unique_labels = sorted(list(set(true_labels + predicted_labels)))

        # Create confusion matrix
        cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)

        # Calculate accuracy
        accuracy = accuracy_score(true_labels, predicted_labels)

        # Plot confusion matrix
        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=unique_labels, yticklabels=unique_labels)
        plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
        plt.ylabel('True Color')
        plt.xlabel('Predicted Color')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()

        return plt.gcf(), accuracy, cm

    def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings"):
        """Evaluate classification performance and create a confusion matrix"""
        # Predict colors
        predictions = self.predict_colors_from_embeddings(embeddings, labels)

        # Calculate accuracy
        accuracy = accuracy_score(labels, predictions)

        # Create confusion matrix
        fig, acc, cm = self.create_confusion_matrix(labels, predictions,
                                                    f"{embedding_type} - Color Classification")

        # Generate classification report
        unique_labels = sorted(list(set(labels)))
        report = classification_report(labels, predictions, labels=unique_labels,
                                       target_names=unique_labels, output_dict=True)

        return {
            'accuracy': accuracy,
            'predictions': predictions,
            'confusion_matrix': cm,
            'classification_report': report,
            'figure': fig
        }

    def evaluate_dataset(self, dataframe, dataset_name="Dataset"):
        """
        Evaluate embeddings on a given dataset.

        This method extracts embeddings for text, image, and color, computes similarity metrics,
        evaluates classification performance, and saves confusion matrices.

        Args:
            dataframe: DataFrame containing the dataset
            dataset_name: Name of the dataset for display purposes

        Returns:
            Dictionary containing evaluation results for text, image, and color embeddings
        """
        print(f"\n{'='*60}")
        print(f"Evaluating {dataset_name}")
        print(f"{'='*60}")

        # Create dataset and dataloader - use KaglDataset for KAGL data
        if "kagl" in dataset_name.lower():
            dataset = KaglDataset(dataframe)
        else:
            dataset = ColorDataset(dataframe)
        # Optimize batch size and workers for faster processing
        dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)

        results = {}

        # Evaluate text embeddings
        text_embeddings, text_labels, texts = self.extract_embeddings(dataloader, 'text')
        text_metrics = self.compute_similarity_metrics(text_embeddings, text_labels)
        text_classification = self.evaluate_classification_performance(text_embeddings, text_labels, "Text Embeddings")
        text_metrics.update(text_classification)
        results['text'] = text_metrics

        # Evaluate image embeddings
        image_embeddings, image_labels, _ = self.extract_embeddings(dataloader, 'image')
        image_metrics = self.compute_similarity_metrics(image_embeddings, image_labels)
        image_classification = self.evaluate_classification_performance(image_embeddings, image_labels, "Image Embeddings")
        image_metrics.update(image_classification)
        results['image'] = image_metrics

        # Evaluate color embeddings
        color_embeddings, color_labels, _ = self.extract_embeddings(dataloader, 'color')
        color_metrics = self.compute_similarity_metrics(color_embeddings, color_labels)
        color_classification = self.evaluate_classification_performance(color_embeddings, color_labels, "Color Embeddings")
        color_metrics.update(color_classification)
        results['color'] = color_metrics

        # Print results
        print(f"\n{dataset_name} Results:")
        print("-" * 40)
        for emb_type, metrics in results.items():
            print(f"{emb_type.capitalize()} Embeddings:")
            print(f"  Intra-class similarity (same color): {metrics['intra_class_mean']:.4f}")
            print(f"  Inter-class similarity (diff colors): {metrics['inter_class_mean']:.4f}")
            print(f"  Separation score: {metrics['separation_score']:.4f}")
            print(f"  Nearest Neighbor Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.1f}%)")
            print(f"  Centroid Accuracy: {metrics['centroid_accuracy']:.4f} ({metrics['centroid_accuracy']*100:.1f}%)")

            # Classification report summary
            report = metrics['classification_report']
            print(f"  📊 Classification Performance:")
            print(f"    • Macro Avg F1-Score: {report['macro avg']['f1-score']:.4f}")
            print(f"    • Weighted Avg F1-Score: {report['weighted avg']['f1-score']:.4f}")
            print(f"    • Support: {report['macro avg']['support']:.0f} samples")
            print()

        # Create visualizations
        os.makedirs('embedding_evaluation', exist_ok=True)

        # Confusion matrices
        results['text']['figure'].savefig(f'embedding_evaluation/{dataset_name.lower()}_text_confusion_matrix.png', dpi=300, bbox_inches='tight')
        plt.close(results['text']['figure'])

        results['image']['figure'].savefig(f'embedding_evaluation/{dataset_name.lower()}_image_confusion_matrix.png', dpi=300, bbox_inches='tight')
        plt.close(results['image']['figure'])

        results['color']['figure'].savefig(f'embedding_evaluation/{dataset_name.lower()}_color_confusion_matrix.png', dpi=300, bbox_inches='tight')
        plt.close(results['color']['figure'])

        return results

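# Hedged usage sketch: shows the intended call pattern for EmbeddingEvaluator
# above, assuming config.color_model_path / config.color_emb_dim point at a
# trained checkpoint and config.local_dataset_path at the evaluation CSV.
# Reminder on the key metric: separation_score = mean intra-class cosine
# similarity minus mean inter-class cosine similarity, so higher is better.
def _example_embedding_evaluation():
    evaluator = EmbeddingEvaluator(model_path=config.color_model_path,
                                   embed_dim=config.color_emb_dim)
    df_small = pd.read_csv(config.local_dataset_path).head(256)
    results = evaluator.evaluate_dataset(df_small, "Smoke Test")
    for emb_type in ('text', 'image', 'color'):
        m = results[emb_type]
        print(f"{emb_type}: separation={m['separation_score']:.3f}, "
              f"centroid_acc={m['centroid_accuracy']:.3f}")
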
class FashionCLIPDataset(Dataset):
    """
    Special dataset for Fashion-CLIP that doesn't normalize images.

    This dataset is used when evaluating with the Fashion-CLIP baseline model,
    which requires different image preprocessing (no normalization).
    """
    def __init__(self, dataframe):
        """
        Initialize the Fashion-CLIP dataset.

        Args:
            dataframe: DataFrame containing image paths/URLs, text, and color labels
        """
        self.dataframe = dataframe
        # Only resize and convert to tensor, no normalization
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]

        # Handle image - it should be in row[config.column_url_image] and contain the image data
        image_data = row[config.column_url_image]

        try:
            # Check if image_data has a 'bytes' key or is already a PIL Image
            if isinstance(image_data, dict) and 'bytes' in image_data:
                image = Image.open(BytesIO(image_data['bytes'])).convert("RGB")
            elif hasattr(image_data, 'convert'):  # Already a PIL Image
                image = image_data.convert("RGB")
            elif isinstance(image_data, str):
                # It's a file path (local or URL)
                if image_data.startswith('http'):
                    # It's a URL - download the image
                    import requests
                    response = requests.get(image_data, timeout=10)
                    response.raise_for_status()
                    image = Image.open(BytesIO(response.content)).convert("RGB")
                else:
                    # It's a local file path
                    image = Image.open(image_data).convert("RGB")
            else:
                # Assume it's raw bytes data
                image = Image.open(BytesIO(image_data)).convert("RGB")

            # Apply the minimal transform (no normalization)
            image = self.transform(image)

        except Exception as e:
            print(f"⚠️ Failed to load image {idx}: {e}")
            # Return a placeholder image instead of an undefined variable
            image = torch.zeros(3, 224, 224)

        # Get text and color
        description = row[config.text_column]
        color = row[config.color_column]

        return image, description, color

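# Hedged note on the design choice above, with a minimal sketch (the image
# path is a placeholder): FashionCLIPDataset deliberately skips Normalize()
# because its tensors are turned back into PIL images and fed to
# CLIPProcessor (see FashionCLIPEvaluator.extract_embeddings below), which
# applies its own resize + normalization; normalizing twice would corrupt
# the pixel statistics.
def _example_fashion_clip_preprocessing(image_path="example.jpg"):
    processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip")
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    # The processor returns fully preprocessed pixel values
    print(inputs["pixel_values"].shape)  # typically torch.Size([1, 3, 224, 224])
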
class FashionCLIPEvaluator:
    """
    Evaluator for the Fashion-CLIP baseline model.

    This class provides methods to evaluate embeddings from the Fashion-CLIP model
    and compare them with the custom ColorCLIP model.
    """

    def __init__(self):
        """
        Initialize the Fashion-CLIP evaluator.

        Loads the Fashion-CLIP model from Hugging Face and prepares it for evaluation.
        """
        # Load Fashion-CLIP model
        patrick_model_name = "patrickjohncyh/fashion-clip"
        print(f"🔄 Loading Fashion-CLIP model: {patrick_model_name}")
        self.processor = CLIPProcessor.from_pretrained(patrick_model_name)
        self.device = config.device
        self.model = TransformersCLIPModel.from_pretrained(patrick_model_name).to(self.device)
        self.model.eval()
        print(f"✅ Fashion-CLIP model loaded successfully")

    def extract_embeddings(self, dataloader, embedding_type='text'):
        """
        Extract embeddings from the Fashion-CLIP model.

        Args:
            dataloader: DataLoader containing images, texts, and colors
            embedding_type: Type of embeddings to extract ('text', 'image', or 'color')

        Returns:
            Tuple of (embeddings array, labels list, texts list)
        """
        all_embeddings = []
        all_labels = []
        all_texts = []

        with torch.no_grad():
            for images, texts, colors in tqdm(dataloader, desc=f"Extracting {embedding_type} embeddings (Fashion-CLIP)"):
                if embedding_type == 'text':
                    # Process text through Fashion-CLIP
                    inputs = self.processor(text=texts, return_tensors="pt", padding=True, truncation=True, max_length=77)
                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
                    text_features = self.model.get_text_features(**inputs)
                    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                    embeddings = text_features.cpu().numpy()
                    labels = colors
                elif embedding_type == 'image':
                    # Convert tensors back to PIL images for the CLIP processor
                    pil_images = []
                    for i in range(images.shape[0]):
                        # Convert tensor back to PIL Image
                        img_tensor = images[i]
                        # Denormalize if needed (images should be in [0,1] range)
                        if img_tensor.min() < 0 or img_tensor.max() > 1:
                            # If normalized, denormalize
                            img_tensor = (img_tensor + 1) / 2  # Assuming [-1,1] to [0,1]
                            img_tensor = torch.clamp(img_tensor, 0, 1)
                        img_pil = transforms.ToPILImage()(img_tensor)
                        pil_images.append(img_pil)

                    # Process images through Fashion-CLIP
                    inputs = self.processor(images=pil_images, return_tensors="pt", padding=True)
                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
                    image_features = self.model.get_image_features(**inputs)
                    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                    embeddings = image_features.cpu().numpy()
                    labels = colors
                elif embedding_type == 'color':
                    # Process color names as text through Fashion-CLIP
                    inputs = self.processor(text=colors, return_tensors="pt", padding=True, truncation=True, max_length=77)
                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
                    text_features = self.model.get_text_features(**inputs)
                    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                    embeddings = text_features.cpu().numpy()
                    labels = colors

                all_embeddings.append(embeddings)
                all_labels.extend(labels)
                all_texts.extend(texts)

        return np.vstack(all_embeddings), all_labels, all_texts

    def compute_similarity_metrics(self, embeddings, labels):
        """Compute intra-class and inter-class similarities"""
        similarities = cosine_similarity(embeddings)

        # Group embeddings by color
        color_groups = defaultdict(list)
        for i, color in enumerate(labels):
            color_groups[color].append(i)

        # Calculate intra-class similarities (same color)
        intra_class_similarities = []
        for color, indices in color_groups.items():
            if len(indices) > 1:
                for i in range(len(indices)):
                    for j in range(i + 1, len(indices)):
                        sim = similarities[indices[i], indices[j]]
                        intra_class_similarities.append(sim)

        # Calculate inter-class similarities (different colors)
        inter_class_similarities = []
        colors = list(color_groups.keys())
        for i in range(len(colors)):
            for j in range(i + 1, len(colors)):
                color1_indices = color_groups[colors[i]]
                color2_indices = color_groups[colors[j]]

                for idx1 in color1_indices:
                    for idx2 in color2_indices:
                        sim = similarities[idx1, idx2]
                        inter_class_similarities.append(sim)

        # Calculate classification accuracy using nearest neighbor in embedding space
        nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities)

        # Calculate classification accuracy using centroids
        centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels)

        return {
            'intra_class_similarities': intra_class_similarities,
            'inter_class_similarities': inter_class_similarities,
            'intra_class_mean': np.mean(intra_class_similarities) if intra_class_similarities else 0,
            'inter_class_mean': np.mean(inter_class_similarities) if inter_class_similarities else 0,
            'separation_score': np.mean(intra_class_similarities) - np.mean(inter_class_similarities) if intra_class_similarities and inter_class_similarities else 0,
            'accuracy': nn_accuracy,
            'centroid_accuracy': centroid_accuracy
        }

    def compute_embedding_accuracy(self, embeddings, labels, similarities):
        """Compute classification accuracy using nearest neighbor in embedding space"""
        correct_predictions = 0
        total_predictions = len(labels)

        for i in range(len(embeddings)):
            true_label = labels[i]

            # Find the most similar embedding (excluding itself)
            similarities_row = similarities[i].copy()
            similarities_row[i] = -1  # Exclude self-similarity
            nearest_neighbor_idx = np.argmax(similarities_row)
            predicted_label = labels[nearest_neighbor_idx]

            if predicted_label == true_label:
                correct_predictions += 1

        return correct_predictions / total_predictions if total_predictions > 0 else 0

    def compute_centroid_accuracy(self, embeddings, labels):
        """Compute classification accuracy using color centroids"""
        # Create centroids for each color
        unique_colors = list(set(labels))
        centroids = {}

        for color in unique_colors:
            color_indices = [i for i, label in enumerate(labels) if label == color]
            color_embeddings = embeddings[color_indices]
            centroids[color] = np.mean(color_embeddings, axis=0)

        # Classify each embedding to the nearest centroid
        correct_predictions = 0
        total_predictions = len(labels)

        for i, embedding in enumerate(embeddings):
            true_label = labels[i]

            # Find the closest centroid
            best_similarity = -1
            predicted_label = None

            for color, centroid in centroids.items():
                similarity = cosine_similarity([embedding], [centroid])[0][0]
                if similarity > best_similarity:
                    best_similarity = similarity
                    predicted_label = color

            if predicted_label == true_label:
                correct_predictions += 1

        return correct_predictions / total_predictions if total_predictions > 0 else 0

    def predict_colors_from_embeddings(self, embeddings, labels):
        """Predict colors from embeddings using centroid-based classification"""
        # Create color centroids from training data
        unique_colors = list(set(labels))
        centroids = {}

        for color in unique_colors:
            color_indices = [i for i, label in enumerate(labels) if label == color]
            color_embeddings = embeddings[color_indices]
            centroids[color] = np.mean(color_embeddings, axis=0)

        # Predict colors for all embeddings
        predictions = []

        for i, embedding in enumerate(embeddings):
            # Find the closest centroid
            best_similarity = -1
            predicted_color = None

            for color, centroid in centroids.items():
                similarity = cosine_similarity([embedding], [centroid])[0][0]
                if similarity > best_similarity:
                    best_similarity = similarity
                    predicted_color = color

            predictions.append(predicted_color)

        return predictions

    def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix"):
        """Create and plot a confusion matrix"""
        # Get unique labels
        unique_labels = sorted(list(set(true_labels + predicted_labels)))

        # Create confusion matrix
        cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)

        # Calculate accuracy
        accuracy = accuracy_score(true_labels, predicted_labels)

        # Plot confusion matrix
        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=unique_labels, yticklabels=unique_labels)
        plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
        plt.ylabel('True Color')
        plt.xlabel('Predicted Color')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()

        return plt.gcf(), accuracy, cm

    def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings"):
        """Evaluate classification performance and create a confusion matrix"""
        # Predict colors
        predictions = self.predict_colors_from_embeddings(embeddings, labels)

        # Calculate accuracy
        accuracy = accuracy_score(labels, predictions)

        # Create confusion matrix
        fig, acc, cm = self.create_confusion_matrix(labels, predictions,
                                                    f"{embedding_type} - Color Classification (Fashion-CLIP)")

        # Generate classification report
        unique_labels = sorted(list(set(labels)))
        report = classification_report(labels, predictions, labels=unique_labels,
                                       target_names=unique_labels, output_dict=True)

        return {
            'accuracy': accuracy,
            'predictions': predictions,
            'confusion_matrix': cm,
            'classification_report': report,
            'figure': fig
        }

    def evaluate_dataset(self, dataframe, dataset_name="Dataset"):
        """
        Evaluate Fashion-CLIP embeddings on a given dataset.

        This method extracts embeddings for text, image, and color, computes similarity metrics,
        evaluates classification performance, and saves confusion matrices.

        Args:
            dataframe: DataFrame containing the dataset
            dataset_name: Name of the dataset for display purposes

        Returns:
            Dictionary containing evaluation results for text, image, and color embeddings
        """
        print(f"\n{'='*60}")
        print(f"Evaluating {dataset_name} with Fashion-CLIP")
        print(f"{'='*60}")

        # Create dataset and dataloader - use FashionCLIPDataset for Fashion-CLIP
        if "kagl" in dataset_name.lower():
            dataset = KaglDataset(dataframe)
        else:
            dataset = FashionCLIPDataset(dataframe)  # Use the special dataset for Fashion-CLIP
        # Optimize batch size for Fashion-CLIP
        dataloader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

        results = {}

        # Evaluate text embeddings
        text_embeddings, text_labels, texts = self.extract_embeddings(dataloader, 'text')
        text_metrics = self.compute_similarity_metrics(text_embeddings, text_labels)
        text_classification = self.evaluate_classification_performance(text_embeddings, text_labels, "Text Embeddings")
        text_metrics.update(text_classification)
        results['text'] = text_metrics

        # Evaluate image embeddings
        image_embeddings, image_labels, _ = self.extract_embeddings(dataloader, 'image')
        image_metrics = self.compute_similarity_metrics(image_embeddings, image_labels)
        image_classification = self.evaluate_classification_performance(image_embeddings, image_labels, "Image Embeddings")
        image_metrics.update(image_classification)
        results['image'] = image_metrics

        # Evaluate color embeddings
        color_embeddings, color_labels, _ = self.extract_embeddings(dataloader, 'color')
        color_metrics = self.compute_similarity_metrics(color_embeddings, color_labels)
        color_classification = self.evaluate_classification_performance(color_embeddings, color_labels, "Color Embeddings")
        color_metrics.update(color_classification)
        results['color'] = color_metrics

        # Print results
        print(f"\n{dataset_name} Results (Fashion-CLIP):")
        print("-" * 40)
        for emb_type, metrics in results.items():
            print(f"{emb_type.capitalize()} Embeddings:")
            print(f"  Intra-class similarity (same color): {metrics['intra_class_mean']:.4f}")
            print(f"  Inter-class similarity (diff colors): {metrics['inter_class_mean']:.4f}")
            print(f"  Separation score: {metrics['separation_score']:.4f}")
            print(f"  Nearest Neighbor Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.1f}%)")
            print(f"  Centroid Accuracy: {metrics['centroid_accuracy']:.4f} ({metrics['centroid_accuracy']*100:.1f}%)")

            # Classification report summary
            report = metrics['classification_report']
            print(f"  📊 Classification Performance:")
            print(f"    • Macro Avg F1-Score: {report['macro avg']['f1-score']:.4f}")
            print(f"    • Weighted Avg F1-Score: {report['weighted avg']['f1-score']:.4f}")
            print(f"    • Support: {report['macro avg']['support']:.0f} samples")
            print()

        # Create visualizations
        os.makedirs('embedding_evaluation', exist_ok=True)

        # Confusion matrices
        results['text']['figure'].savefig(f'embedding_evaluation/{dataset_name.lower()}_text_confusion_matrix_fashion_clip.png', dpi=300, bbox_inches='tight')
        plt.close(results['text']['figure'])

        results['image']['figure'].savefig(f'embedding_evaluation/{dataset_name.lower()}_image_confusion_matrix_fashion_clip.png', dpi=300, bbox_inches='tight')
        plt.close(results['image']['figure'])

        results['color']['figure'].savefig(f'embedding_evaluation/{dataset_name.lower()}_color_confusion_matrix_fashion_clip.png', dpi=300, bbox_inches='tight')
        plt.close(results['color']['figure'])

        return results

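# Hedged zero-shot sketch (illustrative, not part of the original evaluator):
# since get_text_features / get_image_features are L2-normalized above,
# cosine similarity reduces to a dot product, so each image can be assigned
# the color name whose text embedding scores highest.
def _example_zero_shot_color_match(evaluator, pil_images, color_names):
    inputs = evaluator.processor(text=color_names, images=pil_images,
                                 return_tensors="pt", padding=True).to(evaluator.device)
    with torch.no_grad():
        img = evaluator.model.get_image_features(pixel_values=inputs["pixel_values"])
        txt = evaluator.model.get_text_features(input_ids=inputs["input_ids"],
                                                attention_mask=inputs["attention_mask"])
    img = img / img.norm(dim=-1, keepdim=True)
    txt = txt / txt.norm(dim=-1, keepdim=True)
    best = (img @ txt.T).argmax(dim=-1)  # one index per image
    return [color_names[i] for i in best.tolist()]
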
class KaglDataset(Dataset):
    """
    Dataset class for KAGL Marqo dataset evaluation.

    Handles loading images from the KAGL dataset format (with 'bytes' in image_url).
    """
    def __init__(self, dataframe):
        """
        Initialize the KAGL dataset.

        Args:
            dataframe: DataFrame containing image_url (with bytes), text, and color labels
        """
        self.dataframe = dataframe
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]

        # Handle image - it should be in row['image_url'] and contain the image data
        image_data = row["image_url"]

        # Check if image_data has a 'bytes' key or is already a PIL Image
        if isinstance(image_data, dict) and 'bytes' in image_data:
            image = Image.open(BytesIO(image_data['bytes'])).convert("RGB")
        elif hasattr(image_data, 'convert'):  # Already a PIL Image
            image = image_data.convert("RGB")
        else:
            image = Image.open(BytesIO(image_data)).convert("RGB")

        image = self.transform(image)

        # Get text and color from the KAGL row
        description = row['text']
        color = row['color']

        return image, description, color

def load_kagl_marqo_dataset():
    """
    Load and prepare the Marqo/KAGL dataset from Hugging Face.

    This function loads the Marqo/KAGL dataset, filters for valid colors,
    and formats it for evaluation.

    Returns:
        DataFrame with columns: image_url, text, color
    """
    from datasets import load_dataset
    print("Loading Marqo/KAGL dataset...")

    # Load the dataset
    dataset = load_dataset("Marqo/KAGL")
    df = dataset["data"].to_pandas()
    print(f"✅ KAGL dataset loaded")

    # Prepare data - normalize baseColour (lowercase, grey -> gray)
    df['baseColour'] = df['baseColour'].str.lower().str.replace("grey", "gray")
    df_test = df[df['baseColour'].notna()].copy()

    print(f"📊 Before filtering: {len(df_test)} samples")

    # Filter for common colors
    valid_colors = ['red', 'blue', 'green', 'yellow', 'purple', 'pink', 'orange',
                    'brown', 'black', 'white', 'gray', 'navy', 'maroon', 'beige']
    df_test = df_test[df_test['baseColour'].isin(valid_colors)]

    print(f"📊 After filtering invalid colors: {len(df_test)} samples")
    print(f"🎨 Valid colors found: {sorted(df_test['baseColour'].unique())}")

    if len(df_test) == 0:
        print("❌ No samples left after color filtering.")

    # Map to our expected column names (baseColour was already normalized above)
    kagl_formatted = pd.DataFrame({
        'image_url': df_test['image_url'],
        'text': df_test['text'],
        'color': df_test['baseColour']
    })

    # Additional validation - remove rows with missing data
    print(f"📊 Before final validation: {len(kagl_formatted)} samples")
    kagl_formatted = kagl_formatted.dropna(subset=[config.column_url_image, config.text_column, config.color_column])
    print(f"📊 After removing missing data: {len(kagl_formatted)} samples")

    # Check for empty strings
    kagl_formatted = kagl_formatted[
        (kagl_formatted['text'].str.strip() != '') &
        (kagl_formatted['color'].str.strip() != '')
    ]
    print(f"📊 After removing empty strings: {len(kagl_formatted)} samples")

    print(f"📊 Final dataset size: {len(kagl_formatted)} samples")

    return kagl_formatted

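# Hedged mini-check of the color normalization used above, on a toy frame
# (no download needed); this mirrors the lowercase + "grey" -> "gray" step.
def _example_colour_normalization():
    toy = pd.DataFrame({'baseColour': ['Grey', 'NAVY', 'Red']})
    toy['baseColour'] = toy['baseColour'].str.lower().str.replace("grey", "gray")
    print(toy['baseColour'].tolist())  # ['gray', 'navy', 'red']
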
def create_comparison_table(val_results, kagl_results, val_results_fashion_clip, kagl_results_fashion_clip):
    """
    Create a structured comparison table between the custom model and the Fashion-CLIP baseline.

    Args:
        val_results: Evaluation results for the custom model on the validation dataset
        kagl_results: Evaluation results for the custom model on the KAGL dataset
        val_results_fashion_clip: Evaluation results for Fashion-CLIP on the validation dataset
        kagl_results_fashion_clip: Evaluation results for Fashion-CLIP on the KAGL dataset

    Returns:
        DataFrame containing the comparison table
    """

    # Create DataFrame for comparison
    data = []

    # Define embedding types and their display names
    embedding_types = [
        ('text', 'Text Embeddings'),
        ('image', 'Image Embeddings'),
        ('color', 'Color Embeddings')
    ]

    # Define datasets
    datasets = [
        ('Validation Dataset', val_results, val_results_fashion_clip),
        ('kagl Marqo Dataset', kagl_results, kagl_results_fashion_clip)
    ]

    for dataset_name, custom_results, baseline_results in datasets:
        for emb_type, emb_display in embedding_types:
            # Your custom model results
            custom_metrics = custom_results[emb_type]
            # Baseline model results
            baseline_metrics = baseline_results[emb_type]

            data.append({
                'Dataset': dataset_name,
                'Embedding Type': emb_display,
                'Model': 'Your Model',
                'Separation Score': f"{custom_metrics['separation_score']:.4f}",
                'NN Accuracy (%)': f"{custom_metrics['accuracy']*100:.1f}%",
                'Centroid Accuracy (%)': f"{custom_metrics['centroid_accuracy']*100:.1f}%",
                'Intra-class Similarity': f"{custom_metrics['intra_class_mean']:.4f}",
                'Inter-class Similarity': f"{custom_metrics['inter_class_mean']:.4f}",
                'Macro F1-Score': f"{custom_metrics['classification_report']['macro avg']['f1-score']:.4f}",
                'Weighted F1-Score': f"{custom_metrics['classification_report']['weighted avg']['f1-score']:.4f}"
            })

            data.append({
                'Dataset': dataset_name,
                'Embedding Type': emb_display,
                'Model': 'Fashion-CLIP (Baseline)',
                'Separation Score': f"{baseline_metrics['separation_score']:.4f}",
                'NN Accuracy (%)': f"{baseline_metrics['accuracy']*100:.1f}%",
                'Centroid Accuracy (%)': f"{baseline_metrics['centroid_accuracy']*100:.1f}%",
                'Intra-class Similarity': f"{baseline_metrics['intra_class_mean']:.4f}",
                'Inter-class Similarity': f"{baseline_metrics['inter_class_mean']:.4f}",
                'Macro F1-Score': f"{baseline_metrics['classification_report']['macro avg']['f1-score']:.4f}",
                'Weighted F1-Score': f"{baseline_metrics['classification_report']['weighted avg']['f1-score']:.4f}"
            })

    # Create DataFrame
    df_comparison = pd.DataFrame(data)

    # Save to CSV
    df_comparison.to_csv('embedding_evaluation/model_comparison_table.csv', index=False)

    # Print formatted table
    print(f"\n{'='*120}")
    print("📊 COMPREHENSIVE MODEL COMPARISON TABLE")
    print(f"{'='*120}")

    # Print table by dataset
    for dataset_name in df_comparison['Dataset'].unique():
        print(f"\n🔍 {dataset_name.upper()}")
        print("-" * 120)

        dataset_df = df_comparison[df_comparison['Dataset'] == dataset_name]

        for emb_type in dataset_df['Embedding Type'].unique():
            print(f"\n📈 {emb_type}:")
            emb_df = dataset_df[dataset_df['Embedding Type'] == emb_type]

            # Print header
            print(f"{'Model':<20} {'Separation':<12} {'NN Acc':<10} {'Centroid Acc':<13} {'Intra-class':<12} {'Inter-class':<12} {'Macro F1':<10} {'Weighted F1':<12}")
            print("-" * 120)

            # Print data rows
            for _, row in emb_df.iterrows():
                print(f"{row['Model']:<20} {row['Separation Score']:<12} {row['NN Accuracy (%)']:<10} {row['Centroid Accuracy (%)']:<13} {row['Intra-class Similarity']:<12} {row['Inter-class Similarity']:<12} {row['Macro F1-Score']:<10} {row['Weighted F1-Score']:<12}")

    return df_comparison

if __name__ == "__main__":

    # Initialize the evaluator for the custom model
    evaluator = EmbeddingEvaluator(model_path=config.color_model_path, embed_dim=config.color_emb_dim)

    # Initialize the Fashion-CLIP evaluator
    fashion_clip_evaluator = FashionCLIPEvaluator()

    # Load datasets
    print("Loading datasets...")

    # Load validation dataset
    df_val = pd.read_csv(config.local_dataset_path)

    print(f"📊 Original dataset size: {len(df_val)}")
    samples_to_evaluate = 10000

    # Load the Marqo/KAGL dataset
    kagl_df = load_kagl_marqo_dataset()

    # Evaluate the custom model on the validation dataset
    val_results = evaluator.evaluate_dataset(df_val, "Validation Dataset")

    # Evaluate the custom model on the KAGL dataset (reduced sample for speed)
    kagl_sample = kagl_df.sample(min(samples_to_evaluate, len(kagl_df)), random_state=42)
    kagl_results = evaluator.evaluate_dataset(kagl_sample, "kagl Marqo Dataset")

    # Evaluate Fashion-CLIP on the validation dataset
    val_results_fashion_clip = fashion_clip_evaluator.evaluate_dataset(df_val, "Validation Dataset")

    # Evaluate Fashion-CLIP on the same KAGL sample (required by the comparison table)
    kagl_results_fashion_clip = fashion_clip_evaluator.evaluate_dataset(kagl_sample, "kagl Marqo Dataset")

    # Create comprehensive comparison table
    comparison_df = create_comparison_table(
        val_results, kagl_results,
        val_results_fashion_clip, kagl_results_fashion_clip
    )

    print(f"\n{'='*120}")
    print("✅ Evaluation complete!")
    print("📁 Confusion matrices saved in 'embedding_evaluation/' folder")
    print("📁 Comparison table saved as 'model_comparison_table.csv'")
    print("📁 Fashion-CLIP results are saved with '_fashion_clip' suffix.")
    print(f"{'='*120}")
Evaluation/fashion_search.py
DELETED
@@ -1,365 +0,0 @@
#!/usr/bin/env python3
"""
Fashion search system using multi-modal embeddings.
This file implements a fashion search engine that allows searching for clothing items
using text queries. It uses embeddings from the main model to calculate cosine similarities
and return the most relevant items. The system pre-computes embeddings for all items
in the dataset for fast search.
"""

import torch
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
import warnings
import os
from typing import List, Tuple, Union, Optional
import argparse

# Import custom models
from color_model import CLIPModel as ColorModel
from hierarchy_model import Model as HierarchyModel, HierarchyExtractor
from main_model import CustomDataset
import config

warnings.filterwarnings("ignore")

class FashionSearchEngine:
    """
    Fashion search engine using multi-modal embeddings with category emphasis
    """

    def __init__(self, top_k: int = 10, max_items: int = 10000):
        """
        Initialize the fashion search engine

        Args:
            top_k: Number of top results to return
            max_items: Maximum number of items to process (for faster initialization)
        """
        self.device = config.device
        self.top_k = top_k
        self.max_items = max_items
        self.color_dim = config.color_emb_dim
        self.hierarchy_dim = config.hierarchy_emb_dim

        # Load models
        self._load_models()

        # Load dataset
        self._load_dataset()

        # Pre-compute embeddings for all items
        self._precompute_embeddings()

        print("✅ Fashion Search Engine ready!")

    def _load_models(self):
        """Load all required models"""
        print("📦 Loading models...")

        # Load color model
        color_checkpoint = torch.load(config.color_model_path, map_location=self.device, weights_only=True)
        self.color_model = ColorModel(embed_dim=self.color_dim).to(self.device)
        self.color_model.load_state_dict(color_checkpoint)
        self.color_model.eval()

        # Load hierarchy model
        hierarchy_checkpoint = torch.load(config.hierarchy_model_path, map_location=self.device)
        self.hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
        self.hierarchy_model = HierarchyModel(
            num_hierarchy_classes=len(self.hierarchy_classes),
            embed_dim=self.hierarchy_dim
        ).to(self.device)
        self.hierarchy_model.load_state_dict(hierarchy_checkpoint['model_state'])

        # Set hierarchy extractor
        hierarchy_extractor = HierarchyExtractor(self.hierarchy_classes, verbose=False)
        self.hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
        self.hierarchy_model.eval()

        # Load main CLIP model - use the trained model directly
        self.main_model = CLIPModel_transformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')

        # Load the trained weights
        checkpoint = torch.load(config.main_model_path, map_location=self.device)
        if 'model_state_dict' in checkpoint:
            self.main_model.load_state_dict(checkpoint['model_state_dict'])
        else:
            # Fallback: try to load as a state dict directly
            self.main_model.load_state_dict(checkpoint)
            print("✅ Loaded model weights directly")

        self.main_model.to(self.device)
        self.main_model.eval()

        # Load CLIP processor
        self.clip_processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')

        print(f"✅ Models loaded - Colors: {self.color_dim}D, Hierarchy: {self.hierarchy_dim}D")

def _load_dataset(self):
|
| 106 |
-
"""Load the fashion dataset"""
|
| 107 |
-
print("📊 Loading dataset...")
|
| 108 |
-
|
| 109 |
-
# Load dataset
|
| 110 |
-
self.df = pd.read_csv(config.local_dataset_path)
|
| 111 |
-
self.df_clean = self.df.dropna(subset=[config.column_local_image_path])
|
| 112 |
-
|
| 113 |
-
# Create dataset object
|
| 114 |
-
self.dataset = CustomDataset(self.df_clean)
|
| 115 |
-
self.dataset.set_training_mode(False) # No augmentation for search
|
| 116 |
-
|
| 117 |
-
print(f"✅ {len(self.df_clean)} items loaded for search")
|
| 118 |
-
|
| 119 |
-
def _precompute_embeddings(self):
|
| 120 |
-
"""Pre-compute embeddings for all items in the dataset"""
|
| 121 |
-
print("🔄 Pre-computing embeddings...")
|
| 122 |
-
|
| 123 |
-
# OPTIMIZATION: Sample a subset for faster initialization
|
| 124 |
-
print(f"⚠️ Dataset too large ({len(self.dataset)} items). Using stratified sampling of 10 items per color-category combination.")
|
| 125 |
-
|
| 126 |
-
# Stratified sampling by color-category combinations
|
| 127 |
-
sampled_df = self.df_clean.groupby([config.color_column, config.hierarchy_column]).sample(n=20, replace=False)
|
| 128 |
-
|
| 129 |
-
# Get the original indices of sampled items
|
| 130 |
-
sampled_indices = sampled_df.index.tolist()
|
| 131 |
-
|
| 132 |
-
all_embeddings = []
|
| 133 |
-
all_texts = []
|
| 134 |
-
all_colors = []
|
| 135 |
-
all_hierarchies = []
|
| 136 |
-
all_images = []
|
| 137 |
-
all_urls = []
|
| 138 |
-
|
| 139 |
-
# Process in batches for efficiency
|
| 140 |
-
batch_size = 32
|
| 141 |
-
|
| 142 |
-
# Add progress bar
|
| 143 |
-
from tqdm import tqdm
|
| 144 |
-
total_batches = (len(sampled_indices) + batch_size - 1) // batch_size
|
| 145 |
-
|
| 146 |
-
for i in tqdm(range(0, len(sampled_indices), batch_size),
|
| 147 |
-
desc="Computing embeddings",
|
| 148 |
-
total=total_batches):
|
| 149 |
-
batch_end = min(i + batch_size, len(sampled_indices))
|
| 150 |
-
batch_items = []
|
| 151 |
-
|
| 152 |
-
for j in range(i, batch_end):
|
| 153 |
-
try:
|
| 154 |
-
# Use the original dataset with the sampled index
|
| 155 |
-
original_idx = sampled_indices[j]
|
| 156 |
-
image, text, color, hierarchy = self.dataset[original_idx]
|
| 157 |
-
batch_items.append((image, text, color, hierarchy))
|
| 158 |
-
all_texts.append(text)
|
| 159 |
-
all_colors.append(color)
|
| 160 |
-
all_hierarchies.append(hierarchy)
|
| 161 |
-
all_images.append(self.df_clean.iloc[original_idx][config.column_local_image_path])
|
| 162 |
-
all_urls.append(self.df_clean.iloc[original_idx][config.column_url_image])
|
| 163 |
-
except Exception as e:
|
| 164 |
-
print(f"⚠️ Skipping item {j}: {e}")
|
| 165 |
-
continue
|
| 166 |
-
|
| 167 |
-
if not batch_items:
|
| 168 |
-
continue
|
| 169 |
-
|
| 170 |
-
# Process batch
|
| 171 |
-
images = torch.stack([item[0] for item in batch_items]).to(self.device)
|
| 172 |
-
texts = [item[1] for item in batch_items]
|
| 173 |
-
|
| 174 |
-
with torch.no_grad():
|
| 175 |
-
# Get embeddings from main model (text embeddings only)
|
| 176 |
-
text_inputs = self.clip_processor(text=texts, padding=True, return_tensors="pt")
|
| 177 |
-
text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
|
| 178 |
-
|
| 179 |
-
# Create dummy images for the model
|
| 180 |
-
dummy_images = torch.zeros(len(texts), 3, 224, 224).to(self.device)
|
| 181 |
-
|
| 182 |
-
outputs = self.main_model(**text_inputs, pixel_values=dummy_images)
|
| 183 |
-
embeddings = outputs.text_embeds.cpu().numpy()
|
| 184 |
-
|
| 185 |
-
all_embeddings.extend(embeddings)
|
| 186 |
-
|
| 187 |
-
self.all_embeddings = np.array(all_embeddings)
|
| 188 |
-
self.all_texts = all_texts
|
| 189 |
-
self.all_colors = all_colors
|
| 190 |
-
self.all_hierarchies = all_hierarchies
|
| 191 |
-
self.all_images = all_images
|
| 192 |
-
self.all_urls = all_urls
|
| 193 |
-
|
| 194 |
-
print(f"✅ Pre-computed embeddings for {len(self.all_embeddings)} items")
|
| 195 |
-
|
| 196 |
-
def search_by_text(self, query_text: str, filter_category: str = None) -> List[dict]:
|
| 197 |
-
"""
|
| 198 |
-
Search for clothing items using text query
|
| 199 |
-
|
| 200 |
-
Args:
|
| 201 |
-
query_text: Text description to search for
|
| 202 |
-
|
| 203 |
-
Returns:
|
| 204 |
-
List of dictionaries containing search results
|
| 205 |
-
"""
|
| 206 |
-
print(f"🔍 Searching for: '{query_text}'")
|
| 207 |
-
|
| 208 |
-
# Get query embedding
|
| 209 |
-
with torch.no_grad():
|
| 210 |
-
text_inputs = self.clip_processor(text=[query_text], padding=True, return_tensors="pt")
|
| 211 |
-
text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
|
| 212 |
-
|
| 213 |
-
# Create a dummy image tensor to satisfy the model's requirements
|
| 214 |
-
dummy_image = torch.zeros(1, 3, 224, 224).to(self.device)
|
| 215 |
-
|
| 216 |
-
outputs = self.main_model(**text_inputs, pixel_values=dummy_image)
|
| 217 |
-
query_embedding = outputs.text_embeds.cpu().numpy()
|
| 218 |
-
|
| 219 |
-
# Calculate similarities
|
| 220 |
-
similarities = cosine_similarity(query_embedding, self.all_embeddings)[0]
|
| 221 |
-
|
| 222 |
-
# Get top-k results
|
| 223 |
-
top_indices = np.argsort(similarities)[::-1][:self.top_k * 2] # Prendre plus de résultats
|
| 224 |
-
|
| 225 |
-
results = []
|
| 226 |
-
for idx in top_indices:
|
| 227 |
-
if similarities[idx] > -0.5:
|
| 228 |
-
# Filter by category if specified
|
| 229 |
-
if filter_category and filter_category.lower() not in self.all_hierarchies[idx].lower():
|
| 230 |
-
continue
|
| 231 |
-
|
| 232 |
-
results.append({
|
| 233 |
-
'rank': len(results) + 1,
|
| 234 |
-
'image_path': self.all_images[idx],
|
| 235 |
-
'text': self.all_texts[idx],
|
| 236 |
-
'color': self.all_colors[idx],
|
| 237 |
-
'hierarchy': self.all_hierarchies[idx],
|
| 238 |
-
'similarity': float(similarities[idx]),
|
| 239 |
-
'index': int(idx),
|
| 240 |
-
'url': self.all_urls[idx]
|
| 241 |
-
})
|
| 242 |
-
|
| 243 |
-
if len(results) >= self.top_k:
|
| 244 |
-
break
|
| 245 |
-
|
| 246 |
-
print(f"✅ Found {len(results)} results")
|
| 247 |
-
return results
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
def display_results(self, results: List[dict], query_info: str = ""):
|
| 251 |
-
"""
|
| 252 |
-
Display search results with images and information
|
| 253 |
-
|
| 254 |
-
Args:
|
| 255 |
-
results: List of search result dictionaries
|
| 256 |
-
query_info: Information about the query
|
| 257 |
-
"""
|
| 258 |
-
if not results:
|
| 259 |
-
print("❌ No results found")
|
| 260 |
-
return
|
| 261 |
-
|
| 262 |
-
print(f"\n🎯 Search Results for: {query_info}")
|
| 263 |
-
print("=" * 80)
|
| 264 |
-
|
| 265 |
-
# Calculate grid layout
|
| 266 |
-
n_results = len(results)
|
| 267 |
-
cols = min(5, n_results)
|
| 268 |
-
rows = (n_results + cols - 1) // cols
|
| 269 |
-
|
| 270 |
-
fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows))
|
| 271 |
-
if rows == 1:
|
| 272 |
-
axes = axes.reshape(1, -1)
|
| 273 |
-
elif cols == 1:
|
| 274 |
-
axes = axes.reshape(-1, 1)
|
| 275 |
-
|
| 276 |
-
for i, result in enumerate(results):
|
| 277 |
-
row = i // cols
|
| 278 |
-
col = i % cols
|
| 279 |
-
ax = axes[row, col]
|
| 280 |
-
|
| 281 |
-
try:
|
| 282 |
-
# Load and display image
|
| 283 |
-
image = Image.open(result['image_path'])
|
| 284 |
-
ax.imshow(image)
|
| 285 |
-
ax.axis('off')
|
| 286 |
-
|
| 287 |
-
# Add title with similarity score
|
| 288 |
-
title = f"#{result['rank']} (Similarity: {result['similarity']:.3f})\n{result['color']} {result['hierarchy']}"
|
| 289 |
-
ax.set_title(title, fontsize=10, wrap=True)
|
| 290 |
-
|
| 291 |
-
except Exception as e:
|
| 292 |
-
ax.text(0.5, 0.5, f"Error loading image\n{result['image_path']}",
|
| 293 |
-
ha='center', va='center', transform=ax.transAxes)
|
| 294 |
-
ax.axis('off')
|
| 295 |
-
|
| 296 |
-
# Hide empty subplots
|
| 297 |
-
for i in range(n_results, rows * cols):
|
| 298 |
-
row = i // cols
|
| 299 |
-
col = i % cols
|
| 300 |
-
axes[row, col].axis('off')
|
| 301 |
-
|
| 302 |
-
plt.tight_layout()
|
| 303 |
-
plt.show()
|
| 304 |
-
|
| 305 |
-
# Print detailed results
|
| 306 |
-
print("\n📋 Detailed Results:")
|
| 307 |
-
for result in results:
|
| 308 |
-
print(f"#{result['rank']:2d} | Similarity: {result['similarity']:.3f} | "
|
| 309 |
-
f"Color: {result['color']:12s} | Category: {result['hierarchy']:15s} | "
|
| 310 |
-
f"Text: {result['text'][:50]}...")
|
| 311 |
-
print(f" 🔗 URL: {result['url']}")
|
| 312 |
-
print()
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
def main():
|
| 316 |
-
"""Main function for command-line usage"""
|
| 317 |
-
parser = argparse.ArgumentParser(description="Fashion Search Engine with Category Emphasis")
|
| 318 |
-
parser.add_argument("--query", "-q", type=str, help="Search query")
|
| 319 |
-
parser.add_argument("--top-k", "-k", type=int, default=10, help="Number of results (default: 10)")
|
| 320 |
-
parser.add_argument("--fast", "-f", action="store_true", help="Fast mode (less items)")
|
| 321 |
-
parser.add_argument("--interactive", "-i", action="store_true", help="Interactive mode")
|
| 322 |
-
|
| 323 |
-
args = parser.parse_args()
|
| 324 |
-
|
| 325 |
-
print("🎯 Fashion Search Engine with Category Emphasis")
|
| 326 |
-
|
| 327 |
-
search_engine = FashionSearchEngine(
|
| 328 |
-
top_k=args.top_k,
|
| 329 |
-
)
|
| 330 |
-
print("✅ Ready!")
|
| 331 |
-
|
| 332 |
-
# Single query mode
|
| 333 |
-
if args.query:
|
| 334 |
-
print(f"🔍 Search: '{args.query}'...")
|
| 335 |
-
results = search_engine.search_by_text(args.query)
|
| 336 |
-
search_engine.display_results(results, args.query)
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
# Interactive mode
|
| 340 |
-
print("Enter your query (e.g. 'red dress') or 'quit' to exit")
|
| 341 |
-
|
| 342 |
-
while True:
|
| 343 |
-
try:
|
| 344 |
-
user_input = input("\n🔍 Query: ").strip()
|
| 345 |
-
if not user_input or user_input.lower() in ['quit', 'exit', 'q']:
|
| 346 |
-
print("👋 Goodbye!")
|
| 347 |
-
break
|
| 348 |
-
|
| 349 |
-
if user_input.startswith('verify '):
|
| 350 |
-
if 'yellow accessories' in user_input:
|
| 351 |
-
search_engine.display_yellow_accessories()
|
| 352 |
-
continue
|
| 353 |
-
|
| 354 |
-
print(f"🔍 Search: '{user_input}'...")
|
| 355 |
-
results = search_engine.search_by_text(user_input)
|
| 356 |
-
search_engine.display_results(results, user_input)
|
| 357 |
-
|
| 358 |
-
except KeyboardInterrupt:
|
| 359 |
-
print("\n👋 Goodbye!")
|
| 360 |
-
break
|
| 361 |
-
except Exception as e:
|
| 362 |
-
print(f"❌ Error: {e}")
|
| 363 |
-
|
| 364 |
-
if __name__ == "__main__":
|
| 365 |
-
main()
|
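Note: the retrieval core of `search_by_text` above is a plain cosine-similarity ranking over precomputed CLIP text embeddings. For reference, the ranking step can be reproduced standalone; this is a minimal sketch, assuming an `item_embeddings` matrix of shape [N, D] and a 1-D `query_embedding` (both names are illustrative, not from the deleted module):

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def top_k_items(query_embedding, item_embeddings, k=10):
    # cosine_similarity L2-normalizes internally, so raw embeddings are fine
    sims = cosine_similarity(query_embedding.reshape(1, -1), item_embeddings)[0]
    top = np.argsort(sims)[::-1][:k]  # indices of the k most similar items
    return [(int(i), float(sims[i])) for i in top]

One design note: the deleted code feeds dummy `pixel_values` into the full CLIP forward pass just to obtain `text_embeds`; in recent transformers versions, `CLIPModel.get_text_features(**text_inputs)` returns the text embedding without requiring an image batch.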
Evaluation/hierarchy_evaluation.py
DELETED
@@ -1,589 +0,0 @@
"""
Hierarchy embedding evaluation for clothing category classification.
This file evaluates the quality of hierarchy embeddings generated by the hierarchy model
by calculating intra-class and inter-class similarity metrics, nearest neighbor and centroid-based
classification accuracies, and generating confusion matrices. It can be used on different datasets
(local validation, Kagl Marqo) to measure model generalization.
"""

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from collections import defaultdict
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from io import BytesIO
from PIL import Image
import config
import warnings
warnings.filterwarnings('ignore')
from hierarchy_model import Model, HierarchyExtractor, HierarchyDataset, collate_fn


class EmbeddingEvaluator:
    """
    Evaluator for hierarchy embeddings generated by the hierarchy model.

    This class provides methods to evaluate the quality of hierarchy embeddings by computing
    similarity metrics, classification accuracies, and generating visualizations.
    """

    def __init__(self, model_path, directory):
        """
        Initialize the embedding evaluator.

        Args:
            model_path: Path to the trained hierarchy model checkpoint
            directory: Directory to save evaluation results and visualizations
        """
        self.device = config.device
        self.directory = directory

        # 1. Load the dataset
        CSV = config.local_dataset_path
        print(f"📁 Using dataset with local images: {CSV}")
        df = pd.read_csv(CSV)

        print(f"📁 Loaded {len(df)} samples")

        # 2. Get unique hierarchy classes from the dataset
        hierarchy_classes = sorted(df[config.hierarchy_column].unique().tolist())
        print(f"📋 Found {len(hierarchy_classes)} hierarchy classes")

        _, self.val_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df[config.hierarchy_column])

        # 3. Load the model
        if os.path.exists(model_path):
            checkpoint = torch.load(model_path, map_location=self.device)

            # Use model_config to avoid shadowing the imported config module
            model_config = checkpoint.get('config', {})
            saved_hierarchy_classes = checkpoint['hierarchy_classes']

            # Use the saved hierarchy classes
            self.hierarchy_classes = saved_hierarchy_classes

            # Create the hierarchy extractor
            self.vocab = HierarchyExtractor(saved_hierarchy_classes)

            # Create the model with the saved configuration
            self.model = Model(
                num_hierarchy_classes=len(saved_hierarchy_classes),
                embed_dim=model_config['embed_dim'],
                dropout=model_config['dropout']
            ).to(self.device)

            self.model.load_state_dict(checkpoint['model_state'])

            print("✅ Model loaded with:")
            print(f"📋 Hierarchy classes: {len(saved_hierarchy_classes)}")
            print(f"🎯 Embed dim: {model_config['embed_dim']}")
            print(f"💧 Dropout: {model_config['dropout']}")
            print(f"📅 Epoch: {checkpoint.get('epoch', 'unknown')}")

        else:
            raise FileNotFoundError(f"Model file {model_path} not found")

        self.model.eval()

    def create_dataloader(self, dataframe, batch_size=16):
        """
        Create a DataLoader for the hierarchy dataset.

        Args:
            dataframe: DataFrame containing the dataset
            batch_size: Batch size for the DataLoader

        Returns:
            DataLoader instance
        """
        dataset = HierarchyDataset(dataframe, image_size=224)

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda batch: collate_fn(batch, self.vocab),
            num_workers=0
        )

        return dataloader

    def extract_embeddings(self, dataloader, embedding_type='text'):
        """
        Extract embeddings from the model for a given dataloader.

        Args:
            dataloader: DataLoader containing images, texts, and hierarchy labels
            embedding_type: Type of embeddings to extract ('text' or 'image')

        Returns:
            Tuple of (embeddings array, labels list, texts list)
        """
        all_embeddings = []
        all_labels = []
        all_texts = []

        with torch.no_grad():
            for batch in tqdm(dataloader, desc=f"Extracting {embedding_type} embeddings"):
                images = batch['image'].to(self.device)
                hierarchy_indices = batch['hierarchy_indices'].to(self.device)
                hierarchy_labels = batch['hierarchy']

                # Forward pass (any embedding_type other than 'image' falls back to text)
                out = self.model(image=images, hierarchy_indices=hierarchy_indices)
                embeddings = out['z_img'] if embedding_type == 'image' else out['z_txt']

                all_embeddings.append(embeddings.cpu().numpy())
                all_labels.extend(hierarchy_labels)
                all_texts.extend(hierarchy_labels)

        return np.vstack(all_embeddings), all_labels, all_texts

    def compute_similarity_metrics(self, embeddings, labels):
        """
        Compute intra-class and inter-class similarity metrics.

        Args:
            embeddings: Array of embeddings [N, embed_dim]
            labels: List of labels for each embedding

        Returns:
            Dictionary containing similarity metrics, accuracies, and separation scores
        """
        similarities = cosine_similarity(embeddings)

        # Group embeddings by hierarchy
        hierarchy_groups = defaultdict(list)
        for i, hierarchy in enumerate(labels):
            hierarchy_groups[hierarchy].append(i)

        # Calculate intra-class similarities (same hierarchy)
        intra_class_similarities = []
        for hierarchy, indices in hierarchy_groups.items():
            if len(indices) > 1:
                for i in range(len(indices)):
                    for j in range(i+1, len(indices)):
                        sim = similarities[indices[i], indices[j]]
                        intra_class_similarities.append(sim)

        # Calculate inter-class similarities (different hierarchies)
        inter_class_similarities = []
        hierarchies = list(hierarchy_groups.keys())
        for i in range(len(hierarchies)):
            for j in range(i+1, len(hierarchies)):
                hierarchy1_indices = hierarchy_groups[hierarchies[i]]
                hierarchy2_indices = hierarchy_groups[hierarchies[j]]

                for idx1 in hierarchy1_indices:
                    for idx2 in hierarchy2_indices:
                        sim = similarities[idx1, idx2]
                        inter_class_similarities.append(sim)

        # Calculate classification accuracy using nearest neighbor in embedding space
        nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities)

        # Calculate classification accuracy using centroids
        centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels)

        return {
            'intra_class_similarities': intra_class_similarities,
            'inter_class_similarities': inter_class_similarities,
            'intra_class_mean': np.mean(intra_class_similarities) if intra_class_similarities else 0,
            'inter_class_mean': np.mean(inter_class_similarities) if inter_class_similarities else 0,
            'separation_score': np.mean(intra_class_similarities) - np.mean(inter_class_similarities) if intra_class_similarities and inter_class_similarities else 0,
            'accuracy': nn_accuracy,
            'centroid_accuracy': centroid_accuracy
        }

    def compute_embedding_accuracy(self, embeddings, labels, similarities):
        """
        Compute classification accuracy using nearest neighbor in embedding space.

        Args:
            embeddings: Array of embeddings [N, embed_dim]
            labels: List of true labels
            similarities: Pre-computed similarity matrix [N, N]

        Returns:
            Accuracy score (float between 0 and 1)
        """
        correct_predictions = 0
        total_predictions = len(labels)

        for i in range(len(embeddings)):
            true_label = labels[i]

            # Find the most similar embedding (excluding itself)
            similarities_row = similarities[i].copy()
            similarities_row[i] = -1  # Exclude self-similarity
            nearest_neighbor_idx = np.argmax(similarities_row)
            predicted_label = labels[nearest_neighbor_idx]

            if predicted_label == true_label:
                correct_predictions += 1

        return correct_predictions / total_predictions if total_predictions > 0 else 0

    def compute_centroid_accuracy(self, embeddings, labels):
        """
        Compute classification accuracy using hierarchy centroids.

        Each hierarchy class is represented by its centroid (mean embedding), and each
        embedding is classified to the nearest centroid.

        Args:
            embeddings: Array of embeddings [N, embed_dim]
            labels: List of true labels

        Returns:
            Accuracy score (float between 0 and 1)
        """
        # Create centroids for each hierarchy
        unique_hierarchies = list(set(labels))
        centroids = {}

        for hierarchy in unique_hierarchies:
            hierarchy_indices = [i for i, label in enumerate(labels) if label == hierarchy]
            hierarchy_embeddings = embeddings[hierarchy_indices]
            centroids[hierarchy] = np.mean(hierarchy_embeddings, axis=0)

        # Classify each embedding to nearest centroid
        correct_predictions = 0
        total_predictions = len(labels)

        for i, embedding in enumerate(embeddings):
            true_label = labels[i]

            # Find closest centroid
            best_similarity = -1
            predicted_label = None

            for hierarchy, centroid in centroids.items():
                similarity = cosine_similarity([embedding], [centroid])[0][0]
                if similarity > best_similarity:
                    best_similarity = similarity
                    predicted_label = hierarchy

            if predicted_label == true_label:
                correct_predictions += 1

        return correct_predictions / total_predictions if total_predictions > 0 else 0

    def predict_hierarchy_from_embeddings(self, embeddings, labels):
        """
        Predict hierarchy from embeddings using centroid-based classification.

        Args:
            embeddings: Array of embeddings [N, embed_dim]
            labels: List of labels used to compute centroids

        Returns:
            List of predicted hierarchy labels
        """
        # Create hierarchy centroids from training data
        unique_hierarchies = list(set(labels))
        centroids = {}

        for hierarchy in unique_hierarchies:
            hierarchy_indices = [i for i, label in enumerate(labels) if label == hierarchy]
            hierarchy_embeddings = embeddings[hierarchy_indices]
            centroids[hierarchy] = np.mean(hierarchy_embeddings, axis=0)

        # Predict hierarchy for all embeddings
        predictions = []

        for i, embedding in enumerate(embeddings):
            # Find closest centroid
            best_similarity = -1
            predicted_hierarchy = None

            for hierarchy, centroid in centroids.items():
                similarity = cosine_similarity([embedding], [centroid])[0][0]
                if similarity > best_similarity:
                    best_similarity = similarity
                    predicted_hierarchy = hierarchy

            predictions.append(predicted_hierarchy)

        return predictions

    def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix"):
        """
        Create and plot a confusion matrix.

        Args:
            true_labels: List of true labels
            predicted_labels: List of predicted labels
            title: Title for the confusion matrix plot

        Returns:
            Tuple of (figure, accuracy, confusion_matrix)
        """
        # Get unique labels
        unique_labels = sorted(list(set(true_labels + predicted_labels)))

        # Create confusion matrix
        cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)

        # Calculate accuracy
        accuracy = accuracy_score(true_labels, predicted_labels)

        # Plot confusion matrix
        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=unique_labels, yticklabels=unique_labels)
        plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)')
        plt.ylabel('True Hierarchy')
        plt.xlabel('Predicted Hierarchy')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()

        return plt.gcf(), accuracy, cm

    def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings"):
        """
        Evaluate classification performance and create confusion matrix.

        Args:
            embeddings: Array of embeddings [N, embed_dim]
            labels: List of true labels
            embedding_type: Type of embeddings for display purposes

        Returns:
            Dictionary containing accuracy, predictions, confusion matrix, and classification report
        """
        # Predict hierarchy
        predictions = self.predict_hierarchy_from_embeddings(embeddings, labels)

        # Calculate accuracy
        accuracy = accuracy_score(labels, predictions)

        # Create confusion matrix
        fig, acc, cm = self.create_confusion_matrix(labels, predictions,
                                                    f"{embedding_type} - Hierarchy Classification")

        # Generate classification report
        unique_labels = sorted(list(set(labels)))
        report = classification_report(labels, predictions, labels=unique_labels,
                                       target_names=unique_labels, output_dict=True)

        return {
            'accuracy': accuracy,
            'predictions': predictions,
            'confusion_matrix': cm,
            'classification_report': report,
            'figure': fig
        }

    def evaluate_dataset(self, dataframe, dataset_name="Dataset"):
        """
        Evaluate embeddings on a given dataset.

        This method extracts embeddings for text and image, computes similarity metrics,
        evaluates classification performance, and saves confusion matrices.

        Args:
            dataframe: DataFrame containing the dataset
            dataset_name: Name of the dataset for display purposes

        Returns:
            Dictionary containing evaluation results for text and image embeddings
        """
        print(f"\n{'='*60}")
        print(f"Evaluating {dataset_name}")
        print(f"{'='*60}")

        # Create dataloader exactly as during training
        dataloader = self.create_dataloader(dataframe, batch_size=16)

        results = {}

        # Evaluate text embeddings
        text_embeddings, text_labels, texts = self.extract_embeddings(dataloader, 'text')
        text_metrics = self.compute_similarity_metrics(text_embeddings, text_labels)
        text_classification = self.evaluate_classification_performance(text_embeddings, text_labels, "Text Embeddings")
        text_metrics.update(text_classification)
        results['text'] = text_metrics

        # Evaluate image embeddings
        image_embeddings, image_labels, _ = self.extract_embeddings(dataloader, 'image')
        image_metrics = self.compute_similarity_metrics(image_embeddings, image_labels)
        image_classification = self.evaluate_classification_performance(image_embeddings, image_labels, "Image Embeddings")
        image_metrics.update(image_classification)
        results['image'] = image_metrics

        # Evaluate hierarchy embeddings (note: any type other than 'image' maps to
        # text embeddings in extract_embeddings)
        hierarchy_embeddings, hierarchy_labels, _ = self.extract_embeddings(dataloader, 'category2')
        hierarchy_metrics = self.compute_similarity_metrics(hierarchy_embeddings, hierarchy_labels)
        hierarchy_classification = self.evaluate_classification_performance(hierarchy_embeddings, hierarchy_labels, "Hierarchy Embeddings")
        hierarchy_metrics.update(hierarchy_classification)
        results['hierarchy'] = hierarchy_metrics

        # Print results
        print(f"\n{dataset_name} Results:")
        print("-" * 40)
        for emb_type, metrics in results.items():
            print(f"{emb_type.capitalize()} Embeddings:")
            print(f"  Intra-class similarity (same hierarchy): {metrics['intra_class_mean']:.4f}")
            print(f"  Inter-class similarity (diff hierarchy): {metrics['inter_class_mean']:.4f}")
            print(f"  Separation score: {metrics['separation_score']:.4f}")
            print(f"  Nearest Neighbor Accuracy: {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.1f}%)")
            print(f"  Centroid Accuracy: {metrics['centroid_accuracy']:.4f} ({metrics['centroid_accuracy']*100:.1f}%)")

            # Classification report summary
            report = metrics['classification_report']
            print("  📊 Classification Performance:")
            print(f"    • Macro Avg F1-Score: {report['macro avg']['f1-score']:.4f}")
            print(f"    • Weighted Avg F1-Score: {report['weighted avg']['f1-score']:.4f}")
            print(f"    • Support: {report['macro avg']['support']:.0f} samples")
            print()

        # Create visualizations
        os.makedirs(self.directory, exist_ok=True)

        # Confusion matrices
        results['text']['figure'].savefig(f'{self.directory}/{dataset_name.lower()}_text_confusion_matrix.png', dpi=300, bbox_inches='tight')
        plt.close(results['text']['figure'])

        results['image']['figure'].savefig(f'{self.directory}/{dataset_name.lower()}_image_confusion_matrix.png', dpi=300, bbox_inches='tight')
        plt.close(results['image']['figure'])

        results['hierarchy']['figure'].savefig(f'{self.directory}/{dataset_name.lower()}_hierarchy_confusion_matrix.png', dpi=300, bbox_inches='tight')
        plt.close(results['hierarchy']['figure'])

        return results

class KaglDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe
        # Use VALIDATION transforms (no augmentation)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]

        # Handle image
        image_data = row['image_url']
        image = Image.open(BytesIO(image_data['bytes'])).convert("RGB")
        image = self.transform(image)

        # Get text and hierarchy
        description = row['text']
        hierarchy = row['hierarchy']

        return image, description, hierarchy

def load_Kagl_marqo_dataset(evaluator):
    """Load and prepare Kagl KAGL dataset"""
    from datasets import load_dataset
    print("Loading Kagl KAGL dataset...")

    # Load the dataset
    dataset = load_dataset("Marqo/KAGL")
    df = dataset["data"].to_pandas()
    print("✅ Dataset Kagl loaded")
    print(f"📊 Before filtering: {len(df)} samples")
    print(f"📋 Available columns: {list(df.columns)}")

    # Check available categories and map them to our hierarchy
    print(f"🎨 Available categories: {sorted(df['category2'].unique())}")
    # Apply mapping
    df['hierarchy'] = df['category2'].str.lower()
    df['hierarchy'] = df['hierarchy'].replace('bags', 'bag').replace('topwear', 'top').replace('flip flops', 'shoes').replace('sandal', 'shoes')

    # Filter to only include valid hierarchies that exist in our model
    valid_hierarchies = df['hierarchy'].dropna().unique()
    print(f"🎯 Valid hierarchies found: {sorted(valid_hierarchies)}")
    print(f"🎯 Model hierarchies: {sorted(evaluator.hierarchy_classes)}")

    # Filter to only include hierarchies that exist in our model
    df = df[df['hierarchy'].isin(evaluator.hierarchy_classes)]
    print(f"📊 After filtering to model hierarchies: {len(df)} samples")

    if len(df) == 0:
        print("❌ No samples left after hierarchy filtering.")
        return pd.DataFrame()

    # Ensure we have text and image data
    df = df.dropna(subset=['text', 'image'])
    print(f"📊 After removing missing text/image: {len(df)} samples")

    # Show sample of text data to verify quality
    print("📝 Sample texts:")
    for i, (text, hierarchy) in enumerate(zip(df['text'].head(3), df['hierarchy'].head(3))):
        print(f"  {i+1}. [{hierarchy}] {text[:100]}...")

    print(f"📊 After preparation: {len(df)} samples")
    print("📊 Samples per hierarchy:")
    for hierarchy in sorted(df['hierarchy'].unique()):
        count = len(df[df['hierarchy'] == hierarchy])
        print(f"  {hierarchy}: {count} samples")

    # Create formatted dataset with proper column names
    Kagl_formatted = pd.DataFrame({
        'image_url': df['image'],
        'text': df['text'],
        'hierarchy': df['hierarchy']
    })

    print(f"📊 Final dataset size: {len(Kagl_formatted)} samples")
    return Kagl_formatted

if __name__ == "__main__":
    device = config.device
    model_path = config.hierarchy_model_path
    directory = config.evaluation_directory

    print(f"🚀 Starting evaluation with {model_path}")

    evaluator = EmbeddingEvaluator(model_path, directory)

    print(f"📊 Final hierarchy classes after initialization: {len(evaluator.vocab.hierarchy_classes)} classes")

    # Evaluate on validation dataset (same subset as during training)
    print("\n" + "="*60)
    print("EVALUATING VALIDATION DATASET")
    print("="*60)
    val_results = evaluator.evaluate_dataset(evaluator.val_df, "Validation Dataset")

    print("\n" + "="*60)
    print("EVALUATING Kagl MARQO DATASET")
    print("="*60)
    df_Kagl_marqo = load_Kagl_marqo_dataset(evaluator)
    Kagl_results = evaluator.evaluate_dataset(df_Kagl_marqo, "Kagl Marqo Dataset")

    # Compare results
    print(f"\n{'='*60}")
    print("FINAL EVALUATION SUMMARY")
    print(f"{'='*60}")

    print("\n🔍 VALIDATION DATASET RESULTS:")
    print(f"Text      - Separation: {val_results['text']['separation_score']:.4f} | NN Acc: {val_results['text']['accuracy']*100:.1f}% | Centroid Acc: {val_results['text']['centroid_accuracy']*100:.1f}%")
    print(f"Image     - Separation: {val_results['image']['separation_score']:.4f} | NN Acc: {val_results['image']['accuracy']*100:.1f}% | Centroid Acc: {val_results['image']['centroid_accuracy']*100:.1f}%")
    print(f"Hierarchy - Separation: {val_results['hierarchy']['separation_score']:.4f} | NN Acc: {val_results['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {val_results['hierarchy']['centroid_accuracy']*100:.1f}%")

    print("\n🌐 Kagl MARQO DATASET RESULTS:")
    print(f"Text      - Separation: {Kagl_results['text']['separation_score']:.4f} | NN Acc: {Kagl_results['text']['accuracy']*100:.1f}% | Centroid Acc: {Kagl_results['text']['centroid_accuracy']*100:.1f}%")
    print(f"Image     - Separation: {Kagl_results['image']['separation_score']:.4f} | NN Acc: {Kagl_results['image']['accuracy']*100:.1f}% | Centroid Acc: {Kagl_results['image']['centroid_accuracy']*100:.1f}%")
    print(f"Hierarchy - Separation: {Kagl_results['hierarchy']['separation_score']:.4f} | NN Acc: {Kagl_results['hierarchy']['accuracy']*100:.1f}% | Centroid Acc: {Kagl_results['hierarchy']['centroid_accuracy']*100:.1f}%")


    print(f"\n✅ Evaluation completed! Check '{directory}/' for visualization files.")
    print(f"📊 Final hierarchy classes used: {len(evaluator.vocab.hierarchy_classes)} classes")
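Note: `compute_similarity_metrics` and `compute_centroid_accuracy` above loop over Python-level pairs, which is O(N²) scalar work. The same quantities (mean intra-class minus mean inter-class cosine similarity, and nearest-centroid prediction) vectorize to a few matrix operations; this is a minimal sketch under the same definitions, with illustrative names, not code from the deleted module:

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def centroid_predict(embeddings, labels):
    # One centroid per class: the mean embedding of its members
    lab = np.asarray(labels)
    classes = sorted(set(labels))
    centroids = np.stack([embeddings[lab == c].mean(axis=0) for c in classes])
    # Cosine similarity of every embedding to every centroid, argmax per row
    sims = cosine_similarity(embeddings, centroids)
    return [classes[j] for j in sims.argmax(axis=1)]

def separation_score(embeddings, labels):
    sims = cosine_similarity(embeddings)
    # Integer codes make the pairwise same-class mask cheap to build
    codes = np.unique(np.asarray(labels), return_inverse=True)[1]
    same = codes[:, None] == codes[None, :]
    iu = np.triu_indices_from(sims, k=1)  # each unordered pair once, no self-pairs
    return sims[iu][same[iu]].mean() - sims[iu][~same[iu]].mean()

Up to floating-point summation order, these should match the looped versions while scaling to the full validation split.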
Evaluation/hierarchy_evaluation_with_clip_baseline.py
DELETED
|
@@ -1,808 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Hierarchy embedding evaluation with CLIP baseline comparison.
|
| 3 |
-
This file evaluates the quality of hierarchy embeddings from the custom model and compares them
|
| 4 |
-
with a CLIP baseline model (OpenAI CLIP). It calculates similarity metrics, classification accuracies,
|
| 5 |
-
and generates confusion matrices for both models to measure relative performance. It also supports
|
| 6 |
-
evaluation on Fashion-MNIST and kagl Marqo datasets.
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
import pandas as pd
|
| 11 |
-
import numpy as np
|
| 12 |
-
import matplotlib.pyplot as plt
|
| 13 |
-
import seaborn as sns
|
| 14 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
| 15 |
-
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, f1_score
|
| 16 |
-
from collections import defaultdict
|
| 17 |
-
import os
|
| 18 |
-
import requests
|
| 19 |
-
from tqdm import tqdm
|
| 20 |
-
from torch.utils.data import Dataset, DataLoader
|
| 21 |
-
from torchvision import transforms
|
| 22 |
-
from sklearn.model_selection import train_test_split
|
| 23 |
-
from io import BytesIO
|
| 24 |
-
from PIL import Image
|
| 25 |
-
import warnings
|
| 26 |
-
warnings.filterwarnings('ignore')
|
| 27 |
-
|
| 28 |
-
# Import transformers CLIP
|
| 29 |
-
from transformers import CLIPProcessor, CLIPModel as TransformersCLIPModel
|
| 30 |
-
|
| 31 |
-
# Import your custom model
|
| 32 |
-
from hierarchy_model import Model, HierarchyExtractor, HierarchyDataset, collate_fn
|
| 33 |
-
import config
|
| 34 |
-
|
| 35 |
-
def convert_fashion_mnist_to_image(pixel_values):
|
| 36 |
-
"""Convert Fashion-MNIST pixel values to PIL image"""
|
| 37 |
-
# Reshape to 28x28 and convert to PIL Image
|
| 38 |
-
image_array = np.array(pixel_values).reshape(28, 28).astype(np.uint8)
|
| 39 |
-
# Convert to RGB by duplicating the grayscale channel
|
| 40 |
-
image_array = np.stack([image_array] * 3, axis=-1)
|
| 41 |
-
image = Image.fromarray(image_array)
|
| 42 |
-
return image
|
| 43 |
-
|
| 44 |
-
def get_fashion_mnist_labels():
|
| 45 |
-
"""Get Fashion-MNIST class labels"""
|
| 46 |
-
return {
|
| 47 |
-
0: "T-shirt/top",
|
| 48 |
-
1: "Trouser",
|
| 49 |
-
2: "Pullover",
|
| 50 |
-
3: "Dress",
|
| 51 |
-
4: "Coat",
|
| 52 |
-
5: "Sandal",
|
| 53 |
-
6: "Shirt",
|
| 54 |
-
7: "Sneaker",
|
| 55 |
-
8: "Bag",
|
| 56 |
-
9: "Ankle boot"
|
| 57 |
-
}
|
| 58 |
-
|
| 59 |
-
class FashionMNISTDataset(Dataset):
|
| 60 |
-
def __init__(self, dataframe, image_size=224):
|
| 61 |
-
self.dataframe = dataframe
|
| 62 |
-
self.image_size = image_size
|
| 63 |
-
self.labels_map = get_fashion_mnist_labels()
|
| 64 |
-
|
| 65 |
-
# Simple transforms for validation/inference
|
| 66 |
-
self.transform = transforms.Compose([
|
| 67 |
-
transforms.Resize((image_size, image_size)),
|
| 68 |
-
transforms.ToTensor(),
|
| 69 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 70 |
-
])
|
| 71 |
-
|
| 72 |
-
def __len__(self):
|
| 73 |
-
return len(self.dataframe)
|
| 74 |
-
|
| 75 |
-
def __getitem__(self, idx):
|
| 76 |
-
row = self.dataframe.iloc[idx]
|
| 77 |
-
|
| 78 |
-
# Get pixel values (columns 1-784)
|
| 79 |
-
pixel_cols = [f'pixel{i}' for i in range(1, 785)]
|
| 80 |
-
pixel_values = row[pixel_cols].values
|
| 81 |
-
|
| 82 |
-
# Convert to image
|
| 83 |
-
image = convert_fashion_mnist_to_image(pixel_values)
|
| 84 |
-
image = self.transform(image)
|
| 85 |
-
|
| 86 |
-
# Get text description
|
| 87 |
-
text = row['text']
|
| 88 |
-
|
| 89 |
-
# Get hierarchy label
|
| 90 |
-
hierarchy = row['hierarchy']
|
| 91 |
-
|
| 92 |
-
return image, text, hierarchy
|
| 93 |
-
|
| 94 |
-
class CLIPBaselineEvaluator:
|
| 95 |
-
def __init__(self, device='mps'):
|
| 96 |
-
self.device = torch.device(device)
|
| 97 |
-
|
| 98 |
-
# Load CLIP model and processor
|
| 99 |
-
print("🤗 Loading CLIP baseline model from transformers...")
|
| 100 |
-
self.clip_model = TransformersCLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
|
| 101 |
-
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
| 102 |
-
self.clip_model.eval()
|
| 103 |
-
print("✅ CLIP model loaded successfully")
|
| 104 |
-
|
| 105 |
-
def extract_clip_embeddings(self, images, texts):
|
| 106 |
-
"""Extract CLIP embeddings for images and texts"""
|
| 107 |
-
all_image_embeddings = []
|
| 108 |
-
all_text_embeddings = []
|
| 109 |
-
|
| 110 |
-
with torch.no_grad():
|
| 111 |
-
for i in tqdm(range(len(images)), desc="Extracting CLIP embeddings"):
|
| 112 |
-
# Process image
|
| 113 |
-
if isinstance(images[i], torch.Tensor):
|
| 114 |
-
# Convert tensor back to PIL Image
|
| 115 |
-
image_tensor = images[i]
|
| 116 |
-
if image_tensor.dim() == 3:
|
| 117 |
-
# Denormalize
|
| 118 |
-
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
| 119 |
-
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
| 120 |
-
image_tensor = image_tensor * std + mean
|
| 121 |
-
image_tensor = torch.clamp(image_tensor, 0, 1)
|
| 122 |
-
|
| 123 |
-
# Convert to PIL
|
| 124 |
-
image_pil = transforms.ToPILImage()(image_tensor)
|
| 125 |
-
elif isinstance(images[i], Image.Image):
|
| 126 |
-
image_pil = images[i]
|
| 127 |
-
else:
|
| 128 |
-
raise ValueError(f"Unsupported image type: {type(images[i])}")
|
| 129 |
-
|
| 130 |
-
# Process with CLIP
|
| 131 |
-
inputs = self.clip_processor(
|
| 132 |
-
text=texts[i],
|
| 133 |
-
images=image_pil,
|
| 134 |
-
return_tensors="pt",
|
| 135 |
-
padding=True
|
| 136 |
-
).to(self.device)
|
| 137 |
-
|
| 138 |
-
outputs = self.clip_model(**inputs)
|
| 139 |
-
|
| 140 |
-
# Get normalized embeddings
|
| 141 |
-
image_emb = outputs.image_embeds / outputs.image_embeds.norm(p=2, dim=-1, keepdim=True)
|
| 142 |
-
text_emb = outputs.text_embeds / outputs.text_embeds.norm(p=2, dim=-1, keepdim=True)
|
| 143 |
-
|
| 144 |
-
all_image_embeddings.append(image_emb.cpu().numpy())
|
| 145 |
-
all_text_embeddings.append(text_emb.cpu().numpy())
|
| 146 |
-
|
| 147 |
-
return np.vstack(all_image_embeddings), np.vstack(all_text_embeddings)
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
class EmbeddingEvaluator:
|
| 151 |
-
def __init__(self, model_path, directory):
|
| 152 |
-
self.device = config.device
|
| 153 |
-
self.directory = directory
|
| 154 |
-
|
| 155 |
-
# 1. Load the dataset
|
| 156 |
-
CSV = config.local_dataset_path
|
| 157 |
-
print(f"📁 Using dataset with local images: {CSV}")
|
| 158 |
-
df = pd.read_csv(CSV)
|
| 159 |
-
|
| 160 |
-
print(f"📁 Loaded {len(df)} samples")
|
| 161 |
-
|
| 162 |
-
# 2. Get unique hierarchy classes from the dataset
|
| 163 |
-
hierarchy_classes = sorted(df[config.hierarchy_column].unique().tolist())
|
| 164 |
-
print(f"📋 Found {len(hierarchy_classes)} hierarchy classes")
|
| 165 |
-
|
| 166 |
-
_, self.val_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['hierarchy'])
|
| 167 |
-
|
| 168 |
-
# 3. Load the model
|
| 169 |
-
if os.path.exists(model_path):
|
| 170 |
-
checkpoint = torch.load(model_path, map_location=self.device)
|
| 171 |
-
|
| 172 |
-
config = checkpoint.get('config', {})
|
| 173 |
-
saved_hierarchy_classes = checkpoint['hierarchy_classes']
|
| 174 |
-
|
| 175 |
-
# Use the saved hierarchy classes
|
| 176 |
-
self.hierarchy_classes = saved_hierarchy_classes
|
| 177 |
-
|
| 178 |
-
# Create the hierarchy extractor
|
| 179 |
-
self.vocab = HierarchyExtractor(saved_hierarchy_classes)
|
| 180 |
-
|
| 181 |
-
# Create the model with the saved configuration
|
| 182 |
-
self.model = Model(
|
| 183 |
-
num_hierarchy_classes=len(saved_hierarchy_classes),
|
| 184 |
-
embed_dim=config['embed_dim'],
|
| 185 |
-
dropout=config['dropout']
|
| 186 |
-
).to(self.device)
|
| 187 |
-
|
| 188 |
-
self.model.load_state_dict(checkpoint['model_state'])
|
| 189 |
-
|
| 190 |
-
print(f"✅ Custom model loaded with:")
|
| 191 |
-
print(f"📋 Hierarchy classes: {len(saved_hierarchy_classes)}")
|
| 192 |
-
print(f"🎯 Embed dim: {config['embed_dim']}")
|
| 193 |
-
print(f"💧 Dropout: {config['dropout']}")
|
| 194 |
-
print(f"📅 Epoch: {checkpoint.get('epoch', 'unknown')}")
|
| 195 |
-
|
| 196 |
-
else:
|
| 197 |
-
raise FileNotFoundError(f"Model file {model_path} not found")
|
| 198 |
-
|
| 199 |
-
self.model.eval()
|
| 200 |
-
|
| 201 |
-
# Initialize CLIP baseline
|
| 202 |
-
self.clip_evaluator = CLIPBaselineEvaluator(device)
|
| 203 |
-
|
| 204 |
-
def create_dataloader(self, dataframe, batch_size=16):
|
| 205 |
-
"""Create a dataloader for custom model"""
|
| 206 |
-
# Check if this is Fashion-MNIST data (has pixel1 column)
|
| 207 |
-
if 'pixel1' in dataframe.columns:
|
| 208 |
-
print("🔍 Detected Fashion-MNIST data, using FashionMNISTDataset")
|
| 209 |
-
dataset = FashionMNISTDataset(dataframe, image_size=224)
|
| 210 |
-
else:
|
| 211 |
-
dataset = HierarchyDataset(dataframe, image_size=224)
|
| 212 |
-
|
| 213 |
-
dataloader = DataLoader(
|
| 214 |
-
dataset,
|
| 215 |
-
batch_size=batch_size,
|
| 216 |
-
shuffle=False,
|
| 217 |
-
collate_fn=lambda batch: collate_fn(batch, self.vocab),
|
| 218 |
-
num_workers=0
|
| 219 |
-
)
|
| 220 |
-
|
| 221 |
-
return dataloader
|
| 222 |
-
|
| 223 |
-
def create_clip_dataloader(self, dataframe, batch_size=16):
|
| 224 |
-
"""Create a dataloader for CLIP baseline"""
|
| 225 |
-
# Check if this is Fashion-MNIST data (has pixel1 column)
|
| 226 |
-
if 'pixel1' in dataframe.columns:
|
| 227 |
-
print("🔍 Detected Fashion-MNIST data for CLIP, using FashionMNISTDataset")
|
| 228 |
-
dataset = FashionMNISTDataset(dataframe, image_size=224)
|
| 229 |
-
else:
|
| 230 |
-
dataset = CLIPDataset(dataframe)
|
| 231 |
-
|
| 232 |
-
dataloader = DataLoader(
|
| 233 |
-
dataset,
|
| 234 |
-
batch_size=batch_size,
|
| 235 |
-
shuffle=False,
|
| 236 |
-
num_workers=0
|
| 237 |
-
)
|
| 238 |
-
|
| 239 |
-
return dataloader
|
| 240 |
-
|
| 241 |
-
def extract_custom_embeddings(self, dataloader, embedding_type='text'):
|
| 242 |
-
"""Extract embeddings from custom model"""
|
| 243 |
-
all_embeddings = []
|
| 244 |
-
all_labels = []
|
| 245 |
-
all_texts = []
|
| 246 |
-
|
| 247 |
-
with torch.no_grad():
|
| 248 |
-
for batch in tqdm(dataloader, desc=f"Extracting custom {embedding_type} embeddings"):
|
| 249 |
-
images = batch['image'].to(self.device)
|
| 250 |
-
hierarchy_indices = batch['hierarchy_indices'].to(self.device)
|
| 251 |
-
hierarchy_labels = batch['hierarchy']
|
| 252 |
-
|
| 253 |
-
# Forward pass
|
| 254 |
-
out = self.model(image=images, hierarchy_indices=hierarchy_indices)
|
| 255 |
-
embeddings = out['z_txt'] if embedding_type == 'text' else out['z_img'] if embedding_type == 'image' else out['z_txt']
|
| 256 |
-
|
| 257 |
-
all_embeddings.append(embeddings.cpu().numpy())
|
| 258 |
-
all_labels.extend(hierarchy_labels)
|
| 259 |
-
all_texts.extend(hierarchy_labels)
|
| 260 |
-
|
| 261 |
-
return np.vstack(all_embeddings), all_labels, all_texts
|
| 262 |
-
|
| 263 |
-
    def compute_similarity_metrics(self, embeddings, labels):
        """Compute intra-class and inter-class cosine similarities."""
        similarities = cosine_similarity(embeddings)

        # Group embedding indices by hierarchy
        hierarchy_groups = defaultdict(list)
        for i, hierarchy in enumerate(labels):
            hierarchy_groups[hierarchy].append(i)

        # Intra-class similarities (pairs within the same hierarchy)
        intra_class_similarities = []
        for hierarchy, indices in hierarchy_groups.items():
            if len(indices) > 1:
                for i in range(len(indices)):
                    for j in range(i + 1, len(indices)):
                        intra_class_similarities.append(similarities[indices[i], indices[j]])

        # Inter-class similarities (pairs across different hierarchies)
        inter_class_similarities = []
        hierarchies = list(hierarchy_groups.keys())
        for i in range(len(hierarchies)):
            for j in range(i + 1, len(hierarchies)):
                for idx1 in hierarchy_groups[hierarchies[i]]:
                    for idx2 in hierarchy_groups[hierarchies[j]]:
                        inter_class_similarities.append(similarities[idx1, idx2])

        # Classification accuracy using the nearest neighbor in embedding space
        nn_accuracy = self.compute_embedding_accuracy(embeddings, labels, similarities)

        # Classification accuracy using hierarchy centroids
        centroid_accuracy = self.compute_centroid_accuracy(embeddings, labels)

        return {
            'intra_class_similarities': intra_class_similarities,
            'inter_class_similarities': inter_class_similarities,
            'intra_class_mean': np.mean(intra_class_similarities) if intra_class_similarities else 0,
            'inter_class_mean': np.mean(inter_class_similarities) if inter_class_similarities else 0,
            'separation_score': (
                np.mean(intra_class_similarities) - np.mean(inter_class_similarities)
                if intra_class_similarities and inter_class_similarities else 0
            ),
            'accuracy': nn_accuracy,
            'centroid_accuracy': centroid_accuracy,
        }

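# A minimal standalone sketch (not part of the original file) of the same
# separation score computed vectorially instead of with nested Python loops;
# it assumes scikit-learn and uses a boolean same-class mask over the
# pairwise cosine-similarity matrix.
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def separation_score(embeddings, labels):
    """Mean intra-class similarity minus mean inter-class similarity."""
    sims = cosine_similarity(embeddings)
    labels = np.asarray(labels)
    same = labels[:, None] == labels[None, :]              # True for same-class pairs
    upper = np.triu(np.ones_like(sims, dtype=bool), k=1)   # each pair once, no self-pairs
    intra = sims[same & upper]
    inter = sims[~same & upper]
    if intra.size == 0 or inter.size == 0:
        return 0.0
    return float(intra.mean() - inter.mean())
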
    def compute_embedding_accuracy(self, embeddings, labels, similarities):
        """Compute classification accuracy using the nearest neighbor in embedding space."""
        correct_predictions = 0
        total_predictions = len(labels)

        for i in range(len(embeddings)):
            true_label = labels[i]

            # Find the most similar embedding (excluding itself)
            similarities_row = similarities[i].copy()
            similarities_row[i] = -1  # exclude self-similarity
            nearest_neighbor_idx = np.argmax(similarities_row)
            predicted_label = labels[nearest_neighbor_idx]

            if predicted_label == true_label:
                correct_predictions += 1

        return correct_predictions / total_predictions if total_predictions > 0 else 0

    def compute_centroid_accuracy(self, embeddings, labels):
        """Compute classification accuracy using hierarchy centroids."""
        # Build a centroid for each hierarchy
        unique_hierarchies = list(set(labels))
        centroids = {}

        for hierarchy in unique_hierarchies:
            hierarchy_indices = [i for i, label in enumerate(labels) if label == hierarchy]
            centroids[hierarchy] = np.mean(embeddings[hierarchy_indices], axis=0)

        # Classify each embedding to its nearest centroid
        correct_predictions = 0
        total_predictions = len(labels)

        for i, embedding in enumerate(embeddings):
            true_label = labels[i]

            # Find the closest centroid by cosine similarity
            best_similarity = -1
            predicted_label = None

            for hierarchy, centroid in centroids.items():
                similarity = cosine_similarity([embedding], [centroid])[0][0]
                if similarity > best_similarity:
                    best_similarity = similarity
                    predicted_label = hierarchy

            if predicted_label == true_label:
                correct_predictions += 1

        return correct_predictions / total_predictions if total_predictions > 0 else 0

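# A minimal standalone sketch (not part of the original file) of the same
# centroid classification computed in one vectorized pass: one cosine-similarity
# call against the stacked centroid matrix replaces the per-sample Python loop,
# and gives exactly the same predictions as the method above.
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def centroid_accuracy_vectorized(embeddings, labels):
    X = np.asarray(embeddings)
    y = np.asarray(labels)
    classes = np.unique(y)
    C = np.stack([X[y == c].mean(axis=0) for c in classes])  # one centroid per class
    sims = cosine_similarity(X, C)                           # shape [N, n_classes]
    pred = classes[np.argmax(sims, axis=1)]
    return float((pred == y).mean())
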
    def create_confusion_matrix(self, true_labels, predicted_labels, title="Confusion Matrix"):
        """Create and plot a confusion matrix."""
        # Collect the unique labels
        unique_labels = sorted(set(true_labels + predicted_labels))

        # Build the confusion matrix
        cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)

        # Overall accuracy
        accuracy = accuracy_score(true_labels, predicted_labels)

        # Plot the confusion matrix
        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=unique_labels, yticklabels=unique_labels)
        plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy * 100:.1f}%)')
        plt.ylabel('True Hierarchy')
        plt.xlabel('Predicted Hierarchy')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()

        return plt.gcf(), accuracy, cm

    def predict_hierarchy_from_embeddings(self, embeddings, labels):
        """Predict hierarchy from embeddings using centroid-based classification."""
        # Build hierarchy centroids from the labeled embeddings
        unique_hierarchies = list(set(labels))
        centroids = {}

        for hierarchy in unique_hierarchies:
            hierarchy_indices = [i for i, label in enumerate(labels) if label == hierarchy]
            centroids[hierarchy] = np.mean(embeddings[hierarchy_indices], axis=0)

        # Predict a hierarchy for every embedding
        predictions = []

        for embedding in embeddings:
            # Find the closest centroid by cosine similarity
            best_similarity = -1
            predicted_hierarchy = None

            for hierarchy, centroid in centroids.items():
                similarity = cosine_similarity([embedding], [centroid])[0][0]
                if similarity > best_similarity:
                    best_similarity = similarity
                    predicted_hierarchy = hierarchy

            predictions.append(predicted_hierarchy)

        return predictions

    def evaluate_classification_performance(self, embeddings, labels, embedding_type="Embeddings"):
        """Evaluate classification performance and create a confusion matrix."""
        # Predict hierarchies
        predictions = self.predict_hierarchy_from_embeddings(embeddings, labels)

        # Accuracy
        accuracy = accuracy_score(labels, predictions)

        # F1 scores
        unique_labels = sorted(set(labels))
        f1_macro = f1_score(labels, predictions, labels=unique_labels, average='macro', zero_division=0)
        f1_weighted = f1_score(labels, predictions, labels=unique_labels, average='weighted', zero_division=0)
        f1_per_class = f1_score(labels, predictions, labels=unique_labels, average=None, zero_division=0)

        # Confusion matrix
        fig, acc, cm = self.create_confusion_matrix(labels, predictions,
                                                    f"{embedding_type} - Hierarchy Classification")

        # Classification report
        report = classification_report(labels, predictions, labels=unique_labels,
                                       target_names=unique_labels, output_dict=True)

        return {
            'accuracy': accuracy,
            'f1_macro': f1_macro,
            'f1_weighted': f1_weighted,
            'f1_per_class': f1_per_class,
            'predictions': predictions,
            'confusion_matrix': cm,
            'classification_report': report,
            'figure': fig,
        }

    def evaluate_dataset_with_baselines(self, dataframe, dataset_name="Dataset"):
        """Evaluate embeddings on a given dataset with both the custom model and the CLIP baseline."""
        print(f"\n{'=' * 60}")
        print(f"Evaluating {dataset_name}")
        print(f"{'=' * 60}")

        results = {}

        # ===== CUSTOM MODEL EVALUATION =====
        print(f"\n🔧 Evaluating Custom Model on {dataset_name}")
        print("-" * 40)

        # Create a dataloader for the custom model
        custom_dataloader = self.create_dataloader(dataframe, batch_size=16)

        # Evaluate text embeddings
        text_embeddings, text_labels, texts = self.extract_custom_embeddings(custom_dataloader, 'text')
        text_metrics = self.compute_similarity_metrics(text_embeddings, text_labels)
        text_classification = self.evaluate_classification_performance(text_embeddings, text_labels, "Custom Text Embeddings")
        text_metrics.update(text_classification)
        results['custom_text'] = text_metrics

        # Evaluate image embeddings
        image_embeddings, image_labels, _ = self.extract_custom_embeddings(custom_dataloader, 'image')
        image_metrics = self.compute_similarity_metrics(image_embeddings, image_labels)
        image_classification = self.evaluate_classification_performance(image_embeddings, image_labels, "Custom Image Embeddings")
        image_metrics.update(image_classification)
        results['custom_image'] = image_metrics

        # ===== CLIP BASELINE EVALUATION =====
        print(f"\n🤗 Evaluating CLIP Baseline on {dataset_name}")
        print("-" * 40)

        # Create a dataloader for CLIP (smaller batch size)
        clip_dataloader = self.create_clip_dataloader(dataframe, batch_size=8)

        # Collect data for CLIP
        all_images = []
        all_texts = []
        all_labels = []

        for batch in tqdm(clip_dataloader, desc="Preparing data for CLIP"):
            images, texts, labels = batch
            all_images.extend(images)
            all_texts.extend(texts)
            all_labels.extend(labels)

        # Get CLIP embeddings
        clip_image_embeddings, clip_text_embeddings = self.clip_evaluator.extract_clip_embeddings(all_images, all_texts)

        # Evaluate CLIP text embeddings
        clip_text_metrics = self.compute_similarity_metrics(clip_text_embeddings, all_labels)
        clip_text_classification = self.evaluate_classification_performance(clip_text_embeddings, all_labels, "CLIP Text Embeddings")
        clip_text_metrics.update(clip_text_classification)
        results['clip_text'] = clip_text_metrics

        # Evaluate CLIP image embeddings
        clip_image_metrics = self.compute_similarity_metrics(clip_image_embeddings, all_labels)
        clip_image_classification = self.evaluate_classification_performance(clip_image_embeddings, all_labels, "CLIP Image Embeddings")
        clip_image_metrics.update(clip_image_classification)
        results['clip_image'] = clip_image_metrics

        # ===== PRINT COMPARISON RESULTS =====
        print(f"\n{dataset_name} Results Comparison:")
        print(f"Dataset size: {len(dataframe)} samples")
        print("=" * 80)
        print(f"{'Model':<20} {'Embedding':<10} {'Sep Score':<10} {'NN Acc':<8} {'Centroid Acc':<12} {'F1 Macro':<10}")
        print("-" * 80)

        for model_type in ['custom', 'clip']:
            for emb_type in ['text', 'image']:
                key = f"{model_type}_{emb_type}"
                if key in results:
                    metrics = results[key]
                    model_name = "Custom Model" if model_type == 'custom' else "CLIP Baseline"
                    print(f"{model_name:<20} {emb_type.capitalize():<10} {metrics['separation_score']:<10.4f} {metrics['accuracy'] * 100:<8.1f}% {metrics['centroid_accuracy'] * 100:<12.1f}% {metrics['f1_macro'] * 100:<10.1f}%")

        # ===== SAVE VISUALIZATIONS =====
        os.makedirs(f'{self.directory}', exist_ok=True)

        # Save the confusion matrices
        for key, metrics in results.items():
            if 'figure' in metrics:
                metrics['figure'].savefig(f'{self.directory}/{dataset_name.lower()}_{key}_confusion_matrix.png', dpi=300, bbox_inches='tight')
                plt.close(metrics['figure'])

        return results

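# A minimal standalone sketch (not part of the original file) of building the
# same comparison table as a pandas DataFrame instead of hand-aligned format
# strings; `results` is assumed to carry the {model}_{embedding} keys and
# metric fields produced by evaluate_dataset_with_baselines above.
import pandas as pd

def results_table(results):
    rows = []
    for key, m in results.items():
        model, emb = key.split('_', 1)
        rows.append({
            'Model': 'Custom Model' if model == 'custom' else 'CLIP Baseline',
            'Embedding': emb.capitalize(),
            'Sep Score': round(m['separation_score'], 4),
            'NN Acc (%)': round(m['accuracy'] * 100, 1),
            'Centroid Acc (%)': round(m['centroid_accuracy'] * 100, 1),
            'F1 Macro (%)': round(m['f1_macro'] * 100, 1),
        })
    return pd.DataFrame(rows)
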
class CLIPDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe
        # Use VALIDATION transforms (no augmentation)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]

        # Handle image loading (same as HierarchyDataset)
        if config.column_local_image_path in row.index and pd.notna(row[config.column_local_image_path]):
            local_path = row[config.column_local_image_path]
            try:
                if os.path.exists(local_path):
                    image = Image.open(local_path).convert("RGB")
                else:
                    print(f"⚠️ Local image not found: {local_path}")
                    image = Image.new('RGB', (224, 224), color='gray')
            except Exception as e:
                print(f"⚠️ Failed to load local image {idx}: {e}")
                image = Image.new('RGB', (224, 224), color='gray')
        elif isinstance(row[config.column_url_image], dict):
            image = Image.open(BytesIO(row[config.column_url_image]['bytes'])).convert('RGB')
        elif isinstance(row[config.column_url_image], (list, np.ndarray)):
            # fix: use the configured column consistently instead of a hardcoded 'image_url'
            pixels = np.array(row[config.column_url_image]).reshape(28, 28)
            image = Image.fromarray(pixels.astype(np.uint8)).convert("RGB")
        elif isinstance(row[config.column_url_image], Image.Image):
            # Handle PIL Image objects directly (for Fashion-MNIST)
            image = row[config.column_url_image].convert("RGB")
        else:
            try:
                response = requests.get(row[config.column_url_image], timeout=10)
                response.raise_for_status()
                image = Image.open(BytesIO(response.content)).convert("RGB")
            except Exception as e:
                print(f"⚠️ Failed to load image {idx}: {e}")
                image = Image.new('RGB', (224, 224), color='gray')

        # Apply the transforms
        image_tensor = self.transform(image)

        description = row[config.text_column]
        hierarchy = row[config.hierarchy_column]

        return image_tensor, description, hierarchy

def load_fashion_mnist_dataset(evaluator):
    """Load and prepare the Fashion-MNIST test dataset."""
    print("Loading Fashion-MNIST test dataset...")

    # Load the dataset
    df = pd.read_csv(config.fashion_mnist_test_path)
    print(f"✅ Fashion-MNIST dataset loaded")
    print(f"📊 Total samples: {len(df)}")

    # Fashion-MNIST class label mapping
    fashion_mnist_labels = get_fashion_mnist_labels()

    # Map the class labels to hierarchy classes
    hierarchy_mapping = {
        'T-shirt/top': 'top',
        'Trouser': 'bottom',
        'Pullover': 'top',
        'Dress': 'dress',
        'Coat': 'top',
        'Sandal': 'shoes',
        'Shirt': 'top',
        'Sneaker': 'shoes',
        'Bag': 'bag',
        'Ankle boot': 'shoes'
    }

    # Apply the label mapping
    df['hierarchy'] = df['label'].map(fashion_mnist_labels).map(hierarchy_mapping)

    # Check which hierarchies exist in the model
    valid_hierarchies = df['hierarchy'].dropna().unique()
    print(f"🎯 Valid hierarchies found: {sorted(valid_hierarchies)}")
    print(f"🎯 Model hierarchies: {sorted(evaluator.hierarchy_classes)}")

    # Keep only hierarchies that exist in the model
    df = df[df['hierarchy'].isin(evaluator.hierarchy_classes)]
    print(f"📊 After filtering to model hierarchies: {len(df)} samples")

    if len(df) == 0:
        print("❌ No samples left after hierarchy filtering.")
        return pd.DataFrame()

    # Keep pixel columns as they are (FashionMNISTDataset will handle them)

    # Create text descriptions based on hierarchy
    text_descriptions = {
        'top': 'A top clothing item',
        'bottom': 'A bottom clothing item',
        'dress': 'A dress',
        'shoes': 'A pair of shoes',
        'bag': 'A bag'
    }

    df['text'] = df['hierarchy'].map(text_descriptions)

    # Show a sample of the data
    print(f"📝 Sample data:")
    for i, (hierarchy, text) in enumerate(zip(df['hierarchy'].head(3), df['text'].head(3))):
        print(f"  {i + 1}. [{hierarchy}] {text}")

    df_test = df.copy()

    print(f"📊 After sampling: {len(df_test)} samples")
    print(f"📊 Samples per hierarchy:")
    for hierarchy in sorted(df_test['hierarchy'].unique()):
        count = len(df_test[df_test['hierarchy'] == hierarchy])
        print(f"  {hierarchy}: {count} samples")

    # Create a formatted dataset with the expected column names,
    # keeping all pixel columns for FashionMNISTDataset
    pixel_cols = [f'pixel{i}' for i in range(1, 785)]
    fashion_mnist_formatted = df_test[['label'] + pixel_cols + ['text', 'hierarchy']].copy()

    print(f"📊 Final dataset size: {len(fashion_mnist_formatted)} samples")
    return fashion_mnist_formatted

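# A minimal standalone sketch (not part of the original file) of turning one
# Fashion-MNIST CSV row into a PIL image, the way a dataset class like
# FashionMNISTDataset (not shown in this diff) would plausibly handle the 784
# pixel columns: reshape to 28x28 and upscale to the CLIP input size.
import numpy as np
from PIL import Image

def row_to_image(row, image_size=224):
    pixels = row[[f'pixel{i}' for i in range(1, 785)]].to_numpy(dtype=np.uint8)
    img = Image.fromarray(pixels.reshape(28, 28)).convert("RGB")
    return img.resize((image_size, image_size))
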
def load_kagl_marqo_dataset(evaluator):
    """Load and prepare the KAGL dataset."""
    from datasets import load_dataset
    print("Loading KAGL dataset...")

    # Load the dataset
    dataset = load_dataset("Marqo/KAGL")
    df = dataset["data"].to_pandas()
    print(f"✅ Dataset KAGL loaded")
    print(f"📊 Before filtering: {len(df)} samples")
    print(f"📋 Available columns: {list(df.columns)}")

    # Check the available categories and map them to our hierarchy
    print(f"🎨 Available categories: {sorted(df['category2'].unique())}")
    # Apply the mapping (see the alias sketch after this function)
    df['hierarchy'] = df['category2'].str.lower()
    df['hierarchy'] = df['hierarchy'].replace('bags', 'bag').replace('topwear', 'top').replace('flip flops', 'shoes').replace('sandal', 'shoes')

    # Check which hierarchies exist in the model
    valid_hierarchies = df['hierarchy'].dropna().unique()
    print(f"🎯 Valid hierarchies found: {sorted(valid_hierarchies)}")
    print(f"🎯 Model hierarchies: {sorted(evaluator.hierarchy_classes)}")

    # Keep only hierarchies that exist in the model
    df = df[df['hierarchy'].isin(evaluator.hierarchy_classes)]
    print(f"📊 After filtering to model hierarchies: {len(df)} samples")

    if len(df) == 0:
        print("❌ No samples left after hierarchy filtering.")
        return pd.DataFrame()

    # Make sure we have both text and image data
    df = df.dropna(subset=['text', 'image'])
    print(f"📊 After removing missing text/image: {len(df)} samples")

    # Show a sample of the text data to verify its quality
    print(f"📝 Sample texts:")
    for i, (text, hierarchy) in enumerate(zip(df['text'].head(3), df['hierarchy'].head(3))):
        print(f"  {i + 1}. [{hierarchy}] {text[:100]}...")

    df_test = df.copy()  # fix: df_test was referenced below but never defined

    print(f"📊 After sampling: {len(df_test)} samples")
    print(f"📊 Samples per hierarchy:")
    for hierarchy in sorted(df_test['hierarchy'].unique()):
        count = len(df_test[df_test['hierarchy'] == hierarchy])
        print(f"  {hierarchy}: {count} samples")

    # Create a formatted dataset with the expected column names
    kagl_formatted = pd.DataFrame({
        'image_url': df_test['image'],
        'text': df_test['text'],
        'hierarchy': df_test['hierarchy']
    })

    print(f"📊 Final dataset size: {len(kagl_formatted)} samples")
    return kagl_formatted

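# A minimal standalone sketch (not part of the original file) of the same
# category normalization done with a single dict passed to Series.replace,
# which avoids chaining one .replace call per alias.
import pandas as pd

CATEGORY_ALIASES = {'bags': 'bag', 'topwear': 'top', 'flip flops': 'shoes', 'sandal': 'shoes'}

def normalize_categories(s: pd.Series) -> pd.Series:
    return s.str.lower().replace(CATEGORY_ALIASES)

# e.g. normalize_categories(pd.Series(['Bags', 'Topwear'])) -> ['bag', 'top']
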
if __name__ == "__main__":
|
| 726 |
-
device = config.device
|
| 727 |
-
directory = config.evaluation_directory
|
| 728 |
-
|
| 729 |
-
print(f"🚀 Starting evaluation with custom model: {config.hierarchy_model_path}")
|
| 730 |
-
print(f"🤗 Including CLIP baseline comparison")
|
| 731 |
-
|
| 732 |
-
evaluator = EmbeddingEvaluator(config.hierarchy_model_path, directory, device=device)
|
| 733 |
-
|
| 734 |
-
print(f"📊 Final hierarchy classes after initialization: {len(evaluator.vocab.hierarchy_classes)} classes")
|
| 735 |
-
|
| 736 |
-
# Evaluate on validation dataset (same subset as during training)
|
| 737 |
-
print("\n" + "="*60)
|
| 738 |
-
print("EVALUATING VALIDATION DATASET - CUSTOM MODEL vs CLIP BASELINE")
|
| 739 |
-
print("="*60)
|
| 740 |
-
val_results = evaluator.evaluate_dataset_with_baselines(evaluator.val_df, "Validation Dataset")
|
| 741 |
-
|
| 742 |
-
print("\n" + "="*60)
|
| 743 |
-
print("EVALUATING FASHION-MNIST TEST DATASET - CUSTOM MODEL vs CLIP BASELINE")
|
| 744 |
-
print("="*60)
|
| 745 |
-
df_fashion_mnist = load_fashion_mnist_dataset(evaluator)
|
| 746 |
-
if len(df_fashion_mnist) > 0:
|
| 747 |
-
fashion_mnist_results = evaluator.evaluate_dataset_with_baselines(df_fashion_mnist, "Fashion-MNIST Test Dataset")
|
| 748 |
-
else:
|
| 749 |
-
fashion_mnist_results = {}
|
| 750 |
-
|
| 751 |
-
print("\n" + "="*60)
|
| 752 |
-
print("EVALUATING kagl MARQO DATASET - CUSTOM MODEL vs CLIP BASELINE")
|
| 753 |
-
print("="*60)
|
| 754 |
-
df_kagl_marqo = load_kagl_marqo_dataset(evaluator)
|
| 755 |
-
if len(df_kagl_marqo) > 0:
|
| 756 |
-
kagl_results = evaluator.evaluate_dataset_with_baselines(df_kagl_marqo, "kagl Marqo Dataset")
|
| 757 |
-
else:
|
| 758 |
-
kagl_results = {}
|
| 759 |
-
|
| 760 |
-
# Compare results
|
| 761 |
-
print(f"\n{'='*80}")
|
| 762 |
-
print("FINAL EVALUATION SUMMARY - CUSTOM MODEL vs CLIP BASELINE")
|
| 763 |
-
print(f"{'='*80}")
|
| 764 |
-
|
| 765 |
-
print("\n🔍 VALIDATION DATASET RESULTS:")
|
| 766 |
-
print(f"Dataset size: {len(evaluator.val_df)} samples")
|
| 767 |
-
print(f"{'Model':<20} {'Embedding':<10} {'Sep Score':<12} {'NN Acc':<10} {'Centroid Acc':<12} {'F1 Macro':<10}")
|
| 768 |
-
print("-" * 80)
|
| 769 |
-
|
| 770 |
-
for model_type in ['custom', 'clip']:
|
| 771 |
-
for emb_type in ['text', 'image']:
|
| 772 |
-
key = f"{model_type}_{emb_type}"
|
| 773 |
-
if key in val_results:
|
| 774 |
-
metrics = val_results[key]
|
| 775 |
-
model_name = "Custom Model" if model_type == 'custom' else "CLIP Baseline"
|
| 776 |
-
print(f"{model_name:<20} {emb_type.capitalize():<10} {metrics['separation_score']:<12.4f} {metrics['accuracy']*100:<10.1f}% {metrics['centroid_accuracy']*100:<12.1f}% {metrics['f1_macro']*100:<10.1f}%")
|
| 777 |
-
|
| 778 |
-
if fashion_mnist_results:
|
| 779 |
-
print("\n👗 FASHION-MNIST TEST DATASET RESULTS:")
|
| 780 |
-
print(f"Dataset size: {len(df_fashion_mnist)} samples")
|
| 781 |
-
print(f"{'Model':<20} {'Embedding':<10} {'Sep Score':<12} {'NN Acc':<10} {'Centroid Acc':<12} {'F1 Macro':<10}")
|
| 782 |
-
print("-" * 80)
|
| 783 |
-
|
| 784 |
-
for model_type in ['custom', 'clip']:
|
| 785 |
-
for emb_type in ['text', 'image']:
|
| 786 |
-
key = f"{model_type}_{emb_type}"
|
| 787 |
-
if key in fashion_mnist_results:
|
| 788 |
-
metrics = fashion_mnist_results[key]
|
| 789 |
-
model_name = "Custom Model" if model_type == 'custom' else "CLIP Baseline"
|
| 790 |
-
print(f"{model_name:<20} {emb_type.capitalize():<10} {metrics['separation_score']:<12.4f} {metrics['accuracy']*100:<10.1f}% {metrics['centroid_accuracy']*100:<12.1f}% {metrics['f1_macro']*100:<10.1f}%")
|
| 791 |
-
|
| 792 |
-
if kagl_results:
|
| 793 |
-
print("\n🌐 kagl MARQO DATASET RESULTS:")
|
| 794 |
-
print(f"Dataset size: {len(df_kagl_marqo)} samples")
|
| 795 |
-
print(f"{'Model':<20} {'Embedding':<10} {'Sep Score':<12} {'NN Acc':<10} {'Centroid Acc':<12} {'F1 Macro':<10}")
|
| 796 |
-
print("-" * 80)
|
| 797 |
-
|
| 798 |
-
for model_type in ['custom', 'clip']:
|
| 799 |
-
for emb_type in ['text', 'image']:
|
| 800 |
-
key = f"{model_type}_{emb_type}"
|
| 801 |
-
if key in kagl_results:
|
| 802 |
-
metrics = kagl_results[key]
|
| 803 |
-
model_name = "Custom Model" if model_type == 'custom' else "CLIP Baseline"
|
| 804 |
-
print(f"{model_name:<20} {emb_type.capitalize():<10} {metrics['separation_score']:<12.4f} {metrics['accuracy']*100:<10.1f}% {metrics['centroid_accuracy']*100:<12.1f}% {metrics['f1_macro']*100:<10.1f}%")
|
| 805 |
-
|
| 806 |
-
print(f"\n✅ Evaluation completed! Check '{directory}/' for visualization files.")
|
| 807 |
-
print(f"📊 Custom model hierarchy classes: {len(evaluator.vocab.hierarchy_classes)} classes")
|
| 808 |
-
print(f"🤗 CLIP baseline comparison included")

Evaluation/main_model_evaluation.py
DELETED
The diff for this file is too large to render. See raw diff.

Evaluation/tsne_images.py
DELETED
@@ -1,569 +0,0 @@
#!/usr/bin/env python3
"""
Outputs several t-SNE visualizations with color and hierarchy overlays to
verify that the main model separates colors well inside each hierarchy group.
"""

import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from matplotlib.patches import Polygon
from PIL import Image
from sklearn.manifold import TSNE
from sklearn.metrics import (
    silhouette_score,
    davies_bouldin_score,
    calinski_harabasz_score,
)
from sklearn.preprocessing import normalize
from sklearn.metrics.pairwise import cosine_similarity
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
from transformers import CLIPModel as CLIPModel_transformers, CLIPProcessor

try:
    from scipy.spatial import ConvexHull
except ImportError:
    ConvexHull = None

from config import (
    color_column,
    color_emb_dim,
    column_local_image_path,
    device,
    hierarchy_column,
    hierarchy_emb_dim,
    images_dir,
    local_dataset_path,
    main_model_path,
)

class ImageDataset(Dataset):
    """Lightweight dataset that loads local images along with their colors and hierarchies."""

    def __init__(self, dataframe: pd.DataFrame, root_dir: str):
        self.df = dataframe.reset_index(drop=True)
        self.root_dir = root_dir
        self.transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = row[column_local_image_path]
        image = Image.open(img_path).convert("RGB")
        image = self.transform(image)
        color = row[color_column]
        hierarchy = row[hierarchy_column]
        return image, color, hierarchy

def load_main_model():
    """Load the main model with the trained weights."""
    checkpoint = torch.load(main_model_path, map_location=device)
    state_dict = checkpoint.get("model_state_dict", checkpoint)
    model = CLIPModel_transformers.from_pretrained(
        "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
    )
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    # Load the processor for text tokenization
    processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
    return model, processor

def load_clip_baseline():
    """Load the CLIP baseline model from transformers."""
    print("🤗 Loading CLIP baseline model from transformers...")
    clip_model = CLIPModel_transformers.from_pretrained("openai/clip-vit-base-patch32").to(device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    clip_model.eval()
    print("✅ CLIP baseline model loaded successfully")
    return clip_model, clip_processor

def enforce_min_hierarchy_samples(df, min_per_hierarchy):
    """Filter out hierarchy groups with fewer than min_per_hierarchy rows."""
    if not min_per_hierarchy or min_per_hierarchy <= 0:
        return df
    counts = df[hierarchy_column].value_counts()
    keep_values = counts[counts >= min_per_hierarchy].index
    return df[df[hierarchy_column].isin(keep_values)].reset_index(drop=True)

def prepare_dataframe(df, sample_size, per_color_limit, min_per_hierarchy=None):
    """Subsample the dataframe to speed up the t-SNE."""
    if per_color_limit and per_color_limit > 0:
        # Cap the number of rows per color
        df_limited = (
            df.groupby(color_column)
            .apply(lambda g: g.sample(min(len(g), per_color_limit), random_state=42))
            .reset_index(drop=True)
        )
    else:
        df_limited = df

    if sample_size and 0 < sample_size < len(df_limited):
        df_limited = df_limited.sample(sample_size, random_state=42).reset_index(
            drop=True
        )
    df_limited = enforce_min_hierarchy_samples(df_limited, min_per_hierarchy)
    return df_limited

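# A minimal standalone usage sketch (not part of the original file) of the
# per-color capping that prepare_dataframe applies; toy column names stand in
# for the config values.
import pandas as pd

toy = pd.DataFrame({'color': ['red'] * 5 + ['blue'] * 2, 'hierarchy': ['top'] * 7})
capped = (
    toy.groupby('color')
    .apply(lambda g: g.sample(min(len(g), 3), random_state=42))
    .reset_index(drop=True)
)
print(capped['color'].value_counts())  # at most 3 rows per color
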
def compute_embeddings(model, dataloader):
    """Extract color and hierarchy embeddings.

    The image embedding is laid out as [color | hierarchy | ...]: the first
    color_emb_dim dimensions hold the color sub-embedding and the next
    hierarchy_emb_dim dimensions hold the hierarchy sub-embedding.
    """
    color_embeddings = []
    hierarchy_embeddings = []
    color_labels = []
    hierarchy_labels = []
    with torch.no_grad():
        for images, colors, hierarchies in tqdm(
            dataloader, desc="Extracting embeddings"
        ):
            images = images.to(device)
            if images.shape[1] == 1:  # safety: expand grayscale to 3 channels
                images = images.expand(-1, 3, -1, -1)
            image_embeds = model.get_image_features(pixel_values=images)
            # Slice out the color and hierarchy sub-embeddings
            color_part = image_embeds[:, :color_emb_dim]
            hierarchy_part = image_embeds[
                :, color_emb_dim : color_emb_dim + hierarchy_emb_dim
            ]
            color_embeddings.append(color_part.cpu().numpy())
            hierarchy_embeddings.append(hierarchy_part.cpu().numpy())
            color_labels.extend(colors)
            hierarchy_labels.extend(hierarchies)
    return (
        np.concatenate(color_embeddings, axis=0),
        np.concatenate(hierarchy_embeddings, axis=0),
        color_labels,
        hierarchy_labels,
    )

def compute_clip_embeddings(clip_model, clip_processor, dataloader):
    """Extract CLIP baseline embeddings (full image embeddings, not separated)."""
    all_embeddings = []
    color_labels = []
    hierarchy_labels = []

    with torch.no_grad():
        for images, colors, hierarchies in tqdm(
            dataloader, desc="Extracting CLIP embeddings"
        ):
            batch_embeddings = []
            for i in range(images.shape[0]):
                # Get a single image from the batch
                image_tensor = images[i]  # shape: (3, 224, 224)

                # Denormalize on CPU (safer for the PIL conversion)
                mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
                std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
                image_tensor = image_tensor * std + mean
                image_tensor = torch.clamp(image_tensor, 0, 1)

                # Convert to a PIL image (must be on CPU)
                image_pil = transforms.ToPILImage()(image_tensor.cpu())

                # Process with CLIP (empty text since we only need image embeddings)
                inputs = clip_processor(
                    text="",
                    images=image_pil,
                    return_tensors="pt",
                    padding=True
                ).to(device)

                outputs = clip_model(**inputs)
                # L2-normalize the image embeddings
                image_emb = outputs.image_embeds / outputs.image_embeds.norm(p=2, dim=-1, keepdim=True)
                batch_embeddings.append(image_emb.cpu().numpy())

            all_embeddings.append(np.vstack(batch_embeddings))
            color_labels.extend(colors)
            hierarchy_labels.extend(hierarchies)

    # For CLIP we use the full embeddings for all visualizations
    # (no separation into color/hierarchy dimensions)
    full_embeddings = np.concatenate(all_embeddings, axis=0)
    return (
        full_embeddings,  # color embeddings (full CLIP embeddings)
        full_embeddings,  # hierarchy embeddings (full CLIP embeddings)
        full_embeddings,  # combined embeddings (full CLIP embeddings)
        color_labels,
        hierarchy_labels,
    )

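# A minimal standalone sketch (not part of the original file) of the same
# extraction done batched: the CLIP processor accepts a list of PIL images,
# and get_image_features skips the text tower entirely, avoiding both the
# per-image loop and the dummy empty-text input used above.
import torch

@torch.no_grad()
def clip_image_embeddings(clip_model, clip_processor, pil_images, device):
    inputs = clip_processor(images=pil_images, return_tensors="pt").to(device)
    embeds = clip_model.get_image_features(**inputs)
    # L2-normalize, matching the normalization in compute_clip_embeddings
    return (embeds / embeds.norm(p=2, dim=-1, keepdim=True)).cpu().numpy()
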
def compute_dunn_index(embeddings, labels):
    """
    Compute the Dunn Index for clustering evaluation.

    The Dunn Index is the ratio of the minimum inter-cluster distance
    to the maximum intra-cluster distance. Higher values indicate better clustering.

    Args:
        embeddings: Array of embeddings [N, embed_dim]
        labels: Array of cluster labels [N]

    Returns:
        Dunn Index value (float) or None if the calculation fails
    """
    try:
        from scipy.spatial.distance import pdist, cdist

        unique_labels = np.unique(labels)
        if len(unique_labels) < 2:
            return None

        # Maximum intra-cluster distance (the widest cluster diameter)
        max_intra_cluster_dist = 0
        for label in unique_labels:
            cluster_points = embeddings[labels == label]
            if len(cluster_points) > 1:
                # Pairwise distances within the cluster
                intra_dists = pdist(cluster_points, metric='euclidean')
                if len(intra_dists) > 0:
                    max_intra_cluster_dist = max(max_intra_cluster_dist, np.max(intra_dists))

        if max_intra_cluster_dist == 0:
            return None

        # Minimum inter-cluster distance (the closest pair of clusters)
        min_inter_cluster_dist = float('inf')
        for i, label1 in enumerate(unique_labels):
            for label2 in unique_labels[i + 1:]:
                cluster1_points = embeddings[labels == label1]
                cluster2_points = embeddings[labels == label2]
                # Distances between the two clusters
                inter_dists = cdist(cluster1_points, cluster2_points, metric='euclidean')
                min_inter_cluster_dist = min(min_inter_cluster_dist, np.min(inter_dists))

        if min_inter_cluster_dist == float('inf'):
            return None

        # Dunn Index = minimum inter-cluster distance / maximum intra-cluster distance
        return float(min_inter_cluster_dist / max_intra_cluster_dist)
    except Exception as e:
        print(f"⚠️ Error computing Dunn Index: {e}")
        return None

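# A minimal standalone sanity check (not part of the original file, reusing
# compute_dunn_index defined above) on two tight, well-separated toy clusters:
# the closest inter-cluster gap (~14) dwarfs the widest cluster diameter
# (~0.14), so the index comes out well above 1.
import numpy as np

pts = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],          # cluster 0
                [10.0, 10.0], [10.1, 10.0], [10.0, 10.1]])    # cluster 1
lbls = np.array([0, 0, 0, 1, 1, 1])
print(compute_dunn_index(pts, lbls))  # well above 1 for clean separation
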
def build_color_map(labels, prefer_true_colors=False):
    """Build a color mapping for labels.

    Note: prefer_true_colors is accepted for API symmetry with run_tsne but is
    currently unused; a fixed husl palette is always assigned.
    """
    unique_labels = sorted(set(labels))
    palette = sns.color_palette("husl", len(unique_labels))
    return {label: palette[idx] for idx, label in enumerate(unique_labels)}

def compute_color_similarity_matrix(embeddings, colors, title="Color similarity (image embeddings)"):
    """Compute and visualize the similarity matrix between color centroids."""
    # Use only the colors from the reference heatmap
    reference_colors = ['red', 'pink', 'blue', 'green', 'aqua', 'lime', 'yellow', 'orange',
                        'purple', 'brown', 'gray', 'black', 'white']
    # Map non-standard color names onto the reference list
    color_mapping = {
        'yelloworange': 'yellow',
        'grey': 'gray'  # handle the grey/gray variation
    }

    # Keep only the colors that appear in the reference list
    filtered_colors = []
    filtered_embeddings = []
    for i, color in enumerate(colors):
        # Normalize the color name
        normalized_color = color_mapping.get(color.lower(), color.lower())
        if normalized_color in reference_colors:
            filtered_colors.append(normalized_color)
            filtered_embeddings.append(embeddings[i])

    if len(filtered_colors) == 0:
        print("⚠️ No matching colors found in reference list")
        return None

    # Use only the unique reference colors that actually occur in the data
    unique_colors = sorted(c for c in reference_colors if c in filtered_colors)

    # Convert to numpy arrays
    filtered_embeddings = np.array(filtered_embeddings)
    filtered_colors = np.array(filtered_colors)

    # Compute a centroid for each color
    centroids = {}
    for color in unique_colors:
        color_mask = np.array([c == color for c in filtered_colors])
        if color_mask.sum() > 0:
            centroids[color] = np.mean(filtered_embeddings[color_mask], axis=0)

    # Compute the similarity matrix between centroids
    similarity_matrix = np.zeros((len(unique_colors), len(unique_colors)))
    for i, color1 in enumerate(unique_colors):
        for j, color2 in enumerate(unique_colors):
            if i == j:
                similarity_matrix[i, j] = 1.0
            elif color1 in centroids and color2 in centroids:
                similarity_matrix[i, j] = cosine_similarity(
                    [centroids[color1]],
                    [centroids[color2]]
                )[0][0]

    # Draw the heatmap
    plt.figure(figsize=(12, 10))
    sns.heatmap(
        similarity_matrix,
        annot=True,
        fmt='.2f',
        cmap='RdYlBu_r',
        xticklabels=unique_colors,
        yticklabels=unique_colors,
        square=True,
        cbar_kws={'label': 'Cosine Similarity'},
        linewidths=0.5,
        vmin=-0.6,
        vmax=1.0
    )

    plt.title(title, fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Colors', fontsize=14, fontweight='bold')
    plt.ylabel('Colors', fontsize=14, fontweight='bold')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()

    output_path = "color_similarity_image_embeddings.png"
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✅ Color similarity heatmap saved: {output_path}")

    return similarity_matrix

def run_tsne(embeddings, legend_labels, output_path, perplexity, title,
             scatter_color_labels=None, prefer_true_colors=False):
    """Calculate and plot a t-SNE projection."""
    tsne = TSNE(
        n_components=2,
        perplexity=perplexity,
        init="pca",
        learning_rate="auto",
        random_state=42,
    )
    reduced = tsne.fit_transform(embeddings)

    label_array = np.array(legend_labels)
    color_labels = (
        np.array(scatter_color_labels) if scatter_color_labels is not None else label_array
    )

    # Calculate the clustering quality scores
    unique_labels_list = sorted(set(label_array))
    if len(unique_labels_list) > 1 and len(label_array) > 1:
        # Convert labels to numeric indices for the sklearn metrics
        label_to_idx = {label: idx for idx, label in enumerate(unique_labels_list)}
        numeric_labels = np.array([label_to_idx[label] for label in label_array])

        # Calculate in the original embedding space (ground truth - measures real separation)
        silhouette = silhouette_score(embeddings, numeric_labels, metric='euclidean')
        davies_bouldin = davies_bouldin_score(embeddings, numeric_labels)
        calinski_harabasz = calinski_harabasz_score(embeddings, numeric_labels)
        dunn = compute_dunn_index(embeddings, numeric_labels)
    else:
        silhouette = None
        davies_bouldin = None
        calinski_harabasz = None
        dunn = None

    # Helpful reference for the reported clustering indices (see the sketch after this function):
    # • Silhouette Score ∈ [-1, 1] — closer to 1 means points fit their cluster well, 0 means overlap, < 0 suggests misassignment.
    # • Davies–Bouldin Index ∈ [0, +∞) — lower is better; quantifies average similarity between clusters relative to their size.
    # • Calinski–Harabasz Index ∈ [0, +∞) — higher is better; ratio of between-cluster dispersion to within-cluster dispersion.
    # • Dunn Index ∈ [0, +∞) — higher is better; compares the tightest cluster diameter to the closest distance between clusters.

    # Build a color map for the visualization
    color_map = build_color_map(color_labels, prefer_true_colors=prefer_true_colors)
    color_series = np.array([color_map[label] for label in color_labels])

    plt.figure(figsize=(10, 8))
    for label in unique_labels_list:
        mask = label_array == label
        if 'color' in title:
            c = label  # color names double as matplotlib colors
        else:
            c = color_series[mask]
        plt.scatter(
            reduced[mask, 0],
            reduced[mask, 1],
            c=c,
            s=15,
            alpha=0.8,
            label=label,
        )

    # Add the clustering scores to the title (guard against missing scores,
    # which previously crashed the f-string formatting with None values)
    if silhouette is not None:
        dunn_str = f"{dunn:.3f}" if dunn is not None else "n/a"
        title_with_score = (
            f"{title}\n(Silhouette: {silhouette:.3f} | Davies-Bouldin: {davies_bouldin:.3f}"
            f" | Calinski-Harabasz: {calinski_harabasz:.3f} | Dunn: {dunn_str})"
        )
    else:
        title_with_score = title

    plt.title(title_with_score)
    plt.xlabel("t-SNE 1")
    plt.ylabel("t-SNE 2")
    plt.legend(
        bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small", frameon=False
    )
    plt.tight_layout()
    plt.savefig(output_path, dpi=300)
    plt.close()
    print(f"✅ Figure saved in {output_path}")
    if silhouette is not None:
        dunn_str = f"{dunn:.3f}" if dunn is not None else "n/a"
        # Scores are computed in the original embedding space, not the 2-D t-SNE space
        print(f"   📊 Embedding space: Silhouette {silhouette:.3f} | Davies-Bouldin: {davies_bouldin:.3f} | Calinski-Harabasz: {calinski_harabasz:.3f} | Dunn: {dunn_str}")

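# A minimal standalone sketch (not part of the original file, reusing
# compute_dunn_index defined above) illustrating the four clustering indices
# referenced in run_tsne on toy Gaussian blobs from scikit-learn: well-separated
# blobs give a silhouette near 1, a Davies-Bouldin near 0, and a large
# Calinski-Harabasz value.
from sklearn.datasets import make_blobs

X_toy, y_toy = make_blobs(n_samples=300, centers=3, cluster_std=0.5, random_state=42)
print(f"Silhouette:        {silhouette_score(X_toy, y_toy):.3f}")
print(f"Davies-Bouldin:    {davies_bouldin_score(X_toy, y_toy):.3f}")
print(f"Calinski-Harabasz: {calinski_harabasz_score(X_toy, y_toy):.1f}")
print(f"Dunn:              {compute_dunn_index(X_toy, y_toy):.3f}")
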
def filter_valid_rows(dataframe: pd.DataFrame) -> pd.DataFrame:
    """Keep only rows with valid local image paths and known colors."""
    # fix: use the configured color column instead of a hardcoded 'color'
    dataframe = dataframe[dataframe[color_column] != 'unknown'].copy()
    df = dataframe.dropna(
        subset=[column_local_image_path, color_column, hierarchy_column]
    ).copy()
    mask = df[column_local_image_path].apply(lambda x: isinstance(x, str) and len(x.strip()) > 0)
    return df[mask].reset_index(drop=True)

if __name__ == "__main__":
|
| 456 |
-
sample_size = None
|
| 457 |
-
per_color_limit = 500
|
| 458 |
-
min_per_hierarchy = 200
|
| 459 |
-
batch_size = 32
|
| 460 |
-
perplexity = 30
|
| 461 |
-
output_color = "tsne_color_space.png"
|
| 462 |
-
output_hierarchy = "tsne_hierarchy_space.png"
|
| 463 |
-
|
| 464 |
-
print("📥 Loading the dataset...")
|
| 465 |
-
df = pd.read_csv("data/data_with_local_paths.csv")
|
| 466 |
-
df = filter_valid_rows(df)
|
| 467 |
-
print(f"Total len if the dataset: {len(df)}")
|
| 468 |
-
df = prepare_dataframe(df, sample_size, per_color_limit, min_per_hierarchy)
|
| 469 |
-
print(f"✅ {len(df)} samples will be used for the t-SNE")
|
| 470 |
-
print(f"Number of colors in the dataset: {len(df['color'].unique())}")
|
| 471 |
-
print(f"Colors in the dataset: {df['color'].unique()}")
|
| 472 |
-
dataset = ImageDataset(df, images_dir)
|
| 473 |
-
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)
|
| 474 |
-
|
| 475 |
-
# 2) Loading the models
|
| 476 |
-
print("⚙️ Loading the main model...")
|
| 477 |
-
model, processor = load_main_model()
|
| 478 |
-
|
| 479 |
-
print("⚙️ Loading CLIP baseline model...")
|
| 480 |
-
clip_model, clip_processor = load_clip_baseline()
|
| 481 |
-
|
| 482 |
-
# 3) Extracting the embeddings
|
| 483 |
-
print("🎯 Extracting the embeddings...")
|
| 484 |
-
|
| 485 |
-
(
|
| 486 |
-
color_embeddings,
|
| 487 |
-
hierarchy_embeddings,
|
| 488 |
-
colors,
|
| 489 |
-
hierarchies,
|
| 490 |
-
) = compute_embeddings(model, dataloader)
|
| 491 |
-
|
| 492 |
-
# 4) Calculating the t-SNE
|
| 493 |
-
print("🌀 Calculating the color t-SNE...")
|
| 494 |
-
run_tsne(
|
| 495 |
-
color_embeddings,
|
| 496 |
-
colors,
|
| 497 |
-
output_color,
|
| 498 |
-
perplexity,
|
| 499 |
-
"t-SNE of the color embeddings of the main model",
|
| 500 |
-
scatter_color_labels=colors,
|
| 501 |
-
prefer_true_colors=True,
|
| 502 |
-
)
|
| 503 |
-
|
| 504 |
-
print("🎨 Computing color similarity matrix from image embeddings...")
|
| 505 |
-
compute_color_similarity_matrix(
|
| 506 |
-
color_embeddings,
|
| 507 |
-
colors,
|
| 508 |
-
title="Color similarity (image embeddings - main model)"
|
| 509 |
-
)
|
| 510 |
-
|
| 511 |
-
print("🌀 Calculating the hierarchy t-SNE...")
|
| 512 |
-
run_tsne(
|
| 513 |
-
hierarchy_embeddings,
|
| 514 |
-
hierarchies,
|
| 515 |
-
output_hierarchy,
|
| 516 |
-
perplexity,
|
| 517 |
-
"t-SNE of the hierarchy embeddings of the main model",
|
| 518 |
-
scatter_color_labels=hierarchies,
|
| 519 |
-
)
|
| 520 |
-
|
| 521 |
-
# ========== CLIP BASELINE EVALUATION ==========
|
| 522 |
-
print("\n" + "="*60)
|
| 523 |
-
print("🔄 Starting CLIP Baseline Evaluation")
|
| 524 |
-
print("="*60)
|
| 525 |
-
|
| 526 |
-
print("🎯 Extracting CLIP embeddings...")
|
| 527 |
-
(
|
| 528 |
-
clip_color_embeddings,
|
| 529 |
-
clip_hierarchy_embeddings,
|
| 530 |
-
clip_color_hier_embeddings,
|
| 531 |
-
clip_colors,
|
| 532 |
-
clip_hierarchies,
|
| 533 |
-
) = compute_clip_embeddings(clip_model, clip_processor, dataloader)
|
| 534 |
-
|
| 535 |
-
# Output paths for CLIP baseline
|
| 536 |
-
clip_output_color = "clip_baseline_tsne_color_space.png"
|
| 537 |
-
clip_output_hierarchy = "clip_baseline_tsne_hierarchy_space.png"
|
| 538 |
-
|
| 539 |
-
print("🌀 Calculating CLIP baseline color t-SNE...")
|
| 540 |
-
run_tsne(
|
| 541 |
-
clip_color_embeddings,
|
| 542 |
-
clip_colors,
|
| 543 |
-
clip_output_color,
|
| 544 |
-
perplexity,
|
| 545 |
-
"t-SNE of the color embeddings (CLIP Baseline)",
|
| 546 |
-
scatter_color_labels=clip_colors,
|
| 547 |
-
prefer_true_colors=True,
|
| 548 |
-
)
|
| 549 |
-
|
| 550 |
-
print("🎨 Computing color similarity matrix from image embeddings...")
|
| 551 |
-
compute_color_similarity_matrix(
|
| 552 |
-
clip_color_embeddings,
|
| 553 |
-
clip_colors,
|
| 554 |
-
title="Color similarity (image embeddings - CLIP Baseline)"
|
| 555 |
-
)
|
| 556 |
-
|
| 557 |
-
print("🌀 Calculating CLIP baseline hierarchy t-SNE...")
|
| 558 |
-
run_tsne(
|
| 559 |
-
clip_hierarchy_embeddings,
|
| 560 |
-
clip_hierarchies,
|
| 561 |
-
clip_output_hierarchy,
|
| 562 |
-
perplexity,
|
| 563 |
-
"t-SNE of the hierarchy embeddings (CLIP Baseline)",
|
| 564 |
-
scatter_color_labels=clip_hierarchies,
|
| 565 |
-
)
|
| 566 |
-
|
| 567 |
-
print("\n✅ All t-SNE visualizations completed!")
|
| 568 |
-
print(" - Main model: tsne_*.png")
|
| 569 |
-
print(" - CLIP baseline: clip_baseline_tsne_*.png")