#!/usr/bin/env python3
"""
Main file for training the CLIP model with color and hierarchy alignment.

This file centralizes all the logic for training the main model. It uses
pre-trained color and hierarchy models to guide the main model's learning
through contrastive and alignment loss functions. It handles data loading,
training with validation, and checkpoint saving.
"""

import os

# Set environment variable to disable tokenizers parallelism warnings.
# This is done before the `transformers` import below.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
import warnings
from tqdm import tqdm
import config

# Suppress warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)


# -------------------------------
# Loss Functions
# -------------------------------

def _cosine_alignment_loss(main_slice, target_embeddings):
    """Return ``1 - mean cosine similarity`` between two L2-normalized batches."""
    main_norm = F.normalize(main_slice, dim=-1)
    target_norm = F.normalize(target_embeddings, dim=-1)
    return 1 - F.cosine_similarity(main_norm, target_norm).mean()


def enhanced_contrastive_loss(text_features, image_features, attribute_features,
                              color_model, hierarchy_model, colors, hierarchies,
                              temperature=0.07, alignment_weight=0.3,
                              reference_text_features=None,
                              reference_image_features=None,
                              reference_weight=0.1):
    """
    Enhanced contrastive loss with direct alignment between color/hierarchy models and main model.

    This loss combines the original triple contrastive loss with direct alignment
    losses that force the main model's color and hierarchy dimensions to align with
    the specialized color and hierarchy models. Optionally, an MSE "reference" term
    pulls the (normalized) embeddings toward those of a reference model.

    Args:
        text_features: Main model text embeddings [batch_size, embed_dim]
        image_features: Main model image embeddings [batch_size, embed_dim]
        attribute_features: Concatenated color + hierarchy features
            [batch_size, color_dim + hierarchy_dim]
        color_model: Pre-trained color model for extracting color embeddings
        hierarchy_model: Pre-trained hierarchy model for extracting hierarchy embeddings
        colors: List of color strings for this batch [batch_size]
        hierarchies: List of hierarchy strings for this batch [batch_size]
        temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
        alignment_weight: Weight for the alignment loss component (default: 0.3)
        reference_text_features: Optional reference text embeddings; when given, an MSE
            term between normalized main and reference text embeddings is added
        reference_image_features: Optional reference image embeddings, handled the same way
        reference_weight: Weight of the reference term in the total loss (default: 0.1)

    Returns:
        Tuple of (total_loss, metrics_dict) where metrics_dict contains detailed
        loss components as Python floats
    """
    color_dim = config.color_emb_dim
    attr_dim = config.color_emb_dim + config.hierarchy_emb_dim

    # Original triple contrastive loss
    text_features_norm = F.normalize(text_features, dim=-1)
    image_features_norm = F.normalize(image_features, dim=-1)
    attribute_features_norm = F.normalize(attribute_features, dim=-1)

    # The leading `attr_dim` dimensions are compared against the attribute
    # features; the remaining dimensions carry the text<->image contrast.
    text_image_logits = (
        text_features_norm[:, attr_dim:] @ image_features_norm[:, attr_dim:].T
    ) / temperature
    text_attr_logits = (
        text_features_norm[:, :attr_dim] @ attribute_features_norm.T
    ) / temperature
    image_attr_logits = (
        attribute_features_norm @ image_features_norm[:, :attr_dim].T
    ) / temperature

    # Weight distribution for original loss
    weight_text_image = 0.7
    weight_attr_based = 0.15
    original_logits = (weight_text_image * text_image_logits
                       + weight_attr_based * text_attr_logits
                       + weight_attr_based * image_attr_logits)

    labels = torch.arange(len(text_features), device=text_features.device)
    original_loss = (F.cross_entropy(original_logits, labels)
                     + F.cross_entropy(original_logits.T, labels)) / 2

    # Direct alignment targets from the specialized models (no gradients
    # flow into them).
    with torch.no_grad():
        color_embeddings = color_model.get_text_embeddings(colors)
        hierarchy_embeddings = hierarchy_model.get_text_embeddings(hierarchies)

    # Color alignment: first `color_emb_dim` dims of the main embeddings
    # (cosine-only: more natural for normalized embeddings)
    color_text_cosine_loss = _cosine_alignment_loss(
        text_features[:, :color_dim], color_embeddings)
    color_image_cosine_loss = _cosine_alignment_loss(
        image_features[:, :color_dim], color_embeddings)
    color_alignment_loss = (color_text_cosine_loss + color_image_cosine_loss) / 2

    # Hierarchy alignment: the next `hierarchy_emb_dim` dims (cosine-only)
    hierarchy_text_cosine_loss = _cosine_alignment_loss(
        text_features[:, color_dim:attr_dim], hierarchy_embeddings)
    hierarchy_image_cosine_loss = _cosine_alignment_loss(
        image_features[:, color_dim:attr_dim], hierarchy_embeddings)
    hierarchy_alignment_loss = (hierarchy_text_cosine_loss + hierarchy_image_cosine_loss) / 2

    # Combined alignment loss
    alignment_loss = (color_alignment_loss + hierarchy_alignment_loss) / 2

    # Combine losses
    total_loss = (1 - alignment_weight) * original_loss + alignment_weight * alignment_loss

    # Reference loss to keep embeddings close to base CLIP (preserves zero-shot
    # capability). Whichever reference tensors are supplied are averaged, so a
    # lone image reference is no longer silently ignored.
    reference_terms = []
    if reference_text_features is not None:
        reference_terms.append(F.mse_loss(
            F.normalize(text_features, dim=-1),
            F.normalize(reference_text_features, dim=-1)
        ))
    if reference_image_features is not None:
        reference_terms.append(F.mse_loss(
            F.normalize(image_features, dim=-1),
            F.normalize(reference_image_features, dim=-1)
        ))

    reference_loss = 0.0
    if reference_terms:
        reference_loss = sum(reference_terms) / len(reference_terms)
        total_loss = total_loss + reference_weight * reference_loss

    return total_loss, {
        'original_loss': original_loss.item(),
        'alignment_loss': alignment_loss.item(),
        'reference_loss': reference_loss if isinstance(reference_loss, float) else reference_loss.item(),
        'color_text_cosine': color_text_cosine_loss.item(),
        'color_image_cosine': color_image_cosine_loss.item(),
        'hierarchy_text_cosine': hierarchy_text_cosine_loss.item(),
        'hierarchy_image_cosine': hierarchy_image_cosine_loss.item()
    }


# -------------------------------
# Training Functions
# -------------------------------

def _compute_batch_loss(model, batch, feature_models, color_model, hierarchy_model,
                        device, clip_processor, temperature, alignment_weight,
                        reference_model, reference_weight):
    """Run one forward pass on `batch` and return ``(loss, metrics)``.

    Shared by training and validation so both compute the loss identically.
    """
    images, texts, colors, hierarchy = batch

    # Move data to device
    images = images.to(device)
    images = images.expand(-1, 3, -1, -1)  # Ensure 3 channels

    # Process text inputs. Truncation keeps long descriptions within the
    # tokenizer's maximum length instead of failing inside the text encoder.
    text_inputs = clip_processor(text=texts, padding=True, truncation=True,
                                 return_tensors="pt")
    text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

    # Reference features to keep embeddings close to base CLIP
    reference_text_features = None
    reference_image_features = None
    if reference_model is not None:
        with torch.no_grad():
            reference_text_features = reference_model.get_text_features(**text_inputs)
            reference_image_features = reference_model.get_image_features(pixel_values=images)

    # Forward pass
    outputs = model(**text_inputs, pixel_values=images)
    text_features = outputs.text_embeds
    image_features = outputs.image_embeds

    # Get feature embeddings
    color_feature_model = feature_models[config.color_column]
    if hasattr(color_feature_model, 'get_color_name_embeddings'):
        color_features = color_feature_model.get_color_name_embeddings(colors)
    else:
        color_features = color_feature_model.get_text_embeddings(colors)
    hierarchy_features = feature_models[config.hierarchy_column].get_text_embeddings(hierarchy)
    concat_features = torch.cat((color_features, hierarchy_features), dim=1)

    return enhanced_contrastive_loss(
        text_features, image_features, concat_features,
        color_model, hierarchy_model, colors, hierarchy,
        temperature, alignment_weight,
        reference_text_features=reference_text_features,
        reference_image_features=reference_image_features,
        reference_weight=reference_weight
    )


def train_one_epoch(model, train_loader, optimizer, feature_models, color_model,
                    hierarchy_model, device, clip_processor, temperature=0.07,
                    alignment_weight=0.3, reference_model=None, reference_weight=0.1):
    """
    Enhanced training with direct color and hierarchy alignment loss.

    This function trains the model using the enhanced contrastive loss that includes
    direct alignment between the main model's color/hierarchy dimensions and the
    specialized color/hierarchy models.

    Args:
        model: Main CLIP model to train
        train_loader: DataLoader for training data
        optimizer: Optimizer instance
        feature_models: Dictionary containing color and hierarchy models
        color_model: Pre-trained color model for alignment
        hierarchy_model: Pre-trained hierarchy model for alignment
        device: Device to train on
        clip_processor: CLIP processor for text preprocessing
        temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
        alignment_weight: Weight for the alignment loss component (default: 0.3)
        reference_model: Optional model whose text/image features are used as the
            reference target in the loss (default: None)
        reference_weight: Weight of the reference loss term (default: 0.1)

    Returns:
        Tuple of (average_loss, metrics_dict) where metrics_dict contains detailed
        loss components averaged over batches

    Raises:
        ValueError: If `train_loader` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_metrics = {
        'original_loss': 0.0,
        'alignment_loss': 0.0,
        'reference_loss': 0.0,
        'color_text_cosine': 0.0,
        'color_image_cosine': 0.0,
        'hierarchy_text_cosine': 0.0,
        'hierarchy_image_cosine': 0.0
    }
    num_batches = 0

    pbar = tqdm(train_loader, desc="Training Enhanced", leave=False)
    for batch in pbar:
        optimizer.zero_grad()

        # Calculate enhanced loss with hierarchy alignment
        loss, metrics = _compute_batch_loss(
            model, batch, feature_models, color_model, hierarchy_model,
            device, clip_processor, temperature, alignment_weight,
            reference_model, reference_weight
        )

        # Backward pass
        loss.backward()

        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        for key, value in metrics.items():
            total_metrics[key] += value
        num_batches += 1

        # Update progress bar
        pbar.set_postfix({
            'Loss': f'{batch_loss:.4f}',
            'Align': f'{metrics["alignment_loss"]:.4f}',
            'ColCos': f'{metrics["color_text_cosine"]:.3f}',
            'HierCos': f'{metrics["hierarchy_text_cosine"]:.3f}'
        })

    if num_batches == 0:
        raise ValueError("train_loader produced no batches; cannot compute an average loss")

    avg_metrics = {key: value / num_batches for key, value in total_metrics.items()}
    return total_loss / num_batches, avg_metrics


def valid_one_epoch(model, val_loader, feature_models, device, clip_processor,
                    temperature=0.07, alignment_weight=0.3, reference_model=None,
                    reference_weight=0.1):
    """
    Validate the model for one epoch using enhanced contrastive loss.

    Args:
        model: Main CLIP model to validate
        val_loader: DataLoader for validation data
        feature_models: Dictionary containing color and hierarchy models
        device: Device to validate on
        clip_processor: CLIP processor for text preprocessing
        temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
        alignment_weight: Weight for the alignment loss component (default: 0.3)
        reference_model: Optional model whose text/image features are used as the
            reference target in the loss (default: None)
        reference_weight: Weight of the reference loss term (default: 0.1)

    Returns:
        Average validation loss for the epoch

    Raises:
        ValueError: If `val_loader` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    # Extract color and hierarchy models
    color_model = feature_models[config.color_column]
    hierarchy_model = feature_models[config.hierarchy_column]

    # Create progress bar for validation
    pbar = tqdm(val_loader, desc="Validation", leave=False)

    with torch.no_grad():
        for batch in pbar:
            loss, _ = _compute_batch_loss(
                model, batch, feature_models, color_model, hierarchy_model,
                device, clip_processor, temperature, alignment_weight,
                reference_model, reference_weight
            )

            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1

            # Update progress bar
            pbar.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'Avg Loss': f'{total_loss/num_batches:.4f}'
            })

    if num_batches == 0:
        raise ValueError("val_loader produced no batches; cannot compute an average loss")

    return total_loss / num_batches


# -------------------------------
# Dataset
# -------------------------------

class CustomDataset(Dataset):
    """
    Custom dataset for main model training.

    Handles loading images from local paths, extracting text descriptions,
    and applying appropriate transformations for training and validation.
    """

    def __init__(self, dataframe, use_local_images=True, image_size=224):
        """
        Initialize the custom dataset.

        Args:
            dataframe: DataFrame with columns for image paths, text descriptions,
                colors, and hierarchy labels
            use_local_images: Whether to use local images (default: True).
                Stored on the instance; not otherwise read by this class.
            image_size: Size of images after resizing (default: 224)
        """
        self.dataframe = dataframe
        self.use_local_images = use_local_images
        self.image_size = image_size

        # NOTE(review): these look like the standard ImageNet statistics rather
        # than CLIP's own preprocessing constants — confirm that is intended.
        norm_mean = [0.485, 0.456, 0.406]
        norm_std = [0.229, 0.224, 0.225]

        # Transforms with augmentation for training (increased augmentation to reduce overfitting)
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(15),  # Increased for more variation
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15),  # Increased intensity
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),  # Increased transform range
            transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.2),  # Add blur
            transforms.ToTensor(),
            transforms.Normalize(mean=norm_mean, std=norm_std)
        ])

        # Transforms for validation (no augmentation)
        self.val_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=norm_mean, std=norm_std)
        ])

        self.training_mode = True

    def set_training_mode(self, training=True):
        """
        Switch between training and validation transforms.

        Args:
            training: If True, use training transforms with augmentation;
                if False, use validation transforms
        """
        self.training_mode = training

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """
        Get a sample from the dataset.

        Args:
            idx: Index of the sample

        Returns:
            Tuple of (image_tensor, description_text, color_label, hierarchy_label)
        """
        row = self.dataframe.iloc[idx]

        image_path = row[config.column_local_image_path]
        image = Image.open(image_path).convert("RGB")

        # Apply appropriate transform
        active_transform = self.transform if self.training_mode else self.val_transform
        image = active_transform(image)

        # Get text and labels
        description = row[config.text_column]
        color = row[config.color_column]
        hierarchy = row[config.hierarchy_column]

        return image, description, color, hierarchy


# -------------------------------
# Model Loading
# -------------------------------

def load_models():
    """
    Load color and hierarchy models from checkpoints.

    Both models are put in eval mode and tagged with a `name` attribute equal to
    the corresponding config column name, which is used as the dictionary key.

    Returns:
        Dictionary mapping model names to model instances:
        - config.color_column: ColorCLIP model instance
        - config.hierarchy_column: HierarchyModel instance
    """
    from training.color_model import ColorCLIP
    from training.hierarchy_model import HierarchyModel

    # --- Color model ---
    print("Loading ColorCLIP (CLIP-backbone) ...")
    color_model = ColorCLIP.from_checkpoint(config.color_model_path, device=config.device)
    color_model.eval()
    color_model.name = config.color_column

    # --- Hierarchy model ---
    print("Loading HierarchyModel (CLIP-backbone) ...")
    hierarchy_model = HierarchyModel.from_checkpoint(config.hierarchy_model_path, device=config.device)
    hierarchy_model.eval()
    hierarchy_model.name = config.hierarchy_column

    return {
        color_model.name: color_model,
        hierarchy_model.name: hierarchy_model,
    }


# -------------------------------
# Main Training Function
# -------------------------------

def _plot_training_curves(train_losses, val_losses, curves_path):
    """Render the loss / overfitting-gap figure and save it to `curves_path`."""
    plt.figure(figsize=(15, 5))

    # Plot 1: Training and Validation Loss
    plt.subplot(1, 3, 1)
    plt.plot(train_losses, label='Train Loss', color='blue', linewidth=2)
    plt.plot(val_losses, label='Val Loss', color='red', linewidth=2)
    plt.title('Training and Validation Loss', fontsize=12, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Plot 2: Overfitting Gap (Val Loss - Train Loss)
    plt.subplot(1, 3, 2)
    gap = [val - train for train, val in zip(train_losses, val_losses)]
    plt.plot(gap, label='Overfitting Gap', color='purple', linewidth=2)
    plt.axhline(y=0, color='black', linestyle='--', alpha=0.3)
    plt.axhline(y=0.1, color='red', linestyle='--', alpha=0.3, label='Warning threshold')
    plt.title('Overfitting Gap (Val - Train)', fontsize=12, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Gap')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Plot 3: Loss comparison
    plt.subplot(1, 3, 3)
    epochs = list(range(len(train_losses)))
    plt.plot(epochs, train_losses, 'o-', label='Train Loss', color='blue', linewidth=2)
    plt.plot(epochs, val_losses, 's-', label='Val Loss', color='red', linewidth=2)
    plt.fill_between(epochs, train_losses, val_losses, alpha=0.2, color='red')
    plt.title('Loss Comparison', fontsize=12, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(curves_path, dpi=300, bbox_inches='tight')
    plt.close()


def train_model(model, train_loader, val_loader, feature_models, device,
                num_epochs=20, learning_rate=1e-5, temperature=0.07,
                save_path=config.main_model_path, alignment_weight=0.3,
                color_alignment_model=None, weight_decay=3e-4,
                reference_model=None, reference_weight=0.1):
    """
    Custom training loop using train_one_epoch and valid_one_epoch functions.

    This function handles the complete training process including:
    - Training and validation loops
    - Learning rate scheduling
    - Early stopping
    - Model checkpointing
    - Training curve visualization

    Args:
        model: Main CLIP model to train
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        feature_models: Dictionary containing color and hierarchy models
        device: Device to train on
        num_epochs: Number of training epochs (default: 20)
        learning_rate: Learning rate for optimizer (default: 1e-5)
        temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
        save_path: Path to save model checkpoints (default: main_model_path)
        alignment_weight: Weight for alignment loss component if using enhanced loss (default: 0.3)
        color_alignment_model: Optional color model for alignment (default: None, uses feature_models)
        weight_decay: L2 regularization weight (default: 3e-4, increased to reduce overfitting)
        reference_model: Optional frozen reference model used for the reference
            loss term (default: None)
        reference_weight: Weight of the reference loss term (default: 0.1)

    Returns:
        Tuple of (training_losses, validation_losses) lists
    """
    model = model.to(device)

    # Use AdamW with weight decay for better regularization (reduces overfitting)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 7  # Increased from 5 to 7 for better convergence

    print(f"Starting training for {num_epochs} epochs...")
    print(f"Learning rate: {learning_rate}")
    print(f"Temperature: {temperature}")
    print(f"Weight decay: {weight_decay}")
    print(f"Alignment weight: {alignment_weight}")
    print(f"Device: {device}")
    print(f"Training samples: {len(train_loader.dataset)}")
    print(f"Validation samples: {len(val_loader.dataset)}")
    print(f"Batch size: {train_loader.batch_size}")
    print(f"Estimated time per epoch: ~{len(train_loader) * 2 / 60:.1f} minutes")

    # Create processor once for efficiency
    processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')

    # Freeze and move reference model (used for text-space regularization)
    if reference_model is not None:
        reference_model = reference_model.to(device)
        reference_model.eval()
        for param in reference_model.parameters():
            param.requires_grad = False

    # Resolve the alignment models once; they do not change between epochs.
    if color_alignment_model is None:
        color_alignment_model = feature_models[config.color_column]
    hierarchy_model = feature_models[config.hierarchy_column]

    # Make sure output locations exist up front so a missing directory cannot
    # discard a finished training run.
    checkpoint_dir = os.path.dirname(str(save_path))
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    curves_path = str(config.ROOT_DIR / "figures" / "training_curves.png")
    os.makedirs(os.path.dirname(curves_path), exist_ok=True)

    # Create progress bar for epochs
    epoch_pbar = tqdm(range(num_epochs), desc="Training Progress", position=0)

    for epoch in epoch_pbar:
        # Update epoch progress bar
        epoch_pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")

        # Training
        train_loss, align_metrics = train_one_epoch(
            model, train_loader, optimizer, feature_models, color_alignment_model,
            hierarchy_model, device, processor, temperature, alignment_weight,
            reference_model=reference_model, reference_weight=reference_weight
        )
        train_losses.append(train_loss)

        # Validation
        val_loss = valid_one_epoch(
            model, val_loader, feature_models, device, processor,
            temperature=temperature, alignment_weight=alignment_weight,
            reference_model=reference_model, reference_weight=reference_weight
        )
        val_losses.append(val_loss)

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Calculate overfitting gap
        overfitting_gap = val_loss - train_loss

        # Save best model (done before reporting so 'Best Val' includes this epoch)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0

            # Save checkpoint
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'best_val_loss': best_val_loss,
            }, save_path)
        else:
            patience_counter += 1

        # Update epoch progress bar with metrics
        epoch_pbar.set_postfix({
            'Train Loss': f'{train_loss:.4f}',
            'Val Loss': f'{val_loss:.4f}',
            'Gap': f'{overfitting_gap:.4f}',
            'LR': f'{optimizer.param_groups[0]["lr"]:.2e}',
            'Best Val': f'{best_val_loss:.4f}',
            'Align': f"{align_metrics['alignment_loss']:.3f}",
            'ColCos': f"{align_metrics['color_text_cosine']:.3f}",
            'HierCos': f"{align_metrics['hierarchy_text_cosine']:.3f}"
        })

        # Warning if overfitting is detected
        if overfitting_gap > 0.15 and epoch > 3:
            print(f"\n⚠️ Warning: Significant overfitting detected at epoch {epoch+1} (gap={overfitting_gap:.4f})")

        # Early stopping
        if patience_counter >= patience:
            print(f"\n🛑 Early stopping triggered after {patience_counter} epochs without improvement")
            break

    # Plot training curves with overfitting analysis
    _plot_training_curves(train_losses, val_losses, curves_path)

    print(f"\nTraining completed!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Final model saved to: {save_path}")
    print(f"Training curves saved to: {curves_path}")

    return train_losses, val_losses


# -------------------------------
# Main Function
# -------------------------------

def main():
    """Load the data and models, run training, and print a final summary."""
    print("="*80)
    print("🚀 Training of the model with alignment color and hierarchy")
    print("="*80)

    # Configuration (tuned for zero-shot + separation balance)
    num_epochs = 10
    learning_rate = 1.5e-5
    temperature = 0.09
    alignment_weight = 0.10  # reduced from 0.2: softer alignment preserves CLIP zero-shot
    reference_weight = 0.25  # increased from 0.1: stronger regularization toward base CLIP
    weight_decay = 1e-3  # increased from 5e-4: better generalization
    batch_size = 128
    subset_size = 100000

    # Load the data
    print(f"\n📂 Loading the data...")
    df = pd.read_csv(config.local_dataset_path)
    print(f" Data downloaded: {len(df)} samples")

    # filter the rows with NaN values
    df_clean = df.dropna(subset=[config.column_local_image_path])
    df_clean = df_clean[df_clean[config.column_local_image_path].astype(str).str.len() > 0]
    print(f" After filtering NaN: {len(df_clean)} samples")

    # Creation of datasets. Two instances over the same dataframe: the
    # transform mode is per-instance state, so a shared instance would apply
    # training augmentation to the validation split as well.
    dataset = CustomDataset(df_clean)
    val_source = CustomDataset(df_clean)
    val_source.set_training_mode(False)

    # Sample 100k for training
    subset_size = min(subset_size, len(dataset))
    train_size = int(0.8 * subset_size)
    val_size = subset_size - train_size

    np.random.seed(42)
    subset_indices = np.random.choice(len(dataset), subset_size, replace=False)
    subset_dataset = torch.utils.data.Subset(dataset, subset_indices)

    train_dataset, val_split = random_split(
        subset_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )
    # Same validation rows as `val_split`, but read through the
    # augmentation-free dataset instance.
    val_dataset = torch.utils.data.Subset(
        val_source,
        [int(subset_indices[i]) for i in val_split.indices]
    )

    # Creation of dataloaders
    use_pinned_memory = torch.cuda.is_available()
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=use_pinned_memory
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=use_pinned_memory
    )

    print(f" Train: {len(train_dataset)} samples")
    print(f" Validation: {len(val_dataset)} samples")

    # Loading models
    print(f"\n🔧 Loading models...")
    feature_models = load_models()

    # Load or create the main model
    print(f"\n📦 Loading main model...")
    clip_model = CLIPModel_transformers.from_pretrained(
        'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
    )

    # Frozen reference CLIP to regularize text space (improves cross-domain generalization)
    reference_clip = CLIPModel_transformers.from_pretrained(
        'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
    )

    # Move the model on the device
    clip_model = clip_model.to(config.device)
    reference_clip = reference_clip.to(config.device)
    reference_clip.eval()
    for param in reference_clip.parameters():
        param.requires_grad = False

    # Training with enhanced loss
    print(f"\n🎯 Beginning training...")
    print(f"\n" + "="*80)

    train_losses, val_losses = train_model(
        model=clip_model,
        train_loader=train_loader,
        val_loader=val_loader,
        feature_models=feature_models,
        device=config.device,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        temperature=temperature,
        save_path=config.main_model_path,
        alignment_weight=alignment_weight,
        color_alignment_model=feature_models[config.color_column],
        weight_decay=weight_decay,
        reference_model=reference_clip,
        reference_weight=reference_weight
    )

    final_gap = val_losses[-1] - train_losses[-1]

    print("\n" + "="*80)
    print("✅ Training finished!")
    print(f" Model saved: {config.main_model_path}")
    print(f" Training curves: figures/training_curves.png")
    print("\n📊 Final results:")
    print(f" Last train loss: {train_losses[-1]:.4f}")
    print(f" Last validation loss: {val_losses[-1]:.4f}")
    print(f" Best validation loss: {min(val_losses):.4f}")
    print(f" Overfitting gap (val-train): {final_gap:.4f}")
    if final_gap > 0.1:
        print(" ⚠️ Warning: Significant overfitting detected!")
    elif final_gap < 0.05:
        print(" ✅ Good generalization!")
    print("="*80)


if __name__ == "__main__":
    main()