gap-clip / optuna /optuna_optimisation.py
Leacb4's picture
Upload optuna/optuna_optimisation.py with huggingface_hub
4912235 verified
#!/usr/bin/env python3
"""
Optuna hyperparameter optimization for the main CLIP model.
This script uses Optuna to find the best hyperparameters to reduce overfitting.
"""
import os
import sys
# Add parent directory to path to import modules
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from transformers import CLIPModel as CLIPModel_transformers
import optuna
from optuna.trial import TrialState
import warnings
import config
from main_model import (
CustomDataset,
load_models,
train_one_epoch_enhanced,
valid_one_epoch
)
from transformers import CLIPProcessor
# Silence every Python warning for the remainder of the process.
warnings.filterwarnings("ignore")

# Shared state, filled in once by main() and read by objective() so that the
# data loaders and feature models are not rebuilt for every trial.
DEVICE = None
FEATURE_MODELS = None
TRAIN_LOADER = None
VAL_LOADER = None
def prepare_data(subset_size=5000, batch_size=32):
    """
    Build the train and validation DataLoaders used during optimization.

    Reads the CSV at ``config.local_dataset_path``, drops rows that have no
    value in ``config.column_local_image_path``, draws a fixed-seed random
    subset of at most ``subset_size`` samples and splits it 80/20.

    Args:
        subset_size: Upper bound on the number of samples kept for the search.
        batch_size: Batch size applied to both loaders.

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    print("\nπŸ“‚ Loading data...")
    frame = pd.read_csv(config.local_dataset_path)
    usable_rows = frame.dropna(subset=[config.column_local_image_path])
    print(f" Total samples: {len(usable_rows)}")

    full_dataset = CustomDataset(usable_rows)

    # Never ask for more samples than the dataset holds.
    n_total = len(full_dataset)
    n_subset = min(subset_size, n_total)
    n_train = int(0.8 * n_subset)
    n_val = n_subset - n_train

    # Fixed seeds keep the subset and the split identical across runs.
    np.random.seed(42)
    chosen = np.random.choice(n_total, n_subset, replace=False)
    split_rng = torch.Generator().manual_seed(42)
    train_dataset, val_dataset = random_split(
        torch.utils.data.Subset(full_dataset, chosen),
        [n_train, n_val],
        generator=split_rng,
    )

    # Settings common to both loaders; only shuffling differs.
    shared_kwargs = dict(
        batch_size=batch_size,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_kwargs)

    print(f" Train: {len(train_dataset)} samples")
    print(f" Val: {len(val_dataset)} samples")
    return train_loader, val_loader
def objective(trial):
    """
    Objective function for Optuna optimization.

    Samples a learning rate, temperature, alignment weight and weight decay,
    loads a fresh pretrained CLIP model, trains it for up to five epochs on
    the module-level TRAIN_LOADER and evaluates it on VAL_LOADER after each
    epoch. Training stops early after two consecutive epochs without a new
    best validation loss.

    Args:
        trial: The Optuna trial used to sample hyperparameters, report the
            per-epoch validation loss and query the pruner.

    Returns:
        The lowest validation loss observed during the epochs that ran.

    Raises:
        optuna.TrialPruned: When the study's pruner asks to stop the trial.
    """
    # Suggest hyperparameters
    learning_rate = trial.suggest_float("learning_rate", 1e-6, 5e-5, log=True)
    temperature = trial.suggest_float("temperature", 0.05, 0.15)
    alignment_weight = trial.suggest_float("alignment_weight", 0.1, 0.6)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 5e-4, log=True)

    print(f"\n{'='*80}")
    print(f"Trial {trial.number}")
    print(f" LR: {learning_rate:.2e}, Temp: {temperature:.4f}")
    print(f" Align weight: {alignment_weight:.3f}, Weight decay: {weight_decay:.2e}")
    print(f"{'='*80}")

    # Pre-bind so the cleanup in ``finally`` is valid even if loading fails.
    clip_model = optimizer = processor = None
    try:
        # Create fresh model for this trial
        clip_model = CLIPModel_transformers.from_pretrained(
            'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
        ).to(DEVICE)

        # Optimizer with weight decay for regularization
        optimizer = torch.optim.AdamW(
            clip_model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )

        # Create processor
        processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')

        # These lookups do not change between epochs, so do them once.
        color_model = FEATURE_MODELS[config.color_column]
        hierarchy_model = FEATURE_MODELS[config.hierarchy_column]

        # Train for a few epochs (reduced for faster optimization)
        num_epochs = 5
        best_val_loss = float('inf')
        patience_counter = 0
        patience = 2

        for epoch in range(num_epochs):
            # Training (the returned metrics are not used here)
            train_loss, _ = train_one_epoch_enhanced(
                clip_model, TRAIN_LOADER, optimizer, FEATURE_MODELS,
                color_model, hierarchy_model, DEVICE, processor,
                temperature=temperature, alignment_weight=alignment_weight
            )

            # Validation
            val_loss = valid_one_epoch(
                clip_model, VAL_LOADER, FEATURE_MODELS, DEVICE, processor,
                temperature=temperature, alignment_weight=alignment_weight
            )

            print(f" Epoch {epoch+1}/{num_epochs} - Train: {train_loss:.4f}, Val: {val_loss:.4f}")

            # Track best validation loss
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1

            # Early stopping within trial
            if patience_counter >= patience:
                print(f" Early stopping at epoch {epoch+1}")
                break

            # Report intermediate value for pruning
            trial.report(val_loss, epoch)

            # Handle pruning based on intermediate value
            if trial.should_prune():
                print(f" Trial pruned at epoch {epoch+1}")
                raise optuna.TrialPruned()

        return best_val_loss
    finally:
        # Clean up memory on every exit path: normal return, pruning, or an
        # error that the study catches before starting the next trial.
        del clip_model, optimizer, processor
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def _format_param(key, value):
    """Render one hyperparameter as 'name: value'.

    Names containing 'learning_rate' or 'weight_decay' use scientific
    notation; every other value is shown with four decimals.
    """
    if 'learning_rate' in key or 'weight_decay' in key:
        return f"{key}: {value:.2e}"
    return f"{key}: {value:.4f}"


def main():
    """
    Main function to run Optuna optimization.

    Initializes the module-level DEVICE, FEATURE_MODELS, TRAIN_LOADER and
    VAL_LOADER, runs a 30-trial minimization study with median pruning, then
    writes a text report, a pickled study and (when the plotting backends are
    available) two PNG charts into the parent directory of this script.

    If no trial finishes in the COMPLETE state (all were pruned or failed),
    the best-trial section is skipped instead of raising, and the report,
    pickle and statistics are still produced.
    """
    global TRAIN_LOADER, VAL_LOADER, FEATURE_MODELS, DEVICE
    import pickle

    print("="*80)
    print("πŸ” Optuna Hyperparameter Optimization")
    print("="*80)

    # Set device
    DEVICE = config.device
    print(f"\nDevice: {DEVICE}")

    # Load feature models once
    print("\nπŸ”§ Loading feature models...")
    FEATURE_MODELS = load_models()

    # Prepare data once (use smaller subset for faster optimization)
    TRAIN_LOADER, VAL_LOADER = prepare_data(subset_size=5000, batch_size=32)

    # Create Optuna study
    print("\n🎯 Creating Optuna study...")
    study = optuna.create_study(
        direction="minimize",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=2),
        study_name="clip_hyperparameter_optimization"
    )

    # Run optimization
    print("\nπŸš€ Starting optimization...")
    print(" Running 30 trials (this may take a while)...\n")
    study.optimize(
        objective,
        n_trials=30,
        timeout=None,
        catch=(Exception,),
        show_progress_bar=True
    )

    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    # study.best_trial raises ValueError when nothing completed, which can
    # happen because failing trials are caught rather than aborting the study.
    best_trial = study.best_trial if complete_trials else None
    param_lines = (
        [_format_param(key, value) for key, value in best_trial.params.items()]
        if best_trial is not None else []
    )

    # All artifacts go into the directory above the one holding this script.
    parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # Print results
    print("\n" + "="*80)
    print("βœ… Optimization Complete!")
    print("="*80)
    if best_trial is None:
        print("\nNo trial completed successfully; there are no best hyperparameters to report.")
    else:
        print("\nπŸ“Š Best trial:")
        print(f" Value (Val Loss): {best_trial.value:.4f}")
        print("\n Best hyperparameters:")
        for line in param_lines:
            print(f" {line}")

    # Save results in parent directory
    results_file = os.path.join(parent_dir, "optuna_results.txt")
    with open(results_file, 'w', encoding="utf-8") as f:
        f.write("="*80 + "\n")
        f.write("Optuna Hyperparameter Optimization Results\n")
        f.write("="*80 + "\n\n")
        if best_trial is None:
            f.write("No trial completed successfully.\n")
        else:
            f.write(f"Best trial value (validation loss): {best_trial.value:.4f}\n\n")
            f.write("Best hyperparameters:\n")
            for line in param_lines:
                f.write(f" {line}\n")
        f.write("\n" + "="*80 + "\n")
        f.write("All trials:\n")
        f.write("="*80 + "\n\n")
        df_results = study.trials_dataframe()
        f.write(df_results.to_string())
    print(f"\nπŸ’Ύ Results saved to: {results_file}")

    # Save study for later analysis
    study_file = os.path.join(parent_dir, 'optuna_study.pkl')
    with open(study_file, 'wb') as f:
        pickle.dump(study, f)
    print(f"πŸ’Ύ Study object saved to: {study_file}")

    # Print pruned trials statistics
    print("\nπŸ“ˆ Statistics:")
    print(f" Number of finished trials: {len(study.trials)}")
    print(f" Number of pruned trials: {len(pruned_trials)}")
    print(f" Number of complete trials: {len(complete_trials)}")

    # Visualization (optional, requires optuna-dashboard or matplotlib)
    try:
        from optuna.visualization import plot_optimization_history, plot_param_importances

        # Plot optimization history
        fig1 = plot_optimization_history(study)
        history_file = os.path.join(parent_dir, "optuna_optimization_history.png")
        fig1.write_image(history_file)
        print(f"πŸ“Š Optimization history saved to: {history_file}")

        # Plot parameter importances
        fig2 = plot_param_importances(study)
        importance_file = os.path.join(parent_dir, "optuna_param_importances.png")
        fig2.write_image(importance_file)
        print(f"πŸ“Š Parameter importances saved to: {importance_file}")
    except Exception as e:
        # Best effort: plotting must never fail the run.
        print(f"\n⚠️ Visualization skipped: {e}")
        print(" Install plotly and kaleido for visualizations: pip install plotly kaleido")

    print("\n" + "="*80)
    if best_trial is None:
        print("Done, but no trial completed; inspect the saved trial table for the failure states.")
    else:
        print("πŸŽ‰ Done! Update your config with the best hyperparameters.")
    print("="*80)
# Start the hyperparameter search only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()