""" Training pipeline for signature verification model. """ import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, Dataset import numpy as np from typing import Dict, List, Tuple, Optional, Callable import os import json from tqdm import tqdm import matplotlib.pyplot as plt import seaborn as sns from torch.utils.tensorboard import SummaryWriter from ..models.siamese_network import SiameseNetwork, TripletSiameseNetwork from ..data.preprocessing import SignaturePreprocessor from ..data.augmentation import SignatureAugmentationPipeline from .losses import ContrastiveLoss, TripletLoss, CombinedLoss, AdaptiveLoss class SignatureDataset(Dataset): """ Dataset for signature verification training. """ def __init__(self, data_pairs: List[Tuple[str, str, int]], preprocessor: SignaturePreprocessor, augmenter: Optional[SignatureAugmentationPipeline] = None, is_training: bool = True): """ Initialize the dataset. Args: data_pairs: List of (signature1_path, signature2_path, label) tuples preprocessor: Image preprocessor augmenter: Data augmenter is_training: Whether this is training data """ self.data_pairs = data_pairs self.preprocessor = preprocessor self.augmenter = augmenter self.is_training = is_training def __len__(self): return len(self.data_pairs) def __getitem__(self, idx): sig1_path, sig2_path, label = self.data_pairs[idx] # Load and preprocess images sig1 = self.preprocessor.load_image(sig1_path) sig2 = self.preprocessor.load_image(sig2_path) # Apply augmentation if available if self.augmenter and self.is_training: sig1 = self.augmenter.augment_image(sig1, is_training=True) sig2 = self.augmenter.augment_image(sig2, is_training=True) else: sig1 = self.preprocessor.preprocess_image(sig1) sig2 = self.preprocessor.preprocess_image(sig2) return sig1, sig2, torch.tensor(label, dtype=torch.float32) class TripletDataset(Dataset): """ Dataset for triplet training. """ def __init__(self, triplet_data: List[Tuple[str, str, str]], preprocessor: SignaturePreprocessor, augmenter: Optional[SignatureAugmentationPipeline] = None, is_training: bool = True): """ Initialize the triplet dataset. Args: triplet_data: List of (anchor_path, positive_path, negative_path) tuples preprocessor: Image preprocessor augmenter: Data augmenter is_training: Whether this is training data """ self.triplet_data = triplet_data self.preprocessor = preprocessor self.augmenter = augmenter self.is_training = is_training def __len__(self): return len(self.triplet_data) def __getitem__(self, idx): anchor_path, positive_path, negative_path = self.triplet_data[idx] # Load and preprocess images anchor = self.preprocessor.load_image(anchor_path) positive = self.preprocessor.load_image(positive_path) negative = self.preprocessor.load_image(negative_path) # Apply augmentation if available if self.augmenter and self.is_training: anchor = self.augmenter.augment_image(anchor, is_training=True) positive = self.augmenter.augment_image(positive, is_training=True) negative = self.augmenter.augment_image(negative, is_training=True) else: anchor = self.preprocessor.preprocess_image(anchor) positive = self.preprocessor.preprocess_image(positive) negative = self.preprocessor.preprocess_image(negative) return anchor, positive, negative class SignatureTrainer: """ Trainer for signature verification models. """ def __init__(self, model: nn.Module, device: str = 'auto', learning_rate: float = 1e-4, weight_decay: float = 1e-5, loss_type: str = 'contrastive', log_dir: str = 'logs'): """ Initialize the trainer. Args: model: Model to train device: Device to train on learning_rate: Learning rate weight_decay: Weight decay for regularization loss_type: Type of loss function ('contrastive', 'triplet', 'combined') log_dir: Directory for logging """ self.model = model self.device = self._get_device(device) self.model.to(self.device) self.learning_rate = learning_rate self.weight_decay = weight_decay self.loss_type = loss_type # Initialize optimizer self.optimizer = optim.Adam( self.model.parameters(), lr=learning_rate, weight_decay=weight_decay ) # Initialize scheduler self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( self.optimizer, mode='min', patience=5, factor=0.5 ) # Initialize loss function self.criterion = self._get_loss_function() # Initialize logging self.log_dir = log_dir os.makedirs(log_dir, exist_ok=True) self.writer = SummaryWriter(log_dir) # Training history self.train_losses = [] self.val_losses = [] self.train_accuracies = [] self.val_accuracies = [] def _get_device(self, device: str) -> torch.device: """Get the appropriate device.""" if device == 'auto': return torch.device('cuda' if torch.cuda.is_available() else 'cpu') else: return torch.device(device) def _get_loss_function(self) -> nn.Module: """Get the appropriate loss function.""" if self.loss_type == 'contrastive': return ContrastiveLoss() elif self.loss_type == 'triplet': return TripletLoss() elif self.loss_type == 'combined': return CombinedLoss() elif self.loss_type == 'adaptive': return AdaptiveLoss() else: raise ValueError(f"Unsupported loss type: {self.loss_type}") def train_epoch(self, train_loader: DataLoader, epoch: int) -> Dict[str, float]: """ Train for one epoch. Args: train_loader: Training data loader epoch: Current epoch number Returns: Dictionary of training metrics """ self.model.train() total_loss = 0.0 correct_predictions = 0 total_predictions = 0 progress_bar = tqdm(train_loader, desc=f'Epoch {epoch}') for batch_idx, batch_data in enumerate(progress_bar): self.optimizer.zero_grad() if self.loss_type == 'triplet': # Triplet training anchor, positive, negative = batch_data anchor = anchor.to(self.device) positive = positive.to(self.device) negative = negative.to(self.device) # Forward pass anchor_feat, positive_feat, negative_feat = self.model(anchor, positive, negative) # Compute loss loss = self.criterion(anchor_feat, positive_feat, negative_feat) # Compute accuracy (simplified) pos_dist = torch.norm(anchor_feat - positive_feat, dim=1) neg_dist = torch.norm(anchor_feat - negative_feat, dim=1) correct = (pos_dist < neg_dist).sum().item() correct_predictions += correct total_predictions += anchor.size(0) else: # Contrastive training sig1, sig2, labels = batch_data sig1 = sig1.to(self.device) sig2 = sig2.to(self.device) labels = labels.to(self.device) # Forward pass similarity = self.model(sig1, sig2) # Compute loss if self.loss_type == 'adaptive': loss, loss_info = self.criterion(similarity, labels, sig1, sig2, sig1, sig1) else: loss = self.criterion(similarity, labels) # Compute accuracy predictions = (similarity.squeeze() > 0.5).float() correct = (predictions == labels).sum().item() correct_predictions += correct total_predictions += labels.size(0) # Backward pass loss.backward() self.optimizer.step() total_loss += loss.item() # Update progress bar progress_bar.set_postfix({ 'Loss': f'{loss.item():.4f}', 'Acc': f'{correct_predictions/total_predictions:.4f}' if total_predictions > 0 else '0.0000' }) # Compute epoch metrics avg_loss = total_loss / len(train_loader) accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0 metrics = { 'loss': avg_loss, 'accuracy': accuracy } return metrics def validate_epoch(self, val_loader: DataLoader, epoch: int) -> Dict[str, float]: """ Validate for one epoch. Args: val_loader: Validation data loader epoch: Current epoch number Returns: Dictionary of validation metrics """ self.model.eval() total_loss = 0.0 correct_predictions = 0 total_predictions = 0 with torch.no_grad(): for batch_data in val_loader: if self.loss_type == 'triplet': # Triplet validation anchor, positive, negative = batch_data anchor = anchor.to(self.device) positive = positive.to(self.device) negative = negative.to(self.device) # Forward pass anchor_feat, positive_feat, negative_feat = self.model(anchor, positive, negative) # Compute loss loss = self.criterion(anchor_feat, positive_feat, negative_feat) # Compute accuracy pos_dist = torch.norm(anchor_feat - positive_feat, dim=1) neg_dist = torch.norm(anchor_feat - negative_feat, dim=1) correct = (pos_dist < neg_dist).sum().item() correct_predictions += correct total_predictions += anchor.size(0) else: # Contrastive validation sig1, sig2, labels = batch_data sig1 = sig1.to(self.device) sig2 = sig2.to(self.device) labels = labels.to(self.device) # Forward pass similarity = self.model(sig1, sig2) # Compute loss loss = self.criterion(similarity, labels) # Compute accuracy predictions = (similarity.squeeze() > 0.5).float() correct = (predictions == labels).sum().item() correct_predictions += correct total_predictions += labels.size(0) total_loss += loss.item() # Compute epoch metrics avg_loss = total_loss / len(val_loader) accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0 metrics = { 'loss': avg_loss, 'accuracy': accuracy } return metrics def train(self, train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 100, save_best: bool = True, patience: int = 10) -> Dict[str, List[float]]: """ Train the model. Args: train_loader: Training data loader val_loader: Validation data loader num_epochs: Number of epochs to train save_best: Whether to save the best model patience: Early stopping patience Returns: Training history """ best_val_loss = float('inf') patience_counter = 0 print(f"Training on device: {self.device}") print(f"Training samples: {len(train_loader.dataset)}") print(f"Validation samples: {len(val_loader.dataset)}") for epoch in range(num_epochs): # Training train_metrics = self.train_epoch(train_loader, epoch) # Validation val_metrics = self.validate_epoch(val_loader, epoch) # Update learning rate self.scheduler.step(val_metrics['loss']) # Store metrics self.train_losses.append(train_metrics['loss']) self.val_losses.append(val_metrics['loss']) self.train_accuracies.append(train_metrics['accuracy']) self.val_accuracies.append(val_metrics['accuracy']) # Log metrics self.writer.add_scalar('Loss/Train', train_metrics['loss'], epoch) self.writer.add_scalar('Loss/Val', val_metrics['loss'], epoch) self.writer.add_scalar('Accuracy/Train', train_metrics['accuracy'], epoch) self.writer.add_scalar('Accuracy/Val', val_metrics['accuracy'], epoch) self.writer.add_scalar('Learning_Rate', self.optimizer.param_groups[0]['lr'], epoch) # Print progress print(f'Epoch {epoch+1}/{num_epochs}:') print(f' Train Loss: {train_metrics["loss"]:.4f}, Train Acc: {train_metrics["accuracy"]:.4f}') print(f' Val Loss: {val_metrics["loss"]:.4f}, Val Acc: {val_metrics["accuracy"]:.4f}') print(f' Learning Rate: {self.optimizer.param_groups[0]["lr"]:.6f}') # Save best model if save_best and val_metrics['loss'] < best_val_loss: best_val_loss = val_metrics['loss'] self.save_model(os.path.join(self.log_dir, 'best_model.pth')) patience_counter = 0 print(f' New best model saved!') else: patience_counter += 1 # Early stopping if patience_counter >= patience: print(f'Early stopping at epoch {epoch+1}') break print('-' * 50) # Save final model self.save_model(os.path.join(self.log_dir, 'final_model.pth')) # Plot training curves self.plot_training_curves() return { 'train_losses': self.train_losses, 'val_losses': self.val_losses, 'train_accuracies': self.train_accuracies, 'val_accuracies': self.val_accuracies } def save_model(self, filepath: str): """Save model checkpoint.""" checkpoint = { 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), 'scheduler_state_dict': self.scheduler.state_dict(), 'train_losses': self.train_losses, 'val_losses': self.val_losses, 'train_accuracies': self.train_accuracies, 'val_accuracies': self.val_accuracies } torch.save(checkpoint, filepath) def load_model(self, filepath: str): """Load model checkpoint.""" checkpoint = torch.load(filepath, map_location=self.device) self.model.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) self.train_losses = checkpoint.get('train_losses', []) self.val_losses = checkpoint.get('val_losses', []) self.train_accuracies = checkpoint.get('train_accuracies', []) self.val_accuracies = checkpoint.get('val_accuracies', []) def plot_training_curves(self): """Plot training curves.""" fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5)) # Loss curves ax1.plot(self.train_losses, label='Train Loss') ax1.plot(self.val_losses, label='Val Loss') ax1.set_xlabel('Epoch') ax1.set_ylabel('Loss') ax1.set_title('Training and Validation Loss') ax1.legend() ax1.grid(True) # Accuracy curves ax2.plot(self.train_accuracies, label='Train Accuracy') ax2.plot(self.val_accuracies, label='Val Accuracy') ax2.set_xlabel('Epoch') ax2.set_ylabel('Accuracy') ax2.set_title('Training and Validation Accuracy') ax2.legend() ax2.grid(True) plt.tight_layout() plt.savefig(os.path.join(self.log_dir, 'training_curves.png')) plt.close() def close(self): """Close the trainer and clean up resources.""" self.writer.close()