| """
|
| Training loop for forgery localization network
|
| Implements chunked training for RAM constraints
|
| """
|
|
|
| import os
|
| import torch
|
| import torch.nn as nn
|
| import torch.optim as optim
|
| from torch.utils.data import DataLoader
|
| from torch.cuda.amp import autocast, GradScaler
|
| from typing import Dict, Optional, Tuple
|
| from pathlib import Path
|
| from tqdm import tqdm
|
| import json
|
| import csv
|
|
|
| from ..models import get_model, get_loss_function
|
| from ..data import get_dataset
|
| from .metrics import MetricsTracker, EarlyStopping
|
|
|
|
|
class Trainer:
    """
    Trainer for forgery localization network.

    Supports chunked training for large datasets (DocTamper): a chunk is the
    ``[chunk_start, chunk_end)`` fraction of the training split, so the full
    dataset can be consumed in RAM-sized pieces across successive calls to
    :meth:`train`.
    """

    def __init__(self, config, dataset_name: str = 'doctamper'):
        """
        Initialize trainer.

        Args:
            config: Configuration object; assumed to expose ``get(key, default)``
                and ``has_pixel_mask(dataset_name)`` — TODO confirm against the
                project's config class.
            dataset_name: Dataset to train on.
        """
        self.config = config
        self.dataset_name = dataset_name

        # Fall back to CPU unless CUDA is both available and explicitly requested.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() and config.get('system.device') == 'cuda'
            else 'cpu'
        )
        print(f"Training on: {self.device}")

        # Model and loss come from the project factories.
        self.model = get_model(config).to(self.device)
        self.criterion = get_loss_function(config)

        lr = config.get('training.learning_rate', 0.001)
        weight_decay = config.get('training.weight_decay', 0.0001)
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=lr,
            weight_decay=weight_decay
        )

        epochs = config.get('training.epochs', 50)
        warmup_epochs = config.get('training.scheduler.warmup_epochs', 5)
        min_lr = config.get('training.scheduler.min_lr', 1e-5)

        # CosineAnnealingWarmRestarts requires T_0 >= 1; guard against a
        # misconfiguration where warmup_epochs >= epochs.
        self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=max(1, epochs - warmup_epochs),
            T_mult=1,
            eta_min=min_lr
        )

        # Mixed-precision loss scaling is a CUDA feature; disable it on CPU so
        # the scaler becomes a transparent no-op instead of warning/failing.
        self.use_amp = self.device.type == 'cuda'
        self.scaler = GradScaler(enabled=self.use_amp)

        self.metrics_tracker = MetricsTracker(config)

        patience = config.get('training.early_stopping.patience', 10)
        min_delta = config.get('training.early_stopping.min_delta', 0.001)
        self.early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)

        # Output directories are created eagerly so save paths always exist.
        self.checkpoint_dir = Path(config.get('outputs.checkpoints', 'outputs/checkpoints'))
        self.log_dir = Path(config.get('outputs.logs', 'outputs/logs'))
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # Training state, updated by train() / load_checkpoint().
        self.current_epoch = 0
        self.best_metric = 0.0

    def create_dataloaders(self,
                           chunk_start: float = 0.0,
                           chunk_end: float = 1.0) -> Tuple[DataLoader, DataLoader]:
        """
        Create train and validation dataloaders.

        Args:
            chunk_start: Start ratio for chunked training (DocTamper only).
            chunk_end: End ratio for chunked training (DocTamper only).

        Returns:
            Train and validation dataloaders.
        """
        batch_size = self.config.get('data.batch_size', 8)
        num_workers = self.config.get('system.num_workers', 4)

        # Only DocTamper supports chunked loading; other datasets load fully.
        if self.dataset_name == 'doctamper':
            train_dataset = get_dataset(
                self.config,
                self.dataset_name,
                split='train',
                chunk_start=chunk_start,
                chunk_end=chunk_end
            )
        else:
            train_dataset = get_dataset(
                self.config,
                self.dataset_name,
                split='train'
            )

        # FCD/SCD have no validation split of their own; validate on DocTamper.
        if self.dataset_name in ['fcd', 'scd']:
            val_dataset = get_dataset(
                self.config,
                'doctamper',
                split='val'
            )
        else:
            # Datasets without a 'val' split fall back to their 'test' split.
            val_dataset = get_dataset(
                self.config,
                self.dataset_name,
                split='val' if self.dataset_name in ['doctamper', 'receipts'] else 'test'
            )

        # drop_last=True keeps batch statistics (e.g. BatchNorm) stable in training.
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=self.config.get('system.pin_memory', True),
            drop_last=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True
        )

        return train_loader, val_loader

    def train_epoch(self, dataloader: DataLoader) -> Tuple[float, Dict]:
        """
        Train for one epoch.

        Args:
            dataloader: Training dataloader yielding (images, masks, metadata).

        Returns:
            Average loss and computed metrics dict.
        """
        self.model.train()
        self.metrics_tracker.reset()

        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(dataloader, desc=f"Epoch {self.current_epoch} [Train]")

        for batch_idx, (images, masks, metadata) in enumerate(pbar):
            images = images.to(self.device)
            masks = masks.to(self.device)

            self.optimizer.zero_grad()

            # Autocast only when AMP is active (CUDA); no-op on CPU.
            with autocast(enabled=self.use_amp):
                outputs, _ = self.model(images)

                # Weakly-labelled datasets may lack pixel masks; the loss
                # function handles both regimes.
                has_pixel_mask = self.config.has_pixel_mask(self.dataset_name)
                losses = self.criterion.combined_loss(outputs, masks, has_pixel_mask)

            # Scaled backward + step; scaler is a pass-through when disabled.
            self.scaler.scale(losses['total']).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            with torch.no_grad():
                probs = torch.sigmoid(outputs)
                self.metrics_tracker.update_segmentation(
                    probs, masks, self.dataset_name
                )

            total_loss += losses['total'].item()
            num_batches += 1

            pbar.set_postfix({
                'loss': f"{losses['total'].item():.4f}",
                'bce': f"{losses['bce'].item():.4f}"
            })

        # Guard against an empty dataloader (e.g. tiny chunk + drop_last).
        avg_loss = total_loss / max(num_batches, 1)
        metrics = self.metrics_tracker.compute_all()

        return avg_loss, metrics

    def validate(self, dataloader: DataLoader) -> Tuple[float, Dict]:
        """
        Validate model.

        Args:
            dataloader: Validation dataloader yielding (images, masks, metadata).

        Returns:
            Average loss and computed metrics dict.
        """
        self.model.eval()
        self.metrics_tracker.reset()

        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(dataloader, desc=f"Epoch {self.current_epoch} [Val]")

        with torch.no_grad():
            for images, masks, metadata in pbar:
                images = images.to(self.device)
                masks = masks.to(self.device)

                outputs, _ = self.model(images)

                has_pixel_mask = self.config.has_pixel_mask(self.dataset_name)
                losses = self.criterion.combined_loss(outputs, masks, has_pixel_mask)

                probs = torch.sigmoid(outputs)
                self.metrics_tracker.update_segmentation(
                    probs, masks, self.dataset_name
                )

                total_loss += losses['total'].item()
                num_batches += 1

                pbar.set_postfix({
                    'loss': f"{losses['total'].item():.4f}"
                })

        # Guard against an empty validation set.
        avg_loss = total_loss / max(num_batches, 1)
        metrics = self.metrics_tracker.compute_all()

        return avg_loss, metrics

    def save_checkpoint(self,
                        filename: str,
                        is_best: bool = False,
                        chunk_id: Optional[int] = None):
        """
        Save model checkpoint.

        Args:
            filename: Checkpoint filename (relative to checkpoint_dir).
            is_best: Also write a copy as ``best_<dataset>.pth``.
            chunk_id: Chunk identifier stored for bookkeeping.
        """
        checkpoint = {
            'epoch': self.current_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            # Persist AMP loss-scale state so mixed-precision resume is exact.
            'scaler_state_dict': self.scaler.state_dict(),
            'best_metric': self.best_metric,
            'dataset': self.dataset_name,
            'chunk_id': chunk_id
        }

        path = self.checkpoint_dir / filename
        torch.save(checkpoint, path)
        print(f"Saved checkpoint: {path}")

        if is_best:
            best_path = self.checkpoint_dir / f'best_{self.dataset_name}.pth'
            torch.save(checkpoint, best_path)
            print(f"Saved best model: {best_path}")

    def load_checkpoint(self, filename: str, reset_epoch: bool = False):
        """
        Load model checkpoint.

        Args:
            filename: Checkpoint filename (relative to checkpoint_dir).
            reset_epoch: If True, reset epoch counter to 0 (useful for chunked
                training, where each chunk restarts its epoch count).

        Returns:
            True if the checkpoint was found and loaded, False otherwise.
        """
        path = self.checkpoint_dir / filename

        if not path.exists():
            print(f"Checkpoint not found: {path}")
            return False

        checkpoint = torch.load(path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # Older checkpoints may predate scaler persistence; restore if present.
        scaler_state = checkpoint.get('scaler_state_dict')
        if scaler_state is not None:
            self.scaler.load_state_dict(scaler_state)

        if reset_epoch:
            self.current_epoch = 0
            print(f"Loaded checkpoint: {path} (epoch counter reset to 0)")
        else:
            self.current_epoch = checkpoint['epoch'] + 1
            print(f"Loaded checkpoint: {path} (resuming from epoch {self.current_epoch})")

        self.best_metric = checkpoint.get('best_metric', 0.0)

        return True

    def train(self,
              epochs: Optional[int] = None,
              chunk_start: float = 0.0,
              chunk_end: float = 1.0,
              chunk_id: Optional[int] = None,
              resume_from: Optional[str] = None):
        """
        Main training loop.

        Args:
            epochs: Number of epochs (None uses config).
            chunk_start: Start ratio for chunked training.
            chunk_end: End ratio for chunked training.
            chunk_id: Chunk identifier for logging/checkpoint names.
            resume_from: Checkpoint filename to resume from.

        Returns:
            The metrics tracker's full training history.
        """
        if epochs is None:
            epochs = self.config.get('training.epochs', 50)

        if resume_from:
            self.load_checkpoint(resume_from)

        train_loader, val_loader = self.create_dataloaders(chunk_start, chunk_end)

        print(f"\n{'='*60}")
        print(f"Training: {self.dataset_name}")
        if chunk_id is not None:
            print(f"Chunk: {chunk_id} [{chunk_start*100:.0f}% - {chunk_end*100:.0f}%]")
        print(f"Epochs: {epochs}")
        print(f"Train samples: {len(train_loader.dataset)}")
        print(f"Val samples: {len(val_loader.dataset)}")
        print(f"{'='*60}\n")

        # Fresh CSV log per chunk; header first, rows appended each epoch.
        log_file = self.log_dir / f'{self.dataset_name}_chunk{chunk_id or 0}_log.csv'
        with open(log_file, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['epoch', 'train_loss', 'val_loss',
                             'train_iou', 'val_iou', 'train_dice', 'val_dice',
                             'train_precision', 'val_precision',
                             'train_recall', 'val_recall', 'lr'])

        for epoch in range(self.current_epoch, epochs):
            self.current_epoch = epoch

            train_loss, train_metrics = self.train_epoch(train_loader)
            val_loss, val_metrics = self.validate(val_loader)

            # Scheduler steps per-epoch (cosine warm restarts).
            self.scheduler.step()
            current_lr = self.optimizer.param_groups[0]['lr']

            self.metrics_tracker.log_epoch(epoch, 'train', train_loss, train_metrics)
            self.metrics_tracker.log_epoch(epoch, 'val', val_loss, val_metrics)

            with open(log_file, 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow([
                    epoch,
                    f"{train_loss:.4f}",
                    f"{val_loss:.4f}",
                    f"{train_metrics.get('iou', 0):.4f}",
                    f"{val_metrics.get('iou', 0):.4f}",
                    f"{train_metrics.get('dice', 0):.4f}",
                    f"{val_metrics.get('dice', 0):.4f}",
                    f"{train_metrics.get('precision', 0):.4f}",
                    f"{val_metrics.get('precision', 0):.4f}",
                    f"{train_metrics.get('recall', 0):.4f}",
                    f"{val_metrics.get('recall', 0):.4f}",
                    f"{current_lr:.6f}"
                ])

            print(f"\nEpoch {epoch}/{epochs-1}")
            print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
            print(f"  Train IoU: {train_metrics.get('iou', 0):.4f} | Val IoU: {val_metrics.get('iou', 0):.4f}")
            print(f"  Train Dice: {train_metrics.get('dice', 0):.4f} | Val Dice: {val_metrics.get('dice', 0):.4f}")
            print(f"  LR: {current_lr:.6f}")

            # Periodic checkpoint (hoisted single config read; 0 disables).
            save_every = self.config.get('training.checkpoint.save_every', 5)
            if save_every > 0 and (epoch + 1) % save_every == 0:
                self.save_checkpoint(
                    f'{self.dataset_name}_chunk{chunk_id or 0}_epoch{epoch}.pth',
                    chunk_id=chunk_id
                )

            # Track best model by validation Dice.
            monitor_metric = val_metrics.get('dice', 0)
            if monitor_metric > self.best_metric:
                self.best_metric = monitor_metric
                self.save_checkpoint(
                    f'{self.dataset_name}_chunk{chunk_id or 0}_best.pth',
                    is_best=True,
                    chunk_id=chunk_id
                )

            if self.early_stopping(monitor_metric):
                print(f"\nEarly stopping triggered at epoch {epoch}")
                break

        self.save_checkpoint(
            f'{self.dataset_name}_chunk{chunk_id or 0}_final.pth',
            chunk_id=chunk_id
        )

        # Dump the full metric history for offline analysis.
        history_file = self.log_dir / f'{self.dataset_name}_chunk{chunk_id or 0}_history.json'
        with open(history_file, 'w') as f:
            json.dump(self.metrics_tracker.get_history(), f, indent=2)

        print(f"\nTraining complete!")
        print(f"Best Dice: {self.best_metric:.4f}")

        return self.metrics_tracker.get_history()
|
|
|
|
|
def get_trainer(config, dataset_name: str = 'doctamper') -> Trainer:
    """Factory function: build a :class:`Trainer` for the given dataset."""
    trainer = Trainer(config, dataset_name)
    return trainer
|
|
|