#!/usr/bin/env python3
"""
Training script using best hyperparameters from Optuna optimization.

This script trains the model with the optimized hyperparameters and additional
regularization techniques to reduce overfitting.
"""

import os

# Must be set before any tokenizer library is imported.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from transformers import CLIPModel as CLIPModel_transformers
import warnings
import config
from main_model import CustomDataset, load_models, train_model

warnings.filterwarnings("ignore")

# Pretrained CLIP checkpoint used for the trainable model, the frozen
# reference model and the processor.
CLIP_MODEL_NAME = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'


def _best_checkpoint_path(model_path):
    """Return the best-validation checkpoint path derived from *model_path*.

    ``model.pt`` becomes ``model_best_optuna.pt``. Splitting off the extension
    (rather than ``str.replace('.pt', ...)``) means a ``.pt`` elsewhere in the
    path is left alone. It also means the result never equals *model_path*, so
    the main checkpoint cannot be overwritten.
    """
    root, ext = os.path.splitext(model_path)
    return f"{root}_best_optuna{ext}"


def _build_dataloaders(df_clean, subset_size, batch_size, seed=42):
    """Sample a reproducible subset of *df_clean* and split it 80/20.

    Args:
        df_clean: DataFrame passed to ``CustomDataset``.
        subset_size: Requested number of samples (capped at the dataset size).
        batch_size: Batch size for both loaders.
        seed: Seed for the subset sampling and the train/val split.

    Returns:
        Tuple ``(train_dataset, val_dataset, train_loader, val_loader)``.

    Raises:
        ValueError: If the subset is too small to give non-empty train and
            validation splits.
    """
    dataset = CustomDataset(df_clean)

    subset_size = min(subset_size, len(dataset))
    train_size = int(0.8 * subset_size)
    val_size = subset_size - train_size
    if train_size == 0 or val_size == 0:
        raise ValueError(
            f"Subset of {subset_size} samples is too small for a train/val split "
            f"(train={train_size}, val={val_size})"
        )

    # Seeds the global NumPy RNG so the sampled subset is reproducible.
    np.random.seed(seed)
    subset_indices = np.random.choice(len(dataset), subset_size, replace=False)
    subset_dataset = torch.utils.data.Subset(dataset, subset_indices)

    train_dataset, val_dataset = random_split(
        subset_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(seed),
    )

    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=pin_memory,
    )
    return train_dataset, val_dataset, train_loader, val_loader


def _maybe_load_checkpoint(clip_model):
    """Interactively offer to load ``config.main_model_path`` into *clip_model*.

    Both checkpoint formats are accepted: a dict with a ``'model_state_dict'``
    key, or a bare state dict. Weights are loaded in place.
    """
    if not os.path.exists(config.main_model_path):
        print(f"   Starting from pretrained model")
        return

    user_input = input(
        f"\n⚠️  Found existing checkpoint at {config.main_model_path}. Load it? (y/n): "
    )
    if user_input.strip().lower() != 'y':
        print(f"   Starting from pretrained model")
        return

    print(f"   Loading checkpoint...")
    checkpoint = torch.load(config.main_model_path, map_location=config.device)
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        clip_model.load_state_dict(checkpoint['model_state_dict'])
        print(f"   ✅ Checkpoint loaded from epoch {checkpoint.get('epoch', '?')}")
    else:
        clip_model.load_state_dict(checkpoint)
        print(f"   ✅ Checkpoint loaded")


def _plot_training_curves(train_losses, val_losses, out_path):
    """Save loss curves and the train-minus-val gap side by side to *out_path*."""
    import matplotlib.pyplot as plt

    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss', color='blue', linewidth=2)
    plt.plot(val_losses, label='Val Loss', color='red', linewidth=2)
    plt.title('Training and Validation Loss (Optimized)', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    gap = [train - val for train, val in zip(train_losses, val_losses)]
    plt.plot(gap, label='Train-Val Gap', color='purple', linewidth=2)
    plt.axhline(y=0, color='black', linestyle='--', alpha=0.3)
    plt.title('Overfitting Gap (Optimized)', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Train Loss - Val Loss', fontsize=12)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close()


def train_with_best_params(
    learning_rate=1.42e-05,      # Best from Optuna
    temperature=0.0503,          # Best from Optuna
    alignment_weight=0.5639,     # Best from Optuna
    weight_decay=2.76e-05,       # Best from Optuna
    num_epochs=20,
    batch_size=32,
    subset_size=20000,           # Increased for better generalization
    use_early_stopping=True,
    patience=7,
    reference_weight=0.1,
):
    """
    Train model with best hyperparameters and anti-overfitting techniques.

    Args:
        learning_rate: Learning rate for optimizer (from Optuna)
        temperature: Temperature for contrastive loss (from Optuna)
        alignment_weight: Weight for alignment loss (from Optuna)
        weight_decay: L2 regularization weight passed to AdamW (from Optuna)
        num_epochs: Number of training epochs
        batch_size: Batch size for training
        subset_size: Size of dataset subset
        use_early_stopping: Whether to use early stopping
        patience: Number of epochs without validation improvement before
            early stopping
        reference_weight: Weight of the frozen-reference-CLIP regularization
            term, passed to both train_one_epoch and valid_one_epoch

    Returns:
        Tuple ``(train_losses, val_losses)`` with one entry per completed epoch.
    """
    # Imported lazily so importing this module stays lightweight.
    from transformers import CLIPProcessor
    from tqdm import tqdm
    from main_model import train_one_epoch, valid_one_epoch

    print("=" * 80)
    print("🚀 Training with Optimized Hyperparameters")
    print("=" * 80)
    print(f"\n📋 Configuration:")
    print(f"   Learning rate: {learning_rate:.2e}")
    print(f"   Temperature: {temperature:.4f}")
    print(f"   Alignment weight: {alignment_weight:.4f}")
    print(f"   Weight decay: {weight_decay:.2e}")
    print(f"   Reference weight: {reference_weight}")
    print(f"   Num epochs: {num_epochs}")
    print(f"   Batch size: {batch_size}")
    print(f"   Subset size: {subset_size}")
    print(f"   Early stopping: {use_early_stopping} (patience={patience})")

    # Load data
    print(f"\n📂 Loading data...")
    df = pd.read_csv(config.local_dataset_path)
    df_clean = df.dropna(subset=[config.column_local_image_path])
    print(f"   Total samples: {len(df_clean)}")

    train_dataset, val_dataset, train_loader, val_loader = _build_dataloaders(
        df_clean, subset_size, batch_size
    )
    print(f"   Train: {len(train_dataset)} samples")
    print(f"   Val: {len(val_dataset)} samples")

    # Load feature models
    print(f"\n🔧 Loading feature models...")
    feature_models = load_models()
    color_model = feature_models[config.color_column]
    hierarchy_model = feature_models[config.hierarchy_column]

    # Load main model
    print(f"\n📦 Loading main model...")
    clip_model = CLIPModel_transformers.from_pretrained(CLIP_MODEL_NAME)
    # Frozen reference CLIP for text-space regularization (helps cross-domain generalization)
    reference_clip = CLIPModel_transformers.from_pretrained(CLIP_MODEL_NAME)

    # Optionally load previous checkpoint
    _maybe_load_checkpoint(clip_model)

    model = clip_model.to(config.device)
    reference_clip = reference_clip.to(config.device)
    reference_clip.eval()
    for param in reference_clip.parameters():
        param.requires_grad = False

    print(f"\n🎯 Starting training...")
    print(f"\n" + "=" * 80)

    # Custom loop (instead of main_model.train_model) so AdamW can take weight_decay.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=3, factor=0.5
    )
    processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)

    # Computed up front so it is defined even if no epoch ever improves.
    save_path = _best_checkpoint_path(config.main_model_path)
    curves_path = 'training_curves_optimized.png'

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    best_epoch = None
    patience_counter = 0

    epoch_pbar = tqdm(range(num_epochs), desc="Training Progress", position=0)
    for epoch in epoch_pbar:
        epoch_pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")

        # Training
        train_loss, _align_metrics = train_one_epoch(
            model, train_loader, optimizer, feature_models, color_model,
            hierarchy_model, config.device, processor, temperature, alignment_weight,
            reference_model=reference_clip, reference_weight=reference_weight,
        )
        train_losses.append(train_loss)

        # Validation
        val_loss = valid_one_epoch(
            model, val_loader, feature_models, config.device, processor,
            temperature=temperature, alignment_weight=alignment_weight,
            reference_model=reference_clip, reference_weight=reference_weight,
        )
        val_losses.append(val_loss)

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Save best model (a NaN val_loss never compares as an improvement)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'best_val_loss': best_val_loss,
                'hyperparameters': {
                    'learning_rate': learning_rate,
                    'temperature': temperature,
                    'alignment_weight': alignment_weight,
                    'weight_decay': weight_decay,
                    'reference_weight': reference_weight,
                },
            }, save_path)
            print(f"\n💾 Best model saved at epoch {epoch+1}")
        else:
            patience_counter += 1

        # Updated after the best-loss bookkeeping so 'Best Val' is current.
        epoch_pbar.set_postfix({
            'Train Loss': f'{train_loss:.4f}',
            'Val Loss': f'{val_loss:.4f}',
            'LR': f'{optimizer.param_groups[0]["lr"]:.2e}',
            'Best Val': f'{best_val_loss:.4f}',
        })

        # Early stopping
        if use_early_stopping and patience_counter >= patience:
            print(f"\n🛑 Early stopping triggered after {patience_counter} epochs without improvement")
            break

    _plot_training_curves(train_losses, val_losses, curves_path)

    print("\n" + "=" * 80)
    print("✅ Training completed!")
    if best_epoch is not None:
        print(f"   Best model: {save_path}")
    else:
        print("   Best model: none saved (validation loss never improved)")
    print(f"   Training curves: {curves_path}")
    if train_losses:
        print("\n📊 Final results:")
        print(f"   Last train loss: {train_losses[-1]:.4f}")
        print(f"   Last validation loss: {val_losses[-1]:.4f}")
        print(f"   Best validation loss: {best_val_loss:.4f}")
        print(f"   Overfitting gap: {train_losses[-1] - val_losses[-1]:.4f}")
    print("=" * 80)

    return train_losses, val_losses


def main():
    """
    Main function - Uses best parameters from Optuna optimization.
    """
    print("\n" + "=" * 80)
    print("🚀 Training with Best Optuna Hyperparameters")
    print("=" * 80)

    # Best hyperparameters from Optuna optimization (Trial 29 - Best validation loss: 0.1129)
    # Source: optuna_results.txt
    BEST_PARAMS = {
        'learning_rate': 1.42e-05,      # From Optuna (best trial)
        'temperature': 0.0503,          # From Optuna (best trial)
        'alignment_weight': 0.5639,     # From Optuna (best trial)
        'weight_decay': 2.76e-05,       # From Optuna (best trial)
        'num_epochs': 20,
        'batch_size': 32,
        'subset_size': 20000,           # Increased for better generalization
        'patience': 7,
    }

    print(f"\n✅ Using optimized hyperparameters from Optuna:")
    print(f"   Learning rate: {BEST_PARAMS['learning_rate']:.2e}")
    print(f"   Temperature: {BEST_PARAMS['temperature']:.4f}")
    print(f"   Alignment weight: {BEST_PARAMS['alignment_weight']:.4f}")
    print(f"   Weight decay: {BEST_PARAMS['weight_decay']:.2e}")
    print(f"   Expected validation loss: ~0.1129 (from Optuna)\n")

    train_with_best_params(**BEST_PARAMS)


if __name__ == "__main__":
    main()