|
|
|
|
|
""" |
|
|
Main file for training the CLIP model with color and hierarchy alignment. |
|
|
This file centralizes all the logic for training the main model. It uses |
|
|
pre-trained color and hierarchy models to guide the main model's learning |
|
|
through contrastive and alignment loss functions. It handles data loading, |
|
|
training with validation, and checkpoint saving. |
|
|
""" |
|
|
|
|
|
import os |
|
|
|
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
|
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import Dataset, DataLoader, random_split |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import matplotlib.pyplot as plt |
|
|
from transformers import CLIPProcessor, CLIPModel as CLIPModel_transformers |
|
|
import warnings |
|
|
from tqdm import tqdm |
|
|
import json |
|
|
import config |
|
|
|
|
|
|
|
|
warnings.filterwarnings("ignore", category=FutureWarning) |
|
|
warnings.filterwarnings("ignore", category=UserWarning) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def enhanced_contrastive_loss(text_features, image_features, attribute_features,
                              color_model, hierarchy_model, colors, hierarchies, temperature=0.07, alignment_weight=0.3,
                              reference_text_features=None, reference_weight=0.1):
    """
    Enhanced contrastive loss with direct alignment between color/hierarchy models and main model.

    This loss combines the original triple contrastive loss with direct alignment losses
    that force the main model's color and hierarchy dimensions to align with the
    specialized color and hierarchy models.

    Args:
        text_features: Main model text embeddings [batch_size, embed_dim]
        image_features: Main model image embeddings [batch_size, embed_dim]
        attribute_features: Concatenated color + hierarchy features [batch_size, color_dim + hierarchy_dim]
        color_model: Pre-trained color model for extracting color embeddings
        hierarchy_model: Pre-trained hierarchy model for extracting hierarchy embeddings
        colors: List of color strings for this batch [batch_size]
        hierarchies: List of hierarchy strings for this batch [batch_size]
        temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
        alignment_weight: Weight for the alignment loss component (default: 0.3)
        reference_text_features: Optional frozen reference text embeddings; when provided, an
            extra MSE term pulls the trained (normalized) text embeddings toward them
            (default: None)
        reference_weight: Weight of the reference regularizer, applied only when
            reference_text_features is given (default: 0.1)

    Returns:
        Tuple of (total_loss, metrics_dict) where metrics_dict contains detailed loss components
    """
    # Normalize all embeddings so the matrix products below are cosine similarities.
    text_features_norm = F.normalize(text_features, dim=-1)
    image_features_norm = F.normalize(image_features, dim=-1)
    attribute_features_norm = F.normalize(attribute_features, dim=-1)

    # Layout convention: the first (color_emb_dim + hierarchy_emb_dim) embedding
    # dimensions carry attribute information; the remaining dimensions carry the
    # general text<->image semantics used for the main contrastive term.
    text_image_logits = (text_features_norm[:, config.color_emb_dim+config.hierarchy_emb_dim:] @
                         image_features_norm[:, config.color_emb_dim+config.hierarchy_emb_dim:].T) / temperature
    # Text/image attribute dimensions vs the specialist attribute embeddings.
    text_attr_logits = (text_features_norm[:, :config.color_emb_dim+config.hierarchy_emb_dim] @
                        attribute_features_norm.T) / temperature
    image_attr_logits = (attribute_features_norm @
                         image_features_norm[:, :config.color_emb_dim+config.hierarchy_emb_dim].T) / temperature

    # Weighted mix of the three contrastive views (0.7 + 0.15 + 0.15 = 1.0).
    weight_text_image = 0.7
    weight_attr_based = 0.15

    original_logits = (weight_text_image * text_image_logits +
                       weight_attr_based * text_attr_logits +
                       weight_attr_based * image_attr_logits)

    # Symmetric InfoNCE: the diagonal entries are the matching (positive) pairs.
    labels = torch.arange(len(text_features)).to(text_features.device)
    original_loss = (F.cross_entropy(original_logits, labels) +
                     F.cross_entropy(original_logits.T, labels)) / 2

    # Alignment targets come from the frozen specialist models; no gradient flows
    # into them.
    with torch.no_grad():
        color_embeddings = color_model.get_text_embeddings(colors)
        hierarchy_embeddings = hierarchy_model.get_text_embeddings(hierarchies)

    # Slice the main model's attribute sub-spaces:
    # [0 : color_emb_dim] -> color, [color_emb_dim : color_emb_dim+hierarchy_emb_dim] -> hierarchy.
    main_color_text = text_features[:, :config.color_emb_dim]
    main_color_image = image_features[:, :config.color_emb_dim]

    main_hierarchy_text = text_features[:, config.color_emb_dim:config.color_emb_dim+config.hierarchy_emb_dim]
    main_hierarchy_image = image_features[:, config.color_emb_dim:config.color_emb_dim+config.hierarchy_emb_dim]

    # Normalize both sides before comparing, so MSE/cosine operate on unit vectors.
    color_embeddings_norm = F.normalize(color_embeddings, dim=-1)
    main_color_text_norm = F.normalize(main_color_text, dim=-1)
    main_color_image_norm = F.normalize(main_color_image, dim=-1)

    hierarchy_embeddings_norm = F.normalize(hierarchy_embeddings, dim=-1)
    main_hierarchy_text_norm = F.normalize(main_hierarchy_text, dim=-1)
    main_hierarchy_image_norm = F.normalize(main_hierarchy_image, dim=-1)

    # Color alignment: MSE (element-wise closeness) + cosine distance (angular
    # closeness) for both the text and the image sides.
    color_text_alignment_loss = F.mse_loss(main_color_text_norm, color_embeddings_norm)
    color_image_alignment_loss = F.mse_loss(main_color_image_norm, color_embeddings_norm)
    color_text_cosine_loss = 1 - F.cosine_similarity(main_color_text_norm, color_embeddings_norm).mean()
    color_image_cosine_loss = 1 - F.cosine_similarity(main_color_image_norm, color_embeddings_norm).mean()

    # Average of the four color alignment components.
    color_alignment_loss = (
        color_text_alignment_loss + color_image_alignment_loss +
        color_text_cosine_loss + color_image_cosine_loss
    ) / 4

    # Hierarchy alignment mirrors the color alignment above.
    hierarchy_text_alignment_loss = F.mse_loss(main_hierarchy_text_norm, hierarchy_embeddings_norm)
    hierarchy_image_alignment_loss = F.mse_loss(main_hierarchy_image_norm, hierarchy_embeddings_norm)
    hierarchy_text_cosine_loss = 1 - F.cosine_similarity(main_hierarchy_text_norm, hierarchy_embeddings_norm).mean()
    hierarchy_image_cosine_loss = 1 - F.cosine_similarity(main_hierarchy_image_norm, hierarchy_embeddings_norm).mean()

    hierarchy_alignment_loss = (
        hierarchy_text_alignment_loss + hierarchy_image_alignment_loss +
        hierarchy_text_cosine_loss + hierarchy_image_cosine_loss
    ) / 4

    # Equal weighting of the two attribute alignment terms.
    alignment_loss = (color_alignment_loss + hierarchy_alignment_loss) / 2

    # Optional anti-drift regularizer toward the frozen reference text features.
    reference_loss = 0.0
    if reference_text_features is not None:
        reference_loss = F.mse_loss(
            F.normalize(text_features, dim=-1),
            F.normalize(reference_text_features, dim=-1)
        )

    # Convex combination of contrastive and alignment losses, plus the optional
    # reference term on top.
    total_loss = (1 - alignment_weight) * original_loss + alignment_weight * alignment_loss
    if reference_text_features is not None:
        total_loss = total_loss + reference_weight * reference_loss

    # reference_loss stays a plain float (0.0) when no reference features were given.
    return total_loss, {
        'original_loss': original_loss.item(),
        'alignment_loss': alignment_loss.item(),
        'reference_loss': reference_loss if isinstance(reference_loss, float) else reference_loss.item(),
        'color_text_alignment': color_text_alignment_loss.item(),
        'color_image_alignment': color_image_alignment_loss.item(),
        'color_text_cosine': color_text_cosine_loss.item(),
        'color_image_cosine': color_image_cosine_loss.item(),
        'hierarchy_text_alignment': hierarchy_text_alignment_loss.item(),
        'hierarchy_image_alignment': hierarchy_image_alignment_loss.item(),
        'hierarchy_text_cosine': hierarchy_text_cosine_loss.item(),
        'hierarchy_image_cosine': hierarchy_image_cosine_loss.item()
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_one_epoch(model, train_loader, optimizer, feature_models, color_model, hierarchy_model,
                    device, clip_processor, temperature=0.07, alignment_weight=0.3,
                    reference_model=None, reference_weight=0.1):
    """
    Enhanced training with direct color and hierarchy alignment loss.

    This function trains the model using the enhanced contrastive loss that includes
    direct alignment between the main model's color/hierarchy dimensions and the
    specialized color/hierarchy models.

    Args:
        model: Main CLIP model to train
        train_loader: DataLoader for training data
        optimizer: Optimizer instance
        feature_models: Dictionary containing color and hierarchy models
        color_model: Pre-trained color model for alignment
        hierarchy_model: Pre-trained hierarchy model for alignment
        device: Device to train on
        clip_processor: CLIP processor for text preprocessing
        temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
        alignment_weight: Weight for the alignment loss component (default: 0.3)
        reference_model: Optional frozen CLIP model whose text features regularize
            the trained text embeddings (default: None)
        reference_weight: Weight of the reference regularizer (default: 0.1)

    Returns:
        Tuple of (average_loss, metrics_dict) where metrics_dict contains detailed loss
        components averaged over the epoch; returns (0.0, {}) for an empty loader.
    """
    model.train()
    total_loss = 0.0
    # Accumulate whatever component keys the loss function reports, instead of
    # hard-coding the key list here (avoids KeyError drift between the two).
    total_metrics = {}
    num_batches = 0

    pbar = tqdm(train_loader, desc="Training Enhanced", leave=False)

    for batch_idx, (images, texts, colors, hierarchy) in enumerate(pbar):
        images = images.to(device)
        # Replicate a single channel across 3 so images match CLIP's RGB input
        # (assumes grayscale input tensors — TODO confirm against the dataset).
        images = images.expand(-1, 3, -1, -1)

        # Tokenize the batch texts and move every tensor to the training device.
        text_inputs = clip_processor(text=texts, padding=True, return_tensors="pt")
        text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

        # Frozen reference text features (anti-drift regularizer), if provided.
        reference_text_features = None
        if reference_model is not None:
            with torch.no_grad():
                reference_text_features = reference_model.get_text_features(**text_inputs)

        optimizer.zero_grad()
        outputs = model(**text_inputs, pixel_values=images)

        text_features = outputs.text_embeds
        image_features = outputs.image_embeds

        # Attribute targets from the frozen specialist models; the color model may
        # expose a dedicated color-name API.
        if hasattr(feature_models[config.color_column], 'get_color_name_embeddings'):
            color_features = feature_models[config.color_column].get_color_name_embeddings(colors)
        else:
            color_features = feature_models[config.color_column].get_text_embeddings(colors)
        hierarchy_features = feature_models[config.hierarchy_column].get_text_embeddings(hierarchy)
        concat_features = torch.cat((color_features, hierarchy_features), dim=1)

        loss, metrics = enhanced_contrastive_loss(
            text_features, image_features, concat_features,
            color_model, hierarchy_model, colors, hierarchy, temperature, alignment_weight,
            reference_text_features=reference_text_features, reference_weight=reference_weight
        )

        loss.backward()

        # Clip gradients to stabilize training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        total_loss += loss.item()
        for key, value in metrics.items():
            total_metrics[key] = total_metrics.get(key, 0.0) + value
        num_batches += 1

        # Live progress readout of the most informative components.
        pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Align': f'{metrics["alignment_loss"]:.4f}',
            'ColCos': f'{metrics["color_text_cosine"]:.3f}',
            'HierCos': f'{metrics["hierarchy_text_cosine"]:.3f}'
        })

    # Guard against an empty loader instead of dividing by zero.
    if num_batches == 0:
        return 0.0, {}

    avg_metrics = {key: value / num_batches for key, value in total_metrics.items()}
    return total_loss / num_batches, avg_metrics
|
|
|
|
|
def valid_one_epoch(model, val_loader, feature_models, device, clip_processor, temperature=0.07, alignment_weight=0.3,
                    reference_model=None, reference_weight=0.1):
    """
    Validate the model for one epoch using enhanced contrastive loss.

    Args:
        model: Main CLIP model to validate
        val_loader: DataLoader for validation data
        feature_models: Dictionary containing color and hierarchy models
        device: Device to validate on
        clip_processor: CLIP processor for text preprocessing
        temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
        alignment_weight: Weight for the alignment loss component (default: 0.3)
        reference_model: Optional frozen CLIP model whose text features are used by
            the reference regularizer (default: None)
        reference_weight: Weight of the reference regularizer (default: 0.1)

    Returns:
        Average validation loss for the epoch (0.0 if the loader is empty)
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    # Same specialist models serve both as feature extractors and alignment targets.
    color_model = feature_models[config.color_column]
    hierarchy_model = feature_models[config.hierarchy_column]

    pbar = tqdm(val_loader, desc="Validation", leave=False)

    # No gradients during validation.
    with torch.no_grad():
        for batch_idx, (images, texts, colors, hierarchy) in enumerate(pbar):
            images = images.to(device)
            # Replicate a single channel across 3 to match CLIP's RGB input
            # (assumes grayscale input tensors — TODO confirm against the dataset).
            images = images.expand(-1, 3, -1, -1)

            text_inputs = clip_processor(text=texts, padding=True, return_tensors="pt")
            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

            # Reference text features; already inside no_grad, no extra guard needed.
            reference_text_features = None
            if reference_model is not None:
                reference_text_features = reference_model.get_text_features(**text_inputs)

            outputs = model(**text_inputs, pixel_values=images)

            text_features = outputs.text_embeds
            image_features = outputs.image_embeds

            # Attribute targets, mirroring the training loop.
            if hasattr(feature_models[config.color_column], 'get_color_name_embeddings'):
                color_features = feature_models[config.color_column].get_color_name_embeddings(colors)
            else:
                color_features = feature_models[config.color_column].get_text_embeddings(colors)
            hierarchy_features = feature_models[config.hierarchy_column].get_text_embeddings(hierarchy)
            concat_features = torch.cat((color_features, hierarchy_features), dim=1)

            loss, metrics = enhanced_contrastive_loss(
                text_features, image_features, concat_features,
                color_model, hierarchy_model, colors, hierarchy,
                temperature, alignment_weight,
                reference_text_features=reference_text_features, reference_weight=reference_weight
            )

            total_loss += loss.item()
            num_batches += 1

            pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Avg Loss': f'{total_loss/num_batches:.4f}'
            })

    # Guard against an empty validation loader instead of dividing by zero.
    if num_batches == 0:
        return 0.0

    return total_loss / num_batches
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CustomDataset(Dataset):
    """
    Dataset wrapper used for main model training.

    Reads images from local file paths, pairs each one with its text
    description, color label and hierarchy label, and applies either an
    augmenting (training) or a deterministic (validation) transform pipeline.
    """

    def __init__(self, dataframe, use_local_images=True, image_size=224):
        """
        Build the dataset.

        Args:
            dataframe: DataFrame with columns for image paths, text descriptions, colors, and hierarchy labels
            use_local_images: Whether images are read from local paths (default: True)
            image_size: Target square size for resized images (default: 224)
        """
        self.dataframe = dataframe
        self.use_local_images = use_local_images
        self.image_size = image_size

        # Shared stateless stages reused by both pipelines.
        resize_stage = transforms.Resize((image_size, image_size))
        normalize_stage = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # Augmenting pipeline applied while training.
        self.transform = transforms.Compose([
            resize_stage,
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.2),
            transforms.ToTensor(),
            normalize_stage,
        ])

        # Deterministic pipeline applied during validation.
        self.val_transform = transforms.Compose([
            resize_stage,
            transforms.ToTensor(),
            normalize_stage,
        ])

        # Augmentations are active until set_training_mode(False) is called.
        self.training_mode = True

    def set_training_mode(self, training=True):
        """
        Select which transform pipeline __getitem__ applies.

        Args:
            training: True -> augmented training transforms; False -> plain validation transforms
        """
        self.training_mode = training

    def __len__(self):
        """Number of rows in the backing dataframe."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """
        Fetch one sample.

        Args:
            idx: Index of the sample

        Returns:
            Tuple of (image_tensor, description_text, color_label, hierarchy_label)
        """
        record = self.dataframe.iloc[idx]

        # Load the image from its local path and force an RGB representation.
        image = Image.open(record[config.column_local_image_path]).convert("RGB")

        # Pick the pipeline matching the current mode.
        pipeline = self.transform if self.training_mode else self.val_transform
        image = pipeline(image)

        return (
            image,
            record[config.text_column],
            record[config.color_column],
            record[config.hierarchy_column],
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_models():
    """
    Load color and hierarchy models from checkpoints.

    This function loads the pre-trained color and hierarchy models along with
    their tokenizers and extractors, and prepares them for use in main model training.

    Returns:
        Dictionary mapping model names to model instances:
        - 'color': ColorCLIP model instance
        - 'hierarchy': Hierarchy model instance
    """
    # Local imports keep the specialist-model modules out of this file's import time.
    from color_model import ColorCLIP, Tokenizer
    from hierarchy_model import Model, HierarchyExtractor

    tokenizer = Tokenizer()

    # Restore the tokenizer vocabulary so token ids match the trained checkpoint;
    # fall back to the default (empty) vocabulary if the file is missing.
    if os.path.exists(config.tokeniser_path):
        with open(config.tokeniser_path, 'r') as f:
            vocab_dict = json.load(f)
        tokenizer.load_vocab(vocab_dict)
        print(f"Tokenizer vocabulary loaded from {config.tokeniser_path}")
    else:
        print(f"Warning: {config.tokeniser_path} not found. Using default tokenizer.")

    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — only load checkpoints from trusted sources.
    checkpoint = torch.load(config.color_model_path, map_location=config.device)

    # Infer the embedding-table size directly from the checkpoint weights
    # (assumes the checkpoint is a plain state dict containing this key — verify).
    vocab_size_from_checkpoint = checkpoint['text_encoder.embedding.weight'].shape[0]
    print(f"Vocab size from checkpoint: {vocab_size_from_checkpoint}")
    print(f"Vocab size from tokenizer: {tokenizer.counter}")

    # Use the larger of the two so no known token index falls out of range.
    vocab_size = max(vocab_size_from_checkpoint, tokenizer.counter)

    color_model = ColorCLIP(vocab_size=vocab_size, embedding_dim=config.color_emb_dim).to(config.device)
    color_model.tokenizer = tokenizer

    color_model.load_state_dict(checkpoint)
    print(f"Color model loaded from {config.color_model_path}")

    color_model.eval()
    # Tag with the column name so feature_models can be indexed via config.
    color_model.name = config.color_column

    # The hierarchy checkpoint stores both the class list and the model weights.
    hierarchy_checkpoint = torch.load(config.hierarchy_model_path, map_location=config.device)
    hierarchy_classes = hierarchy_checkpoint.get('hierarchy_classes', [])
    hierarchy_model = Model(
        num_hierarchy_classes=len(hierarchy_classes),
        embed_dim=config.hierarchy_emb_dim
    ).to(config.device)
    hierarchy_model.load_state_dict(hierarchy_checkpoint['model_state'])

    # Attach the extractor that maps raw hierarchy strings to class structure.
    hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=False)
    hierarchy_model.set_hierarchy_extractor(hierarchy_extractor)
    hierarchy_model.eval()
    hierarchy_model.name = config.hierarchy_column

    # Key each model by its column name for lookup elsewhere in the pipeline.
    feature_models = {model.name: model for model in [color_model, hierarchy_model]}

    return feature_models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _plot_training_curves(train_losses, val_losses, output_path='training_curves.png'):
    """Render train/val loss, the overfitting gap, and a filled comparison; save to output_path."""
    plt.figure(figsize=(15, 5))

    # Panel 1: raw loss curves.
    plt.subplot(1, 3, 1)
    plt.plot(train_losses, label='Train Loss', color='blue', linewidth=2)
    plt.plot(val_losses, label='Val Loss', color='red', linewidth=2)
    plt.title('Training and Validation Loss', fontsize=12, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Panel 2: overfitting gap (val - train) with a warning threshold line.
    plt.subplot(1, 3, 2)
    gap = [val_losses[i] - train_losses[i] for i in range(len(train_losses))]
    plt.plot(gap, label='Overfitting Gap', color='purple', linewidth=2)
    plt.axhline(y=0, color='black', linestyle='--', alpha=0.3)
    plt.axhline(y=0.1, color='red', linestyle='--', alpha=0.3, label='Warning threshold')
    plt.title('Overfitting Gap (Val - Train)', fontsize=12, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Gap')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Panel 3: both curves again, with the gap area shaded.
    plt.subplot(1, 3, 3)
    epochs = list(range(len(train_losses)))
    plt.plot(epochs, train_losses, 'o-', label='Train Loss', color='blue', linewidth=2)
    plt.plot(epochs, val_losses, 's-', label='Val Loss', color='red', linewidth=2)
    plt.fill_between(epochs, train_losses, val_losses, alpha=0.2, color='red')
    plt.title('Loss Comparison', fontsize=12, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()


def train_model(model, train_loader, val_loader, feature_models, device,
                num_epochs=20, learning_rate=1e-5, temperature=0.07,
                save_path=config.main_model_path, alignment_weight=0.3,
                color_alignment_model=None, weight_decay=3e-4,
                reference_model=None, reference_weight=0.1):
    """
    Custom training loop using train_one_epoch and valid_one_epoch functions.

    This function handles the complete training process including:
    - Training and validation loops
    - Learning rate scheduling
    - Early stopping
    - Model checkpointing
    - Training curve visualization

    Args:
        model: Main CLIP model to train
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        feature_models: Dictionary containing color and hierarchy models
        device: Device to train on
        num_epochs: Number of training epochs (default: 20)
        learning_rate: Learning rate for optimizer (default: 1e-5)
        temperature: Temperature scaling parameter for contrastive loss (default: 0.07)
        save_path: Path to save model checkpoints (default: main_model_path)
        alignment_weight: Weight for alignment loss component if using enhanced loss (default: 0.3)
        color_alignment_model: Optional color model for alignment (default: None, uses feature_models)
        weight_decay: L2 regularization weight (default: 3e-4, increased to reduce overfitting)
        reference_model: Optional frozen CLIP model used as anti-drift reference (default: None)
        reference_weight: Weight of the reference regularizer (default: 0.1)

    Returns:
        Tuple of (training_losses, validation_losses) lists
    """
    model = model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # Halve the LR after 3 epochs without validation improvement.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 7  # early-stopping patience, in epochs

    print(f"Starting training for {num_epochs} epochs...")
    print(f"Learning rate: {learning_rate}")
    print(f"Temperature: {temperature}")
    print(f"Weight decay: {weight_decay}")
    print(f"Alignment weight: {alignment_weight}")
    print(f"Device: {device}")
    print(f"Training samples: {len(train_loader.dataset)}")
    print(f"Validation samples: {len(val_loader.dataset)}")
    print(f"Batch size: {train_loader.batch_size}")
    print(f"Estimated time per epoch: ~{len(train_loader) * 2 / 60:.1f} minutes")

    processor = CLIPProcessor.from_pretrained('laion/CLIP-ViT-B-32-laion2B-s34B-b79K')

    # Freeze the reference model entirely; it only provides target features.
    if reference_model is not None:
        reference_model = reference_model.to(device)
        reference_model.eval()
        for param in reference_model.parameters():
            param.requires_grad = False

    # Resolve the alignment models once, outside the epoch loop — they never change.
    if color_alignment_model is None:
        color_alignment_model = feature_models[config.color_column]
    hierarchy_model = feature_models[config.hierarchy_column]

    epoch_pbar = tqdm(range(num_epochs), desc="Training Progress", position=0)

    for epoch in epoch_pbar:
        epoch_pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")

        train_loss, align_metrics = train_one_epoch(
            model, train_loader, optimizer, feature_models, color_alignment_model, hierarchy_model,
            device, processor, temperature, alignment_weight,
            reference_model=reference_model, reference_weight=reference_weight
        )
        train_losses.append(train_loss)

        val_loss = valid_one_epoch(
            model, val_loader, feature_models, device, processor,
            temperature=temperature, alignment_weight=alignment_weight,
            reference_model=reference_model, reference_weight=reference_weight
        )
        val_losses.append(val_loss)

        # ReduceLROnPlateau tracks the validation loss.
        scheduler.step(val_loss)

        overfitting_gap = val_loss - train_loss

        postfix = {
            'Train Loss': f'{train_loss:.4f}',
            'Val Loss': f'{val_loss:.4f}',
            'Gap': f'{overfitting_gap:.4f}',
            'LR': f'{optimizer.param_groups[0]["lr"]:.2e}',
            'Best Val': f'{best_val_loss:.4f}'
        }
        if align_metrics is not None:
            postfix.update({
                'Align': f"{align_metrics['alignment_loss']:.3f}",
                'ColCos': f"{align_metrics['color_text_cosine']:.3f}",
                'HierCos': f"{align_metrics['hierarchy_text_cosine']:.3f}"
            })
        epoch_pbar.set_postfix(postfix)

        # Warn once the gap is large and we are past the warm-up epochs.
        if overfitting_gap > 0.15 and epoch > 3:
            print(f"\nโ ๏ธ Warning: Significant overfitting detected at epoch {epoch+1} (gap={overfitting_gap:.4f})")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0

            # Checkpoint only on improvement.
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'best_val_loss': best_val_loss,
            }, save_path)
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"\n๐ Early stopping triggered after {patience_counter} epochs without improvement")
            break

    # Persist the three-panel training summary figure.
    _plot_training_curves(train_losses, val_losses)

    print(f"\nTraining completed!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Final model saved to: {save_path}")
    print(f"Training curves saved to: training_curves.png")

    return train_losses, val_losses
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Entry point: load data, build loaders, load models, and run training."""
    print("="*80)
    print("๐ Training of the model with alignement color and hierarchy")
    print("="*80)

    # Hyperparameters for this run (override the train_model defaults).
    num_epochs = 20
    learning_rate = 1.5e-5
    temperature = 0.09
    alignment_weight = 0.2
    weight_decay = 5e-4
    batch_size = 32
    subset_size = 20000

    # Load the dataset index and drop rows without a local image path.
    print(f"\n๐ Loading the data...")
    df = pd.read_csv(config.local_dataset_path)
    print(f" Data downloaded: {len(df)} samples")

    df_clean = df.dropna(subset=[config.column_local_image_path])
    print(f" After filtering NaN: {len(df_clean)} samples")

    dataset = CustomDataset(df_clean)

    # Work on a fixed-size random subset, split 80/20 into train/validation.
    print(f"\n๐ Creation of a subset of {subset_size} samples...")
    subset_size = min(subset_size, len(dataset))
    train_size = int(0.8 * subset_size)
    val_size = subset_size - train_size

    # Fixed seeds so the subset and the split are reproducible.
    np.random.seed(42)
    subset_indices = np.random.choice(len(dataset), subset_size, replace=False)
    subset_dataset = torch.utils.data.Subset(dataset, subset_indices)

    train_dataset, val_dataset = random_split(
        subset_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )

    # pin_memory speeds up host->GPU transfer and is only useful with CUDA.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True if torch.cuda.is_available() else False
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True if torch.cuda.is_available() else False
    )

    print(f" Train: {len(train_dataset)} samples")
    print(f" Validation: {len(val_dataset)} samples")

    # Load the frozen color/hierarchy specialist models.
    print(f"\n๐ง Loading models...")
    feature_models = load_models()

    # Two copies of the same pretrained CLIP: one to train, one kept frozen
    # as the reference for the anti-drift regularizer.
    print(f"\n๐ฆ Loading main model...")
    clip_model = CLIPModel_transformers.from_pretrained(
        'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
    )

    reference_clip = CLIPModel_transformers.from_pretrained(
        'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
    )

    clip_model = clip_model.to(config.device)
    reference_clip = reference_clip.to(config.device)
    reference_clip.eval()
    for param in reference_clip.parameters():
        param.requires_grad = False

    print(f"\n๐ฏ Beginning training...")
    print(f"\n" + "="*80)

    train_losses, val_losses = train_model(
        model=clip_model,
        train_loader=train_loader,
        val_loader=val_loader,
        feature_models=feature_models,
        device=config.device,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        temperature=temperature,
        save_path=config.main_model_path,
        alignment_weight=alignment_weight,
        color_alignment_model=feature_models[config.color_column],
        weight_decay=weight_decay,
        reference_model=reference_clip,
        reference_weight=0.1
    )

    # Final summary of the run.
    print("\n" + "="*80)
    print("โ
Training finished!")
    print(f" Model saved: {config.main_model_path}")
    print(f" Training curves: training_curves.png")
    print("\n๐ Final results:")
    print(f" Last train loss: {train_losses[-1]:.4f}")
    print(f" Last validation loss: {val_losses[-1]:.4f}")
    print(f" Best validation loss: {min(val_losses):.4f}")
    print(f" Overfitting gap (val-train): {val_losses[-1] - train_losses[-1]:.4f}")
    if val_losses[-1] - train_losses[-1] > 0.1:
        print(" โ ๏ธ Warning: Significant overfitting detected!")
    elif val_losses[-1] - train_losses[-1] < 0.05:
        print(" โ
Good generalization!")
    print("="*80)
|
|
|
|
|
if __name__ == "__main__":
    # Run the full training pipeline when executed as a script.
    main()
|
|
|