Leacb4 committed on
Commit
c97f09a
·
verified ·
1 Parent(s): 8c71174

Upload evaluation/heatmap_color_similarities.py with huggingface_hub

Browse files
evaluation/heatmap_color_similarities.py ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import pandas as pd
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import seaborn as sns
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
+ from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
9
+ from sklearn.model_selection import train_test_split
10
+ from config import local_dataset_path, column_local_image_path, color_emb_dim, main_model_path, device
11
+ from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
12
+ import warnings
13
+ warnings.filterwarnings('ignore')
14
+ from torch.utils.data import Dataset, DataLoader
15
+ from torchvision import transforms
16
+ from PIL import Image
17
+ from tqdm import tqdm
18
+
19
+
20
# Color labels kept for evaluation. The order of this list is also the order
# of the rows/columns in the confusion matrices and similarity heatmaps.
PRIMARY_COLORS = [
    'red',
    'pink',
    'blue',
    'green',
    'aqua',
    'lime',
    'yellow',
    'orange',
    'purple',
    'brown',
    'gray',
    'black',
    'white',
]
24
+
25
class ColorEncoder:
    """Evaluate how well the leading embedding dimensions of a CLIP model encode color.

    The class loads a fine-tuned CLIP checkpoint, builds a validation split
    restricted to ``PRIMARY_COLORS``, extracts the first ``color_emb_dim``
    dimensions of the text and image embeddings, and then:

    * classifies every sample by its nearest color centroid (cosine
      similarity) and plots a confusion matrix;
    * plots a heatmap of the cosine similarity between color centroids.
    """

    # Base architecture/processor the fine-tuned weights are loaded into.
    _BASE_CHECKPOINT = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'

    def __init__(self, main_model_path, device='mps'):
        """Load the model, the processor and the validation dataframe.

        Args:
            main_model_path: Path to a checkpoint file that contains a
                ``'model_state_dict'`` entry.
            device: Torch device string the model is moved to.

        Raises:
            FileNotFoundError: If ``main_model_path`` does not exist.
        """
        self.device = torch.device(device)
        self.color_emb_dim = color_emb_dim
        self.primary_colors = PRIMARY_COLORS

        print(f"🚀 Loading Main Model from {main_model_path}")

        # Fail fast before downloading/instantiating the base model.
        if not os.path.exists(main_model_path):
            raise FileNotFoundError(f"Main model file {main_model_path} not found")

        # NOTE(review): torch.load unpickles the file - only load checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(main_model_path, map_location=self.device)
        self.main_model = CLIPModel_transformers.from_pretrained(self._BASE_CHECKPOINT)
        self.main_model.load_state_dict(checkpoint['model_state_dict'])
        self.main_model.to(self.device)
        self.main_model.eval()
        print(f"✅ Main model loaded successfully")

        # The processor is only used for text tokenization in this class.
        self.processor = CLIPProcessor.from_pretrained(self._BASE_CHECKPOINT)

        self._load_dataset()

    def _load_dataset(self, max_samples=10000, val_fraction=0.2, random_state=42):
        """Read the dataset CSV and keep a validation split in ``self.val_df``.

        Rows without an image path and rows whose ``'color'`` is not in
        ``self.primary_colors`` are dropped. At most ``max_samples`` rows are
        kept (random sample) before splitting.

        Args:
            max_samples: Upper bound on the rows kept before the split.
            val_fraction: Fraction of the kept rows used for validation.
            random_state: Seed for both the sub-sampling and the split.
        """
        print("📊 Loading dataset...")
        df = pd.read_csv(local_dataset_path)
        print(f"📊 Loaded {len(df)} samples")

        df_clean = df.dropna(subset=[column_local_image_path])
        print(f"📊 After filtering NaN image paths: {len(df_clean)} samples")

        df_primary = df_clean[df_clean['color'].isin(self.primary_colors)]
        print(f"📊 After filtering for primary colors: {len(df_primary)} samples")

        color_counts = df_primary['color'].value_counts()
        print(f"📊 Color distribution:")
        for color in self.primary_colors:
            print(f" {color}: {color_counts.get(color, 0)} samples")

        if len(df_primary) == 0:
            print("❌ No samples found for primary colors!")
            self.val_df = pd.DataFrame()
            return

        if len(df_primary) > max_samples:
            df_primary = df_primary.sample(n=max_samples, random_state=random_state)
            print(f"📊 Limited to {max_samples} samples for processing")

        try:
            _, self.val_df = train_test_split(
                df_primary,
                test_size=val_fraction,
                random_state=random_state,
                stratify=df_primary['color'],
            )
        except ValueError:
            # Stratification is impossible when a color has fewer than two
            # samples (or the split is smaller than the number of colors);
            # a plain random split is still usable for evaluation.
            print("⚠️ Stratified split not possible - falling back to a random split")
            _, self.val_df = train_test_split(
                df_primary,
                test_size=val_fraction,
                random_state=random_state,
            )
        print(f"📊 Validation samples: {len(self.val_df)}")

    def create_dataloader(self, dataframe, batch_size=8):
        """Wrap ``dataframe`` in a non-shuffling ``DataLoader`` of ``CustomDataset``.

        Args:
            dataframe: Rows to evaluate.
            batch_size: Number of samples per batch.

        Returns:
            A ``DataLoader`` yielding ``(images, texts, colors, hierarchies)``.
        """
        dataset = CustomDataset(dataframe, image_size=224)
        dataset.set_training_mode(False)  # evaluation: no augmentation

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,  # no worker processes, to avoid memory issues
        )

    def _extract_modal_embeddings(self, dataloader, max_samples, desc):
        """Run the model once over ``dataloader`` and collect both modalities.

        Every forward pass produces text and image embeddings together, so a
        single pass is enough to evaluate both.

        Args:
            dataloader: Loader yielding ``(images, texts, colors, hierarchies)``.
            max_samples: Maximum number of samples to collect.
            desc: Label shown on the progress bar.

        Returns:
            ``(embeddings, colors)`` where ``embeddings`` maps ``'text'`` and
            ``'image'`` to arrays of shape ``(n, color_emb_dim)`` and
            ``colors`` is the list of the ``n`` color labels.
        """
        dim = self.color_emb_dim
        chunks = {'text': [], 'image': []}
        labels = []
        remaining = max_samples

        with torch.no_grad():
            for images, texts, colors, _hierarchies in tqdm(dataloader, desc=desc):
                if remaining <= 0:
                    break

                images = images.to(self.device)
                # Broadcast single-channel batches to 3 channels (no-op for RGB).
                images = images.expand(-1, 3, -1, -1)

                # truncation=True keeps long captions within the text encoder's
                # maximum sequence length instead of failing in the model.
                text_inputs = self.processor(
                    text=list(texts), padding=True, truncation=True, return_tensors="pt"
                )
                text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}

                outputs = self.main_model(**text_inputs, pixel_values=images)

                # Trim the last batch so max_samples is never exceeded.
                keep = min(len(images), remaining)
                chunks['text'].append(outputs.text_embeds[:keep, :dim].cpu().numpy())
                chunks['image'].append(outputs.image_embeds[:keep, :dim].cpu().numpy())
                labels.extend(list(colors)[:keep])
                remaining -= keep

                del images, text_inputs, outputs
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        embeddings = {
            modality: (np.vstack(parts) if parts else np.empty((0, dim), dtype=np.float32))
            for modality, parts in chunks.items()
        }
        return embeddings, labels

    def extract_color_embeddings(self, dataloader, embedding_type='text', max_samples=10000):
        """Extract the color part (first ``color_emb_dim`` dims) of the embeddings.

        Args:
            dataloader: Loader yielding ``(images, texts, colors, hierarchies)``.
            embedding_type: ``'image'`` for image embeddings; any other value
                selects the text embeddings (kept for backward compatibility).
            max_samples: Maximum number of samples to process.

        Returns:
            ``(embeddings, colors)``: an array of shape ``(n, color_emb_dim)``
            and the list of the ``n`` matching color labels. The array is
            empty when no batch was processed.
        """
        embeddings, labels = self._extract_modal_embeddings(
            dataloader, max_samples, f"Extracting {embedding_type} color embeddings"
        )
        modality = 'image' if embedding_type == 'image' else 'text'
        return embeddings[modality], labels

    @staticmethod
    def _cosine_similarity_matrix(a, b):
        """Return the pairwise cosine similarity between the rows of ``a`` and ``b``.

        Zero-norm rows are left as zero vectors, so their similarity to
        anything is 0 instead of NaN.
        """
        a = np.asarray(a, dtype=float)
        b = np.asarray(b, dtype=float)
        a_norm = np.linalg.norm(a, axis=1, keepdims=True)
        b_norm = np.linalg.norm(b, axis=1, keepdims=True)
        a_unit = a / np.where(a_norm == 0, 1.0, a_norm)
        b_unit = b / np.where(b_norm == 0, 1.0, b_norm)
        return a_unit @ b_unit.T

    def _compute_centroids(self, embeddings, colors):
        """Return ``{color: mean embedding}`` for each primary color present.

        Keys follow the order of ``self.primary_colors``; labels outside that
        list are ignored.
        """
        embeddings = np.asarray(embeddings)
        labels = np.asarray(list(colors), dtype=object)
        centroids = {}
        for color in self.primary_colors:
            mask = labels == color
            if mask.any():
                centroids[color] = embeddings[mask].mean(axis=0)
        return centroids

    def predict_colors_from_embeddings(self, embeddings, colors):
        """Predict a color for every embedding by nearest centroid (cosine similarity).

        NOTE: the centroids are computed from the very embeddings that are
        then classified, so the resulting accuracy is an optimistic estimate.

        Args:
            embeddings: Array of shape ``(n, d)``.
            colors: The ``n`` true color labels, used to build the centroids.

        Returns:
            List of ``n`` predicted color names. Ties go to the color that
            comes first in ``self.primary_colors``. Every entry is ``None``
            when no label belongs to ``self.primary_colors``.
        """
        embeddings = np.asarray(embeddings)
        centroids = self._compute_centroids(embeddings, colors)
        if not centroids or len(embeddings) == 0:
            return [None] * len(embeddings)

        names = list(centroids)
        centroid_matrix = np.stack([centroids[name] for name in names])
        similarities = self._cosine_similarity_matrix(embeddings, centroid_matrix)
        return [names[best] for best in similarities.argmax(axis=1)]

    def create_color_confusion_matrix(self, true_colors, predicted_colors, title="Primary Colors Confusion Matrix"):
        """Plot the confusion matrix of ``predicted_colors`` against ``true_colors``.

        Only primary colors that occur in either list are shown, in the order
        of ``self.primary_colors``.

        Returns:
            ``(figure, accuracy, confusion_matrix)``.
        """
        unique_colors = [
            c for c in self.primary_colors if c in true_colors or c in predicted_colors
        ]

        cm = confusion_matrix(true_colors, predicted_colors, labels=unique_colors)
        accuracy = accuracy_score(true_colors, predicted_colors)

        plt.figure(figsize=(14, 12))
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=unique_colors,
            yticklabels=unique_colors,
            cbar_kws={'label': 'Number of Samples'},
        )
        plt.title(f'{title}\nAccuracy: {accuracy:.3f} ({accuracy*100:.1f}%)', fontsize=16, fontweight='bold')
        plt.ylabel('True Color', fontsize=14, fontweight='bold')
        plt.xlabel('Predicted Color', fontsize=14, fontweight='bold')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()

        return plt.gcf(), accuracy, cm

    def _print_classification_report(self, heading, true_colors, predicted_colors):
        """Print per-color precision/recall/F1/support under ``heading``."""
        print(heading)
        report = classification_report(
            true_colors,
            predicted_colors,
            labels=self.primary_colors,
            target_names=self.primary_colors,
            output_dict=True,
        )
        for color in self.primary_colors:
            if color in report:
                stats = report[color]
                print(
                    f" {color:>8}: P={stats['precision']:.3f} R={stats['recall']:.3f} "
                    f"F1={stats['f1-score']:.3f} S={stats['support']}"
                )

    def evaluate_color_classification(self, dataframe, max_samples=10000, batch_size=8):
        """Evaluate nearest-centroid color classification for text and image embeddings.

        Confusion-matrix figures are written to
        ``evaluation/color_evaluation_results/`` and then closed.

        Args:
            dataframe: Rows to evaluate.
            max_samples: Maximum number of samples to process.
            batch_size: Batch size used for inference.

        Returns:
            ``None`` when there is nothing to evaluate, otherwise a dict with
            the keys ``'text'`` and ``'image'``; each value holds
            ``'embeddings'``, ``'true_colors'``, ``'predicted_colors'``,
            ``'accuracy'``, ``'confusion_matrix'`` and ``'figure'``.
        """
        if len(dataframe) == 0:
            print("❌ No data available for evaluation")
            return None

        print(f"\n{'='*60}")
        print(f"Evaluating Primary Color Classification (max {max_samples} samples)")
        print(f"Target colors: {', '.join(self.primary_colors)}")
        print(f"{'='*60}")

        dataloader = self.create_dataloader(dataframe, batch_size=batch_size)

        # One pass over the data yields both modalities.
        print(f"🎨 Extracting text and image color embeddings (first {self.color_emb_dim} dimensions)...")
        embeddings, color_labels = self._extract_modal_embeddings(
            dataloader, max_samples, "Extracting text/image color embeddings"
        )
        if not color_labels:
            print("❌ No data available for evaluation")
            return None

        results = {}
        for modality in ('text', 'image'):
            predictions = self.predict_colors_from_embeddings(embeddings[modality], color_labels)
            figure, accuracy, cm = self.create_color_confusion_matrix(
                color_labels,
                predictions,
                f"{modality.capitalize()} Color Embeddings ({self.color_emb_dim}D) - Confusion Matrix",
            )
            results[modality] = {
                'embeddings': embeddings[modality],
                # Separate list per modality so callers can mutate one safely.
                'true_colors': list(color_labels),
                'predicted_colors': predictions,
                'accuracy': accuracy,
                'confusion_matrix': cm,
                'figure': figure,
            }

        print(f"\nPrimary Color Classification Results:")
        print("-" * 50)
        for modality in ('text', 'image'):
            accuracy = results[modality]['accuracy']
            print(f"{modality.capitalize()} Color Embeddings:")
            print(f" Accuracy: {accuracy:.4f} ({accuracy*100:.1f}%)")

        for modality in ('text', 'image'):
            self._print_classification_report(
                f"\n📊 Detailed Classification Report - {modality.capitalize()}:",
                results[modality]['true_colors'],
                results[modality]['predicted_colors'],
            )

        output_dir = 'evaluation/color_evaluation_results'
        os.makedirs(output_dir, exist_ok=True)
        for modality in ('text', 'image'):
            figure = results[modality]['figure']
            figure.savefig(
                f'{output_dir}/{modality}_color_confusion_matrix.png',
                dpi=300,
                bbox_inches='tight',
            )
            plt.close(figure)

        return results

    def create_color_similarity_heatmap(self, embeddings, colors, embedding_type='text', save_path='evaluation/color_similarity_results/color_similarity_heatmap.png'):
        """Plot and save the cosine similarity between color centroids.

        Args:
            embeddings: Array of shape ``(n, d)``.
            colors: The ``n`` color labels of ``embeddings``.
            embedding_type: Name used in the plot title and log message.
            save_path: Output image path; its directory is created if needed.

        Returns:
            ``(figure, similarity_matrix)``. Rows/columns follow the primary
            colors present in ``colors``, in ``self.primary_colors`` order.
        """
        print(f"🎨 Creating color similarity heatmap for {embedding_type} embeddings...")

        centroids = self._compute_centroids(embeddings, colors)
        unique_colors = list(centroids)

        if unique_colors:
            centroid_matrix = np.stack([centroids[c] for c in unique_colors])
            similarity_matrix = self._cosine_similarity_matrix(centroid_matrix, centroid_matrix)
            # Self-similarity is 1 by definition; avoid floating point drift.
            np.fill_diagonal(similarity_matrix, 1.0)
        else:
            similarity_matrix = np.zeros((0, 0))

        plt.figure(figsize=(12, 10))
        sns.heatmap(
            similarity_matrix,
            annot=True,
            fmt='.3f',
            cmap='RdYlBu_r',
            xticklabels=unique_colors,
            yticklabels=unique_colors,
            square=True,
            cbar_kws={'label': 'Cosine Similarity'},
            linewidths=0.5,
            vmin=-0.6,
            vmax=1.0,
        )

        plt.title(f'Color similarity ({embedding_type} embeddings)',
                  fontsize=16, fontweight='bold', pad=20)
        plt.xlabel('Colors', fontsize=14, fontweight='bold')
        plt.ylabel('Colors', fontsize=14, fontweight='bold')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()

        # Make sure the target directory exists when called on its own.
        output_dir = os.path.dirname(save_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"💾 Heatmap saved: {save_path}")

        return plt.gcf(), similarity_matrix

    def create_color_similarity_analysis(self, results):
        """Build similarity heatmaps for the text and image embeddings in ``results``.

        Args:
            results: Output of ``evaluate_color_classification``.

        Returns:
            Dict keyed by modality (``'text'``/``'image'``); each value holds
            ``'similarity_matrix'``, ``'figure'`` (already closed) and
            ``'colors'`` (row/column labels of the matrix).
        """
        print(f"\n{'='*60}")
        print("🎨 ANALYSIS OF SIMILARITIES BETWEEN COLORS")
        print(f"{'='*60}")

        output_dir = 'evaluation/color_similarity_results'
        os.makedirs(output_dir, exist_ok=True)

        headings = {
            'text': "\n📝 Analyse des similarités - Text Embeddings:",
            'image': "\n🖼️ Analyse des similarités - Image Embeddings:",
        }

        similarity_results = {}
        for modality, heading in headings.items():
            if modality not in results:
                continue
            print(heading)
            true_colors = results[modality]['true_colors']
            figure, matrix = self.create_color_similarity_heatmap(
                results[modality]['embeddings'],
                true_colors,
                modality,
                f'{output_dir}/{modality}_color_similarity_heatmap.png',
            )
            similarity_results[modality] = {
                'similarity_matrix': matrix,
                'figure': figure,
                # Same selection/order the heatmap uses for its rows.
                'colors': [c for c in self.primary_colors if c in true_colors],
            }
            plt.close(figure)

        # Report the most and least similar color pairs.
        self._analyze_similarity_patterns(similarity_results)

        return similarity_results

    def _analyze_similarity_patterns(self, similarity_results):
        """Print the most/least similar color pairs and the mean similarity.

        Args:
            similarity_results: Dict keyed by embedding type; each value has a
                square ``'similarity_matrix'`` and optionally ``'colors'``,
                the labels of its rows. Without ``'colors'`` the labels fall
                back to ``self.primary_colors`` by position.
        """
        print(f"\n📊 ANALYSE DES PATTERNS DE SIMILARITÉ")
        print("-" * 50)

        for embedding_type, data in similarity_results.items():
            matrix = np.asarray(data['similarity_matrix'])
            labels = list(data.get('colors') or self.primary_colors)

            def label(index):
                return labels[index] if index < len(labels) else f"Color_{index}"

            print(f"\n{embedding_type.upper()} Embeddings:")

            # Upper triangle only: skips the diagonal and mirrored pairs.
            n = len(matrix)
            pairs = [(i, j, matrix[i, j]) for i in range(n) for j in range(i + 1, n)]
            if not pairs:
                print("Fewer than two colors - no pairs to compare")
                continue

            pairs.sort(key=lambda pair: pair[2], reverse=True)

            print("🔗 Couleurs les plus similaires:")
            for rank, (idx1, idx2, sim) in enumerate(pairs[:5], start=1):
                print(f" {rank}. {label(idx1)} ↔ {label(idx2)}: {sim:.3f}")

            print("🔗 Couleurs les moins similaires:")
            # Reversed so that rank 1 is the least similar pair.
            for rank, (idx1, idx2, sim) in enumerate(reversed(pairs[-5:]), start=1):
                print(f" {rank}. {label(idx1)} ↔ {label(idx2)}: {sim:.3f}")

            off_diagonal = np.array([sim for _, _, sim in pairs])
            mean_similarity = np.mean(off_diagonal)
            std_similarity = np.std(off_diagonal)

            print(f"📈 Similarité moyenne: {mean_similarity:.3f} ± {std_similarity:.3f}")
451
+
452
class CustomDataset(Dataset):
    """Dataset yielding ``(image_tensor, text, color, hierarchy)`` per dataframe row.

    Each row must provide the image path column named by
    ``column_local_image_path`` plus the ``'text'``, ``'color'`` and
    ``'hierarchy'`` columns.
    """

    def __init__(self, dataframe, image_size=224):
        """Store the dataframe and build the (augmentation-free) transform.

        Args:
            dataframe: Rows to serve; accessed by position.
            image_size: Side length images are resized to (square).
        """
        self.dataframe = dataframe
        self.image_size = image_size

        # Validation transform: resize, tensor conversion and normalization.
        # NOTE(review): these are the ImageNet mean/std - presumably they match
        # the preprocessing the checkpoint was trained with; TODO confirm.
        self.val_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Only recorded; __getitem__ always applies the validation transform.
        self.training_mode = True

    def set_training_mode(self, training=True):
        """Record whether the dataset is used for training."""
        self.training_mode = training

    def __len__(self):
        """Return the number of rows in the dataframe."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load the image of row ``idx`` and return it with its text and labels.

        Returns:
            ``(image, description, color, hierarchy)`` where ``image`` is the
            transformed tensor of the RGB-converted image.
        """
        row = self.dataframe.iloc[idx]

        # The context manager closes the file handle right after decoding
        # instead of leaving it open until garbage collection.
        with Image.open(row[column_local_image_path]) as raw_image:
            image = raw_image.convert("RGB")

        image = self.val_transform(image)

        return image, row['text'], row['color'], row['hierarchy']
487
+
488
# Script entry point: evaluate primary-color classification, then analyse the
# similarity between color centroids and print a few sample predictions.
if __name__ == "__main__":
    banner = "=" * 70
    print("🚀 Starting Primary Color Encoding and Similarity Analysis")
    print(banner)
    print(f"Target Primary Colors: {', '.join(PRIMARY_COLORS)}")
    print(banner)

    # Loads the model, the processor and the validation split.
    color_encoder = ColorEncoder(main_model_path=main_model_path, device=device)

    results = color_encoder.evaluate_color_classification(
        color_encoder.val_df, max_samples=10000
    )

    if not results:
        print("❌ No results generated - check if primary colors exist in dataset")
    else:
        print(f"\n✅ Primary color encoding and confusion matrix generation completed!")
        print(f"📊 Results saved in 'evaluation/color_evaluation_results/' directory")
        print(f"🎨 Text Primary Color Accuracy: {results['text']['accuracy']*100:.1f}%")
        print(f"🖼️ Image Primary Color Accuracy: {results['image']['accuracy']*100:.1f}%")

        # Similarity analysis between color centroids.
        print(f"\n🎨 Starting Color Similarity Analysis...")
        similarity_results = color_encoder.create_color_similarity_analysis(results)

        print(f"\n✅ Color similarity analysis completed!")
        print(f"📊 Similarity heatmaps saved in 'evaluation/color_similarity_results/' directory")

        # Show the first ten predictions of each modality.
        sample_sections = (
            (f"\n📝 Sample Text Predictions:", 'text'),
            (f"\n🖼️ Sample Image Predictions:", 'image'),
        )
        for heading, modality in sample_sections:
            print(heading)
            expected = results[modality]['true_colors']
            predicted = results[modality]['predicted_colors']
            for i in range(min(10, len(expected))):
                true_color, pred_color = expected[i], predicted[i]
                status = "✓" if true_color == pred_color else "✗"
                print(f" {status} True: {true_color:>8} | Predicted: {pred_color:>8}")