# gap-clip / train_main_model.py
# Uploaded to the Hugging Face Hub with huggingface_hub by Leacb4
# (commit 51820f5, verified).
#!/usr/bin/env python3
"""
Training script using best hyperparameters from Optuna optimization.
This script trains the model with the optimized hyperparameters and additional
regularization techniques to reduce overfitting.
"""
import os
# Must be set BEFORE transformers/tokenizers is imported: silences the
# HuggingFace tokenizers fork-parallelism warning triggered by DataLoader
# worker processes.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from transformers import CLIPModel as CLIPModel_transformers
import warnings
import config
# NOTE(review): `train_model` is imported but never used in this script —
# confirm before removing.
from main_model import CustomDataset, load_models, train_model
# Suppress third-party deprecation noise so the tqdm progress bars stay readable.
warnings.filterwarnings("ignore")
def train_with_best_params(
    learning_rate=1.42e-05,   # Best from Optuna
    temperature=0.0503,       # Best from Optuna
    alignment_weight=0.5639,  # Best from Optuna
    weight_decay=2.76e-05,    # Best from Optuna
    num_epochs=20,
    batch_size=32,
    subset_size=20000,        # Increased for better generalization
    use_early_stopping=True,
    patience=7
):
    """
    Train the CLIP-based main model with Optuna-optimized hyperparameters and
    anti-overfitting techniques (AdamW weight decay, LR-on-plateau scheduling,
    early stopping, frozen-reference CLIP regularization).

    Args:
        learning_rate: AdamW learning rate (from Optuna).
        temperature: Temperature for the contrastive loss (from Optuna).
        alignment_weight: Weight of the alignment loss term (from Optuna).
        weight_decay: L2 regularization weight for AdamW (from Optuna).
        num_epochs: Maximum number of training epochs.
        batch_size: Batch size for both train and validation loaders.
        subset_size: Number of samples drawn from the full dataset
            (clamped to the dataset length).
        use_early_stopping: Stop when val loss has not improved for
            `patience` consecutive epochs.
        patience: Early-stopping patience, in epochs.

    Returns:
        tuple[list[float], list[float]]: per-epoch train and validation losses.

    Side effects:
        - May prompt on stdin to resume from an existing checkpoint.
        - Saves the best checkpoint next to `config.main_model_path`
          (suffix '_best_optuna').
        - Writes 'training_curves_optimized.png' to the working directory.
    """
    # Kept function-local (as originally written) so importing this module
    # does not require matplotlib/tqdm; hoisted to the top of the function so
    # a missing dependency fails fast, before any data loading.
    from transformers import CLIPProcessor
    from tqdm import tqdm
    from main_model import train_one_epoch, valid_one_epoch
    import matplotlib.pyplot as plt

    # Single source of truth for the pretrained checkpoint name (was
    # duplicated three times).
    clip_checkpoint = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'

    print("="*80)
    print("🚀 Training with Optimized Hyperparameters")
    print("="*80)
    print(f"\n📋 Configuration:")
    print(f" Learning rate: {learning_rate:.2e}")
    print(f" Temperature: {temperature:.4f}")
    print(f" Alignment weight: {alignment_weight:.4f}")
    print(f" Weight decay: {weight_decay:.2e}")
    print(f" Num epochs: {num_epochs}")
    print(f" Batch size: {batch_size}")
    print(f" Subset size: {subset_size}")
    print(f" Early stopping: {use_early_stopping} (patience={patience})")

    # ---- Data ----
    print(f"\n📂 Loading data...")
    df = pd.read_csv(config.local_dataset_path)
    # Rows without a local image path cannot be used for training.
    df_clean = df.dropna(subset=[config.column_local_image_path])
    print(f" Total samples: {len(df_clean)}")
    dataset = CustomDataset(df_clean)

    # Deterministic subset selection + 80/20 train/val split; both seeds are
    # fixed so runs are reproducible and comparable to the Optuna trials.
    subset_size = min(subset_size, len(dataset))
    train_size = int(0.8 * subset_size)
    val_size = subset_size - train_size
    np.random.seed(42)
    subset_indices = np.random.choice(len(dataset), subset_size, replace=False)
    subset_dataset = torch.utils.data.Subset(dataset, subset_indices)
    train_dataset, val_dataset = random_split(
        subset_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )

    # pin_memory only helps when transferring to a CUDA device.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=pin_memory
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=pin_memory
    )
    print(f" Train: {len(train_dataset)} samples")
    print(f" Val: {len(val_dataset)} samples")

    # ---- Models ----
    print(f"\n🔧 Loading feature models...")
    feature_models = load_models()

    print(f"\n📦 Loading main model...")
    clip_model = CLIPModel_transformers.from_pretrained(clip_checkpoint)
    # Frozen reference CLIP for text-space regularization (helps cross-domain
    # generalization).
    reference_clip = CLIPModel_transformers.from_pretrained(clip_checkpoint)

    # Optionally resume from a previous checkpoint (interactive prompt).
    if os.path.exists(config.main_model_path):
        user_input = input(f"\n⚠️ Found existing checkpoint at {config.main_model_path}. Load it? (y/n): ")
        if user_input.lower() == 'y':
            print(f" Loading checkpoint...")
            checkpoint = torch.load(config.main_model_path, map_location=config.device)
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                clip_model.load_state_dict(checkpoint['model_state_dict'])
                print(f" ✅ Checkpoint loaded from epoch {checkpoint.get('epoch', '?')}")
            else:
                # Bare state_dict (older save format).
                clip_model.load_state_dict(checkpoint)
                print(f" ✅ Checkpoint loaded")
        else:
            print(f" Starting from pretrained model")
    else:
        print(f" Starting from pretrained model")

    # Single device move (the original moved clip_model to the device twice).
    model = clip_model.to(config.device)
    reference_clip = reference_clip.to(config.device)
    reference_clip.eval()
    for param in reference_clip.parameters():
        param.requires_grad = False

    # ---- Optimization setup ----
    print(f"\n🎯 Starting training...")
    print("\n" + "="*80)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # Halve the LR after 3 epochs without validation-loss improvement.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=3, factor=0.5
    )
    processor = CLIPProcessor.from_pretrained(clip_checkpoint)

    # Loop-invariant lookups hoisted out of the epoch loop.
    color_model = feature_models[config.color_column]
    hierarchy_model = feature_models[config.hierarchy_column]

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    patience_counter = 0
    # Defined up-front so the final summary is valid even when the epoch loop
    # body never runs (e.g. num_epochs == 0); previously `save_path` was only
    # bound inside the improvement branch.
    save_path = config.main_model_path.replace('.pt', '_best_optuna.pt')

    epoch_pbar = tqdm(range(num_epochs), desc="Training Progress", position=0)
    for epoch in epoch_pbar:
        epoch_pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")

        # Training (alignment metrics are returned but unused here).
        train_loss, _align_metrics = train_one_epoch(
            model, train_loader, optimizer, feature_models, color_model, hierarchy_model,
            config.device, processor, temperature, alignment_weight,
            reference_model=reference_clip, reference_weight=0.1
        )
        train_losses.append(train_loss)

        # Validation
        val_loss = valid_one_epoch(
            model, val_loader, feature_models, config.device, processor,
            temperature=temperature, alignment_weight=alignment_weight,
            reference_model=reference_clip, reference_weight=0.1
        )
        val_losses.append(val_loss)

        scheduler.step(val_loss)
        epoch_pbar.set_postfix({
            'Train Loss': f'{train_loss:.4f}',
            'Val Loss': f'{val_loss:.4f}',
            'LR': f'{optimizer.param_groups[0]["lr"]:.2e}',
            'Best Val': f'{best_val_loss:.4f}'
        })

        # Save best model; checkpoint carries everything needed to resume or
        # audit the run (epoch, optimizer state, hyperparameters).
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'best_val_loss': best_val_loss,
                'hyperparameters': {
                    'learning_rate': learning_rate,
                    'temperature': temperature,
                    'alignment_weight': alignment_weight,
                    'weight_decay': weight_decay,
                }
            }, save_path)
            print(f"\n💾 Best model saved at epoch {epoch+1}")
        else:
            patience_counter += 1

        if use_early_stopping and patience_counter >= patience:
            print(f"\n🛑 Early stopping triggered after {patience_counter} epochs without improvement")
            break

    # ---- Diagnostics: loss curves and the train/val (overfitting) gap ----
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss', color='blue', linewidth=2)
    plt.plot(val_losses, label='Val Loss', color='red', linewidth=2)
    plt.title('Training and Validation Loss (Optimized)', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.subplot(1, 2, 2)
    gap = [t - v for t, v in zip(train_losses, val_losses)]
    plt.plot(gap, label='Train-Val Gap', color='purple', linewidth=2)
    plt.axhline(y=0, color='black', linestyle='--', alpha=0.3)
    plt.title('Overfitting Gap (Optimized)', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Train Loss - Val Loss', fontsize=12)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('training_curves_optimized.png', dpi=300, bbox_inches='tight')
    plt.close()

    print("\n" + "="*80)
    print("✅ Training completed!")
    print(f" Best model: {save_path}")
    print(f" Training curves: training_curves_optimized.png")
    # Guarded: train_losses[-1] would raise IndexError if no epoch ran.
    if train_losses:
        print("\n📊 Final results:")
        print(f" Last train loss: {train_losses[-1]:.4f}")
        print(f" Last validation loss: {val_losses[-1]:.4f}")
        print(f" Best validation loss: {best_val_loss:.4f}")
        print(f" Overfitting gap: {train_losses[-1] - val_losses[-1]:.4f}")
    print("="*80)
    return train_losses, val_losses
def main():
    """
    Main function - Uses best parameters from Optuna optimization.

    Assembles the fixed best-trial hyperparameter set, echoes it to stdout,
    and launches training via `train_with_best_params`.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("🚀 Training with Best Optuna Hyperparameters")
    print(banner)

    # Best hyperparameters from Optuna optimization
    # (Trial 29 - Best validation loss: 0.1129). Source: optuna_results.txt
    best_params = {
        'learning_rate': 1.42e-05,    # From Optuna (best trial)
        'temperature': 0.0503,        # From Optuna (best trial)
        'alignment_weight': 0.5639,   # From Optuna (best trial)
        'weight_decay': 2.76e-05,     # From Optuna (best trial)
        'num_epochs': 20,
        'batch_size': 32,
        'subset_size': 20000,         # Increased for better generalization
        'patience': 7,
    }

    print(f"\n✅ Using optimized hyperparameters from Optuna:")
    print(f" Learning rate: {best_params['learning_rate']:.2e}")
    print(f" Temperature: {best_params['temperature']:.4f}")
    print(f" Alignment weight: {best_params['alignment_weight']:.4f}")
    print(f" Weight decay: {best_params['weight_decay']:.2e}")
    print(f" Expected validation loss: ~0.1129 (from Optuna)\n")

    train_with_best_params(**best_params)
# Entry point: run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()