Leacb4 committed on
Commit
ae8a3ca
·
verified ·
1 Parent(s): bcdaf40

Upload main_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. main_model.py +156 -162
main_model.py CHANGED
@@ -33,45 +33,9 @@ warnings.filterwarnings("ignore", category=UserWarning)
33
  # Loss Functions
34
  # -------------------------------
35
 
36
- def triple_contrastive_loss(text_features, image_features, attribute_features, temperature=0.07):
37
- """
38
- Calculate triple contrastive loss for text, image, and attribute features.
39
-
40
- This loss combines text-image similarity with attribute-based similarities
41
- (color and hierarchy) to learn aligned embeddings.
42
-
43
- Args:
44
- text_features: Text embeddings from main model [batch_size, embed_dim]
45
- image_features: Image embeddings from main model [batch_size, embed_dim]
46
- attribute_features: Concatenated color + hierarchy embeddings [batch_size, color_dim + hierarchy_dim]
47
- temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
48
-
49
- Returns:
50
- Contrastive loss value
51
- """
52
- text_features = F.normalize(text_features, dim=-1)
53
- image_features = F.normalize(image_features, dim=-1)
54
- attribute_features = F.normalize(attribute_features, dim=-1)
55
-
56
- text_image_logits = (text_features[:, config.color_emb_dim+config.hierarchy_emb_dim:] @ image_features[:, config.color_emb_dim+config.hierarchy_emb_dim:].T) / temperature
57
- text_attr_logits = (text_features[:, :config.color_emb_dim+config.hierarchy_emb_dim] @ attribute_features.T) / temperature
58
- image_attr_logits = (attribute_features @ image_features[:,:config.color_emb_dim+config.hierarchy_emb_dim].T) / temperature
59
-
60
- # Weight distribution
61
- weight_text_image = 0.7
62
- weight_attr_based = 0.15
63
-
64
- logits = (weight_text_image * text_image_logits +
65
- weight_attr_based * text_attr_logits +
66
- weight_attr_based * image_attr_logits)
67
-
68
- labels = torch.arange(len(text_features)).to(text_features.device)
69
- loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2
70
-
71
- return loss
72
-
73
  def enhanced_contrastive_loss(text_features, image_features, attribute_features,
74
- color_model, hierarchy_model, colors, hierarchies, temperature=0.07, alignment_weight=0.3):
 
75
  """
76
  Enhanced contrastive loss with direct alignment between color/hierarchy models and main model.
77
 
@@ -167,12 +131,23 @@ def enhanced_contrastive_loss(text_features, image_features, attribute_features,
167
  # Combined alignment loss
168
  alignment_loss = (color_alignment_loss + hierarchy_alignment_loss) / 2
169
 
 
 
 
 
 
 
 
 
170
  # Combine losses
171
  total_loss = (1 - alignment_weight) * original_loss + alignment_weight * alignment_loss
 
 
172
 
173
  return total_loss, {
174
  'original_loss': original_loss.item(),
175
  'alignment_loss': alignment_loss.item(),
 
176
  'color_text_alignment': color_text_alignment_loss.item(),
177
  'color_image_alignment': color_image_alignment_loss.item(),
178
  'color_text_cosine': color_text_cosine_loss.item(),
@@ -187,74 +162,10 @@ def enhanced_contrastive_loss(text_features, image_features, attribute_features,
187
  # Training Functions
188
  # -------------------------------
189
 
190
- def train_one_epoch(model, train_loader, optimizer, feature_models, device, clip_processor, temperature=0.07):
191
- """
192
- Train the model for one epoch using triple contrastive loss.
193
-
194
- Args:
195
- model: Main CLIP model to train
196
- train_loader: DataLoader for training data
197
- optimizer: Optimizer instance
198
- feature_models: Dictionary containing color and hierarchy models
199
- device: Device to train on
200
- clip_processor: CLIP processor for text preprocessing
201
- temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
202
-
203
- Returns:
204
- Average training loss for the epoch
205
- """
206
- model.train()
207
- total_loss = 0.0
208
- num_batches = 0
209
-
210
- # Create progress bar for training
211
- pbar = tqdm(train_loader, desc="Training", leave=False)
212
-
213
- for batch_idx, (images, texts, colors, hierarchy) in enumerate(pbar):
214
- # Move data to device
215
- images = images.to(device)
216
- images = images.expand(-1, 3, -1, -1) # Ensure 3 channels
217
-
218
- # Process text inputs
219
- text_inputs = clip_processor(text=texts, padding=True, return_tensors="pt")
220
- text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
221
-
222
- # Forward pass
223
- optimizer.zero_grad()
224
- outputs = model(**text_inputs, pixel_values=images)
225
-
226
- text_features = outputs.text_embeds
227
- image_features = outputs.image_embeds
228
-
229
- # Get feature embeddings
230
- # Use exact color-name embeddings if available (new color model)
231
- if hasattr(feature_models[config.color_column], 'get_color_name_embeddings'):
232
- color_features = feature_models[config.color_column].get_color_name_embeddings(colors)
233
- else:
234
- color_features = feature_models[config.color_column].get_text_embeddings(colors)
235
- hierarchy_features = feature_models[config.hierarchy_column].get_text_embeddings(hierarchy)
236
- concat_features = torch.cat((color_features, hierarchy_features), dim=1)
237
-
238
- # Calculate loss
239
- loss = triple_contrastive_loss(text_features, image_features, concat_features, temperature)
240
-
241
- # Backward pass
242
- loss.backward()
243
- optimizer.step()
244
-
245
- total_loss += loss.item()
246
- num_batches += 1
247
-
248
- # Update progress bar
249
- pbar.set_postfix({
250
- 'Loss': f'{loss.item():.4f}',
251
- 'Avg Loss': f'{total_loss/num_batches:.4f}'
252
- })
253
-
254
- return total_loss / num_batches
255
 
256
- def train_one_epoch_enhanced(model, train_loader, optimizer, feature_models, color_model, hierarchy_model,
257
- device, clip_processor, temperature=0.07, alignment_weight=0.3):
 
258
  """
259
  Enhanced training with direct color and hierarchy alignment loss.
260
 
@@ -282,6 +193,7 @@ def train_one_epoch_enhanced(model, train_loader, optimizer, feature_models, col
282
  total_metrics = {
283
  'original_loss': 0.0,
284
  'alignment_loss': 0.0,
 
285
  'color_text_alignment': 0.0,
286
  'color_image_alignment': 0.0,
287
  'color_text_cosine': 0.0,
@@ -304,6 +216,12 @@ def train_one_epoch_enhanced(model, train_loader, optimizer, feature_models, col
304
  text_inputs = clip_processor(text=texts, padding=True, return_tensors="pt")
305
  text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
306
 
 
 
 
 
 
 
307
  # Forward pass
308
  optimizer.zero_grad()
309
  outputs = model(**text_inputs, pixel_values=images)
@@ -322,11 +240,16 @@ def train_one_epoch_enhanced(model, train_loader, optimizer, feature_models, col
322
  # Calculate enhanced loss with hierarchy alignment
323
  loss, metrics = enhanced_contrastive_loss(
324
  text_features, image_features, concat_features,
325
- color_model, hierarchy_model, colors, hierarchy, temperature, alignment_weight
 
326
  )
327
 
328
  # Backward pass
329
  loss.backward()
 
 
 
 
330
  optimizer.step()
331
 
332
  total_loss += loss.item()
@@ -345,9 +268,10 @@ def train_one_epoch_enhanced(model, train_loader, optimizer, feature_models, col
345
  avg_metrics = {key: value / num_batches for key, value in total_metrics.items()}
346
  return total_loss / num_batches, avg_metrics
347
 
348
- def valid_one_epoch(model, val_loader, feature_models, device, clip_processor, temperature=0.07):
 
349
  """
350
- Validate the model for one epoch using triple contrastive loss.
351
 
352
  Args:
353
  model: Main CLIP model to validate
@@ -356,6 +280,7 @@ def valid_one_epoch(model, val_loader, feature_models, device, clip_processor, t
356
  device: Device to validate on
357
  clip_processor: CLIP processor for text preprocessing
358
  temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
 
359
 
360
  Returns:
361
  Average validation loss for the epoch
@@ -364,6 +289,10 @@ def valid_one_epoch(model, val_loader, feature_models, device, clip_processor, t
364
  total_loss = 0.0
365
  num_batches = 0
366
 
 
 
 
 
367
  # Create progress bar for validation
368
  pbar = tqdm(val_loader, desc="Validation", leave=False)
369
 
@@ -377,6 +306,11 @@ def valid_one_epoch(model, val_loader, feature_models, device, clip_processor, t
377
  text_inputs = clip_processor(text=texts, padding=True, return_tensors="pt")
378
  text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
379
 
 
 
 
 
 
380
  # Forward pass
381
  outputs = model(**text_inputs, pixel_values=images)
382
 
@@ -391,8 +325,13 @@ def valid_one_epoch(model, val_loader, feature_models, device, clip_processor, t
391
  hierarchy_features = feature_models[config.hierarchy_column].get_text_embeddings(hierarchy)
392
  concat_features = torch.cat((color_features, hierarchy_features), dim=1)
393
 
394
- # Calculate loss
395
- loss = triple_contrastive_loss(text_features, image_features, concat_features, temperature)
 
 
 
 
 
396
 
397
  total_loss += loss.item()
398
  num_batches += 1
@@ -430,13 +369,14 @@ class CustomDataset(Dataset):
430
  self.use_local_images = use_local_images
431
  self.image_size = image_size
432
 
433
- # Transforms with augmentation for training
434
  self.transform = transforms.Compose([
435
  transforms.Resize((image_size, image_size)),
436
  transforms.RandomHorizontalFlip(p=0.5),
437
- transforms.RandomRotation(15),
438
- transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15),
439
- transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
 
440
  transforms.ToTensor(),
441
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
442
  ])
@@ -569,7 +509,9 @@ def load_models():
569
 
570
  def train_model(model, train_loader, val_loader, feature_models, device,
571
  num_epochs=20, learning_rate=1e-5, temperature=0.07,
572
- save_path=config.main_model_path, use_enhanced_loss=False, alignment_weight=0.3, color_alignment_model=None):
 
 
573
  """
574
  Custom training loop using train_one_epoch and valid_one_epoch functions.
575
 
@@ -590,26 +532,29 @@ def train_model(model, train_loader, val_loader, feature_models, device,
590
  learning_rate: Learning rate for optimizer (default: 1e-5)
591
  temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
592
  save_path: Path to save model checkpoints (default: main_model_path)
593
- use_enhanced_loss: Whether to use enhanced contrastive loss with alignment (default: False)
594
  alignment_weight: Weight for alignment loss component if using enhanced loss (default: 0.3)
595
  color_alignment_model: Optional color model for alignment (default: None, uses feature_models)
 
596
 
597
  Returns:
598
  Tuple of (training_losses, validation_losses) lists
599
  """
600
  model = model.to(device)
601
- optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
 
602
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
603
 
604
  train_losses = []
605
  val_losses = []
606
  best_val_loss = float('inf')
607
  patience_counter = 0
608
- patience = 5
609
 
610
  print(f"Starting training for {num_epochs} epochs...")
611
  print(f"Learning rate: {learning_rate}")
612
  print(f"Temperature: {temperature}")
 
 
613
  print(f"Device: {device}")
614
  print(f"Training samples: {len(train_loader.dataset)}")
615
  print(f"Validation samples: {len(val_loader.dataset)}")
@@ -619,6 +564,13 @@ def train_model(model, train_loader, val_loader, feature_models, device,
619
  # Create processor once for efficiency
620
  processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
621
 
 
 
 
 
 
 
 
622
  # Create progress bar for epochs
623
  epoch_pbar = tqdm(range(num_epochs), desc="Training Progress", position=0)
624
 
@@ -627,29 +579,35 @@ def train_model(model, train_loader, val_loader, feature_models, device,
627
  epoch_pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")
628
 
629
  # Training
630
- if use_enhanced_loss:
631
- if color_alignment_model is None:
632
- color_alignment_model = feature_models[config.color_column]
633
- hierarchy_model = feature_models[config.hierarchy_column]
634
- train_loss, align_metrics = train_one_epoch_enhanced(
635
- model, train_loader, optimizer, feature_models, color_alignment_model, hierarchy_model, device, processor, temperature, alignment_weight
636
- )
637
- else:
638
- train_loss = train_one_epoch(model, train_loader, optimizer, feature_models, device, processor, temperature)
639
- align_metrics = None
640
  train_losses.append(train_loss)
641
 
642
  # Validation
643
- val_loss = valid_one_epoch(model, val_loader, feature_models, device, processor, temperature)
 
 
 
 
644
  val_losses.append(val_loss)
645
 
646
  # Learning rate scheduling
647
  scheduler.step(val_loss)
648
 
 
 
 
649
  # Update epoch progress bar with metrics
650
  postfix = {
651
  'Train Loss': f'{train_loss:.4f}',
652
  'Val Loss': f'{val_loss:.4f}',
 
653
  'LR': f'{optimizer.param_groups[0]["lr"]:.2e}',
654
  'Best Val': f'{best_val_loss:.4f}'
655
  }
@@ -661,6 +619,10 @@ def train_model(model, train_loader, val_loader, feature_models, device,
661
  })
662
  epoch_pbar.set_postfix(postfix)
663
 
 
 
 
 
664
  # Save best model
665
  if val_loss < best_val_loss:
666
  best_val_loss = val_loss
@@ -683,21 +645,38 @@ def train_model(model, train_loader, val_loader, feature_models, device,
683
  print(f"\n🛑 Early stopping triggered after {patience_counter} epochs without improvement")
684
  break
685
 
686
- # Plot training curves
687
- plt.figure(figsize=(12, 4))
688
 
689
- plt.subplot(1, 2, 1)
690
- plt.plot(train_losses, label='Train Loss', color='blue')
691
- plt.plot(val_losses, label='Val Loss', color='red')
692
- plt.title('Training and Validation Loss')
 
693
  plt.xlabel('Epoch')
694
  plt.ylabel('Loss')
695
  plt.legend()
696
  plt.grid(True, alpha=0.3)
697
 
698
- plt.subplot(1, 2, 2)
699
- plt.plot(train_losses, label='Train Loss', color='blue')
700
- plt.title('Training Loss')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
701
  plt.xlabel('Epoch')
702
  plt.ylabel('Loss')
703
  plt.legend()
@@ -723,14 +702,14 @@ def main():
723
  print("🚀 Training of the model with alignement color and hierarchy")
724
  print("="*80)
725
 
726
- # Configuration
727
  num_epochs = 20
728
- learning_rate = 1e-5
729
- temperature = 0.07
730
- alignment_weight = 0.5
 
731
  batch_size = 32
732
- subset_size = 10000
733
- use_enhanced_loss = True
734
 
735
  # Load the data
736
  print(f"\n📂 Loading the data...")
@@ -789,23 +768,31 @@ def main():
789
  clip_model = CLIPModel_transformers.from_pretrained(
790
  'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
791
  )
 
 
 
 
792
 
793
- # Load the model
794
- if os.path.exists(config.main_model_path):
795
- print(f" Model found {config.main_model_path}")
796
- print(f" Loading checkpoint...")
797
- checkpoint = torch.load(config.main_model_path, map_location=config.device)
798
- if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
799
- clip_model.load_state_dict(checkpoint['model_state_dict'])
800
- print(f" ✅ Checkpoint loaded from {checkpoint.get('epoch', '?')}")
801
- else:
802
- clip_model.load_state_dict(checkpoint)
803
- print(f" ✅ Checkpoint loaded")
804
- else:
805
- print(f" New model, no checkpoint found")
806
 
807
  # Move the model on the device
808
  clip_model = clip_model.to(config.device)
 
 
 
 
809
 
810
  # Training with enhanced loss
811
  print(f"\n🎯 Beginning training...")
@@ -821,19 +808,26 @@ def main():
821
  learning_rate=learning_rate,
822
  temperature=temperature,
823
  save_path=config.main_model_path,
824
- use_enhanced_loss=use_enhanced_loss,
825
  alignment_weight=alignment_weight,
826
- color_alignment_model=feature_models[config.color_column]
 
 
 
827
  )
828
 
829
  print("\n" + "="*80)
830
- print("✅ Traning finished!")
831
- print(f" Modèle sauvegardé: {config.main_model_path}")
832
  print(f" Training curves: training_curves.png")
833
  print("\n📊 Final results:")
834
  print(f" Last train loss: {train_losses[-1]:.4f}")
835
  print(f" Last validation loss: {val_losses[-1]:.4f}")
836
- print(f" Best loss validation: {min(val_losses):.4f}")
 
 
 
 
 
837
  print("="*80)
838
 
839
  if __name__ == "__main__":
 
33
  # Loss Functions
34
  # -------------------------------
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def enhanced_contrastive_loss(text_features, image_features, attribute_features,
37
+ color_model, hierarchy_model, colors, hierarchies, temperature=0.07, alignment_weight=0.3,
38
+ reference_text_features=None, reference_weight=0.1):
39
  """
40
  Enhanced contrastive loss with direct alignment between color/hierarchy models and main model.
41
 
 
131
  # Combined alignment loss
132
  alignment_loss = (color_alignment_loss + hierarchy_alignment_loss) / 2
133
 
134
+ # Optional guidance to keep text space close to base CLIP (helps cross-domain generalization)
135
+ reference_loss = 0.0
136
+ if reference_text_features is not None:
137
+ reference_loss = F.mse_loss(
138
+ F.normalize(text_features, dim=-1),
139
+ F.normalize(reference_text_features, dim=-1)
140
+ )
141
+
142
  # Combine losses
143
  total_loss = (1 - alignment_weight) * original_loss + alignment_weight * alignment_loss
144
+ if reference_text_features is not None:
145
+ total_loss = total_loss + reference_weight * reference_loss
146
 
147
  return total_loss, {
148
  'original_loss': original_loss.item(),
149
  'alignment_loss': alignment_loss.item(),
150
+ 'reference_loss': reference_loss if isinstance(reference_loss, float) else reference_loss.item(),
151
  'color_text_alignment': color_text_alignment_loss.item(),
152
  'color_image_alignment': color_image_alignment_loss.item(),
153
  'color_text_cosine': color_text_cosine_loss.item(),
 
162
  # Training Functions
163
  # -------------------------------
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
+ def train_one_epoch(model, train_loader, optimizer, feature_models, color_model, hierarchy_model,
167
+ device, clip_processor, temperature=0.07, alignment_weight=0.3,
168
+ reference_model=None, reference_weight=0.1):
169
  """
170
  Enhanced training with direct color and hierarchy alignment loss.
171
 
 
193
  total_metrics = {
194
  'original_loss': 0.0,
195
  'alignment_loss': 0.0,
196
+ 'reference_loss': 0.0,
197
  'color_text_alignment': 0.0,
198
  'color_image_alignment': 0.0,
199
  'color_text_cosine': 0.0,
 
216
  text_inputs = clip_processor(text=texts, padding=True, return_tensors="pt")
217
  text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
218
 
219
+ # Optional reference text features to keep close to base CLIP
220
+ reference_text_features = None
221
+ if reference_model is not None:
222
+ with torch.no_grad():
223
+ reference_text_features = reference_model.get_text_features(**text_inputs)
224
+
225
  # Forward pass
226
  optimizer.zero_grad()
227
  outputs = model(**text_inputs, pixel_values=images)
 
240
  # Calculate enhanced loss with hierarchy alignment
241
  loss, metrics = enhanced_contrastive_loss(
242
  text_features, image_features, concat_features,
243
+ color_model, hierarchy_model, colors, hierarchy, temperature, alignment_weight,
244
+ reference_text_features=reference_text_features, reference_weight=reference_weight
245
  )
246
 
247
  # Backward pass
248
  loss.backward()
249
+
250
+ # Gradient clipping to prevent exploding gradients
251
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
252
+
253
  optimizer.step()
254
 
255
  total_loss += loss.item()
 
268
  avg_metrics = {key: value / num_batches for key, value in total_metrics.items()}
269
  return total_loss / num_batches, avg_metrics
270
 
271
+ def valid_one_epoch(model, val_loader, feature_models, device, clip_processor, temperature=0.07, alignment_weight=0.3,
272
+ reference_model=None, reference_weight=0.1):
273
  """
274
+ Validate the model for one epoch using enhanced contrastive loss.
275
 
276
  Args:
277
  model: Main CLIP model to validate
 
280
  device: Device to validate on
281
  clip_processor: CLIP processor for text preprocessing
282
  temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
283
+ alignment_weight: Weight for the alignment loss component (default: 0.3)
284
 
285
  Returns:
286
  Average validation loss for the epoch
 
289
  total_loss = 0.0
290
  num_batches = 0
291
 
292
+ # Extract color and hierarchy models
293
+ color_model = feature_models[config.color_column]
294
+ hierarchy_model = feature_models[config.hierarchy_column]
295
+
296
  # Create progress bar for validation
297
  pbar = tqdm(val_loader, desc="Validation", leave=False)
298
 
 
306
  text_inputs = clip_processor(text=texts, padding=True, return_tensors="pt")
307
  text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
308
 
309
+ # Optional reference text features
310
+ reference_text_features = None
311
+ if reference_model is not None:
312
+ reference_text_features = reference_model.get_text_features(**text_inputs)
313
+
314
  # Forward pass
315
  outputs = model(**text_inputs, pixel_values=images)
316
 
 
325
  hierarchy_features = feature_models[config.hierarchy_column].get_text_embeddings(hierarchy)
326
  concat_features = torch.cat((color_features, hierarchy_features), dim=1)
327
 
328
+ # Calculate loss with all required arguments
329
+ loss, metrics = enhanced_contrastive_loss(
330
+ text_features, image_features, concat_features,
331
+ color_model, hierarchy_model, colors, hierarchy,
332
+ temperature, alignment_weight,
333
+ reference_text_features=reference_text_features, reference_weight=reference_weight
334
+ )
335
 
336
  total_loss += loss.item()
337
  num_batches += 1
 
369
  self.use_local_images = use_local_images
370
  self.image_size = image_size
371
 
372
+ # Transforms with augmentation for training (increased augmentation to reduce overfitting)
373
  self.transform = transforms.Compose([
374
  transforms.Resize((image_size, image_size)),
375
  transforms.RandomHorizontalFlip(p=0.5),
376
+ transforms.RandomRotation(15), # Increased for more variation
377
+ transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15), # Increased intensity
378
+ transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)), # Increased transform range
379
+ transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.2), # Add blur
380
  transforms.ToTensor(),
381
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
382
  ])
 
509
 
510
  def train_model(model, train_loader, val_loader, feature_models, device,
511
  num_epochs=20, learning_rate=1e-5, temperature=0.07,
512
+ save_path=config.main_model_path, alignment_weight=0.3,
513
+ color_alignment_model=None, weight_decay=3e-4,
514
+ reference_model=None, reference_weight=0.1):
515
  """
516
  Custom training loop using train_one_epoch and valid_one_epoch functions.
517
 
 
532
  learning_rate: Learning rate for optimizer (default: 1e-5)
533
  temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
534
  save_path: Path to save model checkpoints (default: main_model_path)
 
535
  alignment_weight: Weight for alignment loss component if using enhanced loss (default: 0.3)
536
  color_alignment_model: Optional color model for alignment (default: None, uses feature_models)
537
+ weight_decay: L2 regularization weight (default: 3e-4, increased to reduce overfitting)
538
 
539
  Returns:
540
  Tuple of (training_losses, validation_losses) lists
541
  """
542
  model = model.to(device)
543
+ # Use AdamW with weight decay for better regularization (reduces overfitting)
544
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
545
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
546
 
547
  train_losses = []
548
  val_losses = []
549
  best_val_loss = float('inf')
550
  patience_counter = 0
551
+ patience = 7 # Increased from 5 to 7 for better convergence
552
 
553
  print(f"Starting training for {num_epochs} epochs...")
554
  print(f"Learning rate: {learning_rate}")
555
  print(f"Temperature: {temperature}")
556
+ print(f"Weight decay: {weight_decay}")
557
+ print(f"Alignment weight: {alignment_weight}")
558
  print(f"Device: {device}")
559
  print(f"Training samples: {len(train_loader.dataset)}")
560
  print(f"Validation samples: {len(val_loader.dataset)}")
 
564
  # Create processor once for efficiency
565
  processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
566
 
567
+ # Freeze and move reference model (used for text-space regularization)
568
+ if reference_model is not None:
569
+ reference_model = reference_model.to(device)
570
+ reference_model.eval()
571
+ for param in reference_model.parameters():
572
+ param.requires_grad = False
573
+
574
  # Create progress bar for epochs
575
  epoch_pbar = tqdm(range(num_epochs), desc="Training Progress", position=0)
576
 
 
579
  epoch_pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")
580
 
581
  # Training
582
+ if color_alignment_model is None:
583
+ color_alignment_model = feature_models[config.color_column]
584
+ hierarchy_model = feature_models[config.hierarchy_column]
585
+ train_loss, align_metrics = train_one_epoch_enhanced(
586
+ model, train_loader, optimizer, feature_models, color_alignment_model, hierarchy_model,
587
+ device, processor, temperature, alignment_weight,
588
+ reference_model=reference_model, reference_weight=reference_weight
589
+ )
 
 
590
  train_losses.append(train_loss)
591
 
592
  # Validation
593
+ val_loss = valid_one_epoch(
594
+ model, val_loader, feature_models, device, processor,
595
+ temperature=temperature, alignment_weight=alignment_weight,
596
+ reference_model=reference_model, reference_weight=reference_weight
597
+ )
598
  val_losses.append(val_loss)
599
 
600
  # Learning rate scheduling
601
  scheduler.step(val_loss)
602
 
603
+ # Calculate overfitting gap
604
+ overfitting_gap = val_loss - train_loss
605
+
606
  # Update epoch progress bar with metrics
607
  postfix = {
608
  'Train Loss': f'{train_loss:.4f}',
609
  'Val Loss': f'{val_loss:.4f}',
610
+ 'Gap': f'{overfitting_gap:.4f}',
611
  'LR': f'{optimizer.param_groups[0]["lr"]:.2e}',
612
  'Best Val': f'{best_val_loss:.4f}'
613
  }
 
619
  })
620
  epoch_pbar.set_postfix(postfix)
621
 
622
+ # Warning if overfitting is detected
623
+ if overfitting_gap > 0.15 and epoch > 3:
624
+ print(f"\n⚠️ Warning: Significant overfitting detected at epoch {epoch+1} (gap={overfitting_gap:.4f})")
625
+
626
  # Save best model
627
  if val_loss < best_val_loss:
628
  best_val_loss = val_loss
 
645
  print(f"\n🛑 Early stopping triggered after {patience_counter} epochs without improvement")
646
  break
647
 
648
+ # Plot training curves with overfitting analysis
649
+ plt.figure(figsize=(15, 5))
650
 
651
+ # Plot 1: Training and Validation Loss
652
+ plt.subplot(1, 3, 1)
653
+ plt.plot(train_losses, label='Train Loss', color='blue', linewidth=2)
654
+ plt.plot(val_losses, label='Val Loss', color='red', linewidth=2)
655
+ plt.title('Training and Validation Loss', fontsize=12, fontweight='bold')
656
  plt.xlabel('Epoch')
657
  plt.ylabel('Loss')
658
  plt.legend()
659
  plt.grid(True, alpha=0.3)
660
 
661
+ # Plot 2: Overfitting Gap (Val Loss - Train Loss)
662
+ plt.subplot(1, 3, 2)
663
+ gap = [val_losses[i] - train_losses[i] for i in range(len(train_losses))]
664
+ plt.plot(gap, label='Overfitting Gap', color='purple', linewidth=2)
665
+ plt.axhline(y=0, color='black', linestyle='--', alpha=0.3)
666
+ plt.axhline(y=0.1, color='red', linestyle='--', alpha=0.3, label='Warning threshold')
667
+ plt.title('Overfitting Gap (Val - Train)', fontsize=12, fontweight='bold')
668
+ plt.xlabel('Epoch')
669
+ plt.ylabel('Gap')
670
+ plt.legend()
671
+ plt.grid(True, alpha=0.3)
672
+
673
+ # Plot 3: Loss comparison
674
+ plt.subplot(1, 3, 3)
675
+ epochs = list(range(len(train_losses)))
676
+ plt.plot(epochs, train_losses, 'o-', label='Train Loss', color='blue', linewidth=2)
677
+ plt.plot(epochs, val_losses, 's-', label='Val Loss', color='red', linewidth=2)
678
+ plt.fill_between(epochs, train_losses, val_losses, alpha=0.2, color='red')
679
+ plt.title('Loss Comparison', fontsize=12, fontweight='bold')
680
  plt.xlabel('Epoch')
681
  plt.ylabel('Loss')
682
  plt.legend()
 
702
  print("🚀 Training of the model with alignement color and hierarchy")
703
  print("="*80)
704
 
705
+ # Configuration (optimized to reduce overfitting)
706
  num_epochs = 20
707
+ learning_rate = 1.5e-5 # Reduced slightly to prevent overfitting
708
+ temperature = 0.09 # Increased from 0.07 for softer contrastive learning
709
+ alignment_weight = 0.2 # Reduced from 0.3 to prevent overfitting on alignment
710
+ weight_decay = 5e-4 # Increased weight decay for stronger regularization
711
  batch_size = 32
712
+ subset_size = 20000 # Increased dataset size for better generalization
 
713
 
714
  # Load the data
715
  print(f"\n📂 Loading the data...")
 
768
  clip_model = CLIPModel_transformers.from_pretrained(
769
  'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
770
  )
771
+ # Frozen reference CLIP to regularize text space (improves cross-domain generalization)
772
+ reference_clip = CLIPModel_transformers.from_pretrained(
773
+ 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
774
+ )
775
 
776
+ # # Load the model
777
+ # if os.path.exists(config.main_model_path):
778
+ # print(f" Model found {config.main_model_path}")
779
+ # print(f" Loading checkpoint...")
780
+ # checkpoint = torch.load(config.main_model_path, map_location=config.device)
781
+ # if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
782
+ # clip_model.load_state_dict(checkpoint['model_state_dict'])
783
+ # print(f" ✅ Checkpoint loaded from {checkpoint.get('epoch', '?')}")
784
+ # else:
785
+ # clip_model.load_state_dict(checkpoint)
786
+ # print(f" ✅ Checkpoint loaded")
787
+ # else:
788
+ # print(f" New model, no checkpoint found")
789
 
790
  # Move the model on the device
791
  clip_model = clip_model.to(config.device)
792
+ reference_clip = reference_clip.to(config.device)
793
+ reference_clip.eval()
794
+ for param in reference_clip.parameters():
795
+ param.requires_grad = False
796
 
797
  # Training with enhanced loss
798
  print(f"\n🎯 Beginning training...")
 
808
  learning_rate=learning_rate,
809
  temperature=temperature,
810
  save_path=config.main_model_path,
 
811
  alignment_weight=alignment_weight,
812
+ color_alignment_model=feature_models[config.color_column],
813
+ weight_decay=weight_decay,
814
+ reference_model=reference_clip,
815
+ reference_weight=0.1
816
  )
817
 
818
  print("\n" + "="*80)
819
+ print("✅ Training finished!")
820
+ print(f" Model saved: {config.main_model_path}")
821
  print(f" Training curves: training_curves.png")
822
  print("\n📊 Final results:")
823
  print(f" Last train loss: {train_losses[-1]:.4f}")
824
  print(f" Last validation loss: {val_losses[-1]:.4f}")
825
+ print(f" Best validation loss: {min(val_losses):.4f}")
826
+ print(f" Overfitting gap (val-train): {val_losses[-1] - train_losses[-1]:.4f}")
827
+ if val_losses[-1] - train_losses[-1] > 0.1:
828
+ print(" ⚠️ Warning: Significant overfitting detected!")
829
+ elif val_losses[-1] - train_losses[-1] < 0.05:
830
+ print(" ✅ Good generalization!")
831
  print("="*80)
832
 
833
  if __name__ == "__main__":