#!/usr/bin/env python3
"""
Optuna hyperparameter optimization for the main CLIP model.

This script uses Optuna to find the best hyperparameters to reduce overfitting.
It tunes the learning rate, contrastive temperature, alignment-loss weight and
weight decay, training a fresh CLIP model on a small data subset per trial and
minimizing the best validation loss seen during that trial.
"""

import os
import pickle
import sys
import warnings

# Add parent directory to path to import modules
PARENT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, PARENT_DIR)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from transformers import CLIPModel as CLIPModel_transformers
from transformers import CLIPProcessor
import optuna
from optuna.trial import TrialState

import config
from main_model import (
    CustomDataset,
    load_models,
    train_one_epoch_enhanced,
    valid_one_epoch
)

warnings.filterwarnings("ignore")

# Pretrained checkpoint used for both the model and its processor.
MODEL_NAME = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'

# Per-trial training budget (reduced for faster optimization).
NUM_EPOCHS = 5
PATIENCE = 2

# Global variables for data (to avoid reloading for each trial)
TRAIN_LOADER = None
VAL_LOADER = None
FEATURE_MODELS = None
DEVICE = None


def prepare_data(subset_size=5000, batch_size=32):
    """
    Prepare data loaders for optimization.

    Uses a smaller, reproducibly sampled subset of the dataset for faster
    trials. Rows without a local image path are dropped, `subset_size` samples
    are drawn without replacement (NumPy seed 42), and the subset is split
    80/20 into train/validation (torch generator seed 42).

    Args:
        subset_size: Maximum number of samples to use; clamped to the dataset
            size.
        batch_size: Batch size for both loaders.

    Returns:
        A `(train_loader, val_loader)` tuple of `DataLoader` objects.

    Raises:
        ValueError: If the effective subset is too small to yield at least one
            training sample.
    """
    print("\n📂 Loading data...")
    df = pd.read_csv(config.local_dataset_path)
    df_clean = df.dropna(subset=[config.column_local_image_path])
    print(f" Total samples: {len(df_clean)}")

    # Create dataset
    dataset = CustomDataset(df_clean)

    # Create smaller subset for optimization
    subset_size = min(subset_size, len(dataset))
    train_size = int(0.8 * subset_size)
    val_size = subset_size - train_size
    if train_size <= 0:
        # Fail early with a clear message instead of an opaque sampler error.
        raise ValueError(
            f"Need at least 2 usable samples to build train/val splits, "
            f"got subset_size={subset_size}"
        )

    # NOTE: seeds the global NumPy RNG on purpose so the sampled subset stays
    # identical across runs.
    np.random.seed(42)
    subset_indices = np.random.choice(len(dataset), subset_size, replace=False)
    subset_dataset = torch.utils.data.Subset(dataset, subset_indices)

    train_dataset, val_dataset = random_split(
        subset_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )

    # Create data loaders
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=pin_memory
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=pin_memory
    )

    print(f" Train: {len(train_dataset)} samples")
    print(f" Val: {len(val_dataset)} samples")

    return train_loader, val_loader


def objective(trial):
    """
    Objective function for Optuna optimization.

    Samples one hyperparameter configuration, trains a fresh CLIP model for up
    to `NUM_EPOCHS` epochs with early stopping, and reports the per-epoch
    validation loss to Optuna so unpromising trials can be pruned.

    Relies on the module-level `TRAIN_LOADER`, `VAL_LOADER`, `FEATURE_MODELS`
    and `DEVICE`, which `main()` populates before the study starts.

    Args:
        trial: The Optuna trial used to sample hyperparameters and report
            intermediate values.

    Returns:
        The best (lowest) validation loss observed during the trial.

    Raises:
        optuna.TrialPruned: If the pruner decides to stop the trial early.
    """
    # Suggest hyperparameters
    learning_rate = trial.suggest_float("learning_rate", 1e-6, 5e-5, log=True)
    temperature = trial.suggest_float("temperature", 0.05, 0.15)
    alignment_weight = trial.suggest_float("alignment_weight", 0.1, 0.6)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 5e-4, log=True)

    print(f"\n{'='*80}")
    print(f"Trial {trial.number}")
    print(f" LR: {learning_rate:.2e}, Temp: {temperature:.4f}")
    print(f" Align weight: {alignment_weight:.3f}, Weight decay: {weight_decay:.2e}")
    print(f"{'='*80}")

    # These lookups do not change between epochs, so resolve them once.
    color_model = FEATURE_MODELS[config.color_column]
    hierarchy_model = FEATURE_MODELS[config.hierarchy_column]

    # Bound up front so the `finally` cleanup is safe even if loading fails.
    clip_model = optimizer = processor = None
    try:
        # Create fresh model for this trial
        clip_model = CLIPModel_transformers.from_pretrained(MODEL_NAME).to(DEVICE)

        # Optimizer with weight decay for regularization
        optimizer = torch.optim.AdamW(
            clip_model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )

        # Create processor
        processor = CLIPProcessor.from_pretrained(MODEL_NAME)

        best_val_loss = float('inf')
        patience_counter = 0

        for epoch in range(NUM_EPOCHS):
            # Training
            train_loss, metrics = train_one_epoch_enhanced(
                clip_model, TRAIN_LOADER, optimizer, FEATURE_MODELS,
                color_model, hierarchy_model, DEVICE, processor,
                temperature=temperature, alignment_weight=alignment_weight
            )

            # Validation
            val_loss = valid_one_epoch(
                clip_model, VAL_LOADER, FEATURE_MODELS, DEVICE, processor,
                temperature=temperature,
                alignment_weight=alignment_weight
            )

            print(f" Epoch {epoch+1}/{NUM_EPOCHS} - Train: {train_loss:.4f}, Val: {val_loss:.4f}")

            # Track best validation loss
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1

            # Early stopping within trial. This exits before the current
            # epoch's value is reported, so the trial finishes as COMPLETE
            # rather than being considered for pruning.
            if patience_counter >= PATIENCE:
                print(f" Early stopping at epoch {epoch+1}")
                break

            # Report intermediate value for pruning
            trial.report(val_loss, epoch)

            # Handle pruning based on intermediate value
            if trial.should_prune():
                print(f" Trial pruned at epoch {epoch+1}")
                raise optuna.TrialPruned()

        return best_val_loss
    finally:
        # Clean up memory on every exit path (normal return, pruning, errors)
        # so cached GPU blocks are released before the next trial starts.
        del clip_model, optimizer, processor
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def _format_params(params):
    """Return one display line per hyperparameter (scientific for LR/decay)."""
    lines = []
    for key, value in params.items():
        if 'learning_rate' in key or 'weight_decay' in key:
            lines.append(f" {key}: {value:.2e}")
        else:
            lines.append(f" {key}: {value:.4f}")
    return lines


def _save_results(study, trial, results_file):
    """Write the best trial and the full trials table to a text report."""
    with open(results_file, 'w', encoding="utf-8") as f:
        f.write("="*80 + "\n")
        f.write("Optuna Hyperparameter Optimization Results\n")
        f.write("="*80 + "\n\n")
        f.write(f"Best trial value (validation loss): {trial.value:.4f}\n\n")
        f.write("Best hyperparameters:\n")
        for line in _format_params(trial.params):
            f.write(line + "\n")
        f.write("\n" + "="*80 + "\n")
        f.write("All trials:\n")
        f.write("="*80 + "\n\n")
        df_results = study.trials_dataframe()
        f.write(df_results.to_string())


def _save_plots(study):
    """Best-effort export of Optuna plots; skipped if plotting deps are missing."""
    try:
        from optuna.visualization import plot_optimization_history, plot_param_importances

        # Plot optimization history
        fig1 = plot_optimization_history(study)
        history_file = os.path.join(PARENT_DIR, "optuna_optimization_history.png")
        fig1.write_image(history_file)
        print(f"📊 Optimization history saved to: {history_file}")

        # Plot parameter importances
        fig2 = plot_param_importances(study)
        importance_file = os.path.join(PARENT_DIR, "optuna_param_importances.png")
        fig2.write_image(importance_file)
        print(f"📊 Parameter importances saved to: {importance_file}")
    except Exception as e:
        # Visualization is optional; never let it fail the whole run.
        print(f"\n⚠️ Visualization skipped: {e}")
        print(" Install plotly and kaleido for visualizations: pip install plotly kaleido")


def main(n_trials=30, subset_size=5000, batch_size=32):
    """
    Main function to run Optuna optimization.

    Loads the feature models and data once, runs the study, prints the best
    trial, and saves a text report, the pickled study and (optionally) plots
    into the parent directory of this script.

    Args:
        n_trials: Number of Optuna trials to run.
        subset_size: Number of samples used for the optimization subset.
        batch_size: Batch size for the train/validation loaders.
    """
    global TRAIN_LOADER, VAL_LOADER, FEATURE_MODELS, DEVICE

    print("="*80)
    print("🔍 Optuna Hyperparameter Optimization")
    print("="*80)

    # Set device
    DEVICE = config.device
    print(f"\nDevice: {DEVICE}")

    # Load feature models once
    print("\n🔧 Loading feature models...")
    FEATURE_MODELS = load_models()

    # Prepare data once (use smaller subset for faster optimization)
    TRAIN_LOADER, VAL_LOADER = prepare_data(subset_size=subset_size, batch_size=batch_size)

    # Create Optuna study
    print("\n🎯 Creating Optuna study...")
    study = optuna.create_study(
        direction="minimize",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=2),
        study_name="clip_hyperparameter_optimization"
    )

    # Run optimization. Exceptions raised inside a trial are caught so one bad
    # configuration does not abort the whole study.
    print("\n🚀 Starting optimization...")
    print(f" Running {n_trials} trials (this may take a while)...\n")
    study.optimize(
        objective,
        n_trials=n_trials,
        timeout=None,
        catch=(Exception,),
        show_progress_bar=True
    )

    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
    failed_trials = study.get_trials(deepcopy=False, states=[TrialState.FAIL])
    # `study.best_trial` raises ValueError when no trial completed, which can
    # happen here because trial exceptions are caught above.
    has_result = bool(complete_trials)

    # Print results
    print("\n" + "="*80)
    print("✅ Optimization Complete!")
    print("="*80)

    if has_result:
        print("\n📊 Best trial:")
        trial = study.best_trial
        print(f" Value (Val Loss): {trial.value:.4f}")
        print("\n Best hyperparameters:")
        for line in _format_params(trial.params):
            print(line)

        # Save results in parent directory
        results_file = os.path.join(PARENT_DIR, "optuna_results.txt")
        _save_results(study, trial, results_file)
        print(f"\n💾 Results saved to: {results_file}")
    else:
        print("\n⚠️ No trial completed successfully; skipping best-trial report.")

    # Save study for later analysis (kept even without a best trial so the
    # failed/pruned trials can be inspected).
    study_file = os.path.join(PARENT_DIR, 'optuna_study.pkl')
    with open(study_file, 'wb') as f:
        pickle.dump(study, f)
    print(f"💾 Study object saved to: {study_file}")

    # Print pruned trials statistics
    print("\n📈 Statistics:")
    print(f" Number of finished trials: {len(study.trials)}")
    print(f" Number of pruned trials: {len(pruned_trials)}")
    print(f" Number of complete trials: {len(complete_trials)}")
    print(f" Number of failed trials: {len(failed_trials)}")

    if not has_result:
        return

    # Visualization (optional, requires optuna-dashboard or matplotlib)
    _save_plots(study)

    print("\n" + "="*80)
    print("🎉 Done! Update your config with the best hyperparameters.")
    print("="*80)


if __name__ == "__main__":
    main()