File size: 11,297 Bytes
9c2cc41 51820f5 9c2cc41 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 |
#!/usr/bin/env python3
"""
Training script using best hyperparameters from Optuna optimization.
This script trains the model with the optimized hyperparameters and additional
regularization techniques to reduce overfitting.
"""
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from transformers import CLIPModel as CLIPModel_transformers
import warnings
import config
from main_model import CustomDataset, load_models, train_model
warnings.filterwarnings("ignore")
def train_with_best_params(
learning_rate=1.42e-05, # Best from Optuna
temperature=0.0503, # Best from Optuna
alignment_weight=0.5639, # Best from Optuna
weight_decay=2.76e-05, # Best from Optuna
num_epochs=20,
batch_size=32,
subset_size=20000, # Increased for better generalization
use_early_stopping=True,
patience=7
):
"""
Train model with best hyperparameters and anti-overfitting techniques.
Args:
learning_rate: Learning rate for optimizer (from Optuna)
temperature: Temperature for contrastive loss (from Optuna)
alignment_weight: Weight for alignment loss (from Optuna)
weight_decay: L2 regularization weight (from Optuna)
num_epochs: Number of training epochs
batch_size: Batch size for training
subset_size: Size of dataset subset
use_early_stopping: Whether to use early stopping
patience: Patience for early stopping
"""
print("="*80)
print("๐ Training with Optimized Hyperparameters")
print("="*80)
print(f"\n๐ Configuration:")
print(f" Learning rate: {learning_rate:.2e}")
print(f" Temperature: {temperature:.4f}")
print(f" Alignment weight: {alignment_weight:.4f}")
print(f" Weight decay: {weight_decay:.2e}")
print(f" Num epochs: {num_epochs}")
print(f" Batch size: {batch_size}")
print(f" Subset size: {subset_size}")
print(f" Early stopping: {use_early_stopping} (patience={patience})")
# Load data
print(f"\n๐ Loading data...")
df = pd.read_csv(config.local_dataset_path)
df_clean = df.dropna(subset=[config.column_local_image_path])
print(f" Total samples: {len(df_clean)}")
# Create dataset
dataset = CustomDataset(df_clean)
# Create subset
subset_size = min(subset_size, len(dataset))
train_size = int(0.8 * subset_size)
val_size = subset_size - train_size
np.random.seed(42)
subset_indices = np.random.choice(len(dataset), subset_size, replace=False)
subset_dataset = torch.utils.data.Subset(dataset, subset_indices)
train_dataset, val_dataset = random_split(
subset_dataset,
[train_size, val_size],
generator=torch.Generator().manual_seed(42)
)
# Create data loaders
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=2,
pin_memory=True if torch.cuda.is_available() else False
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=2,
pin_memory=True if torch.cuda.is_available() else False
)
print(f" Train: {len(train_dataset)} samples")
print(f" Val: {len(val_dataset)} samples")
# Load feature models
print(f"\n๐ง Loading feature models...")
feature_models = load_models()
# Load main model
print(f"\n๐ฆ Loading main model...")
clip_model = CLIPModel_transformers.from_pretrained(
'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
)
# Frozen reference CLIP for text-space regularization (helps cross-domain generalization)
reference_clip = CLIPModel_transformers.from_pretrained(
'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
)
# Optionally load previous checkpoint
if os.path.exists(config.main_model_path):
user_input = input(f"\nโ ๏ธ Found existing checkpoint at {config.main_model_path}. Load it? (y/n): ")
if user_input.lower() == 'y':
print(f" Loading checkpoint...")
checkpoint = torch.load(config.main_model_path, map_location=config.device)
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
clip_model.load_state_dict(checkpoint['model_state_dict'])
print(f" โ
Checkpoint loaded from epoch {checkpoint.get('epoch', '?')}")
else:
clip_model.load_state_dict(checkpoint)
print(f" โ
Checkpoint loaded")
else:
print(f" Starting from pretrained model")
else:
print(f" Starting from pretrained model")
clip_model = clip_model.to(config.device)
reference_clip = reference_clip.to(config.device)
reference_clip.eval()
for param in reference_clip.parameters():
param.requires_grad = False
# Train model with custom training function that uses weight_decay
print(f"\n๐ฏ Starting training...")
print(f"\n" + "="*80)
# We need to modify the train_model function to accept weight_decay
# For now, we'll use a modified version
model = clip_model.to(config.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', patience=3, factor=0.5
)
from transformers import CLIPProcessor
from tqdm import tqdm
from main_model import train_one_epoch, valid_one_epoch
import matplotlib.pyplot as plt
train_losses = []
val_losses = []
best_val_loss = float('inf')
patience_counter = 0
processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
epoch_pbar = tqdm(range(num_epochs), desc="Training Progress", position=0)
for epoch in epoch_pbar:
epoch_pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")
# Training
color_model = feature_models[config.color_column]
hierarchy_model = feature_models[config.hierarchy_column]
train_loss, align_metrics = train_one_epoch(
model, train_loader, optimizer, feature_models, color_model, hierarchy_model,
config.device, processor, temperature, alignment_weight,
reference_model=reference_clip, reference_weight=0.1
)
train_losses.append(train_loss)
# Validation
val_loss = valid_one_epoch(
model, val_loader, feature_models, config.device, processor,
temperature=temperature, alignment_weight=alignment_weight,
reference_model=reference_clip, reference_weight=0.1
)
val_losses.append(val_loss)
# Learning rate scheduling
scheduler.step(val_loss)
# Update progress bar
epoch_pbar.set_postfix({
'Train Loss': f'{train_loss:.4f}',
'Val Loss': f'{val_loss:.4f}',
'LR': f'{optimizer.param_groups[0]["lr"]:.2e}',
'Best Val': f'{best_val_loss:.4f}'
})
# Save best model
if val_loss < best_val_loss:
best_val_loss = val_loss
patience_counter = 0
# Save checkpoint
save_path = config.main_model_path.replace('.pt', '_best_optuna.pt')
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'train_loss': train_loss,
'val_loss': val_loss,
'best_val_loss': best_val_loss,
'hyperparameters': {
'learning_rate': learning_rate,
'temperature': temperature,
'alignment_weight': alignment_weight,
'weight_decay': weight_decay,
}
}, save_path)
print(f"\n๐พ Best model saved at epoch {epoch+1}")
else:
patience_counter += 1
# Early stopping
if use_early_stopping and patience_counter >= patience:
print(f"\n๐ Early stopping triggered after {patience_counter} epochs without improvement")
break
# Plot training curves
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss', color='blue', linewidth=2)
plt.plot(val_losses, label='Val Loss', color='red', linewidth=2)
plt.title('Training and Validation Loss (Optimized)', fontsize=14, fontweight='bold')
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.subplot(1, 2, 2)
gap = [train_losses[i] - val_losses[i] for i in range(len(train_losses))]
plt.plot(gap, label='Train-Val Gap', color='purple', linewidth=2)
plt.axhline(y=0, color='black', linestyle='--', alpha=0.3)
plt.title('Overfitting Gap (Optimized)', fontsize=14, fontweight='bold')
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Train Loss - Val Loss', fontsize=12)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('training_curves_optimized.png', dpi=300, bbox_inches='tight')
plt.close()
print("\n" + "="*80)
print("โ
Training completed!")
print(f" Best model: {save_path}")
print(f" Training curves: training_curves_optimized.png")
print("\n๐ Final results:")
print(f" Last train loss: {train_losses[-1]:.4f}")
print(f" Last validation loss: {val_losses[-1]:.4f}")
print(f" Best validation loss: {best_val_loss:.4f}")
print(f" Overfitting gap: {train_losses[-1] - val_losses[-1]:.4f}")
print("="*80)
return train_losses, val_losses
def main():
    """
    Main function - Uses best parameters from Optuna optimization.

    Prints the chosen hyperparameters and forwards them to
    ``train_with_best_params``. The values are the best trial reported in
    ``optuna_results.txt`` (Trial 29, best validation loss 0.1129).
    """
    print("\n" + "="*80)
    print("🚀 Training with Best Optuna Hyperparameters")
    print("="*80)

    # Best hyperparameters from Optuna optimization (Trial 29 - Best validation loss: 0.1129)
    # Source: optuna_results.txt
    best_params = {
        'learning_rate': 1.42e-05,  # From Optuna (best trial)
        'temperature': 0.0503,  # From Optuna (best trial)
        'alignment_weight': 0.5639,  # From Optuna (best trial)
        'weight_decay': 2.76e-05,  # From Optuna (best trial)
        'num_epochs': 20,
        'batch_size': 32,
        'subset_size': 20000,  # Increased for better generalization
        'patience': 7,
    }
    # Best validation loss reported by Optuna for the trial above.
    expected_val_loss = 0.1129

    print(f"\n✅ Using optimized hyperparameters from Optuna:")
    print(f" Learning rate: {best_params['learning_rate']:.2e}")
    print(f" Temperature: {best_params['temperature']:.4f}")
    print(f" Alignment weight: {best_params['alignment_weight']:.4f}")
    print(f" Weight decay: {best_params['weight_decay']:.2e}")
    print(f" Expected validation loss: ~{expected_val_loss} (from Optuna)\n")

    train_with_best_params(**best_params)


if __name__ == "__main__":
    main()
|