Leacb4 commited on
Commit
4611995
·
verified ·
1 Parent(s): d1b4c8f

Upload evaluation/tsne_images.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. evaluation/tsne_images.py +569 -0
evaluation/tsne_images.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Outputs several t-SNE visualizations with color and hierarchy overlays to
4
+ verify that the main model separates colors well inside each hierarchy group.
5
+ """
6
+
7
+ import math
8
+
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ import pandas as pd
12
+ import seaborn as sns
13
+ import torch
14
+ from matplotlib.patches import Polygon
15
+ from PIL import Image
16
+ from sklearn.manifold import TSNE
17
+ from sklearn.metrics import (
18
+ silhouette_score,
19
+ davies_bouldin_score,
20
+ calinski_harabasz_score,
21
+ )
22
+ from sklearn.preprocessing import normalize
23
+ from sklearn.metrics.pairwise import cosine_similarity
24
+ from torch.utils.data import DataLoader, Dataset
25
+ from torchvision import transforms
26
+ from tqdm import tqdm
27
+ from transformers import CLIPModel as CLIPModel_transformers, CLIPProcessor
28
+
29
+ try:
30
+ from scipy.spatial import ConvexHull
31
+ except ImportError:
32
+ ConvexHull = None
33
+
34
+ from config import (
35
+ color_column,
36
+ color_emb_dim,
37
+ column_local_image_path,
38
+ device,
39
+ hierarchy_column,
40
+ hierarchy_emb_dim,
41
+ images_dir,
42
+ local_dataset_path,
43
+ main_model_path,
44
+ )
45
+
46
+
47
+ class ImageDataset(Dataset):
48
+ """Lightweight dataset to load local images along with colors and hierarchies."""
49
+
50
+ def __init__(self, dataframe: pd.DataFrame, root_dir: str):
51
+ self.df = dataframe.reset_index(drop=True)
52
+ self.root_dir = root_dir
53
+ self.transform = transforms.Compose(
54
+ [
55
+ transforms.Resize((224, 224)),
56
+ transforms.ToTensor(),
57
+ transforms.Normalize(
58
+ mean=[0.485, 0.456, 0.406],
59
+ std=[0.229, 0.224, 0.225],
60
+ ),
61
+ ]
62
+ )
63
+
64
+ def __len__(self):
65
+ return len(self.df)
66
+
67
+ def __getitem__(self, idx):
68
+ row = self.df.iloc[idx]
69
+ img_path = row[column_local_image_path]
70
+ image = Image.open(img_path).convert("RGB")
71
+ image = self.transform(image)
72
+ color = row[color_column]
73
+ hierarchy = row[hierarchy_column]
74
+ return image, color, hierarchy
75
+
76
+
77
+
78
+ def load_main_model():
79
+ """Load the main model with the trained weights."""
80
+ checkpoint = torch.load(main_model_path, map_location=device)
81
+ state_dict = checkpoint.get("model_state_dict", checkpoint)
82
+ model = CLIPModel_transformers.from_pretrained(
83
+ "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
84
+ )
85
+ model.load_state_dict(state_dict)
86
+ model.to(device)
87
+ model.eval()
88
+ # Load processor for text tokenization
89
+ processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
90
+ return model, processor
91
+
92
+
93
+ def load_clip_baseline():
94
+ """Load the CLIP baseline model from transformers."""
95
+ print("🤗 Loading CLIP baseline model from transformers...")
96
+ clip_model = CLIPModel_transformers.from_pretrained("openai/clip-vit-base-patch32").to(device)
97
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
98
+ clip_model.eval()
99
+ print("✅ CLIP baseline model loaded successfully")
100
+ return clip_model, clip_processor
101
+
102
+
103
+ def enforce_min_hierarchy_samples(df, min_per_hierarchy):
104
+ """Filter out hierarchy groups with fewer than min_per_hierarchy rows."""
105
+ if not min_per_hierarchy or min_per_hierarchy <= 0:
106
+ return df
107
+ counts = df[hierarchy_column].value_counts()
108
+ keep_values = counts[counts >= min_per_hierarchy].index
109
+ filtered = df[df[hierarchy_column].isin(keep_values)].reset_index(drop=True)
110
+ return filtered
111
+
112
+
113
+ def prepare_dataframe(df, sample_size, per_color_limit, min_per_hierarchy=None):
114
+ """Subsample the dataframe to speed up the t-SNE."""
115
+ if per_color_limit and per_color_limit > 0:
116
+ df_limited = (
117
+ df.groupby(color_column)
118
+ .apply(lambda g: g.sample(min(len(g), per_color_limit), random_state=42))
119
+ .reset_index(drop=True)
120
+ )
121
+ else:
122
+ df_limited = df
123
+
124
+ if sample_size and 0 < sample_size < len(df_limited):
125
+ df_limited = df_limited.sample(sample_size, random_state=42).reset_index(
126
+ drop=True
127
+ )
128
+ df_limited = enforce_min_hierarchy_samples(df_limited, min_per_hierarchy)
129
+ return df_limited
130
+
131
+
132
+ def compute_embeddings(model, dataloader):
133
+ """Extract color, hierarchy, and combined embeddings."""
134
+ color_embeddings = []
135
+ hierarchy_embeddings = []
136
+ color_labels = []
137
+ hierarchy_labels = []
138
+ with torch.no_grad():
139
+ for images, colors, hierarchies in tqdm(
140
+ dataloader, desc="Extracting embeddings"
141
+ ):
142
+ images = images.to(device)
143
+ if images.shape[1] == 1: # safety in case
144
+ images = images.expand(-1, 3, -1, -1)
145
+ image_embeds = model.get_image_features(pixel_values=images)
146
+ color_part = image_embeds[:, :color_emb_dim]
147
+ hierarchy_part = image_embeds[
148
+ :, color_emb_dim : color_emb_dim + hierarchy_emb_dim
149
+ ]
150
+ color_embeddings.append(color_part.cpu().numpy())
151
+ hierarchy_embeddings.append(hierarchy_part.cpu().numpy())
152
+ color_labels.extend(colors)
153
+ hierarchy_labels.extend(hierarchies)
154
+ return (
155
+ np.concatenate(color_embeddings, axis=0),
156
+ np.concatenate(hierarchy_embeddings, axis=0),
157
+ color_labels,
158
+ hierarchy_labels,
159
+ )
160
+
161
+
162
+ def compute_clip_embeddings(clip_model, clip_processor, dataloader):
163
+ """Extract CLIP baseline embeddings (full image embeddings, not separated)."""
164
+ all_embeddings = []
165
+ color_labels = []
166
+ hierarchy_labels = []
167
+
168
+ with torch.no_grad():
169
+ for images, colors, hierarchies in tqdm(
170
+ dataloader, desc="Extracting CLIP embeddings"
171
+ ):
172
+ batch_embeddings = []
173
+ for i in range(images.shape[0]):
174
+ # Get single image from batch
175
+ image_tensor = images[i] # Shape: (3, 224, 224)
176
+
177
+ # Denormalize on CPU (safer for PIL conversion)
178
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
179
+ std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
180
+ image_tensor = image_tensor * std + mean
181
+ image_tensor = torch.clamp(image_tensor, 0, 1)
182
+
183
+ # Convert to PIL Image (must be on CPU)
184
+ image_pil = transforms.ToPILImage()(image_tensor.cpu())
185
+
186
+ # Process with CLIP (using empty text since we only need image embeddings)
187
+ inputs = clip_processor(
188
+ text="",
189
+ images=image_pil,
190
+ return_tensors="pt",
191
+ padding=True
192
+ ).to(device)
193
+
194
+ outputs = clip_model(**inputs)
195
+ # Get normalized image embeddings
196
+ image_emb = outputs.image_embeds / outputs.image_embeds.norm(p=2, dim=-1, keepdim=True)
197
+ batch_embeddings.append(image_emb.cpu().numpy())
198
+
199
+ all_embeddings.append(np.vstack(batch_embeddings))
200
+ color_labels.extend(colors)
201
+ hierarchy_labels.extend(hierarchies)
202
+
203
+ # For CLIP, we use the full embeddings for all visualizations
204
+ # (no separation into color/hierarchy dimensions)
205
+ full_embeddings = np.concatenate(all_embeddings, axis=0)
206
+ return (
207
+ full_embeddings, # color_embeddings (using full CLIP embeddings)
208
+ full_embeddings, # hierarchy_embeddings (using full CLIP embeddings)
209
+ full_embeddings, # color_hier_embeddings (using full CLIP embeddings)
210
+ color_labels,
211
+ hierarchy_labels,
212
+ )
213
+
214
+
215
+ def compute_dunn_index(embeddings, labels):
216
+ """
217
+ Compute the Dunn Index for clustering evaluation.
218
+
219
+ The Dunn Index is the ratio of the minimum inter-cluster distance
220
+ to the maximum intra-cluster distance. Higher values indicate better clustering.
221
+
222
+ Args:
223
+ embeddings: Array of embeddings [N, embed_dim]
224
+ labels: Array of cluster labels [N]
225
+
226
+ Returns:
227
+ Dunn Index value (float) or None if calculation fails
228
+ """
229
+ try:
230
+ unique_labels = np.unique(labels)
231
+ if len(unique_labels) < 2:
232
+ return None
233
+
234
+ # Calculate intra-cluster distances (maximum within each cluster)
235
+ max_intra_cluster_dist = 0
236
+ for label in unique_labels:
237
+ cluster_points = embeddings[labels == label]
238
+ if len(cluster_points) > 1:
239
+ # Calculate pairwise distances within cluster
240
+ from scipy.spatial.distance import pdist
241
+ intra_dists = pdist(cluster_points, metric='euclidean')
242
+ if len(intra_dists) > 0:
243
+ max_intra = np.max(intra_dists)
244
+ max_intra_cluster_dist = max(max_intra_cluster_dist, max_intra)
245
+
246
+ if max_intra_cluster_dist == 0:
247
+ return None
248
+
249
+ # Calculate inter-cluster distances (minimum between clusters)
250
+ min_inter_cluster_dist = float('inf')
251
+ for i, label1 in enumerate(unique_labels):
252
+ for label2 in unique_labels[i+1:]:
253
+ cluster1_points = embeddings[labels == label1]
254
+ cluster2_points = embeddings[labels == label2]
255
+
256
+ # Calculate distances between clusters
257
+ from scipy.spatial.distance import cdist
258
+ inter_dists = cdist(cluster1_points, cluster2_points, metric='euclidean')
259
+ min_inter = np.min(inter_dists)
260
+ min_inter_cluster_dist = min(min_inter_cluster_dist, min_inter)
261
+
262
+ if min_inter_cluster_dist == float('inf'):
263
+ return None
264
+
265
+ # Dunn Index = minimum inter-cluster distance / maximum intra-cluster distance
266
+ dunn_index = min_inter_cluster_dist / max_intra_cluster_dist
267
+ return float(dunn_index)
268
+ except Exception as e:
269
+ print(f"⚠️ Error computing Dunn Index: {e}")
270
+ return None
271
+
272
+
273
+ def build_color_map(labels, prefer_true_colors=False):
274
+ """Build a color mapping for labels."""
275
+ unique_labels = sorted(set(labels))
276
+ palette = sns.color_palette("husl", len(unique_labels))
277
+ return {label: palette[idx] for idx, label in enumerate(unique_labels)}
278
+
279
+
280
+ def compute_color_similarity_matrix(embeddings, colors, title="Color similarity (image embeddings)"):
281
+ """Compute and visualize similarity matrix between color centroids."""
282
+ # Use only the colors from the reference heatmap
283
+ reference_colors = ['red', 'pink', 'blue', 'green', 'aqua', 'lime', 'yellow', 'orange',
284
+ 'purple', 'brown', 'gray', 'black', 'white']
285
+ # Map 'yelloworange' to 'yellow' or 'orange' if needed
286
+ color_mapping = {
287
+ 'yelloworange': 'yellow',
288
+ 'grey': 'gray' # Handle grey/gray variation
289
+ }
290
+
291
+ # Filter to only include colors that are in the reference list
292
+ filtered_colors = []
293
+ filtered_embeddings = []
294
+ for i, color in enumerate(colors):
295
+ # Normalize color name
296
+ normalized_color = color_mapping.get(color.lower(), color.lower())
297
+ if normalized_color in reference_colors:
298
+ filtered_colors.append(normalized_color)
299
+ filtered_embeddings.append(embeddings[i])
300
+
301
+ if len(filtered_colors) == 0:
302
+ print("⚠️ No matching colors found in reference list")
303
+ return None
304
+
305
+ # Use only unique colors from reference that exist in data
306
+ unique_colors = sorted([c for c in reference_colors if c in filtered_colors])
307
+
308
+ # Convert to numpy arrays
309
+ filtered_embeddings = np.array(filtered_embeddings)
310
+ filtered_colors = np.array(filtered_colors)
311
+
312
+ # Compute centroids for each color
313
+ centroids = {}
314
+ for color in unique_colors:
315
+ color_mask = np.array([c == color for c in filtered_colors])
316
+ if color_mask.sum() > 0:
317
+ centroids[color] = np.mean(filtered_embeddings[color_mask], axis=0)
318
+
319
+ # Compute similarity matrix
320
+ similarity_matrix = np.zeros((len(unique_colors), len(unique_colors)))
321
+ for i, color1 in enumerate(unique_colors):
322
+ for j, color2 in enumerate(unique_colors):
323
+ if i == j:
324
+ similarity_matrix[i, j] = 1.0
325
+ else:
326
+ if color1 in centroids and color2 in centroids:
327
+ similarity = cosine_similarity(
328
+ [centroids[color1]],
329
+ [centroids[color2]]
330
+ )[0][0]
331
+ similarity_matrix[i, j] = similarity
332
+
333
+ # Create heatmap
334
+ plt.figure(figsize=(12, 10))
335
+ sns.heatmap(
336
+ similarity_matrix,
337
+ annot=True,
338
+ fmt='.2f',
339
+ cmap='RdYlBu_r',
340
+ xticklabels=unique_colors,
341
+ yticklabels=unique_colors,
342
+ square=True,
343
+ cbar_kws={'label': 'Cosine Similarity'},
344
+ linewidths=0.5,
345
+ vmin=-0.6,
346
+ vmax=1.0
347
+ )
348
+
349
+ plt.title(title, fontsize=16, fontweight='bold', pad=20)
350
+ plt.xlabel('Colors', fontsize=14, fontweight='bold')
351
+ plt.ylabel('Colors', fontsize=14, fontweight='bold')
352
+ plt.xticks(rotation=45, ha='right')
353
+ plt.yticks(rotation=0)
354
+ plt.tight_layout()
355
+
356
+ output_path = "color_similarity_image_embeddings.png"
357
+ plt.savefig(output_path, dpi=300, bbox_inches='tight')
358
+ plt.close()
359
+ print(f"✅ Color similarity heatmap saved: {output_path}")
360
+
361
+ return similarity_matrix
362
+
363
+
364
+ def run_tsne(embeddings,legend_labels,output_path,perplexity,title,scatter_color_labels=None,prefer_true_colors=False):
365
+ """Calculate and plot a t-SNE projection."""
366
+ tsne = TSNE(
367
+ n_components=2,
368
+ perplexity=perplexity,
369
+ init="pca",
370
+ learning_rate="auto",
371
+ random_state=42,
372
+ )
373
+ reduced = tsne.fit_transform(embeddings)
374
+
375
+ label_array = np.array(legend_labels)
376
+ color_labels = (
377
+ np.array(scatter_color_labels) if scatter_color_labels is not None else label_array
378
+ )
379
+
380
+ # Calculate silhouette scores
381
+ unique_labels_list = sorted(set(label_array))
382
+ if len(unique_labels_list) > 1 and len(label_array) > 1:
383
+ # Convert labels to numeric indices for silhouette_score
384
+ label_to_idx = {label: idx for idx, label in enumerate(unique_labels_list)}
385
+ numeric_labels = np.array([label_to_idx[label] for label in label_array])
386
+
387
+ # Calculate in original embedding space (ground truth - measures real separation)
388
+ silhouette = silhouette_score(embeddings, numeric_labels, metric='euclidean')
389
+ davies_bouldin = davies_bouldin_score(embeddings, numeric_labels)
390
+ calinski_harabasz = calinski_harabasz_score(embeddings, numeric_labels)
391
+ dunn = compute_dunn_index(embeddings, numeric_labels)
392
+
393
+ else:
394
+ silhouette = None
395
+ davies_bouldin = None
396
+ calinski_harabasz = None
397
+ dunn = None
398
+
399
+ # Helpful reference for the reported clustering indices:
400
+ # • Silhouette Score ∈ [-1, 1] — closer to 1 means points fit their cluster well, 0 means overlap, < 0 suggests misassignment.
401
+ # • Davies–Bouldin Index ∈ [0, +∞) — lower is better; quantifies average similarity between clusters relative to their size.
402
+ # • Calinski–Harabasz Index ∈ [0, +∞) — higher is better; ratio of between-cluster dispersion to within-cluster dispersion.
403
+ # • Dunn Index ∈ [0, +∞) — higher is better; compares the tightest cluster diameter to the closest distance between clusters.
404
+
405
+ # Build color map for visualization
406
+ color_map = build_color_map(color_labels, prefer_true_colors=prefer_true_colors)
407
+ color_series = np.array([color_map[label] for label in color_labels])
408
+
409
+ plt.figure(figsize=(10, 8))
410
+ unique_labels = sorted(set(label_array))
411
+ for label in unique_labels:
412
+ mask = label_array == label
413
+ if 'color' in title:
414
+ c = label
415
+ else:
416
+ c = color_series[mask]
417
+ plt.scatter(
418
+ reduced[mask, 0],
419
+ reduced[mask, 1],
420
+ c=c,
421
+ s=15,
422
+ alpha=0.8,
423
+ label=label,
424
+ )
425
+
426
+ # Add silhouette score to title
427
+ if silhouette is not None:
428
+ title_with_score = f"{title}\n(t-SNE Silhouette: {silhouette:.3f} | Davies-Bouldin: {davies_bouldin:.3f} | Calinski-Harabasz: {calinski_harabasz:.3f} | Dunn: {dunn:.3f})"
429
+ else:
430
+ title_with_score = title
431
+
432
+ plt.title(title_with_score)
433
+ plt.xlabel("t-SNE 1")
434
+ plt.ylabel("t-SNE 2")
435
+ plt.legend(
436
+ bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small", frameon=False
437
+ )
438
+ plt.tight_layout()
439
+ plt.savefig(output_path, dpi=300)
440
+ plt.close()
441
+ print(f"✅ Figure saved in {output_path}")
442
+ print(f" 📊 t-SNE space: {silhouette:.3f} (matches visualization) | Davies-Bouldin: {davies_bouldin:.3f} | Calinski-Harabasz: {calinski_harabasz:.3f} | Dunn: {dunn:.3f}")
443
+
444
+
445
+
446
+ def filter_valid_rows(dataframe: pd.DataFrame) -> pd.DataFrame:
447
+ """Keep only rows with valid local image paths and colors."""
448
+ dataframe = dataframe[dataframe['color'] != 'unknown'].copy()
449
+ df = dataframe.dropna(
450
+ subset=[column_local_image_path, color_column, hierarchy_column]
451
+ ).copy()
452
+ mask = df[column_local_image_path].apply(lambda x: isinstance(x, str) and len(x.strip()) > 0)
453
+ return df[mask].reset_index(drop=True)
454
+
455
+ if __name__ == "__main__":
456
+ sample_size = None
457
+ per_color_limit = 500
458
+ min_per_hierarchy = 200
459
+ batch_size = 32
460
+ perplexity = 30
461
+ output_color = "tsne_color_space.png"
462
+ output_hierarchy = "tsne_hierarchy_space.png"
463
+
464
+ print("📥 Loading the dataset...")
465
+ df = pd.read_csv("data/data_with_local_paths.csv")
466
+ df = filter_valid_rows(df)
467
+ print(f"Total len if the dataset: {len(df)}")
468
+ df = prepare_dataframe(df, sample_size, per_color_limit, min_per_hierarchy)
469
+ print(f"✅ {len(df)} samples will be used for the t-SNE")
470
+ print(f"Number of colors in the dataset: {len(df['color'].unique())}")
471
+ print(f"Colors in the dataset: {df['color'].unique()}")
472
+ dataset = ImageDataset(df, images_dir)
473
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)
474
+
475
+ # 2) Loading the models
476
+ print("⚙️ Loading the main model...")
477
+ model, processor = load_main_model()
478
+
479
+ print("⚙️ Loading CLIP baseline model...")
480
+ clip_model, clip_processor = load_clip_baseline()
481
+
482
+ # 3) Extracting the embeddings
483
+ print("🎯 Extracting the embeddings...")
484
+
485
+ (
486
+ color_embeddings,
487
+ hierarchy_embeddings,
488
+ colors,
489
+ hierarchies,
490
+ ) = compute_embeddings(model, dataloader)
491
+
492
+ # 4) Calculating the t-SNE
493
+ print("🌀 Calculating the color t-SNE...")
494
+ run_tsne(
495
+ color_embeddings,
496
+ colors,
497
+ output_color,
498
+ perplexity,
499
+ "t-SNE of the color embeddings of the main model",
500
+ scatter_color_labels=colors,
501
+ prefer_true_colors=True,
502
+ )
503
+
504
+ print("🎨 Computing color similarity matrix from image embeddings...")
505
+ compute_color_similarity_matrix(
506
+ color_embeddings,
507
+ colors,
508
+ title="Color similarity (image embeddings - main model)"
509
+ )
510
+
511
+ print("🌀 Calculating the hierarchy t-SNE...")
512
+ run_tsne(
513
+ hierarchy_embeddings,
514
+ hierarchies,
515
+ output_hierarchy,
516
+ perplexity,
517
+ "t-SNE of the hierarchy embeddings of the main model",
518
+ scatter_color_labels=hierarchies,
519
+ )
520
+
521
+ # ========== CLIP BASELINE EVALUATION ==========
522
+ print("\n" + "="*60)
523
+ print("🔄 Starting CLIP Baseline Evaluation")
524
+ print("="*60)
525
+
526
+ print("🎯 Extracting CLIP embeddings...")
527
+ (
528
+ clip_color_embeddings,
529
+ clip_hierarchy_embeddings,
530
+ clip_color_hier_embeddings,
531
+ clip_colors,
532
+ clip_hierarchies,
533
+ ) = compute_clip_embeddings(clip_model, clip_processor, dataloader)
534
+
535
+ # Output paths for CLIP baseline
536
+ clip_output_color = "clip_baseline_tsne_color_space.png"
537
+ clip_output_hierarchy = "clip_baseline_tsne_hierarchy_space.png"
538
+
539
+ print("🌀 Calculating CLIP baseline color t-SNE...")
540
+ run_tsne(
541
+ clip_color_embeddings,
542
+ clip_colors,
543
+ clip_output_color,
544
+ perplexity,
545
+ "t-SNE of the color embeddings (CLIP Baseline)",
546
+ scatter_color_labels=clip_colors,
547
+ prefer_true_colors=True,
548
+ )
549
+
550
+ print("🎨 Computing color similarity matrix from image embeddings...")
551
+ compute_color_similarity_matrix(
552
+ clip_color_embeddings,
553
+ clip_colors,
554
+ title="Color similarity (image embeddings - CLIP Baseline)"
555
+ )
556
+
557
+ print("🌀 Calculating CLIP baseline hierarchy t-SNE...")
558
+ run_tsne(
559
+ clip_hierarchy_embeddings,
560
+ clip_hierarchies,
561
+ clip_output_hierarchy,
562
+ perplexity,
563
+ "t-SNE of the hierarchy embeddings (CLIP Baseline)",
564
+ scatter_color_labels=clip_hierarchies,
565
+ )
566
+
567
+ print("\n✅ All t-SNE visualizations completed!")
568
+ print(" - Main model: tsne_*.png")
569
+ print(" - CLIP baseline: clip_baseline_tsne_*.png")