#!/usr/bin/env python3 """ Main file for training the CLIP model with color and hierarchy alignment. This file centralizes all the logic for training the main model. It uses pre-trained color and hierarchy models to guide the main model's learning through contrastive and alignment loss functions. It handles data loading, training with validation, and checkpoint saving. """ import os # Set environment variable to disable tokenizers parallelism warnings os.environ["TOKENIZERS_PARALLELISM"] = "false" import pandas as pd import numpy as np import torch import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader, random_split from torchvision import transforms from PIL import Image import matplotlib.pyplot as plt from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers import warnings from tqdm import tqdm import json import config # Suppress warnings warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) # ------------------------------- # Loss Functions # ------------------------------- def enhanced_contrastive_loss(text_features, image_features, attribute_features, color_model, hierarchy_model, colors, hierarchies, temperature=0.07, alignment_weight=0.3, reference_text_features=None, reference_weight=0.1): """ Enhanced contrastive loss with direct alignment between color/hierarchy models and main model. This loss combines the original triple contrastive loss with direct alignment losses that force the main model's color and hierarchy dimensions to align with the specialized color and hierarchy models. Args: text_features: Main model text embeddings [batch_size, embed_dim] image_features: Main model image embeddings [batch_size, embed_dim] attribute_features: Concatenated color + hierarchy features [batch_size, color_dim + hierarchy_dim] color_model: Pre-trained color model for extracting color embeddings hierarchy_model: Pre-trained hierarchy model for extracting hierarchy embeddings colors: List of color strings for this batch [batch_size] hierarchies: List of hierarchy strings for this batch [batch_size] temperature: Temperature scaling parameter for contrastive loss (default: 0.07) alignment_weight: Weight for the alignment loss component (default: 0.3) Returns: Tuple of (total_loss, metrics_dict) where metrics_dict contains detailed loss components """ # Original triple contrastive loss text_features_norm = F.normalize(text_features, dim=-1) image_features_norm = F.normalize(image_features, dim=-1) attribute_features_norm = F.normalize(attribute_features, dim=-1) text_image_logits = (text_features_norm[:, config.color_emb_dim+config.hierarchy_emb_dim:] @ image_features_norm[:, config.color_emb_dim+config.hierarchy_emb_dim:].T) / temperature text_attr_logits = (text_features_norm[:, :config.color_emb_dim+config.hierarchy_emb_dim] @ attribute_features_norm.T) / temperature image_attr_logits = (attribute_features_norm @ image_features_norm[:,:config.color_emb_dim+config.hierarchy_emb_dim].T) / temperature # Weight distribution for original loss weight_text_image = 0.7 weight_attr_based = 0.15 original_logits = (weight_text_image * text_image_logits + weight_attr_based * text_attr_logits + weight_attr_based * image_attr_logits) labels = torch.arange(len(text_features)).to(text_features.device) original_loss = (F.cross_entropy(original_logits, labels) + F.cross_entropy(original_logits.T, labels)) / 2 # Direct alignment loss between color model and main model first 16 dims with torch.no_grad(): color_embeddings = color_model.get_text_embeddings(colors) hierarchy_embeddings = hierarchy_model.get_text_embeddings(hierarchies) # Extract color dimensions from main model embeddings main_color_text = text_features[:, :config.color_emb_dim] main_color_image = image_features[:, :config.color_emb_dim] # Extract hierarchy dimensions from main model embeddings main_hierarchy_text = text_features[:, config.color_emb_dim:config.color_emb_dim+config.hierarchy_emb_dim] main_hierarchy_image = image_features[:, config.color_emb_dim:config.color_emb_dim+config.hierarchy_emb_dim] # Normalize for better correlation color_embeddings_norm = F.normalize(color_embeddings, dim=-1) main_color_text_norm = F.normalize(main_color_text, dim=-1) main_color_image_norm = F.normalize(main_color_image, dim=-1) hierarchy_embeddings_norm = F.normalize(hierarchy_embeddings, dim=-1) main_hierarchy_text_norm = F.normalize(main_hierarchy_text, dim=-1) main_hierarchy_image_norm = F.normalize(main_hierarchy_image, dim=-1) # Color alignment loss using MSE and cosine similarity color_text_alignment_loss = F.mse_loss(main_color_text_norm, color_embeddings_norm) color_image_alignment_loss = F.mse_loss(main_color_image_norm, color_embeddings_norm) color_text_cosine_loss = 1 - F.cosine_similarity(main_color_text_norm, color_embeddings_norm).mean() color_image_cosine_loss = 1 - F.cosine_similarity(main_color_image_norm, color_embeddings_norm).mean() # Color alignment loss color_alignment_loss = ( color_text_alignment_loss + color_image_alignment_loss + color_text_cosine_loss + color_image_cosine_loss ) / 4 # Hierarchy alignment loss using MSE and cosine similarity hierarchy_text_alignment_loss = F.mse_loss(main_hierarchy_text_norm, hierarchy_embeddings_norm) hierarchy_image_alignment_loss = F.mse_loss(main_hierarchy_image_norm, hierarchy_embeddings_norm) hierarchy_text_cosine_loss = 1 - F.cosine_similarity(main_hierarchy_text_norm, hierarchy_embeddings_norm).mean() hierarchy_image_cosine_loss = 1 - F.cosine_similarity(main_hierarchy_image_norm, hierarchy_embeddings_norm).mean() # Hierarchy alignment loss hierarchy_alignment_loss = ( hierarchy_text_alignment_loss + hierarchy_image_alignment_loss + hierarchy_text_cosine_loss + hierarchy_image_cosine_loss ) / 4 # Combined alignment loss alignment_loss = (color_alignment_loss + hierarchy_alignment_loss) / 2 # Optional guidance to keep text space close to base CLIP (helps cross-domain generalization) reference_loss = 0.0 if reference_text_features is not None: reference_loss = F.mse_loss( F.normalize(text_features, dim=-1), F.normalize(reference_text_features, dim=-1) ) # Combine losses total_loss = (1 - alignment_weight) * original_loss + alignment_weight * alignment_loss if reference_text_features is not None: total_loss = total_loss + reference_weight * reference_loss return total_loss, { 'original_loss': original_loss.item(), 'alignment_loss': alignment_loss.item(), 'reference_loss': reference_loss if isinstance(reference_loss, float) else reference_loss.item(), 'color_text_alignment': color_text_alignment_loss.item(), 'color_image_alignment': color_image_alignment_loss.item(), 'color_text_cosine': color_text_cosine_loss.item(), 'color_image_cosine': color_image_cosine_loss.item(), 'hierarchy_text_alignment': hierarchy_text_alignment_loss.item(), 'hierarchy_image_alignment': hierarchy_image_alignment_loss.item(), 'hierarchy_text_cosine': hierarchy_text_cosine_loss.item(), 'hierarchy_image_cosine': hierarchy_image_cosine_loss.item() } # ------------------------------- # Training Functions # ------------------------------- def train_one_epoch(model, train_loader, optimizer, feature_models, color_model, hierarchy_model, device, clip_processor, temperature=0.07, alignment_weight=0.3, reference_model=None, reference_weight=0.1): """ Enhanced training with direct color and hierarchy alignment loss. This function trains the model using the enhanced contrastive loss that includes direct alignment between the main model's color/hierarchy dimensions and the specialized color/hierarchy models. Args: model: Main CLIP model to train train_loader: DataLoader for training data optimizer: Optimizer instance feature_models: Dictionary containing color and hierarchy models color_model: Pre-trained color model for alignment hierarchy_model: Pre-trained hierarchy model for alignment device: Device to train on clip_processor: CLIP processor for text preprocessing temperature: Temperature scaling parameter for contrastive loss (default: 0.07) alignment_weight: Weight for the alignment loss component (default: 0.3) Returns: Tuple of (average_loss, metrics_dict) where metrics_dict contains detailed loss components """ model.train() total_loss = 0.0 total_metrics = { 'original_loss': 0.0, 'alignment_loss': 0.0, 'reference_loss': 0.0, 'color_text_alignment': 0.0, 'color_image_alignment': 0.0, 'color_text_cosine': 0.0, 'color_image_cosine': 0.0, 'hierarchy_text_alignment': 0.0, 'hierarchy_image_alignment': 0.0, 'hierarchy_text_cosine': 0.0, 'hierarchy_image_cosine': 0.0 } num_batches = 0 pbar = tqdm(train_loader, desc="Training Enhanced", leave=False) for batch_idx, (images, texts, colors, hierarchy) in enumerate(pbar): # Move data to device images = images.to(device) images = images.expand(-1, 3, -1, -1) # Process text inputs text_inputs = clip_processor(text=texts, padding=True, return_tensors="pt") text_inputs = {k: v.to(device) for k, v in text_inputs.items()} # Optional reference text features to keep close to base CLIP reference_text_features = None if reference_model is not None: with torch.no_grad(): reference_text_features = reference_model.get_text_features(**text_inputs) # Forward pass optimizer.zero_grad() outputs = model(**text_inputs, pixel_values=images) text_features = outputs.text_embeds image_features = outputs.image_embeds # Get feature embeddings if hasattr(feature_models[config.color_column], 'get_color_name_embeddings'): color_features = feature_models[config.color_column].get_color_name_embeddings(colors) else: color_features = feature_models[config.color_column].get_text_embeddings(colors) hierarchy_features = feature_models[config.hierarchy_column].get_text_embeddings(hierarchy) concat_features = torch.cat((color_features, hierarchy_features), dim=1) # Calculate enhanced loss with hierarchy alignment loss, metrics = enhanced_contrastive_loss( text_features, image_features, concat_features, color_model, hierarchy_model, colors, hierarchy, temperature, alignment_weight, reference_text_features=reference_text_features, reference_weight=reference_weight ) # Backward pass loss.backward() # Gradient clipping to prevent exploding gradients torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() total_loss += loss.item() for key, value in metrics.items(): total_metrics[key] += value num_batches += 1 # Update progress bar pbar.set_postfix({ 'Loss': f'{loss.item():.4f}', 'Align': f'{metrics["alignment_loss"]:.4f}', 'ColCos': f'{metrics["color_text_cosine"]:.3f}', 'HierCos': f'{metrics["hierarchy_text_cosine"]:.3f}' }) avg_metrics = {key: value / num_batches for key, value in total_metrics.items()} return total_loss / num_batches, avg_metrics def valid_one_epoch(model, val_loader, feature_models, device, clip_processor, temperature=0.07, alignment_weight=0.3, reference_model=None, reference_weight=0.1): """ Validate the model for one epoch using enhanced contrastive loss. Args: model: Main CLIP model to validate val_loader: DataLoader for validation data feature_models: Dictionary containing color and hierarchy models device: Device to validate on clip_processor: CLIP processor for text preprocessing temperature: Temperature scaling parameter for contrastive loss (default: 0.07) alignment_weight: Weight for the alignment loss component (default: 0.3) Returns: Average validation loss for the epoch """ model.eval() total_loss = 0.0 num_batches = 0 # Extract color and hierarchy models color_model = feature_models[config.color_column] hierarchy_model = feature_models[config.hierarchy_column] # Create progress bar for validation pbar = tqdm(val_loader, desc="Validation", leave=False) with torch.no_grad(): for batch_idx, (images, texts, colors, hierarchy) in enumerate(pbar): # Move data to device images = images.to(device) images = images.expand(-1, 3, -1, -1) # Ensure 3 channels # Process text inputs text_inputs = clip_processor(text=texts, padding=True, return_tensors="pt") text_inputs = {k: v.to(device) for k, v in text_inputs.items()} # Optional reference text features reference_text_features = None if reference_model is not None: reference_text_features = reference_model.get_text_features(**text_inputs) # Forward pass outputs = model(**text_inputs, pixel_values=images) text_features = outputs.text_embeds image_features = outputs.image_embeds # Get feature embeddings if hasattr(feature_models[config.color_column], 'get_color_name_embeddings'): color_features = feature_models[config.color_column].get_color_name_embeddings(colors) else: color_features = feature_models[config.color_column].get_text_embeddings(colors) hierarchy_features = feature_models[config.hierarchy_column].get_text_embeddings(hierarchy) concat_features = torch.cat((color_features, hierarchy_features), dim=1) # Calculate loss with all required arguments loss, metrics = enhanced_contrastive_loss( text_features, image_features, concat_features, color_model, hierarchy_model, colors, hierarchy, temperature, alignment_weight, reference_text_features=reference_text_features, reference_weight=reference_weight ) total_loss += loss.item() num_batches += 1 # Update progress bar pbar.set_postfix({ 'Loss': f'{loss.item():.4f}', 'Avg Loss': f'{total_loss/num_batches:.4f}' }) return total_loss / num_batches # ------------------------------- # Dataset # ------------------------------- class CustomDataset(Dataset): """ Custom dataset for main model training. Handles loading images from local paths, extracting text descriptions, and applying appropriate transformations for training and validation. """ def __init__(self, dataframe, use_local_images=True, image_size=224): """ Initialize the custom dataset. Args: dataframe: DataFrame with columns for image paths, text descriptions, colors, and hierarchy labels use_local_images: Whether to use local images (default: True) image_size: Size of images after resizing (default: 224) """ self.dataframe = dataframe self.use_local_images = use_local_images self.image_size = image_size # Transforms with augmentation for training (increased augmentation to reduce overfitting) self.transform = transforms.Compose([ transforms.Resize((image_size, image_size)), transforms.RandomHorizontalFlip(p=0.5), transforms.RandomRotation(15), # Increased for more variation transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15), # Increased intensity transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)), # Increased transform range transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.2), # Add blur transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # Transforms for validation (no augmentation) self.val_transform = transforms.Compose([ transforms.Resize((image_size, image_size)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) self.training_mode = True def set_training_mode(self, training=True): """ Switch between training and validation transforms. Args: training: If True, use training transforms with augmentation; if False, use validation transforms """ self.training_mode = training def __len__(self): """Return the number of samples in the dataset.""" return len(self.dataframe) def __getitem__(self, idx): """ Get a sample from the dataset. Args: idx: Index of the sample Returns: Tuple of (image_tensor, description_text, color_label, hierarchy_label) """ row = self.dataframe.iloc[idx] image_data = row[config.column_local_image_path] image = Image.open(image_data).convert("RGB") # Apply appropriate transform if self.training_mode: image = self.transform(image) else: image = self.val_transform(image) # Get text and labels description = row[config.text_column] color = row[config.color_column] hierarchy = row[config.hierarchy_column] return image, description, color, hierarchy # ------------------------------- # Model Loading # ------------------------------- def load_models(): """ Load color and hierarchy models from checkpoints. This function loads the pre-trained color and hierarchy models along with their tokenizers and extractors, and prepares them for use in main model training. Returns: Dictionary mapping model names to model instances: - 'color': ColorCLIP model instance - 'hierarchy': Hierarchy model instance """ from color_model import ColorCLIP, Tokenizer from hierarchy_model import Model, HierarchyExtractor # Initialize tokenizer first tokenizer = Tokenizer() # Load vocabulary if available if os.path.exists(config.tokeniser_path): with open(config.tokeniser_path, 'r') as f: vocab_dict = json.load(f) tokenizer.load_vocab(vocab_dict) print(f"Tokenizer vocabulary loaded from {config.tokeniser_path}") else: print(f"Warning: {config.tokeniser_path} not found. Using default tokenizer.") # Load trained model first to get correct vocab size checkpoint = torch.load(config.color_model_path, map_location=config.device) # Extract vocab size from the checkpoint's embedding layer vocab_size_from_checkpoint = checkpoint['text_encoder.embedding.weight'].shape[0] print(f"Vocab size from checkpoint: {vocab_size_from_checkpoint}") print(f"Vocab size from tokenizer: {tokenizer.counter}") # Use the larger of the two to ensure compatibility vocab_size = max(vocab_size_from_checkpoint, tokenizer.counter) # Initialize model with correct vocab size color_model = ColorCLIP(vocab_size=vocab_size, embedding_dim=config.color_emb_dim).to(config.device) color_model.tokenizer = tokenizer # Load the checkpoint color_model.load_state_dict(checkpoint) print(f"Color model loaded from {config.color_model_path}") color_model.eval() color_model.name = config.color_column # Load hierarchy model hierarchy_checkpoint = torch.load(config.hierarchy_model_path, map_location=config.device) hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', []) hierarchy_model = Model( num_hierarchy_classes=len(hierarchy_classes), embed_dim=config.hierarchy_emb_dim ).to(config.device) hierarchy_model.load_state_dict(hierarchy_checkpoint['model_state']) # Set up hierarchy extractor hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=False) hierarchy_model.set_hierarchy_extractor(hierarchy_extractor) hierarchy_model.eval() hierarchy_model.name = config.hierarchy_column feature_models = {model.name: model for model in [color_model, hierarchy_model]} return feature_models # ------------------------------- # Main Training Function # ------------------------------- def train_model(model, train_loader, val_loader, feature_models, device, num_epochs=20, learning_rate=1e-5, temperature=0.07, save_path=config.main_model_path, alignment_weight=0.3, color_alignment_model=None, weight_decay=3e-4, reference_model=None, reference_weight=0.1): """ Custom training loop using train_one_epoch and valid_one_epoch functions. This function handles the complete training process including: - Training and validation loops - Learning rate scheduling - Early stopping - Model checkpointing - Training curve visualization Args: model: Main CLIP model to train train_loader: DataLoader for training data val_loader: DataLoader for validation data feature_models: Dictionary containing color and hierarchy models device: Device to train on num_epochs: Number of training epochs (default: 20) learning_rate: Learning rate for optimizer (default: 1e-5) temperature: Temperature scaling parameter for contrastive loss (default: 0.07) save_path: Path to save model checkpoints (default: main_model_path) alignment_weight: Weight for alignment loss component if using enhanced loss (default: 0.3) color_alignment_model: Optional color model for alignment (default: None, uses feature_models) weight_decay: L2 regularization weight (default: 3e-4, increased to reduce overfitting) Returns: Tuple of (training_losses, validation_losses) lists """ model = model.to(device) # Use AdamW with weight decay for better regularization (reduces overfitting) optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5) train_losses = [] val_losses = [] best_val_loss = float('inf') patience_counter = 0 patience = 7 # Increased from 5 to 7 for better convergence print(f"Starting training for {num_epochs} epochs...") print(f"Learning rate: {learning_rate}") print(f"Temperature: {temperature}") print(f"Weight decay: {weight_decay}") print(f"Alignment weight: {alignment_weight}") print(f"Device: {device}") print(f"Training samples: {len(train_loader.dataset)}") print(f"Validation samples: {len(val_loader.dataset)}") print(f"Batch size: {train_loader.batch_size}") print(f"Estimated time per epoch: ~{len(train_loader) * 2 / 60:.1f} minutes") # Create processor once for efficiency processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K') # Freeze and move reference model (used for text-space regularization) if reference_model is not None: reference_model = reference_model.to(device) reference_model.eval() for param in reference_model.parameters(): param.requires_grad = False # Create progress bar for epochs epoch_pbar = tqdm(range(num_epochs), desc="Training Progress", position=0) for epoch in epoch_pbar: # Update epoch progress bar epoch_pbar.set_description(f"Epoch {epoch+1}/{num_epochs}") # Training if color_alignment_model is None: color_alignment_model = feature_models[config.color_column] hierarchy_model = feature_models[config.hierarchy_column] train_loss, align_metrics = train_one_epoch( model, train_loader, optimizer, feature_models, color_alignment_model, hierarchy_model, device, processor, temperature, alignment_weight, reference_model=reference_model, reference_weight=reference_weight ) train_losses.append(train_loss) # Validation val_loss = valid_one_epoch( model, val_loader, feature_models, device, processor, temperature=temperature, alignment_weight=alignment_weight, reference_model=reference_model, reference_weight=reference_weight ) val_losses.append(val_loss) # Learning rate scheduling scheduler.step(val_loss) # Calculate overfitting gap overfitting_gap = val_loss - train_loss # Update epoch progress bar with metrics postfix = { 'Train Loss': f'{train_loss:.4f}', 'Val Loss': f'{val_loss:.4f}', 'Gap': f'{overfitting_gap:.4f}', 'LR': f'{optimizer.param_groups[0]["lr"]:.2e}', 'Best Val': f'{best_val_loss:.4f}' } if align_metrics is not None: postfix.update({ 'Align': f"{align_metrics['alignment_loss']:.3f}", 'ColCos': f"{align_metrics['color_text_cosine']:.3f}", 'HierCos': f"{align_metrics['hierarchy_text_cosine']:.3f}" }) epoch_pbar.set_postfix(postfix) # Warning if overfitting is detected if overfitting_gap > 0.15 and epoch > 3: print(f"\nāš ļø Warning: Significant overfitting detected at epoch {epoch+1} (gap={overfitting_gap:.4f})") # Save best model if val_loss < best_val_loss: best_val_loss = val_loss patience_counter = 0 # Save checkpoint torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'train_loss': train_loss, 'val_loss': val_loss, 'best_val_loss': best_val_loss, }, save_path) else: patience_counter += 1 # Early stopping if patience_counter >= patience: print(f"\nšŸ›‘ Early stopping triggered after {patience_counter} epochs without improvement") break # Plot training curves with overfitting analysis plt.figure(figsize=(15, 5)) # Plot 1: Training and Validation Loss plt.subplot(1, 3, 1) plt.plot(train_losses, label='Train Loss', color='blue', linewidth=2) plt.plot(val_losses, label='Val Loss', color='red', linewidth=2) plt.title('Training and Validation Loss', fontsize=12, fontweight='bold') plt.xlabel('Epoch') plt.ylabel('Loss') plt.legend() plt.grid(True, alpha=0.3) # Plot 2: Overfitting Gap (Val Loss - Train Loss) plt.subplot(1, 3, 2) gap = [val_losses[i] - train_losses[i] for i in range(len(train_losses))] plt.plot(gap, label='Overfitting Gap', color='purple', linewidth=2) plt.axhline(y=0, color='black', linestyle='--', alpha=0.3) plt.axhline(y=0.1, color='red', linestyle='--', alpha=0.3, label='Warning threshold') plt.title('Overfitting Gap (Val - Train)', fontsize=12, fontweight='bold') plt.xlabel('Epoch') plt.ylabel('Gap') plt.legend() plt.grid(True, alpha=0.3) # Plot 3: Loss comparison plt.subplot(1, 3, 3) epochs = list(range(len(train_losses))) plt.plot(epochs, train_losses, 'o-', label='Train Loss', color='blue', linewidth=2) plt.plot(epochs, val_losses, 's-', label='Val Loss', color='red', linewidth=2) plt.fill_between(epochs, train_losses, val_losses, alpha=0.2, color='red') plt.title('Loss Comparison', fontsize=12, fontweight='bold') plt.xlabel('Epoch') plt.ylabel('Loss') plt.legend() plt.grid(True, alpha=0.3) plt.tight_layout() plt.savefig('training_curves.png', dpi=300, bbox_inches='tight') plt.close() print(f"\nTraining completed!") print(f"Best validation loss: {best_val_loss:.4f}") print(f"Final model saved to: {save_path}") print(f"Training curves saved to: training_curves.png") return train_losses, val_losses # ------------------------------- # Main Function # ------------------------------- def main(): print("="*80) print("šŸš€ Training of the model with alignement color and hierarchy") print("="*80) # Configuration (optimized to reduce overfitting) num_epochs = 20 learning_rate = 1.5e-5 # Reduced slightly to prevent overfitting temperature = 0.09 # Increased from 0.07 for softer contrastive learning alignment_weight = 0.2 # Reduced from 0.3 to prevent overfitting on alignment weight_decay = 5e-4 # Increased weight decay for stronger regularization batch_size = 32 subset_size = 20000 # Increased dataset size for better generalization # Load the data print(f"\nšŸ“‚ Loading the data...") df = pd.read_csv(config.local_dataset_path) print(f" Data downloaded: {len(df)} samples") # filter the rows with NaN values df_clean = df.dropna(subset=[config.column_local_image_path]) print(f" After filtering NaN: {len(df_clean)} samples") # Creation of datasets dataset = CustomDataset(df_clean) # Creation of a subset for a faster training print(f"\nšŸ“Š Creation of a subset of {subset_size} samples...") subset_size = min(subset_size, len(dataset)) train_size = int(0.8 * subset_size) val_size = subset_size - train_size # Creation of a subset with random indexes but reproductibles np.random.seed(42) subset_indices = np.random.choice(len(dataset), subset_size, replace=False) subset_dataset = torch.utils.data.Subset(dataset, subset_indices) train_dataset, val_dataset = random_split( subset_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42) ) # Creation of dataloaders train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True if torch.cuda.is_available() else False ) val_loader = DataLoader( val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True if torch.cuda.is_available() else False ) print(f" Train: {len(train_dataset)} samples") print(f" Validation: {len(val_dataset)} samples") # Loading models print(f"\nšŸ”§ Loading models...") feature_models = load_models() # Load or create the main model print(f"\nšŸ“¦ Loading main model...") clip_model = CLIPModel_transformers.from_pretrained( 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K' ) # Frozen reference CLIP to regularize text space (improves cross-domain generalization) reference_clip = CLIPModel_transformers.from_pretrained( 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K' ) # # Load the model # if os.path.exists(config.main_model_path): # print(f" Model found {config.main_model_path}") # print(f" Loading checkpoint...") # checkpoint = torch.load(config.main_model_path, map_location=config.device) # if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint: # clip_model.load_state_dict(checkpoint['model_state_dict']) # print(f" āœ… Checkpoint loaded from {checkpoint.get('epoch', '?')}") # else: # clip_model.load_state_dict(checkpoint) # print(f" āœ… Checkpoint loaded") # else: # print(f" New model, no checkpoint found") # Move the model on the device clip_model = clip_model.to(config.device) reference_clip = reference_clip.to(config.device) reference_clip.eval() for param in reference_clip.parameters(): param.requires_grad = False # Training with enhanced loss print(f"\nšŸŽÆ Beginning training...") print(f"\n" + "="*80) train_losses, val_losses = train_model( model=clip_model, train_loader=train_loader, val_loader=val_loader, feature_models=feature_models, device=config.device, num_epochs=num_epochs, learning_rate=learning_rate, temperature=temperature, save_path=config.main_model_path, alignment_weight=alignment_weight, color_alignment_model=feature_models[config.color_column], weight_decay=weight_decay, reference_model=reference_clip, reference_weight=0.1 ) print("\n" + "="*80) print("āœ… Training finished!") print(f" Model saved: {config.main_model_path}") print(f" Training curves: training_curves.png") print("\nšŸ“Š Final results:") print(f" Last train loss: {train_losses[-1]:.4f}") print(f" Last validation loss: {val_losses[-1]:.4f}") print(f" Best validation loss: {min(val_losses):.4f}") print(f" Overfitting gap (val-train): {val_losses[-1] - train_losses[-1]:.4f}") if val_losses[-1] - train_losses[-1] > 0.1: print(" āš ļø Warning: Significant overfitting detected!") elif val_losses[-1] - train_losses[-1] < 0.05: print(" āœ… Good generalization!") print("="*80) if __name__ == "__main__": main()