"""
Hierarchy model for learning clothing category-aligned embeddings.

This file contains the hierarchy model that learns to encode images and texts
in an embedding space specialized for representing clothing categories
(dress, shirt, etc.). It includes a regex pattern-based hierarchy extractor,
a ResNet image encoder, a hierarchy embedding encoder, and loss functions
for training.
"""

import re
from functools import partial
from io import BytesIO

import pandas as pd
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from tqdm import tqdm

import config


# -------------------------
# 1) Dataset
# -------------------------
class HierarchyDataset(Dataset):
    """
    Dataset class for hierarchy embedding training.

    Loads images from local paths, in-memory byte dictionaries or URLs and
    returns them together with the row's text description and hierarchy
    label. Two transform pipelines are kept: an augmenting one for training
    and a deterministic one for validation.
    """

    def __init__(self, dataframe, use_local_images=True, image_size=224):
        """
        Initialize the hierarchy dataset.

        Args:
            dataframe: DataFrame with columns for image paths/URLs, text
                descriptions, and hierarchy labels (column names come from
                ``config``)
            use_local_images: Whether to prefer local images over URLs
                (default: True)
            image_size: Size of images after resizing (default: 224)
        """
        self.dataframe = dataframe
        self.use_local_images = use_local_images
        self.image_size = image_size
        # Augmentation stays active until set_training_mode(False) is called.
        self.training_mode = True

        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

        # Transforms with data augmentation for training
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(p=0.3),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            normalize,
        ])

        # Validation transforms (no augmentation)
        self.val_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            normalize,
        ])

        # Check local image availability
        if use_local_images:
            if config.column_local_image_path not in dataframe.columns:
                print(f"āš ļø Column {config.column_local_image_path} not found. Using URLs.")
                self.use_local_images = False
            else:
                local_available = dataframe[config.column_local_image_path].notna().sum()
                total = len(dataframe)
                # An empty frame must not crash the constructor.
                percent = local_available / total * 100 if total else 0.0
                print(f"šŸ“ Local images available: {local_available}/{total} ({percent:.1f}%)")

    def set_training_mode(self, training=True):
        """
        Switch between training and validation transforms.

        Args:
            training: If True, use training transforms with augmentation;
                if False, use validation transforms
        """
        self.training_mode = training

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """
        Get a sample from the dataset.

        Args:
            idx: Index of the sample

        Returns:
            Tuple of (image_tensor, description_text, hierarchy_label)
        """
        row = self.dataframe.iloc[idx]

        if self.use_local_images and pd.notna(row.get(config.column_local_image_path, '')):
            # Local file first. convert() returns an independent copy, so the
            # file handle can be released immediately.
            with Image.open(row[config.column_local_image_path]) as raw_image:
                image = raw_image.convert("RGB")
        elif isinstance(row[config.column_url_image], dict):
            # The image column holds a dictionary carrying the encoded bytes.
            image = Image.open(BytesIO(row[config.column_url_image]['bytes'])).convert('RGB')
        else:
            # Otherwise, download from the URL.
            image = self._download_image(row[config.column_url_image])

        # getattr keeps instances created before `training_mode` was
        # initialised in __init__ (e.g. unpickled ones) working.
        if getattr(self, 'training_mode', True):
            image = self.transform(image)
        else:
            image = self.val_transform(image)

        description = row[config.text_column]
        hierarchy = row[config.hierarchy_column]
        return image, description, hierarchy

    def _download_image(self, img_url):
        """
        Download an image from a URL with a 10 second timeout.

        Args:
            img_url: URL of the image to download

        Returns:
            PIL Image object in RGB mode

        Raises:
            requests.HTTPError: If the server answers with an error status
        """
        response = requests.get(img_url, timeout=10)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")


# -------------------------
# 2) Hierarchy Extractor
# -------------------------
class HierarchyExtractor:
    """
    Extract hierarchy categories directly from text using pattern matching.

    This class uses regex patterns to identify clothing categories
    (e.g., shirt, dress) from text descriptions, handling variations,
    plurals, and common fashion terms.
    """

    # Synonyms added to a class pattern when the key is a substring of the
    # lower-cased class name.
    _FASHION_TERMS = {
        'shirt': ['shirt', 'shirts', 'tee', 't-shirt', 'tshirt'],
        'jacket': ['jacket', 'jackets', 'coat', 'coats'],
        'pant': ['pant', 'pants', 'trouser', 'trousers', 'jean', 'jeans'],
        'dress': ['dress', 'dresses'],
        'skirt': ['skirt', 'skirts'],
        'shoe': ['shoe', 'shoes', 'boot', 'boots', 'sneaker', 'sneakers'],
        'bag': ['bag', 'bags', 'handbag', 'handbags', 'purse', 'purses'],
        'hat': ['hat', 'hats', 'cap', 'caps'],
        'scarf': ['scarf', 'scarves'],
        'belt': ['belt', 'belts'],
        'sock': ['sock', 'socks'],
        'underwear': ['underwear', 'underpant', 'underpants'],
        'sweater': ['sweater', 'sweaters', 'jumper', 'jumpers'],
        'blouse': ['blouse', 'blouses'],
        'vest': ['vest', 'vests'],
        'short': ['short', 'shorts'],
        'legging': ['legging', 'leggings'],
        'suit': ['suit', 'suits'],
        'tie': ['tie', 'ties'],
        'glove': ['glove', 'gloves'],
        'sandal': ['sandal', 'sandals'],
    }

    def __init__(self, hierarchy_classes, verbose=False):
        """
        Initialize the hierarchy extractor.

        Args:
            hierarchy_classes: List of hierarchy class names
            verbose: Whether to print initialization information
                (default: False)
        """
        self.hierarchy_classes = sorted(hierarchy_classes)
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.hierarchy_classes)}
        self.idx_to_class = {idx: cls for idx, cls in enumerate(self.hierarchy_classes)}

        # Create patterns for each hierarchy
        self.patterns = self._create_patterns()

        if verbose:
            print(f"šŸŽÆ Hierarchy extractor initialized with {len(self.hierarchy_classes)} classes")
            print(f"šŸ“‹ Classes: {self.hierarchy_classes}")

    def _create_patterns(self):
        """
        Create one word-boundary regex pattern per hierarchy class.

        Each pattern matches the class name, its hyphen-free spellings, a
        naive plural, and any synonyms from ``_FASHION_TERMS``. Every
        variation is lower-cased because matching is done against
        lower-cased text.

        Returns:
            Dictionary mapping hierarchy classes to regex pattern strings
        """
        patterns = {}
        for hierarchy in self.hierarchy_classes:
            name = hierarchy.lower()
            variations = [name]

            # Hyphenated names are also accepted with a space or fused.
            if '-' in name:
                variations.append(name.replace('-', ' '))
                variations.append(name.replace('-', ''))

            # Naive plural form
            if not name.endswith('s'):
                variations.append(name + 's')

            # Common fashion synonyms
            for key, terms in self._FASHION_TERMS.items():
                if key in name:
                    variations.extend(terms)

            # Drop duplicates while keeping the first-seen order.
            unique_variations = dict.fromkeys(variations)
            alternatives = '|'.join(re.escape(v) for v in unique_variations)
            patterns[hierarchy] = r'\b(' + alternatives + r')\b'

        return patterns

    def extract_hierarchy(self, text):
        """
        Extract hierarchy category from text using pattern matching.

        The first pass is a plain substring test of each class name (in
        sorted order); the second pass applies the regex patterns.

        Args:
            text: Input text string. Non-string values (e.g. NaN from a
                DataFrame) are treated as "no match".

        Returns:
            Hierarchy class name if found, None otherwise
        """
        if not isinstance(text, str):
            return None

        text_lower = text.lower()

        # Substring match on the class name itself
        for hierarchy in self.hierarchy_classes:
            if hierarchy.lower() in text_lower:
                return hierarchy

        # Pattern matching (variations, plurals, synonyms)
        for hierarchy, pattern in self.patterns.items():
            if re.search(pattern, text_lower):
                return hierarchy

        return None

    def extract_hierarchy_idx(self, text):
        """
        Extract hierarchy index from text.

        Args:
            text: Input text string

        Returns:
            Hierarchy index if found, None otherwise
        """
        hierarchy = self.extract_hierarchy(text)
        if hierarchy:
            return self.class_to_idx[hierarchy]
        return None

    def get_hierarchy_embedding(self, text, embed_dim=config.hierarchy_emb_dim):
        """
        Create a fixed (non-learned) embedding from the extracted hierarchy.

        Three consecutive positions, starting at ``(index * 3) % embed_dim``
        and wrapping around, are set to 1.0, 0.5 and 0.3; all other entries
        are zero.

        Args:
            text: Input text string
            embed_dim: Dimension of the embedding
                (default: config.hierarchy_emb_dim)

        Returns:
            Embedding tensor of shape (embed_dim,); all zeros when no
            hierarchy is found
        """
        embedding = torch.zeros(embed_dim)
        hierarchy_idx = self.extract_hierarchy_idx(text)
        if hierarchy_idx is None:
            # Unknown hierarchy -> zero embedding
            return embedding

        start_idx = (hierarchy_idx * 3) % embed_dim
        embedding[start_idx] = 1.0
        embedding[(start_idx + 1) % embed_dim] = 0.5
        embedding[(start_idx + 2) % embed_dim] = 0.3
        return embedding


# -------------------------
# 3) Models
# -------------------------
class PretrainedImageEncoder(nn.Module):
    """
    Image encoder based on pretrained ResNet18 for extracting image embeddings.

    Uses a pretrained ResNet18 backbone with its first layers frozen and a
    custom projection head that outputs embeddings of the specified dimension.
    """

    def __init__(self, embed_dim, dropout=0.3):
        """
        Initialize the pretrained image encoder.

        Args:
            embed_dim: Dimension of the output embedding
            dropout: Dropout rate for regularization (default: 0.3)
        """
        super().__init__()

        # `pretrained=True` is deprecated in recent torchvision releases;
        # prefer the weights enum and fall back for older installations.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            resnet = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            resnet = models.resnet18(pretrained=True)
        backbone_dim = resnet.fc.in_features

        # Remove the final classification layer
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        # Add custom projection head
        self.projection = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(backbone_dim, embed_dim * 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout * 0.5),
            nn.Linear(embed_dim * 2, embed_dim),
            nn.LayerNorm(embed_dim),
        )

        # Fine-tune only the last few layers
        self._freeze_backbone_layers()

    def _freeze_backbone_layers(self):
        """
        Freeze early layers to prevent overfitting.

        Disables gradients for the first 70% (rounded down) of the
        backbone's direct children, leaving the remaining ones trainable.
        """
        layers = list(self.backbone.children())
        freeze_until = int(len(layers) * 0.7)
        for layer in layers[:freeze_until]:
            for param in layer.parameters():
                param.requires_grad = False

    def forward(self, x):
        """
        Forward pass through the image encoder.

        Args:
            x: Image tensor [batch_size, channels, height, width]

        Returns:
            Image embeddings [batch_size, embed_dim]
        """
        features = self.backbone(x)
        return self.projection(features)


class HierarchyEncoder(nn.Module):
    """
    Encoder that takes hierarchy indices directly.

    Uses an embedding layer to convert hierarchy indices to embeddings,
    followed by a projection head to output embeddings of the specified
    dimension.
    """

    def __init__(self, num_hierarchies, embed_dim, dropout=0.3):
        """
        Initialize the hierarchy encoder.

        Args:
            num_hierarchies: Number of hierarchy classes
            embed_dim: Dimension of the output embedding
            dropout: Dropout rate for regularization (default: 0.3)
        """
        super().__init__()
        self.num_hierarchies = num_hierarchies
        self.embed_dim = embed_dim

        # Embedding layer
        self.embedding = nn.Embedding(num_hierarchies, embed_dim)

        # Projection layer
        self.projection = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 2, embed_dim),
            nn.LayerNorm(embed_dim),
        )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Apply Xavier-uniform init to the embedding and linear weights; zero the biases."""
        nn.init.xavier_uniform_(self.embedding.weight)
        for module in self.projection.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, hierarchy_indices):
        """
        Forward pass through the hierarchy encoder.

        Args:
            hierarchy_indices: Tensor of hierarchy indices [batch_size]

        Returns:
            Hierarchy embeddings [batch_size, embed_dim]

        Note:
            On an MPS device the embedding lookup is performed on CPU and
            the result is moved back to the device (the original author's
            workaround for embedding layers on MPS).
        """
        device = next(self.parameters()).device

        if device.type == 'mps':
            # Lookup on CPU with a CPU copy of the weights, then move the
            # contiguous result back to the model device.
            indices_cpu = hierarchy_indices.cpu()
            emb_weight = self.embedding.weight.cpu()
            emb = F.embedding(indices_cpu, emb_weight)
            emb = emb.contiguous().to(device)
        else:
            emb = self.embedding(hierarchy_indices)

        return self.projection(emb)


class HierarchyClassifierHead(nn.Module):
    """
    Classifier head for hierarchy classification.

    Multi-layer perceptron that takes embeddings as input and outputs
    classification logits for hierarchy classes.
    """

    def __init__(self, in_dim, num_classes, hidden_dim=None, dropout=0.3):
        """
        Initialize the hierarchy classifier head.

        Args:
            in_dim: Input embedding dimension
            num_classes: Number of hierarchy classes
            hidden_dim: Hidden layer dimension
                (default: max(in_dim // 2, num_classes * 2))
            dropout: Dropout rate for regularization (default: 0.3)
        """
        super().__init__()

        if hidden_dim is None:
            hidden_dim = max(in_dim // 2, num_classes * 2)

        self.classifier = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout * 0.5),
            nn.Linear(hidden_dim // 2, num_classes),
        )

    def forward(self, x):
        """
        Forward pass through the classifier head.

        Args:
            x: Input embeddings [batch_size, in_dim]

        Returns:
            Classification logits [batch_size, num_classes]
        """
        return self.classifier(x)


class Model(nn.Module):
    """
    Main hierarchy model for learning clothing category-aligned embeddings.

    Combines image encoder, hierarchy encoder, and classifier heads to learn
    aligned embeddings for images and text descriptions based on clothing
    categories.
    """

    def __init__(self, num_hierarchy_classes, embed_dim, dropout=0.3):
        """
        Initialize the hierarchy model.

        Args:
            num_hierarchy_classes: Number of hierarchy classes
            embed_dim: Dimension of the embedding space
            dropout: Dropout rate for regularization (default: 0.3)
        """
        super().__init__()
        self.img_enc = PretrainedImageEncoder(embed_dim, dropout)
        self.hierarchy_enc = HierarchyEncoder(num_hierarchy_classes, embed_dim, dropout)
        self.hierarchy_head_img = HierarchyClassifierHead(embed_dim, num_hierarchy_classes, dropout=dropout)
        self.hierarchy_head_txt = HierarchyClassifierHead(embed_dim, num_hierarchy_classes, dropout=dropout)
        self.num_hierarchy_classes = num_hierarchy_classes
        # Populated by set_hierarchy_extractor(); needed by get_text_embeddings().
        self.hierarchy_extractor = None

    def forward(self, image=None, hierarchy_indices=None):
        """
        Forward pass through the model.

        Args:
            image: Optional image tensor [batch_size, channels, height, width]
            hierarchy_indices: Optional hierarchy indices tensor [batch_size]

        Returns:
            Dictionary containing:
                - 'z_img': L2-normalized image embeddings
                  [batch_size, embed_dim] (if image provided)
                - 'z_txt': L2-normalized text embeddings
                  [batch_size, embed_dim] (if hierarchy_indices provided)
                - 'hierarchy_logits_img': Image classification logits
                  [batch_size, num_classes] (if image provided)
                - 'hierarchy_logits_txt': Text classification logits
                  [batch_size, num_classes] (if hierarchy_indices provided)
        """
        out = {}

        if image is not None:
            z_img = F.normalize(self.img_enc(image), p=2, dim=1)
            out['hierarchy_logits_img'] = self.hierarchy_head_img(z_img)
            out['z_img'] = z_img

        if hierarchy_indices is not None:
            z_txt = F.normalize(self.hierarchy_enc(hierarchy_indices), p=2, dim=1)
            out['hierarchy_logits_txt'] = self.hierarchy_head_txt(z_txt)
            out['z_txt'] = z_txt

        return out

    def set_hierarchy_extractor(self, hierarchy_extractor):
        """
        Set the hierarchy extractor for text processing.

        Args:
            hierarchy_extractor: HierarchyExtractor instance
        """
        self.hierarchy_extractor = hierarchy_extractor

    def get_text_embeddings(self, text):
        """
        Get text embeddings for a given text string or list of strings.

        Runs under ``torch.no_grad()``; the module's train/eval mode is left
        untouched, so call ``model.eval()`` first for deterministic output.

        Args:
            text: Text string or list/tuple of text strings

        Returns:
            Text embeddings tensor [batch_size, embed_dim]

        Raises:
            ValueError: If the input has the wrong type or a hierarchy
                cannot be extracted from a text
            AttributeError: If no hierarchy extractor has been set
        """
        if isinstance(text, str):
            texts = [text]
        elif isinstance(text, (list, tuple)):
            texts = text
        else:
            raise ValueError(f"Expected string or list/tuple of strings, got {type(text)}: {text}")

        extractor = getattr(self, 'hierarchy_extractor', None)
        hierarchy_indices = []
        for hierarchy_text in texts:
            if not isinstance(hierarchy_text, str):
                raise ValueError(f"Expected string, got {type(hierarchy_text)}: {hierarchy_text}")
            if extractor is None:
                raise AttributeError(
                    "hierarchy_extractor is not set; call set_hierarchy_extractor() first"
                )
            hierarchy_idx = extractor.extract_hierarchy_idx(hierarchy_text)
            if hierarchy_idx is None:
                raise ValueError(f"Could not extract hierarchy for text: '{hierarchy_text}'. Available classes: {extractor.hierarchy_classes}")
            hierarchy_indices.append(hierarchy_idx)

        with torch.no_grad():
            model_device = next(self.parameters()).device
            # Explicit dtype keeps an empty input a valid index tensor.
            index_tensor = torch.tensor(hierarchy_indices, dtype=torch.long, device=model_device)
            output = self.forward(hierarchy_indices=index_tensor)
            return output['z_txt']

    def get_image_embeddings(self, image):
        """
        Get image embeddings for a given image tensor.

        Runs under ``torch.no_grad()``; the module's train/eval mode is left
        untouched.

        Args:
            image: Image tensor [channels, height, width] or
                [batch_size, channels, height, width]

        Returns:
            Image embeddings tensor [batch_size, embed_dim]

        Raises:
            ValueError: If image is not a torch.Tensor
        """
        if not isinstance(image, torch.Tensor):
            raise ValueError("Image must be a torch.Tensor")

        with torch.no_grad():
            # Ensure image is on the same device as the model
            device = next(self.parameters()).device
            if image.device != device:
                image = image.to(device)

            # Add batch dimension if needed
            if image.dim() == 3:
                image = image.unsqueeze(0)

            output = self.forward(image=image)
            return output['z_img']


# -------------------------
# 4) Loss functions
# -------------------------
class Loss(nn.Module):
    """
    Combined loss function for hierarchy model training.

    Weighted sum of a classification loss (cross-entropy on both heads), a
    contrastive loss (InfoNCE) and a consistency loss (MSE between the two
    embeddings).
    """

    def __init__(self, hierarchy_classes, classification_weight=1.0, consistency_weight=0.3,
                 contrastive_weight=0.2, temperature=0.07, label_smoothing=0.1):
        """
        Initialize the loss function.

        Args:
            hierarchy_classes: List of hierarchy class names
            classification_weight: Weight for classification loss (default: 1.0)
            consistency_weight: Weight for consistency loss (default: 0.3)
            contrastive_weight: Weight for contrastive loss (default: 0.2)
            temperature: Temperature scaling for contrastive loss (default: 0.07)
            label_smoothing: Label smoothing parameter (default: 0.1)
        """
        super().__init__()
        self.classification_weight = classification_weight
        self.consistency_weight = consistency_weight
        self.contrastive_weight = contrastive_weight
        self.temperature = temperature

        self.hierarchy_classes = sorted(set(hierarchy_classes))
        self.num_classes = len(self.hierarchy_classes)
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.hierarchy_classes)}

        # Loss functions with label smoothing
        self.ce = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        self.mse = nn.MSELoss()

    def contrastive_loss(self, img_emb, txt_emb):
        """
        InfoNCE contrastive loss for aligning image and text embeddings.

        Row ``i`` of ``img_emb`` is treated as the only positive for row
        ``i`` of ``txt_emb``; every other row in the batch is a negative.

        NOTE(review): text embeddings in this file are derived from class
        indices, so samples of the same class in one batch are scored as
        negatives of each other — confirm that this is intended.

        Args:
            img_emb: Image embeddings [batch_size, embed_dim]
            txt_emb: Text embeddings [batch_size, embed_dim]

        Returns:
            Contrastive loss value (mean of both retrieval directions)
        """
        sim_matrix = torch.matmul(img_emb, txt_emb.T) / self.temperature
        labels = torch.arange(img_emb.size(0), device=img_emb.device)

        loss_i2t = F.cross_entropy(sim_matrix, labels)
        loss_t2i = F.cross_entropy(sim_matrix.T, labels)
        return (loss_i2t + loss_t2i) / 2

    def forward(self, img_logits, txt_logits, img_embeddings, txt_embeddings, target_hierarchies):
        """
        Forward pass through the loss function.

        Args:
            img_logits: Image classification logits [batch_size, num_classes]
            txt_logits: Text classification logits [batch_size, num_classes]
            img_embeddings: Image embeddings [batch_size, embed_dim]
            txt_embeddings: Text embeddings [batch_size, embed_dim]
            target_hierarchies: List of target hierarchy class names
                [batch_size]; names unknown to this loss map to index 0

        Returns:
            Combined loss value
        """
        device = img_embeddings.device

        # Convert hierarchy names to indices
        target_classes = torch.tensor(
            [self.class_to_idx.get(hierarchy, 0) for hierarchy in target_hierarchies],
            device=device,
        )

        # 1. Classification loss
        classification_loss = (self.ce(img_logits, target_classes)
                               + self.ce(txt_logits, target_classes)) / 2

        # 2. Contrastive loss for alignment
        contrastive_loss = self.contrastive_loss(img_embeddings, txt_embeddings)

        # 3. Consistency loss between modalities
        consistency_loss = self.mse(img_embeddings, txt_embeddings)

        return (self.classification_weight * classification_loss
                + self.contrastive_weight * contrastive_loss
                + self.consistency_weight * consistency_loss)


# -------------------------
# 5) Training
# -------------------------
def collate_fn(batch, hierarchy_extractor):
    """
    Collate function for DataLoader that processes batches and extracts
    hierarchy indices.

    The index is extracted from the description text; when extraction
    fails, the sample's own hierarchy label is used instead (index 0 if
    that label is unknown to the extractor).

    Args:
        batch: List of (image, description, hierarchy) tuples
        hierarchy_extractor: HierarchyExtractor instance

    Returns:
        Dictionary containing:
            - 'image': Stacked image tensors [batch_size, channels, height, width]
            - 'hierarchy_indices': Hierarchy indices tensor [batch_size]
            - config.hierarchy_column: List of hierarchy class names [batch_size]
    """
    images = torch.stack([sample[0] for sample in batch], dim=0)
    texts = [sample[1] for sample in batch]
    hierarchies = [sample[2] for sample in batch]

    hierarchy_indices = []
    for text, target_hierarchy in zip(texts, hierarchies):
        idx = hierarchy_extractor.extract_hierarchy_idx(text)
        if idx is None:
            # If no hierarchy found, use the target hierarchy
            idx = hierarchy_extractor.class_to_idx.get(target_hierarchy, 0)
        hierarchy_indices.append(idx)

    return {
        'image': images,
        'hierarchy_indices': torch.tensor(hierarchy_indices, dtype=torch.long),
        config.hierarchy_column: hierarchies,
    }


def calculate_accuracy(logits, target_hierarchies, hierarchy_classes):
    """
    Calculate classification accuracy.

    Args:
        logits: Classification logits [batch_size, num_classes]
        target_hierarchies: List of target hierarchy class names [batch_size]
        hierarchy_classes: List of hierarchy class names

    Returns:
        Accuracy score (float between 0 and 1); 0.0 for an empty batch
    """
    batch_size = logits.size(0)
    if batch_size == 0:
        return 0.0

    num_classes = len(hierarchy_classes)
    pred_indices = torch.argmax(logits, dim=1).tolist()

    correct = 0
    for pred_idx, target_class in zip(pred_indices, target_hierarchies):
        # An out-of-range prediction is compared as the empty string.
        pred_class = hierarchy_classes[pred_idx] if pred_idx < num_classes else ""
        if pred_class == target_class:
            correct += 1

    return correct / batch_size


def _run_epoch(model, dataloader, device, hierarchy_classes, loss_fn, desc,
               optimizer=None, scheduler=None):
    """
    Run one pass over ``dataloader`` and return averaged metrics.

    When ``optimizer`` is given the pass is a training pass (backward,
    gradient clipping, optimizer and optional scheduler step per batch);
    otherwise gradients are disabled. The caller is responsible for
    ``model.train()`` / ``model.eval()``.
    """
    training = optimizer is not None

    # Select the dataset's transform pipeline BEFORE the first batch is
    # fetched; doing it inside the loop is too late for that batch.
    if hasattr(dataloader.dataset, 'set_training_mode'):
        dataloader.dataset.set_training_mode(training)

    total_loss = 0.0
    total_acc_img = 0.0
    total_acc_txt = 0.0
    num_batches = 0

    pbar = tqdm(dataloader, desc=desc, leave=False)
    with torch.set_grad_enabled(training):
        for batch in pbar:
            images = batch['image'].to(device)
            hierarchy_indices = batch['hierarchy_indices'].to(device)
            target_hierarchies = batch[config.hierarchy_column]

            out = model(image=images, hierarchy_indices=hierarchy_indices)
            hierarchy_logits_img = out['hierarchy_logits_img']
            hierarchy_logits_txt = out['hierarchy_logits_txt']
            z_img, z_txt = out['z_img'], out['z_txt']

            loss = loss_fn(hierarchy_logits_img, hierarchy_logits_txt, z_img, z_txt, target_hierarchies)

            if training:
                optimizer.zero_grad()
                loss.backward()
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

            acc_img = calculate_accuracy(hierarchy_logits_img, target_hierarchies, hierarchy_classes)
            acc_txt = calculate_accuracy(hierarchy_logits_txt, target_hierarchies, hierarchy_classes)

            batch_loss = loss.item()
            total_loss += batch_loss
            total_acc_img += acc_img
            total_acc_txt += acc_txt
            num_batches += 1

            pbar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'acc_img': f'{acc_img:.3f}',
                'acc_txt': f'{acc_txt:.3f}',
            })

    # An empty dataloader yields zeros instead of a ZeroDivisionError.
    divisor = max(num_batches, 1)
    return {
        'loss': total_loss / divisor,
        'acc_img': total_acc_img / divisor,
        'acc_txt': total_acc_txt / divisor,
    }


def train_one_epoch(model, dataloader, optimizer, device, hierarchy_classes, scheduler=None):
    """
    Train the model for one epoch.

    Args:
        model: Model instance to train
        dataloader: DataLoader for training data
        optimizer: Optimizer instance
        device: Device to train on
        hierarchy_classes: List of hierarchy class names
        scheduler: Optional learning rate scheduler; it is stepped once per
            batch, so its period must be expressed in optimizer steps

    Returns:
        Dictionary containing training metrics:
            - 'loss': Average training loss
            - 'acc_img': Average image classification accuracy
            - 'acc_txt': Average text classification accuracy
    """
    model.train()
    loss_fn = Loss(
        hierarchy_classes,
        classification_weight=1.0,
        consistency_weight=0.3,
        contrastive_weight=0.2,
        label_smoothing=0.1,
    ).to(device)

    return _run_epoch(
        model, dataloader, device, hierarchy_classes, loss_fn, "Training",
        optimizer=optimizer, scheduler=scheduler,
    )


def validate(model, dataloader, device, hierarchy_classes):
    """
    Validate the model on validation data.

    Args:
        model: Model instance to validate
        dataloader: DataLoader for validation data
        device: Device to validate on
        hierarchy_classes: List of hierarchy class names

    Returns:
        Dictionary containing validation metrics:
            - 'loss': Average validation loss
            - 'acc_img': Average image classification accuracy
            - 'acc_txt': Average text classification accuracy
    """
    model.eval()
    loss_fn = Loss(
        hierarchy_classes,
        classification_weight=1.0,
        consistency_weight=0.3,
        contrastive_weight=0.2,
    ).to(device)

    return _run_epoch(model, dataloader, device, hierarchy_classes, loss_fn, "Validation")


# -------------------------
# 6) Main training script
# -------------------------
def _save_checkpoint(model, hierarchy_classes, epoch, dropout, path):
    """Write the model weights plus the metadata needed to rebuild the model to ``path``."""
    torch.save({
        'model_state': model.state_dict(),
        'hierarchy_classes': hierarchy_classes,
        'epoch': epoch,
        'config': {
            'embed_dim': config.hierarchy_emb_dim,
            'dropout': dropout
        }
    }, path)


def main():
    """Train the hierarchy model on the dataset configured in ``config``."""
    # Configuration
    device = config.device
    batch_size = 16
    lr = 5e-5
    epochs = 20
    val_split = 0.2
    dropout = 0.4
    weight_decay = 1e-3

    print(f"šŸš€ Starting hierarchical training on device: {device}")
    print(f"šŸ“Š Config: {epochs} epochs, batch={batch_size}, lr={lr}, embed_dim={config.hierarchy_emb_dim}")

    # Load dataset
    print(f"šŸ“ Using dataset: { config.local_dataset_path}")
    df = pd.read_csv(config.local_dataset_path)
    print(f"šŸ“ Loaded {len(df)} samples")

    # Get unique hierarchy classes
    hierarchy_classes = sorted(df[config.hierarchy_column].unique().tolist())
    print(f"šŸ“‹ Found {len(hierarchy_classes)} hierarchy classes")

    # Create hierarchy extractor
    hierarchy_extractor = HierarchyExtractor(hierarchy_classes, verbose=True)

    # Train/validation split
    train_df, val_df = train_test_split(
        df,
        test_size=val_split,
        random_state=42,
        stratify=df[config.hierarchy_column]
    )
    train_df = train_df.reset_index(drop=True)
    val_df = val_df.reset_index(drop=True)

    print(f"šŸ“ˆ Train: {len(train_df)}, Validation: {len(val_df)}")

    # Create datasets
    train_ds = HierarchyDataset(train_df, image_size=224)
    val_ds = HierarchyDataset(val_df, image_size=224)

    # Create data loaders. functools.partial (unlike a lambda) stays
    # picklable should worker processes ever be enabled.
    batch_collate = partial(collate_fn, hierarchy_extractor=hierarchy_extractor)
    train_dl = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=batch_collate
    )
    val_dl = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=batch_collate
    )

    # Create model
    model = Model(
        num_hierarchy_classes=len(hierarchy_classes),
        embed_dim=config.hierarchy_emb_dim,
        dropout=dropout
    ).to(device)

    # Optimizer and scheduler. train_one_epoch steps the scheduler once per
    # batch, so the restart period is scaled to mean "5 epochs" instead of
    # "5 batches".
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    steps_per_epoch = max(1, len(train_dl))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=5 * steps_per_epoch, T_mult=2, eta_min=lr/10
    )

    print(f"šŸŽÆ Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print("\n" + "="*80)

    # Training loop
    best_val_loss = float('inf')
    training_history = {'train_loss': [], 'val_loss': [], 'val_acc_img': [], 'val_acc_txt': []}

    for e in range(epochs):
        print(f"\nšŸ”„ Epoch {e+1}/{epochs}")
        print("-" * 50)

        # Training
        train_metrics = train_one_epoch(model, train_dl, optimizer, device, hierarchy_classes, scheduler)

        # Validation
        val_metrics = validate(model, val_dl, device, hierarchy_classes)

        # Track history
        training_history['train_loss'].append(train_metrics['loss'])
        training_history['val_loss'].append(val_metrics['loss'])
        training_history['val_acc_img'].append(val_metrics['acc_img'])
        training_history['val_acc_txt'].append(val_metrics['acc_txt'])

        # Display results
        print(f"šŸ“Š TRAIN - Loss: {train_metrics['loss']:.6f} | "
              f"Img Acc: {train_metrics['acc_img']:.3f} | "
              f"Txt Acc: {train_metrics['acc_txt']:.3f}")
        print(f"āœ… VAL - Loss: {val_metrics['loss']:.6f} | "
              f"Img Acc: {val_metrics['acc_img']:.3f} | "
              f"Txt Acc: {val_metrics['acc_txt']:.3f}")

        # Save best model
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            print(f"šŸ’¾ New best validation loss! Saving model...")
            _save_checkpoint(model, hierarchy_classes, e+1, dropout, config.hierarchy_model_path)

        # Save model every 2 epochs
        if (e + 1) % 2 == 0:
            print(f"šŸ’¾ Saving checkpoint at epoch {e+1}...")
            _save_checkpoint(model, hierarchy_classes, e+1, dropout, f"model_checkpoint_epoch_{e+1}.pth")

    print("\n" + "="*80)
    print("šŸŽ‰ Training completed!")
    print(f"šŸ† Best validation loss: {best_val_loss:.6f}")
    print(f"\nšŸ“ˆ Final validation accuracy: Image={training_history['val_acc_img'][-1]:.3f}, Text={training_history['val_acc_txt'][-1]:.3f}")


if __name__ == "__main__":
    main()