#!/usr/bin/env python3
"""
Main file for training the CLIP model with color and hierarchy alignment.
This file centralizes all the logic for training the main model. It uses
pre-trained color and hierarchy models to guide the main model's learning
through contrastive and alignment loss functions. It handles data loading,
training with validation, and checkpoint saving.
"""
import os
# Set environment variable to disable tokenizers parallelism warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers
import warnings
from tqdm import tqdm
import json
import config
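# Expected attributes on config (as used throughout this module):
#   color_emb_dim, hierarchy_emb_dim, color_column, hierarchy_column,
#   text_column, column_local_image_path, tokeniser_path, color_model_path,
#   hierarchy_model_path, main_model_path, local_dataset_path, device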
# Suppress warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# -------------------------------
# Loss Functions
# -------------------------------
def enhanced_contrastive_loss(text_features, image_features, attribute_features,
color_model, hierarchy_model, colors, hierarchies, temperature=0.07, alignment_weight=0.3,
reference_text_features=None, reference_weight=0.1):
"""
Enhanced contrastive loss with direct alignment between color/hierarchy models and main model.
This loss combines the original triple contrastive loss with direct alignment losses
that force the main model's color and hierarchy dimensions to align with the
specialized color and hierarchy models.
Args:
text_features: Main model text embeddings [batch_size, embed_dim]
image_features: Main model image embeddings [batch_size, embed_dim]
attribute_features: Concatenated color + hierarchy features [batch_size, color_dim + hierarchy_dim]
color_model: Pre-trained color model for extracting color embeddings
hierarchy_model: Pre-trained hierarchy model for extracting hierarchy embeddings
colors: List of color strings for this batch [batch_size]
hierarchies: List of hierarchy strings for this batch [batch_size]
temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
        alignment_weight: Weight for the alignment loss component (default: 0.3)
        reference_text_features: Optional base-CLIP text embeddings used to keep the text space close to the reference model (default: None)
        reference_weight: Weight for the reference regularization term (default: 0.1)
Returns:
Tuple of (total_loss, metrics_dict) where metrics_dict contains detailed loss components
"""
# Original triple contrastive loss
text_features_norm = F.normalize(text_features, dim=-1)
image_features_norm = F.normalize(image_features, dim=-1)
attribute_features_norm = F.normalize(attribute_features, dim=-1)
text_image_logits = (text_features_norm[:, config.color_emb_dim+config.hierarchy_emb_dim:] @
image_features_norm[:, config.color_emb_dim+config.hierarchy_emb_dim:].T) / temperature
text_attr_logits = (text_features_norm[:, :config.color_emb_dim+config.hierarchy_emb_dim] @
attribute_features_norm.T) / temperature
image_attr_logits = (attribute_features_norm @
image_features_norm[:,:config.color_emb_dim+config.hierarchy_emb_dim].T) / temperature
# Weight distribution for original loss
weight_text_image = 0.7
weight_attr_based = 0.15
original_logits = (weight_text_image * text_image_logits +
weight_attr_based * text_attr_logits +
weight_attr_based * image_attr_logits)
labels = torch.arange(len(text_features)).to(text_features.device)
original_loss = (F.cross_entropy(original_logits, labels) +
F.cross_entropy(original_logits.T, labels)) / 2
    # Direct alignment loss between the color model and the main model's first config.color_emb_dim dims
with torch.no_grad():
color_embeddings = color_model.get_text_embeddings(colors)
hierarchy_embeddings = hierarchy_model.get_text_embeddings(hierarchies)
# Extract color dimensions from main model embeddings
main_color_text = text_features[:, :config.color_emb_dim]
main_color_image = image_features[:, :config.color_emb_dim]
# Extract hierarchy dimensions from main model embeddings
main_hierarchy_text = text_features[:, config.color_emb_dim:config.color_emb_dim+config.hierarchy_emb_dim]
main_hierarchy_image = image_features[:, config.color_emb_dim:config.color_emb_dim+config.hierarchy_emb_dim]
# Normalize for better correlation
color_embeddings_norm = F.normalize(color_embeddings, dim=-1)
main_color_text_norm = F.normalize(main_color_text, dim=-1)
main_color_image_norm = F.normalize(main_color_image, dim=-1)
hierarchy_embeddings_norm = F.normalize(hierarchy_embeddings, dim=-1)
main_hierarchy_text_norm = F.normalize(main_hierarchy_text, dim=-1)
main_hierarchy_image_norm = F.normalize(main_hierarchy_image, dim=-1)
# Color alignment loss using MSE and cosine similarity
color_text_alignment_loss = F.mse_loss(main_color_text_norm, color_embeddings_norm)
color_image_alignment_loss = F.mse_loss(main_color_image_norm, color_embeddings_norm)
color_text_cosine_loss = 1 - F.cosine_similarity(main_color_text_norm, color_embeddings_norm).mean()
color_image_cosine_loss = 1 - F.cosine_similarity(main_color_image_norm, color_embeddings_norm).mean()
# Color alignment loss
color_alignment_loss = (
color_text_alignment_loss + color_image_alignment_loss +
color_text_cosine_loss + color_image_cosine_loss
) / 4
# Hierarchy alignment loss using MSE and cosine similarity
hierarchy_text_alignment_loss = F.mse_loss(main_hierarchy_text_norm, hierarchy_embeddings_norm)
hierarchy_image_alignment_loss = F.mse_loss(main_hierarchy_image_norm, hierarchy_embeddings_norm)
hierarchy_text_cosine_loss = 1 - F.cosine_similarity(main_hierarchy_text_norm, hierarchy_embeddings_norm).mean()
hierarchy_image_cosine_loss = 1 - F.cosine_similarity(main_hierarchy_image_norm, hierarchy_embeddings_norm).mean()
# Hierarchy alignment loss
hierarchy_alignment_loss = (
hierarchy_text_alignment_loss + hierarchy_image_alignment_loss +
hierarchy_text_cosine_loss + hierarchy_image_cosine_loss
) / 4
# Combined alignment loss
alignment_loss = (color_alignment_loss + hierarchy_alignment_loss) / 2
# Optional guidance to keep text space close to base CLIP (helps cross-domain generalization)
reference_loss = 0.0
if reference_text_features is not None:
reference_loss = F.mse_loss(
F.normalize(text_features, dim=-1),
F.normalize(reference_text_features, dim=-1)
)
# Combine losses
total_loss = (1 - alignment_weight) * original_loss + alignment_weight * alignment_loss
if reference_text_features is not None:
total_loss = total_loss + reference_weight * reference_loss
return total_loss, {
'original_loss': original_loss.item(),
'alignment_loss': alignment_loss.item(),
'reference_loss': reference_loss if isinstance(reference_loss, float) else reference_loss.item(),
'color_text_alignment': color_text_alignment_loss.item(),
'color_image_alignment': color_image_alignment_loss.item(),
'color_text_cosine': color_text_cosine_loss.item(),
'color_image_cosine': color_image_cosine_loss.item(),
'hierarchy_text_alignment': hierarchy_text_alignment_loss.item(),
'hierarchy_image_alignment': hierarchy_image_alignment_loss.item(),
'hierarchy_text_cosine': hierarchy_text_cosine_loss.item(),
'hierarchy_image_cosine': hierarchy_image_cosine_loss.item()
}
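# Minimal smoke-test sketch for enhanced_contrastive_loss (a hypothetical helper,
# not part of the training pipeline). _StubEmbedder stands in for the real
# color/hierarchy models: the loss only needs get_text_embeddings(list[str]).
# Assumes config.color_emb_dim + config.hierarchy_emb_dim < embed_dim.
def _smoke_test_enhanced_loss(batch_size=4, embed_dim=512):
    class _StubEmbedder:
        def __init__(self, dim):
            self.dim = dim

        def get_text_embeddings(self, labels):
            # Random targets are enough to exercise shapes and the backward pass
            return torch.randn(len(labels), self.dim)

    color_stub = _StubEmbedder(config.color_emb_dim)
    hierarchy_stub = _StubEmbedder(config.hierarchy_emb_dim)
    text = torch.randn(batch_size, embed_dim, requires_grad=True)
    image = torch.randn(batch_size, embed_dim)
    attrs = torch.randn(batch_size, config.color_emb_dim + config.hierarchy_emb_dim)
    labels = ["placeholder"] * batch_size
    loss, metrics = enhanced_contrastive_loss(
        text, image, attrs, color_stub, hierarchy_stub, labels, labels
    )
    loss.backward()  # verifies the combined loss is differentiable end to end
    return loss.item(), metrics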
# -------------------------------
# Training Functions
# -------------------------------
def train_one_epoch(model, train_loader, optimizer, feature_models, color_model, hierarchy_model,
device, clip_processor, temperature=0.07, alignment_weight=0.3,
reference_model=None, reference_weight=0.1):
"""
Enhanced training with direct color and hierarchy alignment loss.
This function trains the model using the enhanced contrastive loss that includes
direct alignment between the main model's color/hierarchy dimensions and the
specialized color/hierarchy models.
Args:
model: Main CLIP model to train
train_loader: DataLoader for training data
optimizer: Optimizer instance
feature_models: Dictionary containing color and hierarchy models
color_model: Pre-trained color model for alignment
hierarchy_model: Pre-trained hierarchy model for alignment
device: Device to train on
clip_processor: CLIP processor for text preprocessing
temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
        alignment_weight: Weight for the alignment loss component (default: 0.3)
        reference_model: Optional frozen base CLIP whose text features regularize the main model (default: None)
        reference_weight: Weight for the reference regularization term (default: 0.1)
Returns:
Tuple of (average_loss, metrics_dict) where metrics_dict contains detailed loss components
"""
model.train()
total_loss = 0.0
total_metrics = {
'original_loss': 0.0,
'alignment_loss': 0.0,
'reference_loss': 0.0,
'color_text_alignment': 0.0,
'color_image_alignment': 0.0,
'color_text_cosine': 0.0,
'color_image_cosine': 0.0,
'hierarchy_text_alignment': 0.0,
'hierarchy_image_alignment': 0.0,
'hierarchy_text_cosine': 0.0,
'hierarchy_image_cosine': 0.0
}
num_batches = 0
pbar = tqdm(train_loader, desc="Training Enhanced", leave=False)
for batch_idx, (images, texts, colors, hierarchy) in enumerate(pbar):
# Move data to device
images = images.to(device)
        images = images.expand(-1, 3, -1, -1)  # Ensure 3 channels
# Process text inputs
        text_inputs = clip_processor(text=texts, padding=True, truncation=True, return_tensors="pt")
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
# Optional reference text features to keep close to base CLIP
reference_text_features = None
if reference_model is not None:
with torch.no_grad():
reference_text_features = reference_model.get_text_features(**text_inputs)
# Forward pass
optimizer.zero_grad()
outputs = model(**text_inputs, pixel_values=images)
text_features = outputs.text_embeds
image_features = outputs.image_embeds
# Get feature embeddings
if hasattr(feature_models[config.color_column], 'get_color_name_embeddings'):
color_features = feature_models[config.color_column].get_color_name_embeddings(colors)
else:
color_features = feature_models[config.color_column].get_text_embeddings(colors)
hierarchy_features = feature_models[config.hierarchy_column].get_text_embeddings(hierarchy)
concat_features = torch.cat((color_features, hierarchy_features), dim=1)
# Calculate enhanced loss with hierarchy alignment
loss, metrics = enhanced_contrastive_loss(
text_features, image_features, concat_features,
color_model, hierarchy_model, colors, hierarchy, temperature, alignment_weight,
reference_text_features=reference_text_features, reference_weight=reference_weight
)
# Backward pass
loss.backward()
# Gradient clipping to prevent exploding gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
total_loss += loss.item()
for key, value in metrics.items():
total_metrics[key] += value
num_batches += 1
# Update progress bar
pbar.set_postfix({
'Loss': f'{loss.item():.4f}',
'Align': f'{metrics["alignment_loss"]:.4f}',
'ColCos': f'{metrics["color_text_cosine"]:.3f}',
'HierCos': f'{metrics["hierarchy_text_cosine"]:.3f}'
})
avg_metrics = {key: value / num_batches for key, value in total_metrics.items()}
return total_loss / num_batches, avg_metrics
def valid_one_epoch(model, val_loader, feature_models, device, clip_processor, temperature=0.07, alignment_weight=0.3,
reference_model=None, reference_weight=0.1):
"""
Validate the model for one epoch using enhanced contrastive loss.
Args:
model: Main CLIP model to validate
val_loader: DataLoader for validation data
feature_models: Dictionary containing color and hierarchy models
device: Device to validate on
clip_processor: CLIP processor for text preprocessing
temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
        alignment_weight: Weight for the alignment loss component (default: 0.3)
        reference_model: Optional frozen base CLIP whose text features regularize the main model (default: None)
        reference_weight: Weight for the reference regularization term (default: 0.1)
Returns:
Average validation loss for the epoch
"""
model.eval()
total_loss = 0.0
num_batches = 0
# Extract color and hierarchy models
color_model = feature_models[config.color_column]
hierarchy_model = feature_models[config.hierarchy_column]
# Create progress bar for validation
pbar = tqdm(val_loader, desc="Validation", leave=False)
with torch.no_grad():
for batch_idx, (images, texts, colors, hierarchy) in enumerate(pbar):
# Move data to device
images = images.to(device)
images = images.expand(-1, 3, -1, -1) # Ensure 3 channels
# Process text inputs
            text_inputs = clip_processor(text=texts, padding=True, truncation=True, return_tensors="pt")
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
# Optional reference text features
reference_text_features = None
if reference_model is not None:
reference_text_features = reference_model.get_text_features(**text_inputs)
# Forward pass
outputs = model(**text_inputs, pixel_values=images)
text_features = outputs.text_embeds
image_features = outputs.image_embeds
# Get feature embeddings
if hasattr(feature_models[config.color_column], 'get_color_name_embeddings'):
color_features = feature_models[config.color_column].get_color_name_embeddings(colors)
else:
color_features = feature_models[config.color_column].get_text_embeddings(colors)
hierarchy_features = feature_models[config.hierarchy_column].get_text_embeddings(hierarchy)
concat_features = torch.cat((color_features, hierarchy_features), dim=1)
# Calculate loss with all required arguments
loss, metrics = enhanced_contrastive_loss(
text_features, image_features, concat_features,
color_model, hierarchy_model, colors, hierarchy,
temperature, alignment_weight,
reference_text_features=reference_text_features, reference_weight=reference_weight
)
total_loss += loss.item()
num_batches += 1
# Update progress bar
pbar.set_postfix({
'Loss': f'{loss.item():.4f}',
'Avg Loss': f'{total_loss/num_batches:.4f}'
})
return total_loss / num_batches
# -------------------------------
# Dataset
# -------------------------------
class CustomDataset(Dataset):
"""
Custom dataset for main model training.
Handles loading images from local paths, extracting text descriptions,
and applying appropriate transformations for training and validation.
"""
def __init__(self, dataframe, use_local_images=True, image_size=224):
"""
Initialize the custom dataset.
Args:
dataframe: DataFrame with columns for image paths, text descriptions, colors, and hierarchy labels
use_local_images: Whether to use local images (default: True)
image_size: Size of images after resizing (default: 224)
"""
self.dataframe = dataframe
self.use_local_images = use_local_images
self.image_size = image_size
# Transforms with augmentation for training (increased augmentation to reduce overfitting)
self.transform = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(15), # Increased for more variation
transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15), # Increased intensity
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)), # Increased transform range
transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.2), # Add blur
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Transforms for validation (no augmentation)
self.val_transform = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
self.training_mode = True
def set_training_mode(self, training=True):
"""
Switch between training and validation transforms.
Args:
training: If True, use training transforms with augmentation; if False, use validation transforms
"""
self.training_mode = training
def __len__(self):
"""Return the number of samples in the dataset."""
return len(self.dataframe)
def __getitem__(self, idx):
"""
Get a sample from the dataset.
Args:
idx: Index of the sample
Returns:
Tuple of (image_tensor, description_text, color_label, hierarchy_label)
"""
row = self.dataframe.iloc[idx]
image_data = row[config.column_local_image_path]
image = Image.open(image_data).convert("RGB")
# Apply appropriate transform
if self.training_mode:
image = self.transform(image)
else:
image = self.val_transform(image)
# Get text and labels
description = row[config.text_column]
color = row[config.color_column]
hierarchy = row[config.hierarchy_column]
return image, description, color, hierarchy
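# Usage sketch for CustomDataset (a hypothetical helper, not called by the
# pipeline), assuming a CSV with the column names declared in config
# (column_local_image_path, text_column, color_column, hierarchy_column);
# "toy.csv" and the batch size are illustrative only.
def _dataset_usage_example(csv_path="toy.csv", batch_size=8):
    df = pd.read_csv(csv_path).dropna(subset=[config.column_local_image_path])
    ds = CustomDataset(df)
    ds.set_training_mode(True)   # augmented transforms for training batches
    loader = DataLoader(ds, batch_size=batch_size, shuffle=True)
    images, texts, colors, hierarchies = next(iter(loader))
    ds.set_training_mode(False)  # deterministic transforms for evaluation
    return images.shape          # e.g. torch.Size([batch_size, 3, 224, 224])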
# -------------------------------
# Model Loading
# -------------------------------
def load_models():
"""
Load color and hierarchy models from checkpoints.
This function loads the pre-trained color and hierarchy models along with
their tokenizers and extractors, and prepares them for use in main model training.
Returns:
Dictionary mapping model names to model instances:
- 'color': ColorCLIP model instance
- 'hierarchy': Hierarchy model instance
"""
from color_model import ColorCLIP, Tokenizer
from hierarchy_model import Model, HierarchyExtractor
# Initialize tokenizer first
tokenizer = Tokenizer()
# Load vocabulary if available
if os.path.exists(config.tokeniser_path):
with open(config.tokeniser_path, 'r') as f:
vocab_dict = json.load(f)
tokenizer.load_vocab(vocab_dict)
print(f"Tokenizer vocabulary loaded from {config.tokeniser_path}")
else:
print(f"Warning: {config.tokeniser_path} not found. Using default tokenizer.")
# Load trained model first to get correct vocab size
checkpoint = torch.load(config.color_model_path, map_location=config.device)
# Extract vocab size from the checkpoint's embedding layer
vocab_size_from_checkpoint = checkpoint['text_encoder.embedding.weight'].shape[0]
print(f"Vocab size from checkpoint: {vocab_size_from_checkpoint}")
print(f"Vocab size from tokenizer: {tokenizer.counter}")
# Use the larger of the two to ensure compatibility
vocab_size = max(vocab_size_from_checkpoint, tokenizer.counter)
# Initialize model with correct vocab size
color_model = ColorCLIP(vocab_size=vocab_size, embedding_dim=config.color_emb_dim).to(config.device)
color_model.tokenizer = tokenizer
# Load the checkpoint
color_model.load_state_dict(checkpoint)
print(f"Color model loaded from {config.color_model_path}")
color_model.eval()
color_model.name = config.color_column
# Load hierarchy model
hierarchy_checkpoint = torch.load(config.hierarchy_model_path, map_location=config.device)
hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
hierarchy_model = Model(
num_hierarchy_classes=len(hierarchy_classes),
embed_dim=config.hierarchy_emb_dim
).to(config.device)
hierarchy_model.load_state_dict(hierarchy_checkpoint['model_state'])
# Set up hierarchy extractor
hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
hierarchy_model.eval()
hierarchy_model.name = config.hierarchy_column
feature_models = {model.name: model for model in [color_model, hierarchy_model]}
return feature_models
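# Usage sketch (labels are illustrative): the returned dict is keyed by the
# configured column names, e.g.
#   feature_models = load_models()
#   color_emb = feature_models[config.color_column].get_text_embeddings(["red", "blue"])
#   # -> tensor of shape [2, config.color_emb_dim], as consumed by the losses above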
# -------------------------------
# Main Training Function
# -------------------------------
def train_model(model, train_loader, val_loader, feature_models, device,
num_epochs=20, learning_rate=1e-5, temperature=0.07,
save_path=config.main_model_path, alignment_weight=0.3,
color_alignment_model=None, weight_decay=3e-4,
reference_model=None, reference_weight=0.1):
"""
Custom training loop using train_one_epoch and valid_one_epoch functions.
This function handles the complete training process including:
- Training and validation loops
- Learning rate scheduling
- Early stopping
- Model checkpointing
- Training curve visualization
Args:
model: Main CLIP model to train
train_loader: DataLoader for training data
val_loader: DataLoader for validation data
feature_models: Dictionary containing color and hierarchy models
device: Device to train on
num_epochs: Number of training epochs (default: 20)
learning_rate: Learning rate for optimizer (default: 1e-5)
temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
save_path: Path to save model checkpoints (default: main_model_path)
alignment_weight: Weight for alignment loss component if using enhanced loss (default: 0.3)
color_alignment_model: Optional color model for alignment (default: None, uses feature_models)
        weight_decay: L2 regularization weight (default: 3e-4, increased to reduce overfitting)
        reference_model: Optional frozen base CLIP used to regularize the text space (default: None)
        reference_weight: Weight for the reference regularization term (default: 0.1)
Returns:
Tuple of (training_losses, validation_losses) lists
"""
model = model.to(device)
# Use AdamW with weight decay for better regularization (reduces overfitting)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
train_losses = []
val_losses = []
best_val_loss = float('inf')
patience_counter = 0
patience = 7 # Increased from 5 to 7 for better convergence
print(f"Starting training for {num_epochs} epochs...")
print(f"Learning rate: {learning_rate}")
print(f"Temperature: {temperature}")
print(f"Weight decay: {weight_decay}")
print(f"Alignment weight: {alignment_weight}")
print(f"Device: {device}")
print(f"Training samples: {len(train_loader.dataset)}")
print(f"Validation samples: {len(val_loader.dataset)}")
print(f"Batch size: {train_loader.batch_size}")
print(f"Estimated time per epoch: ~{len(train_loader) * 2 / 60:.1f} minutes")
# Create processor once for efficiency
processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')
# Freeze and move reference model (used for text-space regularization)
if reference_model is not None:
reference_model = reference_model.to(device)
reference_model.eval()
for param in reference_model.parameters():
param.requires_grad = False
# Create progress bar for epochs
epoch_pbar = tqdm(range(num_epochs), desc="Training Progress", position=0)
for epoch in epoch_pbar:
# Update epoch progress bar
epoch_pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")
# Training
if color_alignment_model is None:
color_alignment_model = feature_models[config.color_column]
hierarchy_model = feature_models[config.hierarchy_column]
train_loss, align_metrics = train_one_epoch(
model, train_loader, optimizer, feature_models, color_alignment_model, hierarchy_model,
device, processor, temperature, alignment_weight,
reference_model=reference_model, reference_weight=reference_weight
)
train_losses.append(train_loss)
# Validation
val_loss = valid_one_epoch(
model, val_loader, feature_models, device, processor,
temperature=temperature, alignment_weight=alignment_weight,
reference_model=reference_model, reference_weight=reference_weight
)
val_losses.append(val_loss)
# Learning rate scheduling
scheduler.step(val_loss)
# Calculate overfitting gap
overfitting_gap = val_loss - train_loss
# Update epoch progress bar with metrics
postfix = {
'Train Loss': f'{train_loss:.4f}',
'Val Loss': f'{val_loss:.4f}',
'Gap': f'{overfitting_gap:.4f}',
'LR': f'{optimizer.param_groups[0]["lr"]:.2e}',
'Best Val': f'{best_val_loss:.4f}'
}
if align_metrics is not None:
postfix.update({
'Align': f"{align_metrics['alignment_loss']:.3f}",
'ColCos': f"{align_metrics['color_text_cosine']:.3f}",
'HierCos': f"{align_metrics['hierarchy_text_cosine']:.3f}"
})
epoch_pbar.set_postfix(postfix)
# Warning if overfitting is detected
if overfitting_gap > 0.15 and epoch > 3:
print(f"\nโš ๏ธ Warning: Significant overfitting detected at epoch {epoch+1} (gap={overfitting_gap:.4f})")
# Save best model
if val_loss < best_val_loss:
best_val_loss = val_loss
patience_counter = 0
# Save checkpoint
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'train_loss': train_loss,
'val_loss': val_loss,
'best_val_loss': best_val_loss,
}, save_path)
else:
patience_counter += 1
# Early stopping
if patience_counter >= patience:
print(f"\n๐Ÿ›‘ Early stopping triggered after {patience_counter} epochs without improvement")
break
# Plot training curves with overfitting analysis
plt.figure(figsize=(15, 5))
# Plot 1: Training and Validation Loss
plt.subplot(1, 3, 1)
plt.plot(train_losses, label='Train Loss', color='blue', linewidth=2)
plt.plot(val_losses, label='Val Loss', color='red', linewidth=2)
plt.title('Training and Validation Loss', fontsize=12, fontweight='bold')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True, alpha=0.3)
# Plot 2: Overfitting Gap (Val Loss - Train Loss)
plt.subplot(1, 3, 2)
gap = [val_losses[i] - train_losses[i] for i in range(len(train_losses))]
plt.plot(gap, label='Overfitting Gap', color='purple', linewidth=2)
plt.axhline(y=0, color='black', linestyle='--', alpha=0.3)
plt.axhline(y=0.1, color='red', linestyle='--', alpha=0.3, label='Warning threshold')
plt.title('Overfitting Gap (Val - Train)', fontsize=12, fontweight='bold')
plt.xlabel('Epoch')
plt.ylabel('Gap')
plt.legend()
plt.grid(True, alpha=0.3)
# Plot 3: Loss comparison
plt.subplot(1, 3, 3)
epochs = list(range(len(train_losses)))
plt.plot(epochs, train_losses, 'o-', label='Train Loss', color='blue', linewidth=2)
plt.plot(epochs, val_losses, 's-', label='Val Loss', color='red', linewidth=2)
plt.fill_between(epochs, train_losses, val_losses, alpha=0.2, color='red')
plt.title('Loss Comparison', fontsize=12, fontweight='bold')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('training_curves.png', dpi=300, bbox_inches='tight')
plt.close()
print(f"\nTraining completed!")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Final model saved to: {save_path}")
print(f"Training curves saved to: training_curves.png")
return train_losses, val_losses
# -------------------------------
# Main Function
# -------------------------------
def main():
print("="*80)
print("๐Ÿš€ Training of the model with alignement color and hierarchy")
print("="*80)
# Configuration (optimized to reduce overfitting)
num_epochs = 20
learning_rate = 1.5e-5 # Reduced slightly to prevent overfitting
temperature = 0.09 # Increased from 0.07 for softer contrastive learning
alignment_weight = 0.2 # Reduced from 0.3 to prevent overfitting on alignment
weight_decay = 5e-4 # Increased weight decay for stronger regularization
batch_size = 32
subset_size = 20000 # Increased dataset size for better generalization
# Load the data
print(f"\n๐Ÿ“‚ Loading the data...")
df = pd.read_csv(config.local_dataset_path)
print(f" Data downloaded: {len(df)} samples")
    # Filter out rows with missing image paths
df_clean = df.dropna(subset=[config.column_local_image_path])
print(f" After filtering NaN: {len(df_clean)} samples")
    # Create the dataset
dataset = CustomDataset(df_clean)
    # Create a subset for faster training
    print(f"\n📊 Creating a subset of {subset_size} samples...")
subset_size = min(subset_size, len(dataset))
train_size = int(0.8 * subset_size)
val_size = subset_size - train_size
    # Create a subset with random but reproducible indices
np.random.seed(42)
subset_indices = np.random.choice(len(dataset), subset_size, replace=False)
subset_dataset = torch.utils.data.Subset(dataset, subset_indices)
train_dataset, val_dataset = random_split(
subset_dataset,
[train_size, val_size],
generator=torch.Generator().manual_seed(42)
)
    # Create the dataloaders
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=2,
pin_memory=True if torch.cuda.is_available() else False
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=2,
pin_memory=True if torch.cuda.is_available() else False
)
print(f" Train: {len(train_dataset)} samples")
print(f" Validation: {len(val_dataset)} samples")
# Loading models
print(f"\n๐Ÿ”ง Loading models...")
feature_models = load_models()
# Load or create the main model
print(f"\n๐Ÿ“ฆ Loading main model...")
clip_model = CLIPModel_transformers.from_pretrained(
'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
)
# Frozen reference CLIP to regularize text space (improves cross-domain generalization)
reference_clip = CLIPModel_transformers.from_pretrained(
'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
)
# # Load the model
# if os.path.exists(config.main_model_path):
# print(f" Model found {config.main_model_path}")
# print(f" Loading checkpoint...")
# checkpoint = torch.load(config.main_model_path, map_location=config.device)
# if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
# clip_model.load_state_dict(checkpoint['model_state_dict'])
# print(f" โœ… Checkpoint loaded from {checkpoint.get('epoch', '?')}")
# else:
# clip_model.load_state_dict(checkpoint)
# print(f" โœ… Checkpoint loaded")
# else:
# print(f" New model, no checkpoint found")
    # Move the models to the device
clip_model = clip_model.to(config.device)
reference_clip = reference_clip.to(config.device)
reference_clip.eval()
for param in reference_clip.parameters():
param.requires_grad = False
# Training with enhanced loss
print(f"\n๐ŸŽฏ Beginning training...")
print(f"\n" + "="*80)
train_losses, val_losses = train_model(
model=clip_model,
train_loader=train_loader,
val_loader=val_loader,
feature_models=feature_models,
device=config.device,
num_epochs=num_epochs,
learning_rate=learning_rate,
temperature=temperature,
save_path=config.main_model_path,
alignment_weight=alignment_weight,
color_alignment_model=feature_models[config.color_column],
weight_decay=weight_decay,
reference_model=reference_clip,
reference_weight=0.1
)
print("\n" + "="*80)
print("โœ… Training finished!")
print(f" Model saved: {config.main_model_path}")
print(f" Training curves: training_curves.png")
print("\n๐Ÿ“Š Final results:")
print(f" Last train loss: {train_losses[-1]:.4f}")
print(f" Last validation loss: {val_losses[-1]:.4f}")
print(f" Best validation loss: {min(val_losses):.4f}")
print(f" Overfitting gap (val-train): {val_losses[-1] - train_losses[-1]:.4f}")
if val_losses[-1] - train_losses[-1] > 0.1:
print(" โš ๏ธ Warning: Significant overfitting detected!")
elif val_losses[-1] - train_losses[-1] < 0.05:
print(" โœ… Good generalization!")
print("="*80)
if __name__ == "__main__":
main()