Leacb4 commited on
Commit
32f4666
·
verified ·
1 Parent(s): 22ab4d2

Delete models/main_model.py

Browse files
Files changed (1) hide show
  1. models/main_model.py +0 -840
models/main_model.py DELETED
@@ -1,840 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Main file for training the CLIP model with color and hierarchy alignment.
4
- This file centralizes all the logic for training the main model. It uses
5
- pre-trained color and hierarchy models to guide the main model's learning
6
- through contrastive and alignment loss functions. It handles data loading,
7
- training with validation, and checkpoint saving.
8
- """
9
-
10
- import os
11
- # Set environment variable to disable tokenizers parallelism warnings
12
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
13
-
14
- import pandas as pd
15
- import numpy as np
16
- import torch
17
- import torch.nn.functional as F
18
- from torch.utils.data import Dataset, DataLoader, random_split
19
- from torchvision import transforms
20
- from PIL import Image
21
- import matplotlib.pyplot as plt
22
- from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
23
- import warnings
24
- from tqdm import tqdm
25
- import json
26
- import config
27
-
28
- # Suppress warnings
29
- warnings.filterwarnings("ignore", category=FutureWarning)
30
- warnings.filterwarnings("ignore", category=UserWarning)
31
-
32
- # -------------------------------
33
- # Loss Functions
34
- # -------------------------------
35
-
36
def triple_contrastive_loss(text_features, image_features, attribute_features, temperature=0.07):
    """
    Compute a weighted contrastive loss over text, image, and attribute embeddings.

    The first ``color_emb_dim + hierarchy_emb_dim`` dimensions of the text and
    image embeddings are compared against the attribute embeddings, while the
    remaining dimensions carry the standard text-image contrastive signal.

    Args:
        text_features: Text embeddings from the main model [batch_size, embed_dim].
        image_features: Image embeddings from the main model [batch_size, embed_dim].
        attribute_features: Concatenated color + hierarchy embeddings
            [batch_size, color_dim + hierarchy_dim].
        temperature: Temperature scaling for the logits (default: 0.07).

    Returns:
        Scalar contrastive loss (symmetric cross-entropy over the batch).
    """
    attr_dim = config.color_emb_dim + config.hierarchy_emb_dim

    txt = F.normalize(text_features, dim=-1)
    img = F.normalize(image_features, dim=-1)
    attr = F.normalize(attribute_features, dim=-1)

    # Similarity matrices: text-image on the "free" dims, attribute terms on
    # the leading (color + hierarchy) dims.
    sim_text_image = txt[:, attr_dim:] @ img[:, attr_dim:].T
    sim_text_attr = txt[:, :attr_dim] @ attr.T
    sim_image_attr = attr @ img[:, :attr_dim].T

    # 70% weight on text-image, 15% on each attribute-based term; dividing the
    # weighted sum by the temperature is equivalent to scaling each term.
    logits = (0.7 * sim_text_image + 0.15 * sim_text_attr + 0.15 * sim_image_attr) / temperature

    targets = torch.arange(len(txt), device=txt.device)
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets)) / 2
72
-
73
def enhanced_contrastive_loss(text_features, image_features, attribute_features,
                              color_model, hierarchy_model, colors, hierarchies, temperature=0.07, alignment_weight=0.3):
    """
    Triple contrastive loss plus direct alignment against the specialist models.

    The contrastive part mirrors ``triple_contrastive_loss``.  On top of it,
    the first ``color_emb_dim`` dimensions of the main model's embeddings are
    pulled toward the frozen color model's text embeddings, and the following
    ``hierarchy_emb_dim`` dimensions toward the hierarchy model's, using an
    average of MSE and (1 - cosine similarity) on both modalities.

    Args:
        text_features: Main model text embeddings [batch_size, embed_dim].
        image_features: Main model image embeddings [batch_size, embed_dim].
        attribute_features: Concatenated color + hierarchy features
            [batch_size, color_dim + hierarchy_dim].
        color_model: Pre-trained color model (teacher for the color dims).
        hierarchy_model: Pre-trained hierarchy model (teacher for the hierarchy dims).
        colors: Color strings for this batch [batch_size].
        hierarchies: Hierarchy strings for this batch [batch_size].
        temperature: Temperature scaling for the contrastive logits (default: 0.07).
        alignment_weight: Mixing weight for the alignment term (default: 0.3).

    Returns:
        Tuple ``(total_loss, metrics)`` where ``metrics`` maps every loss
        component name to its Python float value.
    """
    attr_dim = config.color_emb_dim + config.hierarchy_emb_dim

    # --- contrastive term (identical to triple_contrastive_loss) ---
    txt = F.normalize(text_features, dim=-1)
    img = F.normalize(image_features, dim=-1)
    attr = F.normalize(attribute_features, dim=-1)

    sim_text_image = txt[:, attr_dim:] @ img[:, attr_dim:].T
    sim_text_attr = txt[:, :attr_dim] @ attr.T
    sim_image_attr = attr @ img[:, :attr_dim].T

    logits = (0.7 * sim_text_image + 0.15 * sim_text_attr + 0.15 * sim_image_attr) / temperature
    targets = torch.arange(len(txt), device=txt.device)
    original_loss = (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets)) / 2

    # --- alignment term: teacher embeddings come from the frozen specialists ---
    with torch.no_grad():
        color_teacher = color_model.get_text_embeddings(colors)
        hierarchy_teacher = hierarchy_model.get_text_embeddings(hierarchies)

    def _alignment_components(text_slice, image_slice, teacher):
        """MSE and (1 - cosine) of both modality slices against the teacher."""
        teacher_n = F.normalize(teacher, dim=-1)
        text_n = F.normalize(text_slice, dim=-1)
        image_n = F.normalize(image_slice, dim=-1)
        mse_text = F.mse_loss(text_n, teacher_n)
        mse_image = F.mse_loss(image_n, teacher_n)
        cos_text = 1 - F.cosine_similarity(text_n, teacher_n).mean()
        cos_image = 1 - F.cosine_similarity(image_n, teacher_n).mean()
        return mse_text, mse_image, cos_text, cos_image

    c_mse_t, c_mse_i, c_cos_t, c_cos_i = _alignment_components(
        text_features[:, :config.color_emb_dim],
        image_features[:, :config.color_emb_dim],
        color_teacher,
    )
    h_mse_t, h_mse_i, h_cos_t, h_cos_i = _alignment_components(
        text_features[:, config.color_emb_dim:attr_dim],
        image_features[:, config.color_emb_dim:attr_dim],
        hierarchy_teacher,
    )

    color_alignment_loss = (c_mse_t + c_mse_i + c_cos_t + c_cos_i) / 4
    hierarchy_alignment_loss = (h_mse_t + h_mse_i + h_cos_t + h_cos_i) / 4
    alignment_loss = (color_alignment_loss + hierarchy_alignment_loss) / 2

    # Convex mix of the contrastive and alignment objectives.
    total_loss = (1 - alignment_weight) * original_loss + alignment_weight * alignment_loss

    metrics = {
        'original_loss': original_loss.item(),
        'alignment_loss': alignment_loss.item(),
        'color_text_alignment': c_mse_t.item(),
        'color_image_alignment': c_mse_i.item(),
        'color_text_cosine': c_cos_t.item(),
        'color_image_cosine': c_cos_i.item(),
        'hierarchy_text_alignment': h_mse_t.item(),
        'hierarchy_image_alignment': h_mse_i.item(),
        'hierarchy_text_cosine': h_cos_t.item(),
        'hierarchy_image_cosine': h_cos_i.item()
    }
    return total_loss, metrics
185
-
186
- # -------------------------------
187
- # Training Functions
188
- # -------------------------------
189
-
190
def train_one_epoch(model, train_loader, optimizer, feature_models, device, clip_processor, temperature=0.07):
    """
    Run one training epoch with the triple contrastive loss.

    Args:
        model: Main CLIP model being trained.
        train_loader: DataLoader yielding (images, texts, colors, hierarchy).
        optimizer: Optimizer stepping the main model's parameters.
        feature_models: Dict with the color and hierarchy encoder models.
        device: Device the batch tensors are moved to.
        clip_processor: CLIP processor used to tokenize the texts.
        temperature: Temperature for the contrastive loss (default: 0.07).

    Returns:
        Average training loss over all batches of the epoch.
    """
    model.train()
    running_loss = 0.0
    batches_seen = 0

    progress = tqdm(train_loader, desc="Training", leave=False)

    for images, texts, colors, hierarchy in progress:
        images = images.to(device).expand(-1, 3, -1, -1)  # ensure 3 channels

        tokenized = clip_processor(text=texts, padding=True, return_tensors="pt")
        tokenized = {name: tensor.to(device) for name, tensor in tokenized.items()}

        optimizer.zero_grad()
        outputs = model(**tokenized, pixel_values=images)

        # Prefer exact color-name embeddings when the color model offers them
        # (newer color model API); otherwise fall back to free-text embeddings.
        color_encoder = feature_models[config.color_column]
        if hasattr(color_encoder, 'get_color_name_embeddings'):
            color_features = color_encoder.get_color_name_embeddings(colors)
        else:
            color_features = color_encoder.get_text_embeddings(colors)
        hierarchy_features = feature_models[config.hierarchy_column].get_text_embeddings(hierarchy)
        concat_features = torch.cat((color_features, hierarchy_features), dim=1)

        loss = triple_contrastive_loss(outputs.text_embeds, outputs.image_embeds, concat_features, temperature)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        batches_seen += 1

        progress.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Avg Loss': f'{running_loss/batches_seen:.4f}'
        })

    return running_loss / batches_seen
255
-
256
def train_one_epoch_enhanced(model, train_loader, optimizer, feature_models, color_model, hierarchy_model,
                             device, clip_processor, temperature=0.07, alignment_weight=0.3):
    """
    Run one training epoch with the enhanced (alignment-augmented) loss.

    Same loop as ``train_one_epoch`` but uses ``enhanced_contrastive_loss``,
    which additionally pulls the main model's color/hierarchy dimensions
    toward the frozen specialist models, and accumulates its per-component
    metrics across the epoch.

    Args:
        model: Main CLIP model being trained.
        train_loader: DataLoader yielding (images, texts, colors, hierarchy).
        optimizer: Optimizer stepping the main model's parameters.
        feature_models: Dict with the color and hierarchy encoder models.
        color_model: Pre-trained color model used as alignment teacher.
        hierarchy_model: Pre-trained hierarchy model used as alignment teacher.
        device: Device the batch tensors are moved to.
        clip_processor: CLIP processor used to tokenize the texts.
        temperature: Temperature for the contrastive loss (default: 0.07).
        alignment_weight: Mixing weight for the alignment term (default: 0.3).

    Returns:
        Tuple ``(average_loss, average_metrics)`` for the epoch.
    """
    model.train()
    metric_keys = (
        'original_loss', 'alignment_loss',
        'color_text_alignment', 'color_image_alignment',
        'color_text_cosine', 'color_image_cosine',
        'hierarchy_text_alignment', 'hierarchy_image_alignment',
        'hierarchy_text_cosine', 'hierarchy_image_cosine',
    )
    running_metrics = {key: 0.0 for key in metric_keys}
    running_loss = 0.0
    batches_seen = 0

    progress = tqdm(train_loader, desc="Training Enhanced", leave=False)

    for images, texts, colors, hierarchy in progress:
        images = images.to(device).expand(-1, 3, -1, -1)  # ensure 3 channels

        tokenized = clip_processor(text=texts, padding=True, return_tensors="pt")
        tokenized = {name: tensor.to(device) for name, tensor in tokenized.items()}

        optimizer.zero_grad()
        outputs = model(**tokenized, pixel_values=images)

        # Prefer exact color-name embeddings when the color model offers them.
        color_encoder = feature_models[config.color_column]
        if hasattr(color_encoder, 'get_color_name_embeddings'):
            color_features = color_encoder.get_color_name_embeddings(colors)
        else:
            color_features = color_encoder.get_text_embeddings(colors)
        hierarchy_features = feature_models[config.hierarchy_column].get_text_embeddings(hierarchy)
        concat_features = torch.cat((color_features, hierarchy_features), dim=1)

        loss, metrics = enhanced_contrastive_loss(
            outputs.text_embeds, outputs.image_embeds, concat_features,
            color_model, hierarchy_model, colors, hierarchy, temperature, alignment_weight
        )

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        for key, value in metrics.items():
            running_metrics[key] += value
        batches_seen += 1

        progress.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Align': f'{metrics["alignment_loss"]:.4f}',
            'ColCos': f'{metrics["color_text_cosine"]:.3f}',
            'HierCos': f'{metrics["hierarchy_text_cosine"]:.3f}'
        })

    averaged = {key: total / batches_seen for key, total in running_metrics.items()}
    return running_loss / batches_seen, averaged
347
-
348
def valid_one_epoch(model, val_loader, feature_models, device, clip_processor, temperature=0.07):
    """
    Run one validation pass and return the mean triple contrastive loss.

    Args:
        model: Main CLIP model to evaluate.
        val_loader: DataLoader yielding (images, texts, colors, hierarchy).
        feature_models: Dict with the color and hierarchy encoder models.
        device: Device the batch tensors are moved to.
        clip_processor: CLIP processor used to tokenize the texts.
        temperature: Temperature for the contrastive loss (default: 0.07).

    Returns:
        Average validation loss over all batches.
    """
    model.eval()
    running_loss = 0.0
    batches_seen = 0

    progress = tqdm(val_loader, desc="Validation", leave=False)

    # No gradients needed for evaluation.
    with torch.no_grad():
        for images, texts, colors, hierarchy in progress:
            images = images.to(device).expand(-1, 3, -1, -1)  # ensure 3 channels

            tokenized = clip_processor(text=texts, padding=True, return_tensors="pt")
            tokenized = {name: tensor.to(device) for name, tensor in tokenized.items()}

            outputs = model(**tokenized, pixel_values=images)

            # Prefer exact color-name embeddings when the color model offers them.
            color_encoder = feature_models[config.color_column]
            if hasattr(color_encoder, 'get_color_name_embeddings'):
                color_features = color_encoder.get_color_name_embeddings(colors)
            else:
                color_features = color_encoder.get_text_embeddings(colors)
            hierarchy_features = feature_models[config.hierarchy_column].get_text_embeddings(hierarchy)
            concat_features = torch.cat((color_features, hierarchy_features), dim=1)

            loss = triple_contrastive_loss(outputs.text_embeds, outputs.image_embeds, concat_features, temperature)

            running_loss += loss.item()
            batches_seen += 1

            progress.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Avg Loss': f'{running_loss/batches_seen:.4f}'
            })

    return running_loss / batches_seen
407
-
408
- # -------------------------------
409
- # Dataset
410
- # -------------------------------
411
-
412
class CustomDataset(Dataset):
    """
    Dataset for main-model training.

    Loads images from local paths, returns the description text plus the color
    and hierarchy labels, and applies either an augmenting (training) or a
    deterministic (validation) transform pipeline.  Column names are taken
    from the project-level ``config`` module.
    """

    # Normalization statistics shared by both pipelines.
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    def __init__(self, dataframe, use_local_images=True, image_size=224):
        """
        Args:
            dataframe: DataFrame with image-path, text, color and hierarchy columns.
            use_local_images: Whether images are read from local paths (default: True).
            image_size: Edge length images are resized to (default: 224).
        """
        self.dataframe = dataframe
        self.use_local_images = use_local_images
        self.image_size = image_size

        normalize = transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD)

        # Augmenting pipeline used while training.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.ToTensor(),
            normalize,
        ])

        # Deterministic pipeline used for validation (no augmentation).
        self.val_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            normalize,
        ])

        self.training_mode = True

    def set_training_mode(self, training=True):
        """
        Select the augmenting (True) or deterministic (False) pipeline.

        Args:
            training: Use training transforms with augmentation when True.
        """
        self.training_mode = training

    def __len__(self):
        """Number of rows in the backing dataframe."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """
        Return one sample.

        Args:
            idx: Row index.

        Returns:
            Tuple (image_tensor, description_text, color_label, hierarchy_label).
        """
        row = self.dataframe.iloc[idx]

        image = Image.open(row[config.column_local_image_path]).convert("RGB")

        pipeline = self.transform if self.training_mode else self.val_transform
        image = pipeline(image)

        return (
            image,
            row[config.text_column],
            row[config.color_column],
            row[config.hierarchy_column],
        )
493
-
494
- # -------------------------------
495
- # Model Loading
496
- # -------------------------------
497
-
498
def load_models():
    """
    Load the pre-trained color and hierarchy models from their checkpoints.

    Restores the color tokenizer vocabulary (when the vocab file exists),
    sizes the color model's embedding table from the checkpoint so the state
    dict loads cleanly, and wires a HierarchyExtractor into the hierarchy
    model.  Both models are returned in eval mode.

    Returns:
        Dict mapping model name (config.color_column / config.hierarchy_column)
        to the loaded model instance.
    """
    from color_model import ColorCLIP, Tokenizer
    from hierarchy_model import Model, HierarchyExtractor

    # Tokenizer must exist before the color model so vocab sizes can be compared.
    tokenizer = Tokenizer()

    # Load vocabulary if available.
    if os.path.exists(config.tokeniser_path):
        with open(config.tokeniser_path, 'r') as f:
            vocab_dict = json.load(f)
        tokenizer.load_vocab(vocab_dict)
        print(f"Tokenizer vocabulary loaded from {config.tokeniser_path}")
    else:
        print(f"Warning: {config.tokeniser_path} not found. Using default tokenizer.")

    # Load the trained checkpoint first to get the correct vocab size.
    # BUG FIX: was `config.config.color_model_path`, which raises AttributeError
    # (the rest of this function consistently uses `config.color_model_path`).
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
    checkpoint = torch.load(config.color_model_path, map_location=config.device)

    # Extract vocab size from the checkpoint's embedding layer.
    vocab_size_from_checkpoint = checkpoint['text_encoder.embedding.weight'].shape[0]
    print(f"Vocab size from checkpoint: {vocab_size_from_checkpoint}")
    print(f"Vocab size from tokenizer: {tokenizer.counter}")

    # Use the larger of the two so the embedding table covers both vocabularies.
    vocab_size = max(vocab_size_from_checkpoint, tokenizer.counter)

    # Initialize the color model with the reconciled vocab size.
    color_model = ColorCLIP(vocab_size=vocab_size, embedding_dim=config.color_emb_dim).to(config.device)
    color_model.tokenizer = tokenizer

    # Load the checkpoint weights.
    color_model.load_state_dict(checkpoint)
    print(f"Color model loaded from {config.color_model_path}")

    color_model.eval()
    color_model.name = config.color_column

    # Load hierarchy model; its checkpoint stores the class list and state dict.
    hierarchy_checkpoint = torch.load(config.hierarchy_model_path, map_location=config.device)
    hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
    hierarchy_model = Model(
        num_hierarchy_classes=len(hierarchy_classes),
        embed_dim=config.hierarchy_emb_dim
    ).to(config.device)
    hierarchy_model.load_state_dict(hierarchy_checkpoint['model_state'])

    # Set up hierarchy extractor.
    hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
    hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
    hierarchy_model.eval()
    hierarchy_model.name = config.hierarchy_column

    feature_models = {model.name: model for model in [color_model, hierarchy_model]}

    return feature_models
565
-
566
- # -------------------------------
567
- # Main Training Function
568
- # -------------------------------
569
-
570
def train_model(model, train_loader, val_loader, feature_models, device,
                num_epochs=20, learning_rate=1e-5, temperature=0.07,
                save_path=config.main_model_path, use_enhanced_loss=False, alignment_weight=0.3, color_alignment_model=None):
    """
    Custom training loop using train_one_epoch and valid_one_epoch functions.

    This function handles the complete training process including:
    - Training and validation loops
    - Learning rate scheduling (ReduceLROnPlateau on the validation loss)
    - Early stopping (patience of 5 epochs without improvement)
    - Model checkpointing (best validation loss only)
    - Training curve visualization (saved to training_curves.png)

    Args:
        model: Main CLIP model to train
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        feature_models: Dictionary containing color and hierarchy models
        device: Device to train on
        num_epochs: Number of training epochs (default: 20)
        learning_rate: Learning rate for optimizer (default: 1e-5)
        temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
        save_path: Path to save model checkpoints (default: config.main_model_path)
        use_enhanced_loss: Whether to use enhanced contrastive loss with alignment (default: False)
        alignment_weight: Weight for alignment loss component if using enhanced loss (default: 0.3)
        color_alignment_model: Optional color model for alignment (default: None, uses feature_models)

    Returns:
        Tuple of (training_losses, validation_losses) lists
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Halve the LR after 3 epochs without validation improvement.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 5

    print(f"Starting training for {num_epochs} epochs...")
    print(f"Learning rate: {learning_rate}")
    print(f"Temperature: {temperature}")
    print(f"Device: {device}")
    print(f"Training samples: {len(train_loader.dataset)}")
    print(f"Validation samples: {len(val_loader.dataset)}")
    print(f"Batch size: {train_loader.batch_size}")
    # Rough estimate assuming ~2 s per batch — TODO confirm against real timings.
    print(f"Estimated time per epoch: ~{len(train_loader) * 2 / 60:.1f} minutes")

    # Create processor once for efficiency (avoids re-downloading/parsing per epoch).
    processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')

    # Create progress bar for epochs
    epoch_pbar = tqdm(range(num_epochs), desc="Training Progress", position=0)

    for epoch in epoch_pbar:
        # Update epoch progress bar
        epoch_pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")

        # Training: the enhanced path adds direct color/hierarchy alignment.
        if use_enhanced_loss:
            # Default the alignment teacher to the color feature model
            # (assigned once; subsequent epochs reuse the same object).
            if color_alignment_model is None:
                color_alignment_model = feature_models[config.color_column]
            hierarchy_model = feature_models[config.hierarchy_column]
            train_loss, align_metrics = train_one_epoch_enhanced(
                model, train_loader, optimizer, feature_models, color_alignment_model, hierarchy_model, device, processor, temperature, alignment_weight
            )
        else:
            train_loss = train_one_epoch(model, train_loader, optimizer, feature_models, device, processor, temperature)
            align_metrics = None
        train_losses.append(train_loss)

        # Validation (always uses the plain triple contrastive loss)
        val_loss = valid_one_epoch(model, val_loader, feature_models, device, processor, temperature)
        val_losses.append(val_loss)

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Update epoch progress bar with metrics
        # (Best Val shown here is the value *before* this epoch's update.)
        postfix = {
            'Train Loss': f'{train_loss:.4f}',
            'Val Loss': f'{val_loss:.4f}',
            'LR': f'{optimizer.param_groups[0]["lr"]:.2e}',
            'Best Val': f'{best_val_loss:.4f}'
        }
        if align_metrics is not None:
            postfix.update({
                'Align': f"{align_metrics['alignment_loss']:.3f}",
                'ColCos': f"{align_metrics['color_text_cosine']:.3f}",
                'HierCos': f"{align_metrics['hierarchy_text_cosine']:.3f}"
            })
        epoch_pbar.set_postfix(postfix)

        # Save best model (checkpoint overwritten only on improvement)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0

            # Save checkpoint
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'best_val_loss': best_val_loss,
            }, save_path)
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= patience:
            print(f"\n🛑 Early stopping triggered after {patience_counter} epochs without improvement")
            break

    # Plot training curves
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss', color='blue')
    plt.plot(val_losses, label='Val Loss', color='red')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    plt.plot(train_losses, label='Train Loss', color='blue')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('training_curves.png', dpi=300, bbox_inches='tight')
    plt.close()

    print(f"\nTraining completed!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Final model saved to: {save_path}")
    print(f"Training curves saved to: training_curves.png")

    return train_losses, val_losses
716
-
717
- # -------------------------------
718
- # Main Function
719
- # -------------------------------
720
-
721
def main():
    """
    Entry point: load the dataset, build a reproducible train/val subset,
    load the specialist (color/hierarchy) models, optionally resume the main
    CLIP model from a checkpoint, and run training with the enhanced loss.
    """
    print("="*80)
    # Typo fix: "alignement" -> "alignment"
    print("🚀 Training of the model with alignment color and hierarchy")
    print("="*80)

    # Training configuration
    num_epochs = 20
    learning_rate = 1e-5
    temperature = 0.07
    alignment_weight = 0.5
    batch_size = 32
    subset_size = 10000
    use_enhanced_loss = True

    # Load the data
    print(f"\n📂 Loading the data...")
    df = pd.read_csv(config.local_dataset_path)
    print(f" Data downloaded: {len(df)} samples")

    # Drop rows whose local image path is missing
    df_clean = df.dropna(subset=[config.column_local_image_path])
    print(f" After filtering NaN: {len(df_clean)} samples")

    # Creation of datasets
    dataset = CustomDataset(df_clean)

    # Work on a random — but reproducible — subset for faster training
    print(f"\n📊 Creation of a subset of {subset_size} samples...")
    subset_size = min(subset_size, len(dataset))
    train_size = int(0.8 * subset_size)
    val_size = subset_size - train_size

    np.random.seed(42)
    subset_indices = np.random.choice(len(dataset), subset_size, replace=False)
    subset_dataset = torch.utils.data.Subset(dataset, subset_indices)

    train_dataset, val_dataset = random_split(
        subset_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )

    # Creation of dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=torch.cuda.is_available()
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=torch.cuda.is_available()
    )

    print(f" Train: {len(train_dataset)} samples")
    print(f" Validation: {len(val_dataset)} samples")

    # Loading specialist models
    print(f"\n🔧 Loading models...")
    feature_models = load_models()

    # Load or create the main model
    print(f"\n📦 Loading main model...")
    clip_model = CLIPModel_transformers.from_pretrained(
        'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
    )

    # Resume from a previous checkpoint when one exists
    if os.path.exists(config.main_model_path):
        print(f" Model found {config.main_model_path}")
        print(f" Loading checkpoint...")
        checkpoint = torch.load(config.main_model_path, map_location=config.device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            clip_model.load_state_dict(checkpoint['model_state_dict'])
            print(f" ✅ Checkpoint loaded from {checkpoint.get('epoch', '?')}")
        else:
            # Bare state dict (older checkpoint format)
            clip_model.load_state_dict(checkpoint)
            print(f" ✅ Checkpoint loaded")
    else:
        print(f" New model, no checkpoint found")

    # Move the model to the device
    clip_model = clip_model.to(config.device)

    # Training with enhanced loss
    print(f"\n🎯 Beginning training...")
    print(f"\n" + "="*80)

    train_losses, val_losses = train_model(
        model=clip_model,
        train_loader=train_loader,
        val_loader=val_loader,
        feature_models=feature_models,
        device=config.device,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        temperature=temperature,
        save_path=config.main_model_path,
        use_enhanced_loss=use_enhanced_loss,
        alignment_weight=alignment_weight,
        color_alignment_model=feature_models[config.color_column]
    )

    print("\n" + "="*80)
    # Typo fix: "Traning" -> "Training"
    print("✅ Training finished!")
    print(f" Modèle sauvegardé: {config.main_model_path}")
    print(f" Training curves: training_curves.png")
    print("\n📊 Final results:")
    print(f" Last train loss: {train_losses[-1]:.4f}")
    print(f" Last validation loss: {val_losses[-1]:.4f}")
    print(f" Best loss validation: {min(val_losses):.4f}")
    print("="*80)
838
-
839
# Script entry point: run the full training pipeline when executed directly.
if __name__ == "__main__":
    main()