Zero-Shot Image Classification
Transformers
Safetensors
English
clip
fashion
multimodal
image-search
text-search
embeddings
contrastive-learning
zero-shot-classification
Instructions to use Leacb4/gap-clip with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Leacb4/gap-clip with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("zero-shot-image-classification", model="Leacb4/gap-clip")
pipe(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/hub/parrots.png",
    candidate_labels=["animals", "humans", "landscape"],
)

# Load model directly
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification

processor = AutoProcessor.from_pretrained("Leacb4/gap-clip")
model = AutoModelForZeroShotImageClassification.from_pretrained("Leacb4/gap-clip")
```
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| """ | |
| Main file for training the CLIP model with color and hierarchy alignment. | |
| This file centralizes all the logic for training the main model. It uses | |
| pre-trained color and hierarchy models to guide the main model's learning | |
| through contrastive and alignment loss functions. It handles data loading, | |
| training with validation, and checkpoint saving. | |
| """ | |
| import os | |
| # Set environment variable to disable tokenizers parallelism warnings | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset, DataLoader, random_split | |
| from torchvision import transforms | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers | |
| import warnings | |
| from tqdm import tqdm | |
| import config | |
| # Suppress warnings | |
| warnings.filterwarnings("ignore", category=FutureWarning) | |
| warnings.filterwarnings("ignore", category=UserWarning) | |
| # ------------------------------- | |
| # Loss Functions | |
| # ------------------------------- | |
def enhanced_contrastive_loss(text_features, image_features, attribute_features,
                              color_model, hierarchy_model, colors, hierarchies,
                              temperature=0.07, alignment_weight=0.3,
                              reference_text_features=None, reference_image_features=None,
                              reference_weight=0.1):
    """
    Enhanced contrastive loss with direct alignment between color/hierarchy models and main model.

    The loss is made of three parts:

    1. A symmetric cross-entropy over a weighted sum of three similarity
       matrices (text/image on the trailing dimensions, text/attribute and
       attribute/image on the leading ``color_emb_dim + hierarchy_emb_dim``
       dimensions).
    2. A cosine alignment loss pulling the leading color and hierarchy slices
       of the main embeddings towards the embeddings produced by
       ``color_model`` and ``hierarchy_model``.
    3. An optional MSE "reference" loss between the normalized main embeddings
       and the normalized reference embeddings.

    Args:
        text_features: Main model text embeddings [batch_size, embed_dim]
        image_features: Main model image embeddings [batch_size, embed_dim]
        attribute_features: Concatenated color + hierarchy features
            [batch_size, color_dim + hierarchy_dim]
        color_model: Pre-trained color model; must expose ``get_text_embeddings``
        hierarchy_model: Pre-trained hierarchy model; must expose ``get_text_embeddings``
        colors: List of color strings for this batch [batch_size]
        hierarchies: List of hierarchy strings for this batch [batch_size]
        temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
        alignment_weight: Weight for the alignment loss component (default: 0.3)
        reference_text_features: Optional reference text embeddings, same shape
            as ``text_features`` (default: None)
        reference_image_features: Optional reference image embeddings, same
            shape as ``image_features`` (default: None)
        reference_weight: Weight of the reference loss when at least one
            reference tensor is given (default: 0.1)

    Returns:
        Tuple of (total_loss, metrics_dict) where metrics_dict contains detailed
        loss components as Python floats.
    """
    color_dim = config.color_emb_dim
    attr_dim = color_dim + config.hierarchy_emb_dim

    # Original triple contrastive loss.
    # Normalization is applied to the full vectors *before* slicing, so the
    # individual slices are not unit-norm themselves.
    text_features_norm = F.normalize(text_features, dim=-1)
    image_features_norm = F.normalize(image_features, dim=-1)
    attribute_features_norm = F.normalize(attribute_features, dim=-1)

    text_image_logits = (text_features_norm[:, attr_dim:] @
                         image_features_norm[:, attr_dim:].T) / temperature
    text_attr_logits = (text_features_norm[:, :attr_dim] @
                        attribute_features_norm.T) / temperature
    image_attr_logits = (attribute_features_norm @
                         image_features_norm[:, :attr_dim].T) / temperature

    # Weight distribution for original loss
    weight_text_image = 0.7
    weight_attr_based = 0.15
    original_logits = (weight_text_image * text_image_logits +
                       weight_attr_based * text_attr_logits +
                       weight_attr_based * image_attr_logits)

    # Matching pairs sit on the diagonal; build the targets directly on the
    # right device instead of allocating on CPU and copying.
    labels = torch.arange(len(text_features), device=text_features.device)
    original_loss = (F.cross_entropy(original_logits, labels) +
                     F.cross_entropy(original_logits.T, labels)) / 2

    # Target embeddings from the specialized models; they only act as fixed
    # targets here, so no gradient is tracked for them.
    with torch.no_grad():
        color_embeddings = color_model.get_text_embeddings(colors)
        hierarchy_embeddings = hierarchy_model.get_text_embeddings(hierarchies)

    # Leading `color_emb_dim` dimensions of the main embeddings
    main_color_text = text_features[:, :color_dim]
    main_color_image = image_features[:, :color_dim]
    # Following `hierarchy_emb_dim` dimensions of the main embeddings
    main_hierarchy_text = text_features[:, color_dim:attr_dim]
    main_hierarchy_image = image_features[:, color_dim:attr_dim]

    # Normalize for better correlation
    color_embeddings_norm = F.normalize(color_embeddings, dim=-1)
    main_color_text_norm = F.normalize(main_color_text, dim=-1)
    main_color_image_norm = F.normalize(main_color_image, dim=-1)
    hierarchy_embeddings_norm = F.normalize(hierarchy_embeddings, dim=-1)
    main_hierarchy_text_norm = F.normalize(main_hierarchy_text, dim=-1)
    main_hierarchy_image_norm = F.normalize(main_hierarchy_image, dim=-1)

    # Color alignment loss (cosine-only: more natural for normalized embeddings)
    color_text_cosine_loss = 1 - F.cosine_similarity(main_color_text_norm, color_embeddings_norm).mean()
    color_image_cosine_loss = 1 - F.cosine_similarity(main_color_image_norm, color_embeddings_norm).mean()
    color_alignment_loss = (color_text_cosine_loss + color_image_cosine_loss) / 2

    # Hierarchy alignment loss (cosine-only)
    hierarchy_text_cosine_loss = 1 - F.cosine_similarity(main_hierarchy_text_norm, hierarchy_embeddings_norm).mean()
    hierarchy_image_cosine_loss = 1 - F.cosine_similarity(main_hierarchy_image_norm, hierarchy_embeddings_norm).mean()
    hierarchy_alignment_loss = (hierarchy_text_cosine_loss + hierarchy_image_cosine_loss) / 2

    # Combined alignment loss
    alignment_loss = (color_alignment_loss + hierarchy_alignment_loss) / 2

    # Reference loss to keep embeddings close to the reference model.
    # Each modality contributes independently, so an image-only reference is
    # honoured instead of being silently dropped.
    reference_terms = []
    if reference_text_features is not None:
        reference_terms.append(F.mse_loss(
            text_features_norm,
            F.normalize(reference_text_features, dim=-1)
        ))
    if reference_image_features is not None:
        reference_terms.append(F.mse_loss(
            image_features_norm,
            F.normalize(reference_image_features, dim=-1)
        ))

    # Combine losses
    total_loss = (1 - alignment_weight) * original_loss + alignment_weight * alignment_loss
    if reference_terms:
        reference_loss = sum(reference_terms) / len(reference_terms)
        total_loss = total_loss + reference_weight * reference_loss
        reference_loss_value = reference_loss.item()
    else:
        reference_loss_value = 0.0

    return total_loss, {
        'original_loss': original_loss.item(),
        'alignment_loss': alignment_loss.item(),
        'reference_loss': reference_loss_value,
        'color_text_cosine': color_text_cosine_loss.item(),
        'color_image_cosine': color_image_cosine_loss.item(),
        'hierarchy_text_cosine': hierarchy_text_cosine_loss.item(),
        'hierarchy_image_cosine': hierarchy_image_cosine_loss.item()
    }
| # ------------------------------- | |
| # Training Functions | |
| # ------------------------------- | |
def train_one_epoch(model, train_loader, optimizer, feature_models, color_model, hierarchy_model,
                    device, clip_processor, temperature=0.07, alignment_weight=0.3,
                    reference_model=None, reference_weight=0.1):
    """
    Train the main model for one epoch with the enhanced contrastive loss.

    Each batch is tokenized, pushed through the main model, and scored with
    ``enhanced_contrastive_loss``; gradients are clipped to a max norm of 1.0
    before the optimizer step.

    Args:
        model: Main CLIP model to train
        train_loader: DataLoader yielding (images, texts, colors, hierarchy) batches
        optimizer: Optimizer instance
        feature_models: Dictionary containing color and hierarchy models, keyed by
            ``config.color_column`` and ``config.hierarchy_column``
        color_model: Pre-trained color model for alignment
        hierarchy_model: Pre-trained hierarchy model for alignment
        device: Device to train on
        clip_processor: CLIP processor for text preprocessing
        temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
        alignment_weight: Weight for the alignment loss component (default: 0.3)
        reference_model: Optional model whose text/image features are used as a
            regularization target (default: None, reference loss disabled)
        reference_weight: Weight of the reference loss (default: 0.1)

    Returns:
        Tuple of (average_loss, metrics_dict) where metrics_dict contains the
        per-batch average of every loss component.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_metrics = {
        'original_loss': 0.0,
        'alignment_loss': 0.0,
        'reference_loss': 0.0,
        'color_text_cosine': 0.0,
        'color_image_cosine': 0.0,
        'hierarchy_text_cosine': 0.0,
        'hierarchy_image_cosine': 0.0
    }
    num_batches = 0

    # Resolve the attribute encoders once instead of on every batch.
    color_feature_model = feature_models[config.color_column]
    hierarchy_feature_model = feature_models[config.hierarchy_column]
    if hasattr(color_feature_model, 'get_color_name_embeddings'):
        embed_colors = color_feature_model.get_color_name_embeddings
    else:
        embed_colors = color_feature_model.get_text_embeddings

    pbar = tqdm(train_loader, desc="Training Enhanced", leave=False)
    for images, texts, colors, hierarchy in pbar:
        # Move data to device
        images = images.to(device)
        images = images.expand(-1, 3, -1, -1)

        # Process text inputs. Truncation keeps long descriptions within the
        # tokenizer's maximum length instead of failing inside the text encoder.
        text_inputs = clip_processor(text=texts, padding=True, truncation=True, return_tensors="pt")
        text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

        # Reference features to keep embeddings close to base CLIP
        reference_text_features = None
        reference_image_features = None
        if reference_model is not None:
            with torch.no_grad():
                reference_text_features = reference_model.get_text_features(**text_inputs)
                reference_image_features = reference_model.get_image_features(pixel_values=images)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(**text_inputs, pixel_values=images)
        text_features = outputs.text_embeds
        image_features = outputs.image_embeds

        # Attribute embeddings are fixed targets: the optimizer only holds
        # `model`'s parameters, so tracking gradients for them is wasted work.
        with torch.no_grad():
            color_features = embed_colors(colors)
            hierarchy_features = hierarchy_feature_model.get_text_embeddings(hierarchy)
        concat_features = torch.cat((color_features, hierarchy_features), dim=1)

        # Calculate enhanced loss with hierarchy alignment
        loss, metrics = enhanced_contrastive_loss(
            text_features, image_features, concat_features,
            color_model, hierarchy_model, colors, hierarchy, temperature, alignment_weight,
            reference_text_features=reference_text_features,
            reference_image_features=reference_image_features,
            reference_weight=reference_weight
        )

        # Backward pass
        loss.backward()
        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        for key, value in metrics.items():
            total_metrics[key] += value
        num_batches += 1

        # Update progress bar
        pbar.set_postfix({
            'Loss': f'{batch_loss:.4f}',
            'Align': f'{metrics["alignment_loss"]:.4f}',
            'ColCos': f'{metrics["color_text_cosine"]:.3f}',
            'HierCos': f'{metrics["hierarchy_text_cosine"]:.3f}'
        })

    if num_batches == 0:
        # Fail with a clear message rather than a bare ZeroDivisionError.
        raise ValueError("train_loader yielded no batches; cannot compute an average loss")

    avg_metrics = {key: value / num_batches for key, value in total_metrics.items()}
    return total_loss / num_batches, avg_metrics
def valid_one_epoch(model, val_loader, feature_models, device, clip_processor, temperature=0.07, alignment_weight=0.3,
                    reference_model=None, reference_weight=0.1):
    """
    Validate the model for one epoch using enhanced contrastive loss.

    The model is put in eval mode and the whole pass runs under
    ``torch.no_grad()``. The color and hierarchy models used for alignment are
    taken from ``feature_models``.

    Args:
        model: Main CLIP model to validate
        val_loader: DataLoader yielding (images, texts, colors, hierarchy) batches
        feature_models: Dictionary containing color and hierarchy models, keyed by
            ``config.color_column`` and ``config.hierarchy_column``
        device: Device to validate on
        clip_processor: CLIP processor for text preprocessing
        temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
        alignment_weight: Weight for the alignment loss component (default: 0.3)
        reference_model: Optional model whose text/image features are used as a
            regularization target (default: None, reference loss disabled)
        reference_weight: Weight of the reference loss (default: 0.1)

    Returns:
        Average validation loss for the epoch

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    # Extract color and hierarchy models
    color_model = feature_models[config.color_column]
    hierarchy_model = feature_models[config.hierarchy_column]
    if hasattr(color_model, 'get_color_name_embeddings'):
        embed_colors = color_model.get_color_name_embeddings
    else:
        embed_colors = color_model.get_text_embeddings

    # Create progress bar for validation
    pbar = tqdm(val_loader, desc="Validation", leave=False)
    with torch.no_grad():
        for images, texts, colors, hierarchy in pbar:
            # Move data to device
            images = images.to(device)
            images = images.expand(-1, 3, -1, -1)  # Ensure 3 channels

            # Process text inputs. Truncation keeps long descriptions within
            # the tokenizer's maximum length instead of failing in the encoder.
            text_inputs = clip_processor(text=texts, padding=True, truncation=True, return_tensors="pt")
            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

            # Reference features to keep embeddings close to base CLIP
            reference_text_features = None
            reference_image_features = None
            if reference_model is not None:
                reference_text_features = reference_model.get_text_features(**text_inputs)
                reference_image_features = reference_model.get_image_features(pixel_values=images)

            # Forward pass
            outputs = model(**text_inputs, pixel_values=images)
            text_features = outputs.text_embeds
            image_features = outputs.image_embeds

            # Get feature embeddings
            color_features = embed_colors(colors)
            hierarchy_features = hierarchy_model.get_text_embeddings(hierarchy)
            concat_features = torch.cat((color_features, hierarchy_features), dim=1)

            # Calculate loss with all required arguments (metrics are not
            # aggregated during validation)
            loss, _ = enhanced_contrastive_loss(
                text_features, image_features, concat_features,
                color_model, hierarchy_model, colors, hierarchy,
                temperature, alignment_weight,
                reference_text_features=reference_text_features,
                reference_image_features=reference_image_features,
                reference_weight=reference_weight
            )

            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1

            # Update progress bar
            pbar.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'Avg Loss': f'{total_loss/num_batches:.4f}'
            })

    if num_batches == 0:
        # Fail with a clear message rather than a bare ZeroDivisionError.
        raise ValueError("val_loader yielded no batches; cannot compute an average loss")

    return total_loss / num_batches
| # ------------------------------- | |
| # Dataset | |
| # ------------------------------- | |
class CustomDataset(Dataset):
    """
    Custom dataset for main model training.

    Handles loading images from local paths, extracting text descriptions,
    and applying appropriate transformations for training and validation.
    The column names are read from the ``config`` module.
    """

    def __init__(self, dataframe, use_local_images=True, image_size=224):
        """
        Initialize the custom dataset.

        Args:
            dataframe: DataFrame with columns for image paths, text descriptions,
                colors, and hierarchy labels
            use_local_images: Whether to use local images (default: True).
                Stored on the instance; not consulted by the methods of this class.
            image_size: Size of images after resizing (default: 224)
        """
        self.dataframe = dataframe
        self.use_local_images = use_local_images
        self.image_size = image_size

        # NOTE(review): these are the commonly used ImageNet mean/std values;
        # confirm they are the intended normalization for the CLIP checkpoint.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # Transforms with augmentation for training (increased augmentation to reduce overfitting)
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(15),  # Increased for more variation
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15),  # Increased intensity
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),  # Increased transform range
            transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.2),  # Add blur
            transforms.ToTensor(),
            normalize
        ])

        # Transforms for validation (no augmentation)
        self.val_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            normalize
        ])

        self.training_mode = True

    def set_training_mode(self, training=True):
        """
        Switch between training and validation transforms.

        Args:
            training: If True, use training transforms with augmentation;
                if False, use validation transforms
        """
        self.training_mode = training

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """
        Get a sample from the dataset.

        Args:
            idx: Positional index of the sample (used with ``DataFrame.iloc``)

        Returns:
            Tuple of (image_tensor, description_text, color_label, hierarchy_label)
        """
        row = self.dataframe.iloc[idx]
        image_path = row[config.column_local_image_path]

        # `convert` returns a fully loaded copy, so the file opened by
        # `Image.open` can be closed right away instead of lingering until
        # garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # Apply appropriate transform
        transform = self.transform if self.training_mode else self.val_transform
        image = transform(image)

        # Get text and labels
        description = row[config.text_column]
        color = row[config.color_column]
        hierarchy = row[config.hierarchy_column]
        return image, description, color, hierarchy
| # ------------------------------- | |
| # Model Loading | |
| # ------------------------------- | |
def load_models():
    """
    Restore the color and hierarchy models from their checkpoints.

    Both models are loaded onto ``config.device``, switched to eval mode and
    tagged with a ``name`` attribute equal to the matching config column name.

    Returns:
        Dictionary mapping each model's ``name`` to the model instance:
            - ``config.color_column``: ColorCLIP model instance
            - ``config.hierarchy_column``: HierarchyModel instance
    """
    from training.color_model import ColorCLIP
    from training.hierarchy_model import HierarchyModel

    # --- Color model ---
    print("Loading ColorCLIP (CLIP-backbone) ...")
    color_net = ColorCLIP.from_checkpoint(config.color_model_path, device=config.device)
    color_net.eval()
    color_net.name = config.color_column

    # --- Hierarchy model ---
    print("Loading HierarchyModel (CLIP-backbone) ...")
    hierarchy_net = HierarchyModel.from_checkpoint(config.hierarchy_model_path, device=config.device)
    hierarchy_net.eval()
    hierarchy_net.name = config.hierarchy_column

    # Keyed by the name that was just attached to each model
    return {
        color_net.name: color_net,
        hierarchy_net.name: hierarchy_net,
    }
| # ------------------------------- | |
| # Main Training Function | |
| # ------------------------------- | |
def train_model(model, train_loader, val_loader, feature_models, device,
                num_epochs=20, learning_rate=1e-5, temperature=0.07,
                save_path=config.main_model_path, alignment_weight=0.3,
                color_alignment_model=None, weight_decay=3e-4,
                reference_model=None, reference_weight=0.1):
    """
    Custom training loop using train_one_epoch and valid_one_epoch functions.

    This function handles the complete training process including:
        - Training and validation loops
        - Learning rate scheduling (ReduceLROnPlateau on the validation loss)
        - Early stopping (7 epochs without validation improvement)
        - Model checkpointing (best validation loss only)
        - Training curve visualization

    Args:
        model: Main CLIP model to train
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        feature_models: Dictionary containing color and hierarchy models
        device: Device to train on
        num_epochs: Number of training epochs (default: 20)
        learning_rate: Learning rate for optimizer (default: 1e-5)
        temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
        save_path: Path to save model checkpoints (default: main_model_path)
        alignment_weight: Weight for alignment loss component (default: 0.3)
        color_alignment_model: Optional color model for alignment (default: None, uses feature_models)
        weight_decay: L2 regularization weight (default: 3e-4)
        reference_model: Optional model used as a regularization target; it is
            moved to ``device``, put in eval mode and frozen (default: None)
        reference_weight: Weight of the reference loss (default: 0.1)

    Returns:
        Tuple of (training_losses, validation_losses) lists
    """
    model = model.to(device)

    # Use AdamW with weight decay for better regularization (reduces overfitting)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 7  # Increased from 5 to 7 for better convergence

    print(f"Starting training for {num_epochs} epochs...")
    print(f"Learning rate: {learning_rate}")
    print(f"Temperature: {temperature}")
    print(f"Weight decay: {weight_decay}")
    print(f"Alignment weight: {alignment_weight}")
    print(f"Device: {device}")
    print(f"Training samples: {len(train_loader.dataset)}")
    print(f"Validation samples: {len(val_loader.dataset)}")
    print(f"Batch size: {train_loader.batch_size}")
    print(f"Estimated time per epoch: ~{len(train_loader) * 2 / 60:.1f} minutes")

    # Create processor once for efficiency
    processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')

    # Freeze and move reference model (used for text-space regularization)
    if reference_model is not None:
        reference_model = reference_model.to(device)
        reference_model.eval()
        for param in reference_model.parameters():
            param.requires_grad = False

    # The alignment models do not change between epochs; resolve them once.
    if color_alignment_model is None:
        color_alignment_model = feature_models[config.color_column]
    hierarchy_model = feature_models[config.hierarchy_column]

    # Make sure the checkpoint directory exists before the first save so a
    # missing folder cannot abort the run mid-training.
    save_dir = os.path.dirname(str(save_path))
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Create progress bar for epochs
    epoch_pbar = tqdm(range(num_epochs), desc="Training Progress", position=0)
    for epoch in epoch_pbar:
        # Update epoch progress bar
        epoch_pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")

        # Training
        train_loss, align_metrics = train_one_epoch(
            model, train_loader, optimizer, feature_models, color_alignment_model, hierarchy_model,
            device, processor, temperature, alignment_weight,
            reference_model=reference_model, reference_weight=reference_weight
        )
        train_losses.append(train_loss)

        # Validation
        val_loss = valid_one_epoch(
            model, val_loader, feature_models, device, processor,
            temperature=temperature, alignment_weight=alignment_weight,
            reference_model=reference_model, reference_weight=reference_weight
        )
        val_losses.append(val_loss)

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # Save checkpoint
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'best_val_loss': best_val_loss,
            }, save_path)
        else:
            patience_counter += 1

        # Calculate overfitting gap
        overfitting_gap = val_loss - train_loss

        # Update epoch progress bar with metrics. This happens after the
        # best-loss bookkeeping so 'Best Val' includes the current epoch.
        postfix = {
            'Train Loss': f'{train_loss:.4f}',
            'Val Loss': f'{val_loss:.4f}',
            'Gap': f'{overfitting_gap:.4f}',
            'LR': f'{optimizer.param_groups[0]["lr"]:.2e}',
            'Best Val': f'{best_val_loss:.4f}'
        }
        if align_metrics is not None:
            postfix.update({
                'Align': f"{align_metrics['alignment_loss']:.3f}",
                'ColCos': f"{align_metrics['color_text_cosine']:.3f}",
                'HierCos': f"{align_metrics['hierarchy_text_cosine']:.3f}"
            })
        epoch_pbar.set_postfix(postfix)

        # Warning if overfitting is detected
        if overfitting_gap > 0.15 and epoch > 3:
            print(f"\n⚠️ Warning: Significant overfitting detected at epoch {epoch+1} (gap={overfitting_gap:.4f})")

        # Early stopping
        if patience_counter >= patience:
            print(f"\n🛑 Early stopping triggered after {patience_counter} epochs without improvement")
            break

    # Plot training curves with overfitting analysis
    plt.figure(figsize=(15, 5))

    # Plot 1: Training and Validation Loss
    plt.subplot(1, 3, 1)
    plt.plot(train_losses, label='Train Loss', color='blue', linewidth=2)
    plt.plot(val_losses, label='Val Loss', color='red', linewidth=2)
    plt.title('Training and Validation Loss', fontsize=12, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Plot 2: Overfitting Gap (Val Loss - Train Loss)
    plt.subplot(1, 3, 2)
    gap = [val - train for train, val in zip(train_losses, val_losses)]
    plt.plot(gap, label='Overfitting Gap', color='purple', linewidth=2)
    plt.axhline(y=0, color='black', linestyle='--', alpha=0.3)
    plt.axhline(y=0.1, color='red', linestyle='--', alpha=0.3, label='Warning threshold')
    plt.title('Overfitting Gap (Val - Train)', fontsize=12, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Gap')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Plot 3: Loss comparison
    plt.subplot(1, 3, 3)
    epochs = list(range(len(train_losses)))
    plt.plot(epochs, train_losses, 'o-', label='Train Loss', color='blue', linewidth=2)
    plt.plot(epochs, val_losses, 's-', label='Val Loss', color='red', linewidth=2)
    plt.fill_between(epochs, train_losses, val_losses, alpha=0.2, color='red')
    plt.title('Loss Comparison', fontsize=12, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    curves_path = str(config.ROOT_DIR / "figures" / "training_curves.png")
    # A missing figures directory must not crash the function after training
    # has already finished (which would also lose the returned loss history).
    os.makedirs(os.path.dirname(curves_path), exist_ok=True)
    plt.savefig(curves_path, dpi=300, bbox_inches='tight')
    plt.close()

    print("\nTraining completed!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Final model saved to: {save_path}")
    print(f"Training curves saved to: {curves_path}")
    return train_losses, val_losses
| # ------------------------------- | |
| # Main Function | |
| # ------------------------------- | |
def main():
    """
    Entry point: load the data, build the loaders and models, and run training.

    Reads the dataset CSV from ``config.local_dataset_path``, samples up to
    ``subset_size`` rows, splits them 80/20 into train/validation, loads the
    color/hierarchy models plus two copies of the base CLIP checkpoint (one to
    train, one frozen as reference) and calls ``train_model``.

    Raises:
        ValueError: If no usable samples remain after filtering.
    """
    print("="*80)
    print("🚀 Training of the model with alignment color and hierarchy")
    print("="*80)

    # Configuration (tuned for zero-shot + separation balance)
    num_epochs = 10
    learning_rate = 1.5e-5
    temperature = 0.09
    alignment_weight = 0.10  # reduced from 0.2: softer alignment preserves CLIP zero-shot
    reference_weight = 0.25  # increased from 0.1: stronger regularization toward base CLIP
    weight_decay = 1e-3  # increased from 5e-4: better generalization
    batch_size = 128
    subset_size = 100000

    # Load the data
    print(f"\n📂 Loading the data...")
    df = pd.read_csv(config.local_dataset_path)
    print(f" Data downloaded: {len(df)} samples")

    # filter the rows with NaN values
    df_clean = df.dropna(subset=[config.column_local_image_path])
    df_clean = df_clean[df_clean[config.column_local_image_path].astype(str).str.len() > 0]
    print(f" After filtering NaN: {len(df_clean)} samples")
    if len(df_clean) == 0:
        raise ValueError("No samples with a local image path were found in the dataset")

    # Creation of datasets. Two views over the same rows are needed: one with
    # training augmentation and one without, because the transform mode is a
    # property of the dataset object and both splits would otherwise share it.
    dataset = CustomDataset(df_clean)
    eval_dataset = CustomDataset(df_clean)
    eval_dataset.set_training_mode(False)

    # Sample up to `subset_size` rows for training
    subset_size = min(subset_size, len(dataset))
    train_size = int(0.8 * subset_size)
    val_size = subset_size - train_size
    np.random.seed(42)
    subset_indices = np.random.choice(len(dataset), subset_size, replace=False)
    subset_dataset = torch.utils.data.Subset(dataset, subset_indices)
    train_dataset, val_split = random_split(
        subset_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )
    # Same validation rows as the random split, but read through the
    # augmentation-free dataset so the validation loss is deterministic.
    val_dataset = torch.utils.data.Subset(
        eval_dataset,
        [int(subset_indices[i]) for i in val_split.indices]
    )

    # Creation of dataloaders
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=pin_memory
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=pin_memory
    )
    print(f" Train: {len(train_dataset)} samples")
    print(f" Validation: {len(val_dataset)} samples")

    # Loading models
    print(f"\n🔧 Loading models...")
    feature_models = load_models()

    # Load or create the main model
    print(f"\n📦 Loading main model...")
    clip_model = CLIPModel_transformers.from_pretrained(
        'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
    )
    # Frozen reference CLIP to regularize text space (improves cross-domain generalization)
    reference_clip = CLIPModel_transformers.from_pretrained(
        'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
    )

    # Move the model on the device
    clip_model = clip_model.to(config.device)
    reference_clip = reference_clip.to(config.device)
    reference_clip.eval()
    for param in reference_clip.parameters():
        param.requires_grad = False

    # Training with enhanced loss
    print(f"\n🎯 Beginning training...")
    print(f"\n" + "="*80)
    train_losses, val_losses = train_model(
        model=clip_model,
        train_loader=train_loader,
        val_loader=val_loader,
        feature_models=feature_models,
        device=config.device,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        temperature=temperature,
        save_path=config.main_model_path,
        alignment_weight=alignment_weight,
        color_alignment_model=feature_models[config.color_column],
        weight_decay=weight_decay,
        reference_model=reference_clip,
        reference_weight=reference_weight
    )

    final_gap = val_losses[-1] - train_losses[-1]
    print("\n" + "="*80)
    print("✅ Training finished!")
    print(f" Model saved: {config.main_model_path}")
    print(f" Training curves: figures/training_curves.png")
    print("\n📊 Final results:")
    print(f" Last train loss: {train_losses[-1]:.4f}")
    print(f" Last validation loss: {val_losses[-1]:.4f}")
    print(f" Best validation loss: {min(val_losses):.4f}")
    print(f" Overfitting gap (val-train): {final_gap:.4f}")
    if final_gap > 0.1:
        print(" ⚠️ Warning: Significant overfitting detected!")
    elif final_gap < 0.05:
        print(" ✅ Good generalization!")
    print("="*80)


if __name__ == "__main__":
    main()