"""
Training pipeline for the signature verification model.
"""

import json
import os
from typing import Callable, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from ..models.siamese_network import SiameseNetwork, TripletSiameseNetwork
from ..data.preprocessing import SignaturePreprocessor
from ..data.augmentation import SignatureAugmentationPipeline
from .losses import ContrastiveLoss, TripletLoss, CombinedLoss, AdaptiveLoss


class SignatureDataset(Dataset):
    """
    Dataset of labelled signature pairs for verification training.

    Each item is ``(signature1, signature2, label)``. The two signatures are
    whatever the augmenter / preprocessor return for the loaded images, and
    ``label`` is a scalar ``float32`` tensor built from the pair's label.
    """

    def __init__(self,
                 data_pairs: List[Tuple[str, str, int]],
                 preprocessor: "SignaturePreprocessor",
                 augmenter: Optional["SignatureAugmentationPipeline"] = None,
                 is_training: bool = True):
        """
        Initialize the dataset.

        Args:
            data_pairs: List of (signature1_path, signature2_path, label) tuples
            preprocessor: Image preprocessor; used to load every image and to
                preprocess it whenever augmentation is not applied
            augmenter: Optional data augmenter; only used when ``is_training``
                is true
            is_training: Whether this is training data
        """
        self.data_pairs = data_pairs
        self.preprocessor = preprocessor
        self.augmenter = augmenter
        self.is_training = is_training

    def __len__(self):
        """Return the number of signature pairs."""
        return len(self.data_pairs)

    def _transform(self, image):
        """Augment ``image`` in training mode, otherwise preprocess it."""
        # Explicit None check: an augmenter object that happens to be falsy
        # (e.g. it defines __len__ and is empty) must still be honoured.
        if self.augmenter is not None and self.is_training:
            return self.augmenter.augment_image(image, is_training=True)
        return self.preprocessor.preprocess_image(image)

    def __getitem__(self, idx):
        """
        Load and transform the pair at ``idx``.

        Returns:
            Tuple ``(sig1, sig2, label)`` with ``label`` as a float32 tensor.
        """
        sig1_path, sig2_path, label = self.data_pairs[idx]

        # Load both images before transforming either one.
        sig1 = self.preprocessor.load_image(sig1_path)
        sig2 = self.preprocessor.load_image(sig2_path)

        sig1 = self._transform(sig1)
        sig2 = self._transform(sig2)

        return sig1, sig2, torch.tensor(label, dtype=torch.float32)


class TripletDataset(Dataset):
    """
    Dataset of (anchor, positive, negative) signature triplets.

    Each item is the three loaded images after augmentation (training mode
    with an augmenter) or preprocessing (otherwise).
    """

    def __init__(self,
                 triplet_data: List[Tuple[str, str, str]],
                 preprocessor: "SignaturePreprocessor",
                 augmenter: Optional["SignatureAugmentationPipeline"] = None,
                 is_training: bool = True):
        """
        Initialize the triplet dataset.

        Args:
            triplet_data: List of (anchor_path, positive_path, negative_path) tuples
            preprocessor: Image preprocessor; used to load every image and to
                preprocess it whenever augmentation is not applied
            augmenter: Optional data augmenter; only used when ``is_training``
                is true
            is_training: Whether this is training data
        """
        self.triplet_data = triplet_data
        self.preprocessor = preprocessor
        self.augmenter = augmenter
        self.is_training = is_training

    def __len__(self):
        """Return the number of triplets."""
        return len(self.triplet_data)

    def _transform(self, image):
        """Augment ``image`` in training mode, otherwise preprocess it."""
        # Explicit None check: an augmenter object that happens to be falsy
        # (e.g. it defines __len__ and is empty) must still be honoured.
        if self.augmenter is not None and self.is_training:
            return self.augmenter.augment_image(image, is_training=True)
        return self.preprocessor.preprocess_image(image)

    def __getitem__(self, idx):
        """
        Load and transform the triplet at ``idx``.

        Returns:
            Tuple ``(anchor, positive, negative)``.
        """
        anchor_path, positive_path, negative_path = self.triplet_data[idx]

        # Load all three images before transforming any of them.
        anchor = self.preprocessor.load_image(anchor_path)
        positive = self.preprocessor.load_image(positive_path)
        negative = self.preprocessor.load_image(negative_path)

        anchor = self._transform(anchor)
        positive = self._transform(positive)
        negative = self._transform(negative)

        return anchor, positive, negative


class SignatureTrainer:
    """
    Trainer for signature verification models.

    Bundles a model with an Adam optimizer, a ``ReduceLROnPlateau`` scheduler
    stepped on the validation loss, the criterion selected by ``loss_type``,
    a TensorBoard ``SummaryWriter`` and the per-epoch loss/accuracy history.
    """

    def __init__(self,
                 model: nn.Module,
                 device: str = 'auto',
                 learning_rate: float = 1e-4,
                 weight_decay: float = 1e-5,
                 loss_type: str = 'contrastive',
                 log_dir: str = 'logs'):
        """
        Initialize the trainer.

        Args:
            model: Model to train (moved to the selected device)
            device: Device to train on; 'auto' picks CUDA when available,
                otherwise CPU
            learning_rate: Learning rate
            weight_decay: Weight decay for regularization
            loss_type: Type of loss function ('contrastive', 'triplet',
                'combined' or 'adaptive')
            log_dir: Directory for logging; created if missing. Checkpoints
                and the training-curve plot are written here as well.

        Raises:
            ValueError: If ``loss_type`` is not one of the supported values.
        """
        self.model = model
        self.device = self._get_device(device)
        self.model.to(self.device)

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.loss_type = loss_type

        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )

        # Scheduler is driven by the validation loss in train().
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', patience=5, factor=0.5
        )

        self.criterion = self._get_loss_function()

        self.log_dir = log_dir
        os.makedirs(log_dir, exist_ok=True)
        self.writer = SummaryWriter(log_dir)

        # Per-epoch history; appended to by train() and stored in checkpoints.
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []

    def _get_device(self, device: str) -> torch.device:
        """Get the appropriate device ('auto' resolves to CUDA if available, else CPU)."""
        if device == 'auto':
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return torch.device(device)

    def _get_loss_function(self) -> nn.Module:
        """
        Get the loss function matching ``self.loss_type``.

        Raises:
            ValueError: If ``self.loss_type`` is not supported.
        """
        if self.loss_type == 'contrastive':
            return ContrastiveLoss()
        if self.loss_type == 'triplet':
            return TripletLoss()
        if self.loss_type == 'combined':
            return CombinedLoss()
        if self.loss_type == 'adaptive':
            return AdaptiveLoss()
        raise ValueError(f"Unsupported loss type: {self.loss_type}")

    def _pair_loss(self, similarity, labels, sig1, sig2):
        """Compute the loss for a pair batch, unwrapping the adaptive criterion's tuple."""
        if self.loss_type == 'adaptive':
            # The adaptive criterion takes extra inputs and returns
            # (loss, info); only the loss is needed here.
            # NOTE(review): the trailing (sig1, sig1) arguments are carried
            # over unchanged from the original training call -- confirm
            # against AdaptiveLoss.
            loss, _ = self.criterion(similarity, labels, sig1, sig2, sig1, sig1)
            return loss
        return self.criterion(similarity, labels)

    def _run_batch(self, batch_data) -> Tuple[torch.Tensor, int, int]:
        """
        Run the forward pass and loss for one batch.

        Shared by training and validation so both handle every loss type the
        same way.

        Returns:
            ``(loss, correct, count)``: the loss tensor, the number of correct
            predictions and the number of samples in the batch.
        """
        if self.loss_type == 'triplet':
            anchor, positive, negative = (t.to(self.device) for t in batch_data)
            anchor_feat, positive_feat, negative_feat = self.model(anchor, positive, negative)
            loss = self.criterion(anchor_feat, positive_feat, negative_feat)

            # A triplet counts as correct when the anchor is closer (L2) to
            # the positive than to the negative.
            pos_dist = torch.norm(anchor_feat - positive_feat, dim=1)
            neg_dist = torch.norm(anchor_feat - negative_feat, dim=1)
            correct = (pos_dist < neg_dist).sum().item()
            return loss, correct, anchor.size(0)

        sig1, sig2, labels = (t.to(self.device) for t in batch_data)
        similarity = self.model(sig1, sig2)
        loss = self._pair_loss(similarity, labels, sig1, sig2)

        # Threshold the model output at 0.5 and compare against the labels.
        predictions = (similarity.squeeze() > 0.5).float()
        correct = (predictions == labels).sum().item()
        return loss, correct, labels.size(0)

    @staticmethod
    def _build_metrics(total_loss: float,
                       num_batches: int,
                       correct: int,
                       total: int) -> Dict[str, float]:
        """Average the accumulated loss/accuracy; both are 0.0 when nothing was processed."""
        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        accuracy = correct / total if total > 0 else 0.0
        return {
            'loss': avg_loss,
            'accuracy': accuracy
        }

    def train_epoch(self,
                    train_loader: DataLoader,
                    epoch: int) -> Dict[str, float]:
        """
        Train for one epoch.

        Args:
            train_loader: Training data loader
            epoch: Current epoch number (shown in the progress bar)

        Returns:
            Dictionary of training metrics with keys 'loss' (mean batch loss)
            and 'accuracy'
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        correct_predictions = 0
        total_predictions = 0

        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch}')

        for batch_data in progress_bar:
            self.optimizer.zero_grad()

            loss, correct, count = self._run_batch(batch_data)

            loss.backward()
            self.optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1
            correct_predictions += correct
            total_predictions += count

            progress_bar.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'Acc': f'{correct_predictions/total_predictions:.4f}' if total_predictions > 0 else '0.0000'
            })

        return self._build_metrics(total_loss, num_batches,
                                   correct_predictions, total_predictions)

    def validate_epoch(self,
                       val_loader: DataLoader,
                       epoch: int) -> Dict[str, float]:
        """
        Validate for one epoch.

        Args:
            val_loader: Validation data loader
            epoch: Current epoch number (currently unused)

        Returns:
            Dictionary of validation metrics with keys 'loss' (mean batch
            loss) and 'accuracy'
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        correct_predictions = 0
        total_predictions = 0

        with torch.no_grad():
            for batch_data in val_loader:
                loss, correct, count = self._run_batch(batch_data)
                total_loss += loss.item()
                num_batches += 1
                correct_predictions += correct
                total_predictions += count

        return self._build_metrics(total_loss, num_batches,
                                   correct_predictions, total_predictions)

    def train(self,
              train_loader: DataLoader,
              val_loader: DataLoader,
              num_epochs: int = 100,
              save_best: bool = True,
              patience: int = 10) -> Dict[str, List[float]]:
        """
        Train the model.

        Runs up to ``num_epochs`` epochs of training + validation, steps the
        LR scheduler on the validation loss, logs scalars to TensorBoard and
        stops early once ``patience`` consecutive epochs fail to improve on
        the best validation loss. Afterwards the final model is saved to
        ``final_model.pth`` and the training curves are plotted.

        Args:
            train_loader: Training data loader
            val_loader: Validation data loader
            num_epochs: Number of epochs to train
            save_best: Whether to save the best model (``best_model.pth`` in
                the log directory) whenever the validation loss improves
            patience: Early stopping patience

        Returns:
            Training history with keys 'train_losses', 'val_losses',
            'train_accuracies' and 'val_accuracies'. These are the trainer's
            own history lists, so they also include epochs from earlier calls
            or a loaded checkpoint.
        """
        best_val_loss = float('inf')
        patience_counter = 0

        print(f"Training on device: {self.device}")
        print(f"Training samples: {len(train_loader.dataset)}")
        print(f"Validation samples: {len(val_loader.dataset)}")

        for epoch in range(num_epochs):
            train_metrics = self.train_epoch(train_loader, epoch)
            val_metrics = self.validate_epoch(val_loader, epoch)

            self.scheduler.step(val_metrics['loss'])

            self.train_losses.append(train_metrics['loss'])
            self.val_losses.append(val_metrics['loss'])
            self.train_accuracies.append(train_metrics['accuracy'])
            self.val_accuracies.append(val_metrics['accuracy'])

            self.writer.add_scalar('Loss/Train', train_metrics['loss'], epoch)
            self.writer.add_scalar('Loss/Val', val_metrics['loss'], epoch)
            self.writer.add_scalar('Accuracy/Train', train_metrics['accuracy'], epoch)
            self.writer.add_scalar('Accuracy/Val', val_metrics['accuracy'], epoch)
            self.writer.add_scalar('Learning_Rate', self.optimizer.param_groups[0]['lr'], epoch)

            print(f'Epoch {epoch+1}/{num_epochs}:')
            print(f' Train Loss: {train_metrics["loss"]:.4f}, Train Acc: {train_metrics["accuracy"]:.4f}')
            print(f' Val Loss: {val_metrics["loss"]:.4f}, Val Acc: {val_metrics["accuracy"]:.4f}')
            print(f' Learning Rate: {self.optimizer.param_groups[0]["lr"]:.6f}')

            # Improvement is tracked independently of save_best; otherwise
            # save_best=False would count every epoch as "no improvement"
            # and always stop after `patience` epochs.
            if val_metrics['loss'] < best_val_loss:
                best_val_loss = val_metrics['loss']
                patience_counter = 0
                if save_best:
                    self.save_model(os.path.join(self.log_dir, 'best_model.pth'))
                    print(' New best model saved!')
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f'Early stopping at epoch {epoch+1}')
                break

            print('-' * 50)

        self.save_model(os.path.join(self.log_dir, 'final_model.pth'))

        self.plot_training_curves()

        return {
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'train_accuracies': self.train_accuracies,
            'val_accuracies': self.val_accuracies
        }

    def save_model(self, filepath: str):
        """Save model, optimizer and scheduler state plus the metric history to ``filepath``."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'train_accuracies': self.train_accuracies,
            'val_accuracies': self.val_accuracies
        }
        torch.save(checkpoint, filepath)

    def load_model(self, filepath: str):
        """
        Load a checkpoint written by ``save_model``.

        Model, optimizer and scheduler state are required; the metric history
        entries are optional and default to empty lists.
        """
        # NOTE: torch.load is pickle-based; only load checkpoints from
        # trusted sources.
        checkpoint = torch.load(filepath, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.train_losses = checkpoint.get('train_losses', [])
        self.val_losses = checkpoint.get('val_losses', [])
        self.train_accuracies = checkpoint.get('train_accuracies', [])
        self.val_accuracies = checkpoint.get('val_accuracies', [])

    def plot_training_curves(self):
        """Plot loss and accuracy history side by side and save it as ``training_curves.png`` in the log directory."""
        fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

        panels = (
            (loss_ax, self.train_losses, self.val_losses, 'Loss'),
            (acc_ax, self.train_accuracies, self.val_accuracies, 'Accuracy'),
        )
        for ax, train_values, val_values, name in panels:
            ax.plot(train_values, label=f'Train {name}')
            ax.plot(val_values, label=f'Val {name}')
            ax.set_xlabel('Epoch')
            ax.set_ylabel(name)
            ax.set_title(f'Training and Validation {name}')
            ax.legend()
            ax.grid(True)

        # Operate on this figure explicitly rather than on pyplot's
        # "current figure".
        fig.tight_layout()
        fig.savefig(os.path.join(self.log_dir, 'training_curves.png'))
        plt.close(fig)

    def close(self):
        """Close the trainer and clean up resources (closes the TensorBoard writer)."""
        self.writer.close()