Leacb4 committed on
Commit
9c2cc41
·
verified ·
1 Parent(s): ae8a3ca

Upload train_main_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_main_model.py +298 -0
train_main_model.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Training script using best hyperparameters from Optuna optimization.
4
+ This script trains the model with the optimized hyperparameters and additional
5
+ regularization techniques to reduce overfitting.
6
+ """
7
+
8
+ import os
9
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
10
+
11
+ import pandas as pd
12
+ import numpy as np
13
+ import torch
14
+ from torch.utils.data import DataLoader, random_split
15
+ from transformers import CLIPModel as CLIPModel_transformers
16
+ import warnings
17
+ import config
18
+ from main_model import CustomDataset, load_models, train_model
19
+
20
+ warnings.filterwarnings("ignore")
21
+
22
def train_with_best_params(
    learning_rate=1.42e-05,  # Best from Optuna
    temperature=0.0503,  # Best from Optuna
    alignment_weight=0.5639,  # Best from Optuna
    weight_decay=2.76e-05,  # Best from Optuna
    num_epochs=20,
    batch_size=32,
    subset_size=20000,  # Increased for better generalization
    use_early_stopping=True,
    patience=7
):
    """
    Train model with best hyperparameters and anti-overfitting techniques.

    Fine-tunes a LAION CLIP ViT-B/32 model with AdamW + ReduceLROnPlateau,
    a frozen reference CLIP for text-space regularization, early stopping,
    best-checkpoint saving, and a loss-curve plot saved to disk.

    Args:
        learning_rate: Learning rate for optimizer (from Optuna)
        temperature: Temperature for contrastive loss (from Optuna)
        alignment_weight: Weight for alignment loss (from Optuna)
        weight_decay: L2 regularization weight (from Optuna)
        num_epochs: Number of training epochs
        batch_size: Batch size for training
        subset_size: Size of dataset subset (capped at the dataset length)
        use_early_stopping: Whether to use early stopping
        patience: Patience (epochs without val improvement) for early stopping

    Returns:
        (train_losses, val_losses): lists of per-epoch losses, one entry per
        completed epoch (may be shorter than num_epochs if early stopping fires).
    """
    print("="*80)
    print("🚀 Training with Optimized Hyperparameters")
    print("="*80)

    print("\n📋 Configuration:")
    print(f" Learning rate: {learning_rate:.2e}")
    print(f" Temperature: {temperature:.4f}")
    print(f" Alignment weight: {alignment_weight:.4f}")
    print(f" Weight decay: {weight_decay:.2e}")
    print(f" Num epochs: {num_epochs}")
    print(f" Batch size: {batch_size}")
    print(f" Subset size: {subset_size}")
    print(f" Early stopping: {use_early_stopping} (patience={patience})")

    # Load data; rows without a local image path cannot be used for training.
    print("\n📂 Loading data...")
    df = pd.read_csv(config.local_dataset_path)
    df_clean = df.dropna(subset=[config.column_local_image_path])
    print(f" Total samples: {len(df_clean)}")

    # Create dataset
    dataset = CustomDataset(df_clean)

    # Create subset (seeded so the same subset/split is reproducible run to run)
    subset_size = min(subset_size, len(dataset))
    train_size = int(0.8 * subset_size)
    val_size = subset_size - train_size

    np.random.seed(42)
    subset_indices = np.random.choice(len(dataset), subset_size, replace=False)
    subset_dataset = torch.utils.data.Subset(dataset, subset_indices)

    train_dataset, val_dataset = random_split(
        subset_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )

    # Create data loaders; pin_memory only helps when copying to a CUDA device.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=pin_memory
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=pin_memory
    )

    print(f" Train: {len(train_dataset)} samples")
    print(f" Val: {len(val_dataset)} samples")

    # Load feature models (keyed by config column names — see usage below)
    print("\n🔧 Loading feature models...")
    feature_models = load_models()

    # Load main model
    print("\n📦 Loading main model...")
    clip_model = CLIPModel_transformers.from_pretrained(
        'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
    )
    # Frozen reference CLIP for text-space regularization (helps cross-domain generalization)
    reference_clip = CLIPModel_transformers.from_pretrained(
        'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
    )

    # Optionally load previous checkpoint.
    # NOTE(review): interactive input() blocks non-interactive runs — consider a flag.
    if os.path.exists(config.main_model_path):
        user_input = input(f"\n⚠️ Found existing checkpoint at {config.main_model_path}. Load it? (y/n): ")
        if user_input.lower() == 'y':
            print(f" Loading checkpoint...")
            # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
            checkpoint = torch.load(config.main_model_path, map_location=config.device)
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                clip_model.load_state_dict(checkpoint['model_state_dict'])
                print(f" ✅ Checkpoint loaded from epoch {checkpoint.get('epoch', '?')}")
            else:
                # Bare state_dict checkpoint (older format)
                clip_model.load_state_dict(checkpoint)
                print(f" ✅ Checkpoint loaded")
        else:
            print(f" Starting from pretrained model")
    else:
        print(f" Starting from pretrained model")

    clip_model = clip_model.to(config.device)
    reference_clip = reference_clip.to(config.device)
    reference_clip.eval()
    for param in reference_clip.parameters():
        param.requires_grad = False

    print("\n🎯 Starting training...")
    print("\n" + "="*80)

    model = clip_model  # already on config.device
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=3, factor=0.5
    )

    from transformers import CLIPProcessor
    from tqdm import tqdm
    # FIX: the loop calls train_one_epoch_enhanced, which was never imported
    # (only train_one_epoch was) and raised NameError at the first epoch.
    # Assumes main_model defines train_one_epoch_enhanced — TODO confirm.
    from main_model import train_one_epoch_enhanced, valid_one_epoch
    import matplotlib.pyplot as plt

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    patience_counter = 0

    # FIX: save_path was assigned only inside the "best model" branch but
    # printed unconditionally after the loop (UnboundLocalError if the loop
    # body never runs). Define it once up front.
    save_path = config.main_model_path.replace('.pt', '_best_optuna.pt')

    # Loop-invariant lookups hoisted out of the epoch loop.
    color_model = feature_models[config.color_column]
    hierarchy_model = feature_models[config.hierarchy_column]

    processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
    epoch_pbar = tqdm(range(num_epochs), desc="Training Progress", position=0)

    for epoch in epoch_pbar:
        epoch_pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")

        # Training
        train_loss, align_metrics = train_one_epoch_enhanced(
            model, train_loader, optimizer, feature_models, color_model, hierarchy_model,
            config.device, processor, temperature, alignment_weight,
            reference_model=reference_clip, reference_weight=0.1
        )
        train_losses.append(train_loss)

        # Validation
        val_loss = valid_one_epoch(
            model, val_loader, feature_models, config.device, processor,
            temperature=temperature, alignment_weight=alignment_weight,
            reference_model=reference_clip, reference_weight=0.1
        )
        val_losses.append(val_loss)

        # Learning rate scheduling (plateau on validation loss)
        scheduler.step(val_loss)

        # Update progress bar
        epoch_pbar.set_postfix({
            'Train Loss': f'{train_loss:.4f}',
            'Val Loss': f'{val_loss:.4f}',
            'LR': f'{optimizer.param_groups[0]["lr"]:.2e}',
            'Best Val': f'{best_val_loss:.4f}'
        })

        # Save best model (full checkpoint with hyperparameters for traceability)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'best_val_loss': best_val_loss,
                'hyperparameters': {
                    'learning_rate': learning_rate,
                    'temperature': temperature,
                    'alignment_weight': alignment_weight,
                    'weight_decay': weight_decay,
                }
            }, save_path)
            print(f"\n💾 Best model saved at epoch {epoch+1}")
        else:
            patience_counter += 1

        # Early stopping
        if use_early_stopping and patience_counter >= patience:
            print(f"\n🛑 Early stopping triggered after {patience_counter} epochs without improvement")
            break

    # Plot training curves
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss', color='blue', linewidth=2)
    plt.plot(val_losses, label='Val Loss', color='red', linewidth=2)
    plt.title('Training and Validation Loss (Optimized)', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    gap = [train_losses[i] - val_losses[i] for i in range(len(train_losses))]
    plt.plot(gap, label='Train-Val Gap', color='purple', linewidth=2)
    plt.axhline(y=0, color='black', linestyle='--', alpha=0.3)
    plt.title('Overfitting Gap (Optimized)', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Train Loss - Val Loss', fontsize=12)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('training_curves_optimized.png', dpi=300, bbox_inches='tight')
    plt.close()

    print("\n" + "="*80)
    print("✅ Training completed!")
    print(f" Best model: {save_path}")
    print(f" Training curves: training_curves_optimized.png")
    print("\n📊 Final results:")
    print(f" Last train loss: {train_losses[-1]:.4f}")
    print(f" Last validation loss: {val_losses[-1]:.4f}")
    print(f" Best validation loss: {best_val_loss:.4f}")
    print(f" Overfitting gap: {train_losses[-1] - val_losses[-1]:.4f}")
    print("="*80)

    return train_losses, val_losses
266
+
267
def main():
    """
    Main function - Uses best parameters from Optuna optimization.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("🚀 Training with Best Optuna Hyperparameters")
    print(banner)

    # Best hyperparameters from Optuna optimization (Trial 29 - Best validation loss: 0.1129)
    # Source: optuna_results.txt
    best_params = {
        'learning_rate': 1.42e-05,  # From Optuna (best trial)
        'temperature': 0.0503,  # From Optuna (best trial)
        'alignment_weight': 0.5639,  # From Optuna (best trial)
        'weight_decay': 2.76e-05,  # From Optuna (best trial)
        'num_epochs': 20,
        'batch_size': 32,
        'subset_size': 20000,  # Increased for better generalization
        'patience': 7
    }

    # Echo the chosen configuration before kicking off training.
    print("\n✅ Using optimized hyperparameters from Optuna:")
    lr, temp = best_params['learning_rate'], best_params['temperature']
    aw, wd = best_params['alignment_weight'], best_params['weight_decay']
    print(f" Learning rate: {lr:.2e}")
    print(f" Temperature: {temp:.4f}")
    print(f" Alignment weight: {aw:.4f}")
    print(f" Weight decay: {wd:.2e}")
    print(f" Expected validation loss: ~0.1129 (from Optuna)\n")

    train_with_best_params(**best_params)

if __name__ == "__main__":
    main()