Leacb4 commited on
Commit
5014bf7
·
verified ·
1 Parent(s): 0a876a6

Upload models/main_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/main_model.py +840 -0
models/main_model.py ADDED
@@ -0,0 +1,840 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Main file for training the CLIP model with color and hierarchy alignment.
4
+ This file centralizes all the logic for training the main model. It uses
5
+ pre-trained color and hierarchy models to guide the main model's learning
6
+ through contrastive and alignment loss functions. It handles data loading,
7
+ training with validation, and checkpoint saving.
8
+ """
9
+
10
+ import os
11
+ # Set environment variable to disable tokenizers parallelism warnings
12
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
13
+
14
+ import pandas as pd
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch.utils.data import Dataset, DataLoader, random_split
19
+ from torchvision import transforms
20
+ from PIL import Image
21
+ import matplotlib.pyplot as plt
22
+ from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
23
+ import warnings
24
+ from tqdm import tqdm
25
+ import json
26
+ import config
27
+
28
+ # Suppress warnings
29
+ warnings.filterwarnings("ignore", category=FutureWarning)
30
+ warnings.filterwarnings("ignore", category=UserWarning)
31
+
32
+ # -------------------------------
33
+ # Loss Functions
34
+ # -------------------------------
35
+
36
def triple_contrastive_loss(text_features, image_features, attribute_features, temperature=0.07):
    """
    Symmetric contrastive loss over text, image, and attribute embeddings.

    The first ``color_emb_dim + hierarchy_emb_dim`` dimensions of the main
    model's embeddings are treated as an attribute sub-space and compared
    against the concatenated color/hierarchy features; the remaining
    dimensions carry the usual text-image similarity.

    Args:
        text_features: Text embeddings from the main model [batch, embed_dim].
        image_features: Image embeddings from the main model [batch, embed_dim].
        attribute_features: Concatenated color + hierarchy embeddings
            [batch, color_emb_dim + hierarchy_emb_dim].
        temperature: Softmax temperature for the contrastive logits (default: 0.07).

    Returns:
        Scalar contrastive loss tensor.
    """
    attr_dim = config.color_emb_dim + config.hierarchy_emb_dim

    text_n = F.normalize(text_features, dim=-1)
    image_n = F.normalize(image_features, dim=-1)
    attr_n = F.normalize(attribute_features, dim=-1)

    # Pairwise similarity matrices, each scaled by the temperature.
    text_image_logits = (text_n[:, attr_dim:] @ image_n[:, attr_dim:].T) / temperature
    text_attr_logits = (text_n[:, :attr_dim] @ attr_n.T) / temperature
    image_attr_logits = (attr_n @ image_n[:, :attr_dim].T) / temperature

    # Text-image similarity dominates; the two attribute terms act as soft
    # regularizers (weights sum to 1.0).
    logits = (0.7 * text_image_logits
              + 0.15 * text_attr_logits
              + 0.15 * image_attr_logits)

    # Matching pairs sit on the diagonal of the combined logit matrix.
    targets = torch.arange(len(text_features)).to(text_features.device)
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets)) / 2
72
+
73
def enhanced_contrastive_loss(text_features, image_features, attribute_features,
                              color_model, hierarchy_model, colors, hierarchies, temperature=0.07, alignment_weight=0.3):
    """
    Enhanced contrastive loss with direct alignment between color/hierarchy models and main model.

    This loss combines the original triple contrastive loss with direct alignment losses
    that force the main model's color and hierarchy dimensions (the leading
    ``color_emb_dim + hierarchy_emb_dim`` dimensions of its embeddings) to align
    with the specialized color and hierarchy models.

    Args:
        text_features: Main model text embeddings [batch_size, embed_dim]
        image_features: Main model image embeddings [batch_size, embed_dim]
        attribute_features: Concatenated color + hierarchy features [batch_size, color_dim + hierarchy_dim]
        color_model: Pre-trained color model for extracting color embeddings
        hierarchy_model: Pre-trained hierarchy model for extracting hierarchy embeddings
        colors: List of color strings for this batch [batch_size]
        hierarchies: List of hierarchy strings for this batch [batch_size]
        temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
        alignment_weight: Weight for the alignment loss component (default: 0.3)

    Returns:
        Tuple of (total_loss, metrics_dict) where metrics_dict contains detailed loss components
        as plain floats (callers accumulate these per batch).
    """

    # --- Part 1: original triple contrastive loss (same math as triple_contrastive_loss) ---
    text_features_norm = F.normalize(text_features, dim=-1)
    image_features_norm = F.normalize(image_features, dim=-1)
    attribute_features_norm = F.normalize(attribute_features, dim=-1)

    # Text-image similarity uses only the non-attribute tail of the embeddings;
    # the attribute head is compared against the color+hierarchy features.
    text_image_logits = (text_features_norm[:, config.color_emb_dim+config.hierarchy_emb_dim:] @
                        image_features_norm[:, config.color_emb_dim+config.hierarchy_emb_dim:].T) / temperature
    text_attr_logits = (text_features_norm[:, :config.color_emb_dim+config.hierarchy_emb_dim] @
                       attribute_features_norm.T) / temperature
    image_attr_logits = (attribute_features_norm @
                        image_features_norm[:,:config.color_emb_dim+config.hierarchy_emb_dim].T) / temperature

    # Weight distribution for original loss (sums to 1.0)
    weight_text_image = 0.7
    weight_attr_based = 0.15

    original_logits = (weight_text_image * text_image_logits +
                      weight_attr_based * text_attr_logits +
                      weight_attr_based * image_attr_logits)

    # Matching pairs sit on the diagonal; symmetric cross-entropy over both axes.
    labels = torch.arange(len(text_features)).to(text_features.device)
    original_loss = (F.cross_entropy(original_logits, labels) +
                     F.cross_entropy(original_logits.T, labels)) / 2

    # --- Part 2: direct alignment against the frozen teacher models ---
    # Teacher embeddings carry no gradient; only the main model is pushed
    # toward them. Assumes get_text_embeddings returns [batch, emb_dim]
    # tensors on the same device — TODO confirm in color/hierarchy models.
    with torch.no_grad():
        color_embeddings = color_model.get_text_embeddings(colors)
        hierarchy_embeddings = hierarchy_model.get_text_embeddings(hierarchies)

    # Extract color dimensions from main model embeddings
    main_color_text = text_features[:, :config.color_emb_dim]
    main_color_image = image_features[:, :config.color_emb_dim]

    # Extract hierarchy dimensions from main model embeddings
    main_hierarchy_text = text_features[:, config.color_emb_dim:config.color_emb_dim+config.hierarchy_emb_dim]
    main_hierarchy_image = image_features[:, config.color_emb_dim:config.color_emb_dim+config.hierarchy_emb_dim]

    # Normalize both sides so MSE/cosine compare directions, not magnitudes
    color_embeddings_norm = F.normalize(color_embeddings, dim=-1)
    main_color_text_norm = F.normalize(main_color_text, dim=-1)
    main_color_image_norm = F.normalize(main_color_image, dim=-1)

    hierarchy_embeddings_norm = F.normalize(hierarchy_embeddings, dim=-1)
    main_hierarchy_text_norm = F.normalize(main_hierarchy_text, dim=-1)
    main_hierarchy_image_norm = F.normalize(main_hierarchy_image, dim=-1)

    # Color alignment: MSE + (1 - cosine) for both the text and image heads
    color_text_alignment_loss = F.mse_loss(main_color_text_norm, color_embeddings_norm)
    color_image_alignment_loss = F.mse_loss(main_color_image_norm, color_embeddings_norm)
    color_text_cosine_loss = 1 - F.cosine_similarity(main_color_text_norm, color_embeddings_norm).mean()
    color_image_cosine_loss = 1 - F.cosine_similarity(main_color_image_norm, color_embeddings_norm).mean()

    # Average of the four color terms
    color_alignment_loss = (
        color_text_alignment_loss + color_image_alignment_loss +
        color_text_cosine_loss + color_image_cosine_loss
    ) / 4

    # Hierarchy alignment: same four-term construction as the color branch
    hierarchy_text_alignment_loss = F.mse_loss(main_hierarchy_text_norm, hierarchy_embeddings_norm)
    hierarchy_image_alignment_loss = F.mse_loss(main_hierarchy_image_norm, hierarchy_embeddings_norm)
    hierarchy_text_cosine_loss = 1 - F.cosine_similarity(main_hierarchy_text_norm, hierarchy_embeddings_norm).mean()
    hierarchy_image_cosine_loss = 1 - F.cosine_similarity(main_hierarchy_image_norm, hierarchy_embeddings_norm).mean()

    # Average of the four hierarchy terms
    hierarchy_alignment_loss = (
        hierarchy_text_alignment_loss + hierarchy_image_alignment_loss +
        hierarchy_text_cosine_loss + hierarchy_image_cosine_loss
    ) / 4

    # Combined alignment loss (equal weight to color and hierarchy)
    alignment_loss = (color_alignment_loss + hierarchy_alignment_loss) / 2

    # Convex combination: alignment_weight trades contrastive vs. alignment objectives
    total_loss = (1 - alignment_weight) * original_loss + alignment_weight * alignment_loss

    # Metric keys here must stay in sync with train_one_epoch_enhanced's accumulator.
    return total_loss, {
        'original_loss': original_loss.item(),
        'alignment_loss': alignment_loss.item(),
        'color_text_alignment': color_text_alignment_loss.item(),
        'color_image_alignment': color_image_alignment_loss.item(),
        'color_text_cosine': color_text_cosine_loss.item(),
        'color_image_cosine': color_image_cosine_loss.item(),
        'hierarchy_text_alignment': hierarchy_text_alignment_loss.item(),
        'hierarchy_image_alignment': hierarchy_image_alignment_loss.item(),
        'hierarchy_text_cosine': hierarchy_text_cosine_loss.item(),
        'hierarchy_image_cosine': hierarchy_image_cosine_loss.item()
    }
185
+
186
+ # -------------------------------
187
+ # Training Functions
188
+ # -------------------------------
189
+
190
def train_one_epoch(model, train_loader, optimizer, feature_models, device, clip_processor, temperature=0.07):
    """
    Run a single training epoch with the triple contrastive loss.

    Args:
        model: Main CLIP model to train.
        train_loader: DataLoader yielding (images, texts, colors, hierarchy) batches.
        optimizer: Optimizer instance.
        feature_models: Dict of pre-trained color/hierarchy models, keyed by
            config.color_column / config.hierarchy_column.
        device: Device to train on.
        clip_processor: CLIP processor for text tokenization.
        temperature: Temperature scaling for the contrastive loss (default: 0.07).

    Returns:
        Mean training loss over all batches of the epoch.
    """
    model.train()
    running_loss = 0.0
    seen_batches = 0

    progress = tqdm(train_loader, desc="Training", leave=False)

    for images, texts, colors, hierarchy in progress:
        # Move to device and broadcast to 3 channels.
        images = images.to(device).expand(-1, 3, -1, -1)

        # Tokenize the batch of captions.
        tokens = clip_processor(text=texts, padding=True, return_tensors="pt")
        tokens = {name: tensor.to(device) for name, tensor in tokens.items()}

        optimizer.zero_grad()
        outputs = model(**tokens, pixel_values=images)

        # Attribute targets from the frozen helper models. Prefer the exact
        # color-name embeddings when the (newer) color model provides them.
        color_head = feature_models[config.color_column]
        if hasattr(color_head, 'get_color_name_embeddings'):
            color_features = color_head.get_color_name_embeddings(colors)
        else:
            color_features = color_head.get_text_embeddings(colors)
        hierarchy_features = feature_models[config.hierarchy_column].get_text_embeddings(hierarchy)
        attributes = torch.cat((color_features, hierarchy_features), dim=1)

        loss = triple_contrastive_loss(outputs.text_embeds, outputs.image_embeds, attributes, temperature)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        seen_batches += 1

        progress.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Avg Loss': f'{running_loss/seen_batches:.4f}'
        })

    return running_loss / seen_batches
255
+
256
def train_one_epoch_enhanced(model, train_loader, optimizer, feature_models, color_model, hierarchy_model,
                             device, clip_processor, temperature=0.07, alignment_weight=0.3):
    """
    Enhanced training with direct color and hierarchy alignment loss.

    This function trains the model using the enhanced contrastive loss that includes
    direct alignment between the main model's color/hierarchy dimensions and the
    specialized color/hierarchy models.

    Args:
        model: Main CLIP model to train
        train_loader: DataLoader for training data
        optimizer: Optimizer instance
        feature_models: Dictionary containing color and hierarchy models
        color_model: Pre-trained color model for alignment
        hierarchy_model: Pre-trained hierarchy model for alignment
        device: Device to train on
        clip_processor: CLIP processor for text preprocessing
        temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
        alignment_weight: Weight for the alignment loss component (default: 0.3)

    Returns:
        Tuple of (average_loss, metrics_dict) where metrics_dict contains the
        per-batch averages of each loss component
    """
    model.train()
    total_loss = 0.0
    # Running sums for every loss component; keys must match the metrics dict
    # returned by enhanced_contrastive_loss exactly.
    total_metrics = {
        'original_loss': 0.0,
        'alignment_loss': 0.0,
        'color_text_alignment': 0.0,
        'color_image_alignment': 0.0,
        'color_text_cosine': 0.0,
        'color_image_cosine': 0.0,
        'hierarchy_text_alignment': 0.0,
        'hierarchy_image_alignment': 0.0,
        'hierarchy_text_cosine': 0.0,
        'hierarchy_image_cosine': 0.0
    }
    num_batches = 0

    pbar = tqdm(train_loader, desc="Training Enhanced", leave=False)

    for batch_idx, (images, texts, colors, hierarchy) in enumerate(pbar):
        # Move data to device
        images = images.to(device)
        # Broadcast to 3 channels (assumes a broadcastable channel dim,
        # e.g. grayscale input — TODO confirm upstream loader)
        images = images.expand(-1, 3, -1, -1)

        # Tokenize captions and move the tensors to the device
        text_inputs = clip_processor(text=texts, padding=True, return_tensors="pt")
        text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

        # Forward pass through the main CLIP model
        optimizer.zero_grad()
        outputs = model(**text_inputs, pixel_values=images)

        text_features = outputs.text_embeds
        image_features = outputs.image_embeds

        # Attribute targets from the frozen helper models; prefer the exact
        # color-name embeddings when the newer color model provides them
        if hasattr(feature_models[config.color_column], 'get_color_name_embeddings'):
            color_features = feature_models[config.color_column].get_color_name_embeddings(colors)
        else:
            color_features = feature_models[config.color_column].get_text_embeddings(colors)
        hierarchy_features = feature_models[config.hierarchy_column].get_text_embeddings(hierarchy)
        concat_features = torch.cat((color_features, hierarchy_features), dim=1)

        # Calculate enhanced loss with color/hierarchy alignment
        loss, metrics = enhanced_contrastive_loss(
            text_features, image_features, concat_features,
            color_model, hierarchy_model, colors, hierarchy, temperature, alignment_weight
        )

        # Backward pass
        loss.backward()
        optimizer.step()

        # Accumulate the scalar loss and every per-component metric
        total_loss += loss.item()
        for key, value in metrics.items():
            total_metrics[key] += value
        num_batches += 1

        # Update progress bar with the most informative components
        pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Align': f'{metrics["alignment_loss"]:.4f}',
            'ColCos': f'{metrics["color_text_cosine"]:.3f}',
            'HierCos': f'{metrics["hierarchy_text_cosine"]:.3f}'
        })

    # Average each component over the epoch
    avg_metrics = {key: value / num_batches for key, value in total_metrics.items()}
    return total_loss / num_batches, avg_metrics
347
+
348
def valid_one_epoch(model, val_loader, feature_models, device, clip_processor, temperature=0.07):
    """
    Run a single validation epoch with the triple contrastive loss.

    No gradients are computed and the model is put in eval mode.

    Args:
        model: Main CLIP model to validate.
        val_loader: DataLoader yielding (images, texts, colors, hierarchy) batches.
        feature_models: Dict of pre-trained color/hierarchy models, keyed by
            config.color_column / config.hierarchy_column.
        device: Device to validate on.
        clip_processor: CLIP processor for text tokenization.
        temperature: Temperature scaling for the contrastive loss (default: 0.07).

    Returns:
        Mean validation loss over all batches of the epoch.
    """
    model.eval()
    running_loss = 0.0
    seen_batches = 0

    progress = tqdm(val_loader, desc="Validation", leave=False)

    with torch.no_grad():
        for images, texts, colors, hierarchy in progress:
            # Move to device and broadcast to 3 channels.
            images = images.to(device).expand(-1, 3, -1, -1)

            # Tokenize the batch of captions.
            tokens = clip_processor(text=texts, padding=True, return_tensors="pt")
            tokens = {name: tensor.to(device) for name, tensor in tokens.items()}

            outputs = model(**tokens, pixel_values=images)

            # Attribute targets from the frozen helper models. Prefer the
            # exact color-name embeddings when the color model provides them.
            color_head = feature_models[config.color_column]
            if hasattr(color_head, 'get_color_name_embeddings'):
                color_features = color_head.get_color_name_embeddings(colors)
            else:
                color_features = color_head.get_text_embeddings(colors)
            hierarchy_features = feature_models[config.hierarchy_column].get_text_embeddings(hierarchy)
            attributes = torch.cat((color_features, hierarchy_features), dim=1)

            loss = triple_contrastive_loss(outputs.text_embeds, outputs.image_embeds, attributes, temperature)

            running_loss += loss.item()
            seen_batches += 1

            progress.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Avg Loss': f'{running_loss/seen_batches:.4f}'
            })

    return running_loss / seen_batches
407
+
408
+ # -------------------------------
409
+ # Dataset
410
+ # -------------------------------
411
+
412
class CustomDataset(Dataset):
    """
    Dataset for main-model training.

    Loads images from local paths listed in a DataFrame, returns the caption
    and the color/hierarchy labels alongside the image tensor, and switches
    between an augmented (training) and deterministic (validation) transform
    pipeline via set_training_mode().
    """

    def __init__(self, dataframe, use_local_images=True, image_size=224):
        """
        Build the dataset and its two transform pipelines.

        Args:
            dataframe: DataFrame with columns for the local image path, text
                description, color label, and hierarchy label (names taken
                from the config module).
            use_local_images: Whether images come from local paths (default: True).
            image_size: Square side length images are resized to (default: 224).
        """
        self.dataframe = dataframe
        self.use_local_images = use_local_images
        self.image_size = image_size

        # Shared, stateless stages used by both pipelines.
        resize = transforms.Resize((image_size, image_size))
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # Augmented pipeline applied while training.
        self.transform = transforms.Compose([
            resize,
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.ToTensor(),
            normalize,
        ])

        # Deterministic pipeline applied during validation.
        self.val_transform = transforms.Compose([
            resize,
            transforms.ToTensor(),
            normalize,
        ])

        # Start in training mode (augmentations active).
        self.training_mode = True

    def set_training_mode(self, training=True):
        """
        Choose the preprocessing pipeline.

        Args:
            training: True selects the augmented transforms; False selects
                the deterministic validation transforms.
        """
        self.training_mode = training

    def __len__(self):
        """Return the number of rows (samples) in the backing DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """
        Load and preprocess one sample.

        Args:
            idx: Row index into the DataFrame.

        Returns:
            Tuple of (image_tensor, description_text, color_label, hierarchy_label).
        """
        row = self.dataframe.iloc[idx]

        # Open from the local path and force RGB so channel count is fixed.
        image = Image.open(row[config.column_local_image_path]).convert("RGB")

        pipeline = self.transform if self.training_mode else self.val_transform
        image = pipeline(image)

        return (
            image,
            row[config.text_column],
            row[config.color_column],
            row[config.hierarchy_column],
        )
493
+
494
+ # -------------------------------
495
+ # Model Loading
496
+ # -------------------------------
497
+
498
def load_models():
    """
    Load the pre-trained color and hierarchy models from their checkpoints.

    Loads the color model (with its tokenizer vocabulary, when present on
    disk) and the hierarchy model (with its label extractor), moves both to
    config.device, and puts them in eval mode for use as frozen feature
    extractors during main-model training.

    Returns:
        Dictionary mapping each model's `name` attribute (config.color_column /
        config.hierarchy_column) to the model instance.
    """
    from color_model import ColorCLIP, Tokenizer
    from hierarchy_model import Model, HierarchyExtractor

    # Initialize tokenizer first
    tokenizer = Tokenizer()

    # Load vocabulary if available
    if os.path.exists(config.tokeniser_path):
        with open(config.tokeniser_path, 'r') as f:
            vocab_dict = json.load(f)
        tokenizer.load_vocab(vocab_dict)
        print(f"Tokenizer vocabulary loaded from {config.tokeniser_path}")
    else:
        print(f"Warning: {config.tokeniser_path} not found. Using default tokenizer.")

    # Load the trained color checkpoint first to read the correct vocab size.
    # BUG FIX: was `config.config.color_model_path`, which raises
    # AttributeError at runtime (the rest of this function uses
    # `config.color_model_path`).
    checkpoint = torch.load(config.color_model_path, map_location=config.device)

    # Extract vocab size from the checkpoint's embedding layer
    vocab_size_from_checkpoint = checkpoint['text_encoder.embedding.weight'].shape[0]
    print(f"Vocab size from checkpoint: {vocab_size_from_checkpoint}")
    print(f"Vocab size from tokenizer: {tokenizer.counter}")

    # Use the larger of the two to ensure compatibility
    vocab_size = max(vocab_size_from_checkpoint, tokenizer.counter)

    # Initialize model with correct vocab size
    color_model = ColorCLIP(vocab_size=vocab_size, embedding_dim=config.color_emb_dim).to(config.device)
    color_model.tokenizer = tokenizer

    # Load the checkpoint weights
    color_model.load_state_dict(checkpoint)
    print(f"Color model loaded from {config.color_model_path}")

    color_model.eval()
    color_model.name = config.color_column

    # Load hierarchy model; its checkpoint stores both the class list and the weights
    hierarchy_checkpoint = torch.load(config.hierarchy_model_path, map_location=config.device)
    hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
    hierarchy_model = Model(
        num_hierarchy_classes=len(hierarchy_classes),
        embed_dim=config.hierarchy_emb_dim
    ).to(config.device)
    hierarchy_model.load_state_dict(hierarchy_checkpoint['model_state'])

    # Set up hierarchy extractor so the model can map label strings to classes
    hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
    hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
    hierarchy_model.eval()
    hierarchy_model.name = config.hierarchy_column

    # Key the dict by each model's name so callers look them up via config columns
    feature_models = {model.name: model for model in [color_model, hierarchy_model]}

    return feature_models
565
+
566
+ # -------------------------------
567
+ # Main Training Function
568
+ # -------------------------------
569
+
570
def train_model(model, train_loader, val_loader, feature_models, device,
                num_epochs=20, learning_rate=1e-5, temperature=0.07,
                save_path=config.main_model_path, use_enhanced_loss=False, alignment_weight=0.3, color_alignment_model=None):
    """
    Custom training loop using train_one_epoch and valid_one_epoch functions.

    This function handles the complete training process including:
    - Training and validation loops
    - Learning rate scheduling (ReduceLROnPlateau on the validation loss)
    - Early stopping (patience of 5 epochs without validation improvement)
    - Model checkpointing (best validation loss only)
    - Training curve visualization saved to training_curves.png

    Args:
        model: Main CLIP model to train
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        feature_models: Dictionary containing color and hierarchy models
        device: Device to train on
        num_epochs: Number of training epochs (default: 20)
        learning_rate: Learning rate for optimizer (default: 1e-5)
        temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
        save_path: Path to save model checkpoints. NOTE(review): the default
            binds config.main_model_path at import time — confirm config is
            fully initialized before this module is imported.
        use_enhanced_loss: Whether to use enhanced contrastive loss with alignment (default: False)
        alignment_weight: Weight for alignment loss component if using enhanced loss (default: 0.3)
        color_alignment_model: Optional color model for alignment (default: None, uses feature_models)

    Returns:
        Tuple of (training_losses, validation_losses) lists, one entry per
        completed epoch (may be shorter than num_epochs on early stop)
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Halve the LR after 3 epochs without validation improvement
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 5  # early-stopping patience, independent of the scheduler's

    print(f"Starting training for {num_epochs} epochs...")
    print(f"Learning rate: {learning_rate}")
    print(f"Temperature: {temperature}")
    print(f"Device: {device}")
    print(f"Training samples: {len(train_loader.dataset)}")
    print(f"Validation samples: {len(val_loader.dataset)}")
    print(f"Batch size: {train_loader.batch_size}")
    # Rough estimate only: assumes ~2 s per batch
    print(f"Estimated time per epoch: ~{len(train_loader) * 2 / 60:.1f} minutes")

    # Create processor once for efficiency (avoids re-downloading per epoch)
    processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')

    # Create progress bar for epochs
    epoch_pbar = tqdm(range(num_epochs), desc="Training Progress", position=0)

    for epoch in epoch_pbar:
        # Update epoch progress bar
        epoch_pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")

        # Training: pick the enhanced (alignment) path or the plain path
        if use_enhanced_loss:
            # Fall back to the feature_models color model when no dedicated
            # alignment model was passed (rebinding persists across epochs)
            if color_alignment_model is None:
                color_alignment_model = feature_models[config.color_column]
            hierarchy_model = feature_models[config.hierarchy_column]
            train_loss, align_metrics = train_one_epoch_enhanced(
                model, train_loader, optimizer, feature_models, color_alignment_model, hierarchy_model, device, processor, temperature, alignment_weight
            )
        else:
            train_loss = train_one_epoch(model, train_loader, optimizer, feature_models, device, processor, temperature)
            align_metrics = None
        train_losses.append(train_loss)

        # Validation
        val_loss = valid_one_epoch(model, val_loader, feature_models, device, processor, temperature)
        val_losses.append(val_loss)

        # Learning rate scheduling driven by the validation loss
        scheduler.step(val_loss)

        # Update epoch progress bar with metrics
        postfix = {
            'Train Loss': f'{train_loss:.4f}',
            'Val Loss': f'{val_loss:.4f}',
            'LR': f'{optimizer.param_groups[0]["lr"]:.2e}',
            'Best Val': f'{best_val_loss:.4f}'
        }
        if align_metrics is not None:
            postfix.update({
                'Align': f"{align_metrics['alignment_loss']:.3f}",
                'ColCos': f"{align_metrics['color_text_cosine']:.3f}",
                'HierCos': f"{align_metrics['hierarchy_text_cosine']:.3f}"
            })
        epoch_pbar.set_postfix(postfix)

        # Save a checkpoint only when validation improves (best-model policy)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0

            # Save checkpoint (overwrites save_path each time)
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'best_val_loss': best_val_loss,
            }, save_path)
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= patience:
            print(f"\n🛑 Early stopping triggered after {patience_counter} epochs without improvement")
            break

    # Plot training curves: combined train/val on the left, train-only on the right
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss', color='blue')
    plt.plot(val_losses, label='Val Loss', color='red')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    plt.plot(train_losses, label='Train Loss', color='blue')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('training_curves.png', dpi=300, bbox_inches='tight')
    plt.close()

    print(f"\nTraining completed!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Final model saved to: {save_path}")
    print(f"Training curves saved to: training_curves.png")

    return train_losses, val_losses
716
+
717
+ # -------------------------------
718
+ # Main Function
719
+ # -------------------------------
720
+
721
def main():
    """Run the full main-model training pipeline.

    Loads the dataset CSV, builds train/validation loaders over a reproducible
    random subset, loads the pre-trained color and hierarchy models, restores
    the main CLIP model from an existing checkpoint when one is found, and
    trains it with the enhanced (alignment-aware) contrastive loss.
    """
    print("="*80)
    # FIX: corrected "alignement color and hierarchy" wording
    print("🚀 Training of the model with color and hierarchy alignment")
    print("="*80)

    # Hyper-parameters for this run
    num_epochs = 20
    learning_rate = 1e-5
    temperature = 0.07
    alignment_weight = 0.5
    batch_size = 32
    subset_size = 10000
    use_enhanced_loss = True

    # Load the data from the local CSV
    print(f"\n📂 Loading the data...")
    df = pd.read_csv(config.local_dataset_path)
    # FIX: message said "downloaded" but the CSV is read from a local path
    print(f"   Data loaded: {len(df)} samples")

    # Drop rows without a local image path — the dataset cannot open them
    df_clean = df.dropna(subset=[config.column_local_image_path])
    print(f"   After filtering NaN: {len(df_clean)} samples")

    # Creation of datasets
    dataset = CustomDataset(df_clean)

    # Work on a fixed-size subset for faster training
    print(f"\n📊 Creation of a subset of {subset_size} samples...")
    subset_size = min(subset_size, len(dataset))
    train_size = int(0.8 * subset_size)
    val_size = subset_size - train_size

    # Random but reproducible subset indices (fixed seed)
    np.random.seed(42)
    subset_indices = np.random.choice(len(dataset), subset_size, replace=False)
    subset_dataset = torch.utils.data.Subset(dataset, subset_indices)

    train_dataset, val_dataset = random_split(
        subset_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )

    # DataLoaders; pin memory only when a CUDA device is available
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=torch.cuda.is_available()
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=torch.cuda.is_available()
    )

    print(f"   Train: {len(train_dataset)} samples")
    print(f"   Validation: {len(val_dataset)} samples")

    # Load the frozen color/hierarchy feature models
    print(f"\n🔧 Loading models...")
    feature_models = load_models()

    # Load or create the main model
    print(f"\n📦 Loading main model...")
    clip_model = CLIPModel_transformers.from_pretrained(
        'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
    )

    # Resume from an existing checkpoint when present; it may be either a
    # full training checkpoint dict or a bare state_dict
    if os.path.exists(config.main_model_path):
        print(f"   Model found {config.main_model_path}")
        print(f"   Loading checkpoint...")
        checkpoint = torch.load(config.main_model_path, map_location=config.device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            clip_model.load_state_dict(checkpoint['model_state_dict'])
            print(f"   ✅ Checkpoint loaded from epoch {checkpoint.get('epoch', '?')}")
        else:
            clip_model.load_state_dict(checkpoint)
            print(f"   ✅ Checkpoint loaded")
    else:
        print(f"   New model, no checkpoint found")

    # Move the model to the configured device
    clip_model = clip_model.to(config.device)

    # Training with enhanced loss
    print(f"\n🎯 Beginning training...")
    print(f"\n" + "="*80)

    train_losses, val_losses = train_model(
        model=clip_model,
        train_loader=train_loader,
        val_loader=val_loader,
        feature_models=feature_models,
        device=config.device,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        temperature=temperature,
        save_path=config.main_model_path,
        use_enhanced_loss=use_enhanced_loss,
        alignment_weight=alignment_weight,
        color_alignment_model=feature_models[config.color_column]
    )

    print("\n" + "="*80)
    # FIX: typo "Traning" and mixed-language "Modèle sauvegardé" message
    print("✅ Training finished!")
    print(f"   Model saved: {config.main_model_path}")
    print(f"   Training curves: training_curves.png")
    print("\n📊 Final results:")
    print(f"   Last train loss: {train_losses[-1]:.4f}")
    print(f"   Last validation loss: {val_losses[-1]:.4f}")
    # FIX: inverted wording "Best loss validation"
    print(f"   Best validation loss: {min(val_losses):.4f}")
    print("="*80)
838
+
839
# Run the training pipeline only when executed as a script (not on import).
if __name__ == "__main__":
    main()