Leacb4 commited on
Commit
65073e2
·
verified ·
1 Parent(s): fb6be9e

Upload evaluation/0_shot_classification.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. evaluation/0_shot_classification.py +512 -0
evaluation/0_shot_classification.py ADDED
@@ -0,0 +1,512 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Zero-shot classification evaluation on a new dataset.
3
+ This file evaluates the main model's performance on unseen data by performing
4
+ zero-shot classification. It compares three methods: color-to-color classification,
5
+ text-to-text, and image-to-text. It generates confusion matrices and classification reports
6
+ for each method to analyze the model's generalization capability.
7
+ """
8
+
9
+ import os
10
+ # Set environment variable to disable tokenizers parallelism warnings
11
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+ import pandas as pd
17
+ from torch.utils.data import Dataset
18
+ import matplotlib.pyplot as plt
19
+ from PIL import Image
20
+ from torchvision import transforms
21
+ from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
22
+ import warnings
23
+ import config
24
+ from tqdm import tqdm
25
+ from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
26
+ import seaborn as sns
27
+ from color_model import CLIPModel as ColorModel
28
+ from hierarchy_model import Model, HierarchyExtractor
29
+
30
+ # Suppress warnings
31
+ warnings.filterwarnings("ignore", category=FutureWarning)
32
+ warnings.filterwarnings("ignore", category=UserWarning)
33
+
34
+ def load_trained_model(model_path, device):
35
+ """
36
+ Load the trained CLIP model from checkpoint
37
+ """
38
+ print(f"Loading trained model from: {model_path}")
39
+
40
+ # Load checkpoint
41
+ checkpoint = torch.load(model_path, map_location=device)
42
+
43
+ # Create the base CLIP model
44
+ model = CLIPModel_transformers.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
45
+
46
+ # Load the trained weights
47
+ model.load_state_dict(checkpoint['model_state_dict'])
48
+ model = model.to(device)
49
+ model.eval()
50
+
51
+ print(f"✅ Model loaded successfully!")
52
+ print(f"📊 Training epoch: {checkpoint['epoch']}")
53
+ print(f"📉 Best validation loss: {checkpoint['best_val_loss']:.4f}")
54
+
55
+ return model, checkpoint
56
+
57
+ def load_feature_models(device):
58
+ """Load feature models (color and hierarchy)"""
59
+
60
+ # Load color model (embed_dim=16)
61
+ color_checkpoint = torch.load(config.color_model_path, map_location=device, weights_only=True)
62
+ color_model = ColorModel(embed_dim=config.color_emb_dim).to(device)
63
+ color_model.load_state_dict(color_checkpoint)
64
+ color_model.eval()
65
+ color_model.name = 'color'
66
+
67
+ # Load hierarchy model (embed_dim=64)
68
+ hierarchy_checkpoint = torch.load(config.hierarchy_model_path, map_location=device)
69
+ hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
70
+ hierarchy_model = Model(
71
+ num_hierarchy_classes=len(hierarchy_classes),
72
+ embed_dim=config.hierarchy_emb_dim
73
+ ).to(device)
74
+ hierarchy_model.load_state_dict(hierarchy_checkpoint['model_state'])
75
+
76
+ # Set up hierarchy extractor
77
+ hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
78
+ hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
79
+ hierarchy_model.eval()
80
+ hierarchy_model.name = 'hierarchy'
81
+
82
+ feature_models = {model.name: model for model in [color_model, hierarchy_model]}
83
+ return feature_models
84
+
85
+ def get_image_embedding(model, image, device):
86
+ """Get image embedding from the trained model"""
87
+ model.eval()
88
+ with torch.no_grad():
89
+ # Ensure image has 3 channels
90
+ if image.dim() == 3 and image.size(0) == 1:
91
+ image = image.expand(3, -1, -1)
92
+ elif image.dim() == 4 and image.size(1) == 1:
93
+ image = image.expand(-1, 3, -1, -1)
94
+
95
+ # Add batch dimension if missing
96
+ if image.dim() == 3:
97
+ image = image.unsqueeze(0) # Add batch dimension: (C, H, W) -> (1, C, H, W)
98
+
99
+ image = image.to(device)
100
+
101
+ # Use vision model directly to get image embeddings
102
+ vision_outputs = model.vision_model(pixel_values=image)
103
+ image_features = model.visual_projection(vision_outputs.pooler_output)
104
+
105
+ return F.normalize(image_features, dim=-1)
106
+
107
+ def get_text_embedding(model, text, processor, device):
108
+ """Get text embedding from the trained model"""
109
+ model.eval()
110
+ with torch.no_grad():
111
+ text_inputs = processor(text=text, padding=True, return_tensors="pt")
112
+ text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
113
+
114
+ # Use text model directly to get text embeddings
115
+ text_outputs = model.text_model(**text_inputs)
116
+ text_features = model.text_projection(text_outputs.pooler_output)
117
+
118
+ return F.normalize(text_features, dim=-1)
119
+
120
+ def evaluate_custom_csv_accuracy(model, dataset, processor, method='similarity'):
121
+ """
122
+ Evaluate the accuracy of the model on your custom CSV using text-to-text similarity
123
+
124
+ Args:
125
+ model: The trained CLIP model
126
+ dataset: CustomCSVDataset
127
+ processor: CLIPProcessor
128
+ method: 'similarity' or 'classification'
129
+ """
130
+ print(f"\n📊 === Evaluation of the accuracy on custom CSV (TEXT-TO-TEXT method) ===")
131
+
132
+ model.eval()
133
+
134
+ # Get all unique colors for classification
135
+ all_colors = set()
136
+ for i in range(len(dataset)):
137
+ _, _, color = dataset[i]
138
+ all_colors.add(color)
139
+
140
+ color_list = sorted(list(all_colors))
141
+ print(f"🎨 Colors found: {color_list}")
142
+
143
+ true_labels = []
144
+ predicted_labels = []
145
+
146
+ # Pre-calculate the embeddings of the color descriptions
147
+ print("🔄 Pre-calculating the embeddings of the colors...")
148
+ color_embeddings = {}
149
+ for color in color_list:
150
+ color_emb = get_text_embedding(model, color, processor)
151
+ color_embeddings[color] = color_emb
152
+
153
+ print("🔄 Evaluation in progress...")
154
+ correct_predictions = 0
155
+
156
+ for idx in tqdm(range(len(dataset)), desc="Evaluation"):
157
+ image, text, true_color = dataset[idx]
158
+
159
+ # Get text embedding instead of image embedding
160
+ text_emb = get_text_embedding(model, text, processor)
161
+
162
+ # Calculate the similarity with each possible color
163
+ best_similarity = -1
164
+ predicted_color = color_list[0]
165
+
166
+ for color, color_emb in color_embeddings.items():
167
+ similarity = F.cosine_similarity(text_emb, color_emb, dim=1).item()
168
+ if similarity > best_similarity:
169
+ best_similarity = similarity
170
+ predicted_color = color
171
+
172
+ true_labels.append(true_color)
173
+ predicted_labels.append(predicted_color)
174
+
175
+ if true_color == predicted_color:
176
+ correct_predictions += 1
177
+
178
+ # Calculate the accuracy
179
+ accuracy = accuracy_score(true_labels, predicted_labels)
180
+
181
+ print(f"\n✅ Results of evaluation:")
182
+ print(f"🎯 Global accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
183
+ print(f"📊 Correct predictions: {correct_predictions}/{len(true_labels)}")
184
+
185
+ return true_labels, predicted_labels, accuracy
186
+
187
+ def evaluate_custom_csv_accuracy_image(model, dataset, processor, method='similarity'):
188
+ """
189
+ Evaluate the accuracy of the model on your custom CSV using image-to-text similarity
190
+
191
+ Args:
192
+ model: The trained CLIP model
193
+ dataset: CustomCSVDataset with images loaded
194
+ processor: CLIPProcessor
195
+ method: 'similarity' or 'classification'
196
+ """
197
+ print(f"\n📊 === Evaluation of the accuracy on custom CSV (IMAGE-TO-TEXT method) ===")
198
+
199
+ model.eval()
200
+
201
+ # Get all unique colors for classification
202
+ all_colors = set()
203
+ for i in range(len(dataset)):
204
+ _, _, color = dataset[i]
205
+ all_colors.add(color)
206
+
207
+ color_list = sorted(list(all_colors))
208
+ print(f"🎨 Colors found: {color_list}")
209
+
210
+ true_labels = []
211
+ predicted_labels = []
212
+
213
+ # Pre-calculate the embeddings of the color descriptions
214
+ print("🔄 Pre-calculating the embeddings of the colors...")
215
+ color_embeddings = {}
216
+ for color in color_list:
217
+ color_emb = get_text_embedding(model, color, processor)
218
+ color_embeddings[color] = color_emb
219
+
220
+ print("🔄 Evaluation in progress...")
221
+ correct_predictions = 0
222
+
223
+ for idx in tqdm(range(len(dataset)), desc="Evaluation"):
224
+ image, text, true_color = dataset[idx]
225
+
226
+ # Get image embedding (this is the key difference from text-to-text)
227
+ image_emb = get_image_embedding(model, image, processor)
228
+
229
+ # Calculate the similarity with each possible color
230
+ best_similarity = -1
231
+ predicted_color = color_list[0]
232
+
233
+ for color, color_emb in color_embeddings.items():
234
+ similarity = F.cosine_similarity(image_emb, color_emb, dim=1).item()
235
+ if similarity > best_similarity:
236
+ best_similarity = similarity
237
+ predicted_color = color
238
+
239
+ true_labels.append(true_color)
240
+ predicted_labels.append(predicted_color)
241
+
242
+ if true_color == predicted_color:
243
+ correct_predictions += 1
244
+
245
+ # Calculate the accuracy
246
+ accuracy = accuracy_score(true_labels, predicted_labels)
247
+
248
+ print(f"\n✅ Results of evaluation:")
249
+ print(f"🎯 Global accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
250
+ print(f"📊 Correct predictions: {correct_predictions}/{len(true_labels)}")
251
+
252
+ return true_labels, predicted_labels, accuracy
253
+
254
+ def evaluate_custom_csv_accuracy_color_only(model, dataset, processor):
255
+ """
256
+ Evaluate the accuracy by encoding ONLY the color (not the full text)
257
+ This tests if the embedding space is consistent for colors
258
+
259
+ Args:
260
+ model: The trained CLIP model
261
+ dataset: CustomCSVDataset
262
+ processor: CLIPProcessor
263
+ """
264
+ print(f"\n📊 === Evaluation of the accuracy on custom CSV (COLOR-TO-COLOR method) ===")
265
+ print("🔬 This test encodes ONLY the color name, not the full text")
266
+
267
+ model.eval()
268
+
269
+ # Get all unique colors for classification
270
+ all_colors = set()
271
+ for i in range(len(dataset)):
272
+ _, _, color = dataset[i]
273
+ all_colors.add(color)
274
+
275
+ color_list = sorted(list(all_colors))
276
+ print(f"🎨 Colors found: {color_list}")
277
+
278
+ true_labels = []
279
+ predicted_labels = []
280
+
281
+ # Pre-calculate the embeddings of the color descriptions
282
+ print("🔄 Pre-calculating the embeddings of the colors...")
283
+ color_embeddings = {}
284
+ for color in color_list:
285
+ color_emb = get_text_embedding(model, color, processor)
286
+ color_embeddings[color] = color_emb
287
+
288
+ print("🔄 Evaluation in progress...")
289
+ correct_predictions = 0
290
+
291
+ for idx in tqdm(range(len(dataset)), desc="Evaluation"):
292
+ image, text, true_color = dataset[idx]
293
+
294
+ # KEY DIFFERENCE: Get embedding of the TRUE COLOR only (not the full text)
295
+ true_color_emb = get_text_embedding(model, true_color, processor)
296
+
297
+ # Calculate the similarity with each possible color
298
+ best_similarity = -1
299
+ predicted_color = color_list[0]
300
+
301
+ for color, color_emb in color_embeddings.items():
302
+ similarity = F.cosine_similarity(true_color_emb, color_emb, dim=1).item()
303
+ if similarity > best_similarity:
304
+ best_similarity = similarity
305
+ predicted_color = color
306
+
307
+ true_labels.append(true_color)
308
+ predicted_labels.append(predicted_color)
309
+
310
+ if true_color == predicted_color:
311
+ correct_predictions += 1
312
+
313
+ # Calculate the accuracy
314
+ accuracy = accuracy_score(true_labels, predicted_labels)
315
+
316
+ print(f"\n✅ Results of evaluation:")
317
+ print(f"🎯 Global accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
318
+ print(f"📊 Correct predictions: {correct_predictions}/{len(true_labels)}")
319
+
320
+ return true_labels, predicted_labels, accuracy
321
+
322
+ def search_custom_csv_by_text(model, dataset, query, processor, top_k=5):
323
+ """Search in your CSV by text query"""
324
+ print(f"\n🔍 Search in custom CSV: '{query}'")
325
+
326
+ # Get the embedding of the query
327
+ query_emb = get_text_embedding(model, query, processor)
328
+
329
+ similarities = []
330
+
331
+ print("🔄 Calculating similarities...")
332
+ for idx in tqdm(range(len(dataset)), desc="Processing"):
333
+ image, text, color, _, image_path = dataset[idx]
334
+
335
+ # Get the embedding of the image
336
+ image_emb = get_image_embedding(model, image, processor)
337
+
338
+ # Calculer la similarité
339
+ similarity = F.cosine_similarity(query_emb, image_emb, dim=1).item()
340
+
341
+ similarities.append((idx, similarity, text, color, color, image_path))
342
+
343
+ # Trier par similarité
344
+ similarities.sort(key=lambda x: x[1], reverse=True)
345
+
346
+ return similarities[:top_k]
347
+
348
+ def plot_confusion_matrix(true_labels, predicted_labels, save_path=None, title_suffix="text"):
349
+ """
350
+ Display and save the confusion matrix
351
+ """
352
+ print("\n📈 === Generation of the confusion matrix ===")
353
+
354
+ # Calculate the confusion matrix
355
+ cm = confusion_matrix(true_labels, predicted_labels)
356
+
357
+ # Get unique labels in sorted order
358
+ unique_labels = sorted(set(true_labels + predicted_labels))
359
+
360
+ # Calculate accuracy
361
+ accuracy = accuracy_score(true_labels, predicted_labels)
362
+
363
+ # Calculate the percentages and round to integers
364
+ cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100
365
+ cm_percent = np.around(cm_percent).astype(int)
366
+
367
+ # Create the figure
368
+ plt.figure(figsize=(12, 10))
369
+
370
+ # Confusion matrix with percentages and labels (no decimal points)
371
+ sns.heatmap(cm_percent,
372
+ annot=True,
373
+ fmt='d',
374
+ cmap='Blues',
375
+ cbar_kws={'label': 'Percentage (%)'},
376
+ xticklabels=unique_labels,
377
+ yticklabels=unique_labels)
378
+
379
+ plt.title(f"Confusion Matrix for {title_suffix} - new data - accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)", fontsize=16)
380
+ plt.xlabel('Predictions', fontsize=12)
381
+ plt.ylabel('True colors', fontsize=12)
382
+ plt.xticks(rotation=45, ha='right')
383
+ plt.yticks(rotation=0)
384
+ plt.tight_layout()
385
+
386
+ if save_path:
387
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
388
+ print(f"💾 Confusion matrix saved: {save_path}")
389
+
390
+ plt.show()
391
+
392
+ return cm
393
+
394
+ class CustomCSVDataset(Dataset):
395
+ def __init__(self, dataframe, image_size=224, load_images=True):
396
+ self.dataframe = dataframe
397
+ self.image_size = image_size
398
+ self.load_images = load_images
399
+
400
+ # Define image transformations
401
+ self.transform = transforms.Compose([
402
+ transforms.Resize((image_size, image_size)),
403
+ transforms.ToTensor(),
404
+ transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
405
+ std=[0.26862954, 0.26130258, 0.27577711])
406
+ ])
407
+
408
+ def __len__(self):
409
+ return len(self.dataframe)
410
+
411
+ def __getitem__(self, idx):
412
+ row = self.dataframe.iloc[idx]
413
+ text = row[config.text_column]
414
+ colors = row[config.color_column]
415
+
416
+ if self.load_images and config.column_local_image_path in row:
417
+ # Load the actual image
418
+ try:
419
+ image = Image.open(row[config.column_local_image_path]).convert('RGB')
420
+ image = self.transform(image)
421
+ except Exception as e:
422
+ print(f"Warning: Could not load image {row.get(config.column_local_image_path, 'unknown')}: {e}")
423
+ image = torch.zeros(3, self.image_size, self.image_size)
424
+ else:
425
+ # Return dummy image if not loading images
426
+ image = torch.zeros(3, self.image_size, self.image_size)
427
+
428
+ return image, text, colors
429
+
430
+ if __name__ == "__main__":
431
+ """Main function with evaluation"""
432
+ print("🚀 === Test and Evaluation of the model on new dataset ===")
433
+
434
+ # Load model
435
+ print("🔧 Loading the model...")
436
+ model, checkpoint = load_trained_model(config.main_model_path, config.device)
437
+
438
+ # Create processor
439
+ processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
440
+
441
+ # Load new dataset
442
+ print("📊 Loading the new dataset...")
443
+ df = pd.read_csv(config.local_dataset_path) # replace local_dataset_path with a new df
444
+
445
+ print("\n" + "="*80)
446
+ print("🎨 COLOR-TO-COLOR CLASSIFICATION (Control Test)")
447
+ print("="*80)
448
+
449
+ # Create dataset without loading images
450
+ dataset_color = CustomCSVDataset(df, load_images=False)
451
+
452
+ # 0. Evaluation encoding ONLY the color (control test)
453
+ true_labels_color, predicted_labels_color, accuracy_color = evaluate_custom_csv_accuracy_color_only(
454
+ model, dataset_color, processor
455
+ )
456
+
457
+ # Confusion matrix for color-only
458
+ confusion_matrix_color = plot_confusion_matrix(
459
+ true_labels_color, predicted_labels_color,
460
+ save_path="confusion_matrix_color_only.png",
461
+ title_suffix="color-only"
462
+ )
463
+
464
+ print("\n" + "="*80)
465
+ print("📝 TEXT-TO-TEXT CLASSIFICATION")
466
+ print("="*80)
467
+
468
+ # Create dataset without loading images for text-to-text
469
+ dataset_text = CustomCSVDataset(df, load_images=False)
470
+
471
+ # 1. Evaluation of the accuracy (text-to-text)
472
+ true_labels_text, predicted_labels_text, accuracy_text = evaluate_custom_csv_accuracy(
473
+ model, dataset_text, processor, method='similarity'
474
+ )
475
+
476
+ # 2. Confusion matrix for text
477
+ confusion_matrix_text = plot_confusion_matrix(
478
+ true_labels_text, predicted_labels_text,
479
+ save_path="confusion_matrix_text.png",
480
+ title_suffix="text"
481
+ )
482
+
483
+ print("\n" + "="*80)
484
+ print("🖼️ IMAGE-TO-TEXT CLASSIFICATION")
485
+ print("="*80)
486
+
487
+ # Create dataset with images loaded for image-to-text
488
+ dataset_image = CustomCSVDataset(df, load_images=True)
489
+
490
+ # 3. Evaluation of the accuracy (image-to-text)
491
+ true_labels_image, predicted_labels_image, accuracy_image = evaluate_custom_csv_accuracy_image(
492
+ model, dataset_image, processor, method='similarity'
493
+ )
494
+
495
+ # 4. Confusion matrix for images
496
+ confusion_matrix_image = plot_confusion_matrix(
497
+ true_labels_image, predicted_labels_image,
498
+ save_path="confusion_matrix_image.png",
499
+ title_suffix="image"
500
+ )
501
+
502
+ # 5. Summary comparison
503
+ print("\n" + "="*80)
504
+ print("📊 SUMMARY")
505
+ print("="*80)
506
+ print(f"🎨 Color-to-Color Accuracy (Control): {accuracy_color:.4f} ({accuracy_color*100:.2f}%)")
507
+ print(f"📝 Text-to-Text Accuracy: {accuracy_text:.4f} ({accuracy_text*100:.2f}%)")
508
+ print(f"🖼️ Image-to-Text Accuracy: {accuracy_image:.4f} ({accuracy_image*100:.2f}%)")
509
+ print(f"\n📊 Analysis:")
510
+ print(f" • Loss from full text vs color-only: {abs(accuracy_color - accuracy_text):.4f} ({abs(accuracy_color - accuracy_text)*100:.2f}%)")
511
+ print(f" • Difference text vs image: {abs(accuracy_text - accuracy_image):.4f} ({abs(accuracy_text - accuracy_image)*100:.2f}%)")
512
+