|
|
|
|
|
""" |
|
|
Optuna hyperparameter optimization for the main CLIP model. |
|
|
This script uses Optuna to find the best hyperparameters to reduce overfitting. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
|
|
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
|
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
|
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import torch |
|
|
from torch.utils.data import DataLoader, random_split |
|
|
from transformers import CLIPModel as CLIPModel_transformers |
|
|
import optuna |
|
|
from optuna.trial import TrialState |
|
|
import warnings |
|
|
import config |
|
|
from main_model import ( |
|
|
CustomDataset, |
|
|
load_models, |
|
|
train_one_epoch_enhanced, |
|
|
valid_one_epoch |
|
|
) |
|
|
from transformers import CLIPProcessor |
|
|
|
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
|
|
|
# Module-level state shared with objective(): Optuna's objective receives
# only a `trial` argument, so main() stashes the data loaders, the frozen
# feature models, and the torch device here before study.optimize() runs.
TRAIN_LOADER = None      # DataLoader over the training subset
VAL_LOADER = None        # DataLoader over the validation subset
FEATURE_MODELS = None    # dict of auxiliary models returned by load_models()
DEVICE = None            # torch device taken from config.device
|
|
|
|
|
def prepare_data(subset_size=5000, batch_size=32, seed=42):
    """Build train/validation DataLoaders over a random subset of the dataset.

    A subset is used so that each Optuna trial runs quickly. The split is
    80/20 train/validation and is deterministic for a given ``seed``, so
    every trial sees exactly the same data.

    Args:
        subset_size: Maximum number of samples to draw (clamped to the
            dataset size).
        batch_size: Batch size for both loaders.
        seed: RNG seed used for both the subset sampling and the
            train/validation split (default 42, matching previous behavior).

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    print("\nπ Loading data...")
    df = pd.read_csv(config.local_dataset_path)
    # Rows without a local image path cannot be used for training.
    df_clean = df.dropna(subset=[config.column_local_image_path])
    print(f" Total samples: {len(df_clean)}")

    dataset = CustomDataset(df_clean)

    # Clamp first so train_size + val_size always equals subset_size.
    subset_size = min(subset_size, len(dataset))
    train_size = int(0.8 * subset_size)
    val_size = subset_size - train_size

    # Deterministic sampling without replacement.
    np.random.seed(seed)
    subset_indices = np.random.choice(len(dataset), subset_size, replace=False)
    subset_dataset = torch.utils.data.Subset(dataset, subset_indices)

    # Seeded generator keeps the split reproducible across trials.
    train_dataset, val_dataset = random_split(
        subset_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(seed)
    )

    # pin_memory only helps when transferring batches to a CUDA device.
    pin = torch.cuda.is_available()
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=pin,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=pin,
    )

    print(f" Train: {len(train_dataset)} samples")
    print(f" Val: {len(val_dataset)} samples")

    return train_loader, val_loader
|
|
|
|
|
def objective(trial):
    """Optuna objective: fine-tune a fresh CLIP model with sampled hyperparameters.

    Samples learning rate, contrastive temperature, alignment weight and
    weight decay, trains for up to 5 epochs on the shared global loaders
    (with early stopping, patience 2), and returns the best validation
    loss for Optuna to minimize.

    Args:
        trial: The ``optuna.trial.Trial`` driving this run.

    Returns:
        Best validation loss observed across the trained epochs.

    Raises:
        optuna.TrialPruned: When the pruner decides to stop the trial early.
    """
    global TRAIN_LOADER, VAL_LOADER, FEATURE_MODELS, DEVICE

    # --- search space ---
    learning_rate = trial.suggest_float("learning_rate", 1e-6, 5e-5, log=True)
    temperature = trial.suggest_float("temperature", 0.05, 0.15)
    alignment_weight = trial.suggest_float("alignment_weight", 0.1, 0.6)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 5e-4, log=True)

    print(f"\n{'='*80}")
    print(f"Trial {trial.number}")
    print(f" LR: {learning_rate:.2e}, Temp: {temperature:.4f}")
    print(f" Align weight: {alignment_weight:.3f}, Weight decay: {weight_decay:.2e}")
    print(f"{'='*80}")

    # Fresh pretrained model per trial so trials cannot contaminate each other.
    clip_model = CLIPModel_transformers.from_pretrained(
        'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(
        clip_model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay
    )

    processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')

    num_epochs = 5
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 2

    try:
        for epoch in range(num_epochs):
            color_model = FEATURE_MODELS[config.color_column]
            hierarchy_model = FEATURE_MODELS[config.hierarchy_column]

            train_loss, metrics = train_one_epoch_enhanced(
                clip_model, TRAIN_LOADER, optimizer, FEATURE_MODELS,
                color_model, hierarchy_model, DEVICE, processor,
                temperature=temperature, alignment_weight=alignment_weight
            )

            val_loss = valid_one_epoch(
                clip_model, VAL_LOADER, FEATURE_MODELS, DEVICE, processor,
                temperature=temperature, alignment_weight=alignment_weight
            )

            print(f" Epoch {epoch+1}/{num_epochs} - Train: {train_loss:.4f}, Val: {val_loss:.4f}")

            # BUG FIX: report before the early-stopping break so the pruner
            # receives an intermediate value for every trained epoch
            # (previously the breaking epoch was never reported).
            trial.report(val_loss, epoch)
            if trial.should_prune():
                print(f" Trial pruned at epoch {epoch+1}")
                raise optuna.TrialPruned()

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f" Early stopping at epoch {epoch+1}")
                break
    finally:
        # BUG FIX: cleanup used to sit after the loop, so raising
        # optuna.TrialPruned skipped it and leaked GPU memory across pruned
        # trials. `finally` guarantees the model is freed on every exit path.
        del clip_model, optimizer, processor
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return best_val_loss
|
|
|
|
|
def _format_param_lines(params):
    """Format hyperparameters one per line; scientific notation for lr/decay."""
    lines = []
    for key, value in params.items():
        if 'learning_rate' in key or 'weight_decay' in key:
            lines.append(f" {key}: {value:.2e}")
        else:
            lines.append(f" {key}: {value:.4f}")
    return lines


def _write_results_file(study, trial, results_file):
    """Write the best trial summary and the full trial table to a text file."""
    with open(results_file, 'w') as f:
        f.write("="*80 + "\n")
        f.write("Optuna Hyperparameter Optimization Results\n")
        f.write("="*80 + "\n\n")
        f.write(f"Best trial value (validation loss): {trial.value:.4f}\n\n")
        f.write("Best hyperparameters:\n")
        for line in _format_param_lines(trial.params):
            f.write(line + "\n")
        f.write("\n" + "="*80 + "\n")
        f.write("All trials:\n")
        f.write("="*80 + "\n\n")
        f.write(study.trials_dataframe().to_string())


def _save_visualizations(study, parent_dir):
    """Best-effort export of Optuna plots; requires plotly + kaleido."""
    try:
        from optuna.visualization import plot_optimization_history, plot_param_importances

        fig1 = plot_optimization_history(study)
        history_file = os.path.join(parent_dir, "optuna_optimization_history.png")
        fig1.write_image(history_file)
        print(f"π Optimization history saved to: {history_file}")

        fig2 = plot_param_importances(study)
        importance_file = os.path.join(parent_dir, "optuna_param_importances.png")
        fig2.write_image(importance_file)
        print(f"π Parameter importances saved to: {importance_file}")
    except Exception as e:
        # Missing optional plotting deps must not kill the run.
        print(f"\nβ οΈ Visualization skipped: {e}")
        print(" Install plotly and kaleido for visualizations: pip install plotly kaleido")


def main():
    """Run the Optuna study end-to-end: data prep, 30 trials, reports, artifacts."""
    global TRAIN_LOADER, VAL_LOADER, FEATURE_MODELS, DEVICE

    print("="*80)
    print("π Optuna Hyperparameter Optimization")
    print("="*80)

    DEVICE = config.device
    print(f"\nDevice: {DEVICE}")

    print("\nπ§ Loading feature models...")
    FEATURE_MODELS = load_models()

    TRAIN_LOADER, VAL_LOADER = prepare_data(subset_size=5000, batch_size=32)

    print("\nπ― Creating Optuna study...")
    study = optuna.create_study(
        direction="minimize",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=2),
        study_name="clip_hyperparameter_optimization"
    )

    print("\nπ Starting optimization...")
    print(" Running 30 trials (this may take a while)...\n")

    study.optimize(
        objective,
        n_trials=30,
        timeout=None,
        catch=(Exception,),  # a crashing trial must not abort the whole study
        show_progress_bar=True
    )

    print("\n" + "="*80)
    # BUG FIX: this literal was split across two physical lines (a syntax
    # error); rejoined into a single string.
    print("β Optimization Complete!")
    print("="*80)

    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    # BUG FIX: study.best_trial raises ValueError when no trial completed,
    # which is possible here because catch=(Exception,) absorbs trial errors.
    if not complete_trials:
        print("\nβ οΈ No trials completed successfully; nothing to report.")
        return

    print(f"\nπ Best trial:")
    trial = study.best_trial
    print(f" Value (Val Loss): {trial.value:.4f}")
    print(f"\n Best hyperparameters:")
    for line in _format_param_lines(trial.params):
        print(line)

    parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    results_file = os.path.join(parent_dir, "optuna_results.txt")
    _write_results_file(study, trial, results_file)
    print(f"\nπΎ Results saved to: {results_file}")

    # Pickle the study so it can be reloaded/resumed for later analysis.
    import pickle
    study_file = os.path.join(parent_dir, 'optuna_study.pkl')
    with open(study_file, 'wb') as f:
        pickle.dump(study, f)
    print(f"πΎ Study object saved to: {study_file}")

    print(f"\nπ Statistics:")
    print(f" Number of finished trials: {len(study.trials)}")
    print(f" Number of pruned trials: {len(pruned_trials)}")
    print(f" Number of complete trials: {len(complete_trials)}")

    _save_visualizations(study, parent_dir)

    print("\n" + "="*80)
    print("π Done! Update your config with the best hyperparameters.")
    print("="*80)
|
|
|
|
|
# Script entry point: run the full Optuna study when executed directly.
if __name__ == "__main__":
    main()
|
|
|