|
|
|
|
|
""" |
|
|
Training script using best hyperparameters from Optuna optimization. |
|
|
This script trains the model with the optimized hyperparameters and additional |
|
|
regularization techniques to reduce overfitting. |
|
|
""" |
|
|
|
|
|
import os

# Must be set before the HuggingFace tokenizers library is loaded (it is
# pulled in transitively by `transformers` below); silences the tokenizers
# parallelism warning when DataLoader workers fork the process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from transformers import CLIPModel as CLIPModel_transformers
import warnings

# Project-local modules: dataset paths/device config and model utilities.
import config
from main_model import CustomDataset, load_models, train_model

# Suppress noisy third-party warnings to keep the training log readable.
warnings.filterwarnings("ignore")
|
|
|
|
|
def train_with_best_params(
    learning_rate=1.42e-05,
    temperature=0.0503,
    alignment_weight=0.5639,
    weight_decay=2.76e-05,
    num_epochs=20,
    batch_size=32,
    subset_size=20000,
    use_early_stopping=True,
    patience=7
):
    """
    Train model with best hyperparameters and anti-overfitting techniques.

    Trains the CLIP model on a fixed-seed subset of the dataset with an
    80/20 train/val split, saves the best checkpoint (by validation loss),
    and writes loss/overfitting-gap curves to
    ``training_curves_optimized.png``.

    Args:
        learning_rate: Learning rate for optimizer (from Optuna)
        temperature: Temperature for contrastive loss (from Optuna)
        alignment_weight: Weight for alignment loss (from Optuna)
        weight_decay: L2 regularization weight (from Optuna)
        num_epochs: Number of training epochs
        batch_size: Batch size for training
        subset_size: Size of dataset subset (capped at the dataset length)
        use_early_stopping: Whether to stop after `patience` epochs without
            validation improvement
        patience: Patience for early stopping

    Returns:
        tuple: (train_losses, val_losses) per-epoch loss histories.
    """
    # Deferred imports: heavyweight and only needed once training runs.
    from transformers import CLIPProcessor
    from tqdm import tqdm
    from main_model import train_one_epoch, valid_one_epoch
    import matplotlib.pyplot as plt

    print("="*80)
    print("🚀 Training with Optimized Hyperparameters")
    print("="*80)

    print(f"\n📋 Configuration:")
    print(f" Learning rate: {learning_rate:.2e}")
    print(f" Temperature: {temperature:.4f}")
    print(f" Alignment weight: {alignment_weight:.4f}")
    print(f" Weight decay: {weight_decay:.2e}")
    print(f" Num epochs: {num_epochs}")
    print(f" Batch size: {batch_size}")
    print(f" Subset size: {subset_size}")
    print(f" Early stopping: {use_early_stopping} (patience={patience})")

    # ---- Data -------------------------------------------------------------
    print(f"\n📂 Loading data...")
    df = pd.read_csv(config.local_dataset_path)
    # Rows without a local image file cannot be used for training.
    df_clean = df.dropna(subset=[config.column_local_image_path])
    print(f" Total samples: {len(df_clean)}")

    dataset = CustomDataset(df_clean)

    # Cap the subset at the dataset size, then compute an 80/20 split.
    subset_size = min(subset_size, len(dataset))
    train_size = int(0.8 * subset_size)
    val_size = subset_size - train_size

    # NOTE(review): seeds the *global* NumPy RNG for reproducible subset
    # selection — this mutates global state visible to callers.
    np.random.seed(42)
    subset_indices = np.random.choice(len(dataset), subset_size, replace=False)
    subset_dataset = torch.utils.data.Subset(dataset, subset_indices)

    # Fixed-seed split so train/val membership is stable across runs.
    train_dataset, val_dataset = random_split(
        subset_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )

    # Pin host memory only when a GPU is present (speeds up H2D copies).
    pin = torch.cuda.is_available()
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=pin
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=pin
    )

    print(f" Train: {len(train_dataset)} samples")
    print(f" Val: {len(val_dataset)} samples")

    # ---- Models -----------------------------------------------------------
    print(f"\n🔧 Loading feature models...")
    feature_models = load_models()

    print(f"\n📦 Loading main model...")
    clip_model = CLIPModel_transformers.from_pretrained(
        'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
    )

    # Frozen copy of the pretrained weights, passed to train/valid as a
    # reference model (weight 0.1) to regularize drift from the original.
    reference_clip = CLIPModel_transformers.from_pretrained(
        'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
    )

    # Optionally resume from an existing checkpoint (interactive prompt —
    # NOTE(review): blocks unattended runs; confirm this is intended).
    if os.path.exists(config.main_model_path):
        user_input = input(f"\n⚠️ Found existing checkpoint at {config.main_model_path}. Load it? (y/n): ")
        if user_input.lower() == 'y':
            print(f" Loading checkpoint...")
            checkpoint = torch.load(config.main_model_path, map_location=config.device)
            # Support both full checkpoint dicts and bare state dicts.
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                clip_model.load_state_dict(checkpoint['model_state_dict'])
                print(f" ✅ Checkpoint loaded from epoch {checkpoint.get('epoch', '?')}")
            else:
                clip_model.load_state_dict(checkpoint)
                print(f" ✅ Checkpoint loaded")
        else:
            print(f" Starting from pretrained model")
    else:
        print(f" Starting from pretrained model")

    clip_model = clip_model.to(config.device)
    reference_clip = reference_clip.to(config.device)
    # Freeze the reference model entirely — it only anchors the loss.
    reference_clip.eval()
    for param in reference_clip.parameters():
        param.requires_grad = False

    print(f"\n🎯 Starting training...")
    print(f"\n" + "="*80)

    # ---- Optimizer / scheduler --------------------------------------------
    model = clip_model.to(config.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # Halve the LR when validation loss plateaus for 3 epochs.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=3, factor=0.5
    )

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    patience_counter = 0

    # FIX: define save_path before the loop so the final summary cannot raise
    # NameError when no epoch ever improves (e.g. NaN validation losses).
    save_path = config.main_model_path.replace('.pt', '_best_optuna.pt')

    processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
    epoch_pbar = tqdm(range(num_epochs), desc="Training Progress", position=0)

    # FIX: loop-invariant lookups hoisted out of the epoch loop.
    color_model = feature_models[config.color_column]
    hierarchy_model = feature_models[config.hierarchy_column]

    for epoch in epoch_pbar:
        epoch_pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")

        train_loss, align_metrics = train_one_epoch(
            model, train_loader, optimizer, feature_models, color_model, hierarchy_model,
            config.device, processor, temperature, alignment_weight,
            reference_model=reference_clip, reference_weight=0.1
        )
        train_losses.append(train_loss)

        val_loss = valid_one_epoch(
            model, val_loader, feature_models, config.device, processor,
            temperature=temperature, alignment_weight=alignment_weight,
            reference_model=reference_clip, reference_weight=0.1
        )
        val_losses.append(val_loss)

        # Plateau scheduler steps on the monitored (validation) metric.
        scheduler.step(val_loss)

        epoch_pbar.set_postfix({
            'Train Loss': f'{train_loss:.4f}',
            'Val Loss': f'{val_loss:.4f}',
            'LR': f'{optimizer.param_groups[0]["lr"]:.2e}',
            'Best Val': f'{best_val_loss:.4f}'
        })

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0

            # Persist the best model together with its training context.
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'best_val_loss': best_val_loss,
                'hyperparameters': {
                    'learning_rate': learning_rate,
                    'temperature': temperature,
                    'alignment_weight': alignment_weight,
                    'weight_decay': weight_decay,
                }
            }, save_path)
            print(f"\n💾 Best model saved at epoch {epoch+1}")
        else:
            patience_counter += 1

        if use_early_stopping and patience_counter >= patience:
            print(f"\n🛑 Early stopping triggered after {patience_counter} epochs without improvement")
            break

    # ---- Diagnostics plot -------------------------------------------------
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss', color='blue', linewidth=2)
    plt.plot(val_losses, label='Val Loss', color='red', linewidth=2)
    plt.title('Training and Validation Loss (Optimized)', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    # Train-minus-val gap per epoch; a widening gap signals overfitting.
    gap = [train_losses[i] - val_losses[i] for i in range(len(train_losses))]
    plt.plot(gap, label='Train-Val Gap', color='purple', linewidth=2)
    plt.axhline(y=0, color='black', linestyle='--', alpha=0.3)
    plt.title('Overfitting Gap (Optimized)', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Train Loss - Val Loss', fontsize=12)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('training_curves_optimized.png', dpi=300, bbox_inches='tight')
    plt.close()

    print("\n" + "="*80)
    print("✅ Training completed!")
    print(f" Best model: {save_path}")
    print(f" Training curves: training_curves_optimized.png")
    # FIX: guard the summary so num_epochs == 0 (empty histories) cannot
    # raise IndexError on train_losses[-1].
    if train_losses:
        print("\n📊 Final results:")
        print(f" Last train loss: {train_losses[-1]:.4f}")
        print(f" Last validation loss: {val_losses[-1]:.4f}")
        print(f" Best validation loss: {best_val_loss:.4f}")
        print(f" Overfitting gap: {train_losses[-1] - val_losses[-1]:.4f}")
    print("="*80)

    return train_losses, val_losses
|
|
|
|
|
def main():
    """
    Main function - Uses best parameters from Optuna optimization.
    """
    print("\n" + "="*80)
    print("🚀 Training with Best Optuna Hyperparameters")
    print("="*80)

    # Hyperparameters reported by the best Optuna trial; forwarded
    # verbatim to train_with_best_params below.
    best_params = dict(
        learning_rate=1.42e-05,
        temperature=0.0503,
        alignment_weight=0.5639,
        weight_decay=2.76e-05,
        num_epochs=20,
        batch_size=32,
        subset_size=20000,
        patience=7,
    )

    print(f"\n✅ Using optimized hyperparameters from Optuna:")
    print(f" Learning rate: {best_params['learning_rate']:.2e}")
    print(f" Temperature: {best_params['temperature']:.4f}")
    print(f" Alignment weight: {best_params['alignment_weight']:.4f}")
    print(f" Weight decay: {best_params['weight_decay']:.2e}")
    print(f" Expected validation loss: ~0.1129 (from Optuna)\n")

    train_with_best_params(**best_params)
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: run training only when executed as a script, not on import.
    main()
|
|
|