Spaces:
Sleeping
Sleeping
"""
Training loop for forgery localization network.
Implements chunked training for RAM constraints.
"""
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader | |
| from torch.cuda.amp import autocast, GradScaler | |
| from typing import Dict, Optional, Tuple | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| import json | |
| import csv | |
| from ..models import get_model, get_loss_function | |
| from ..data import get_dataset | |
| from .metrics import MetricsTracker, EarlyStopping | |
class Trainer:
    """
    Trainer for forgery localization network.

    Supports chunked training for large datasets (DocTamper): a chunk is a
    [start, end) fraction of the training set, so the whole set never has
    to be resident in RAM at once.
    """

    def __init__(self, config, dataset_name: str = 'doctamper'):
        """
        Initialize trainer.

        Args:
            config: Configuration object (project type exposing
                ``get(key, default)`` and ``has_pixel_mask(dataset)``).
            dataset_name: Dataset to train on.
        """
        self.config = config
        self.dataset_name = dataset_name

        # Device setup: use CUDA only when it is both available and requested.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() and config.get('system.device') == 'cuda'
            else 'cpu'
        )
        print(f"Training on: {self.device}")

        # Initialize model
        self.model = get_model(config).to(self.device)

        # Loss function (dataset-aware)
        self.criterion = get_loss_function(config)

        # Optimizer
        lr = config.get('training.learning_rate', 0.001)
        weight_decay = config.get('training.weight_decay', 0.0001)
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=lr,
            weight_decay=weight_decay
        )

        # Learning rate scheduler.
        # NOTE(review): `warmup_epochs` only shortens the first cosine cycle
        # (T_0 = epochs - warmup_epochs); no actual LR warmup phase is
        # performed anywhere — confirm this is intended.
        epochs = config.get('training.epochs', 50)
        warmup_epochs = config.get('training.scheduler.warmup_epochs', 5)
        min_lr = config.get('training.scheduler.min_lr', 1e-5)
        self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=epochs - warmup_epochs,
            T_mult=1,
            eta_min=min_lr
        )

        # Mixed precision training.
        # Fix: torch.cuda.amp.GradScaler is CUDA-only; disable it when
        # training on CPU to avoid warnings and pointless scale bookkeeping.
        self.scaler = GradScaler(enabled=(self.device.type == 'cuda'))

        # Metrics
        self.metrics_tracker = MetricsTracker(config)

        # Early stopping on the monitored validation metric.
        patience = config.get('training.early_stopping.patience', 10)
        min_delta = config.get('training.early_stopping.min_delta', 0.001)
        self.early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)

        # Output directories (created eagerly so later saves cannot fail on
        # a missing parent).
        self.checkpoint_dir = Path(config.get('outputs.checkpoints', 'outputs/checkpoints'))
        self.log_dir = Path(config.get('outputs.logs', 'outputs/logs'))
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # Training state
        self.current_epoch = 0
        self.best_metric = 0.0
| def create_dataloaders(self, | |
| chunk_start: float = 0.0, | |
| chunk_end: float = 1.0) -> Tuple[DataLoader, DataLoader]: | |
| """ | |
| Create train and validation dataloaders | |
| Args: | |
| chunk_start: Start ratio for chunked training | |
| chunk_end: End ratio for chunked training | |
| Returns: | |
| Train and validation dataloaders | |
| """ | |
| batch_size = self.config.get('data.batch_size', 8) | |
| num_workers = self.config.get('system.num_workers', 4) | |
| # Training dataset (with chunking for DocTamper) | |
| if self.dataset_name == 'doctamper': | |
| train_dataset = get_dataset( | |
| self.config, | |
| self.dataset_name, | |
| split='train', | |
| chunk_start=chunk_start, | |
| chunk_end=chunk_end | |
| ) | |
| else: | |
| train_dataset = get_dataset( | |
| self.config, | |
| self.dataset_name, | |
| split='train' | |
| ) | |
| # Validation dataset (always full) | |
| # For FCD and SCD, validate on DocTamper TestingSet | |
| if self.dataset_name in ['fcd', 'scd']: | |
| val_dataset = get_dataset( | |
| self.config, | |
| 'doctamper', # Use DocTamper for validation | |
| split='val' | |
| ) | |
| else: | |
| val_dataset = get_dataset( | |
| self.config, | |
| self.dataset_name, | |
| split='val' if self.dataset_name in ['doctamper', 'receipts'] else 'test' | |
| ) | |
| train_loader = DataLoader( | |
| train_dataset, | |
| batch_size=batch_size, | |
| shuffle=True, | |
| num_workers=num_workers, | |
| pin_memory=self.config.get('system.pin_memory', True), | |
| drop_last=True | |
| ) | |
| val_loader = DataLoader( | |
| val_dataset, | |
| batch_size=batch_size, | |
| shuffle=False, | |
| num_workers=num_workers, | |
| pin_memory=True | |
| ) | |
| return train_loader, val_loader | |
| def train_epoch(self, dataloader: DataLoader) -> Tuple[float, Dict]: | |
| """ | |
| Train for one epoch | |
| Args: | |
| dataloader: Training dataloader | |
| Returns: | |
| Average loss and metrics | |
| """ | |
| self.model.train() | |
| self.metrics_tracker.reset() | |
| total_loss = 0.0 | |
| num_batches = 0 | |
| pbar = tqdm(dataloader, desc=f"Epoch {self.current_epoch} [Train]") | |
| for batch_idx, (images, masks, metadata) in enumerate(pbar): | |
| images = images.to(self.device) | |
| masks = masks.to(self.device) | |
| # Forward pass with mixed precision | |
| self.optimizer.zero_grad() | |
| with autocast(): | |
| outputs, _ = self.model(images) | |
| # Dataset-aware loss | |
| has_pixel_mask = self.config.has_pixel_mask(self.dataset_name) | |
| losses = self.criterion.combined_loss(outputs, masks, has_pixel_mask) | |
| # Backward pass with gradient scaling | |
| self.scaler.scale(losses['total']).backward() | |
| self.scaler.step(self.optimizer) | |
| self.scaler.update() | |
| # Update metrics | |
| with torch.no_grad(): | |
| probs = torch.sigmoid(outputs) | |
| self.metrics_tracker.update_segmentation( | |
| probs, masks, self.dataset_name | |
| ) | |
| total_loss += losses['total'].item() | |
| num_batches += 1 | |
| # Update progress bar | |
| pbar.set_postfix({ | |
| 'loss': f"{losses['total'].item():.4f}", | |
| 'bce': f"{losses['bce'].item():.4f}" | |
| }) | |
| avg_loss = total_loss / num_batches | |
| metrics = self.metrics_tracker.compute_all() | |
| return avg_loss, metrics | |
| def validate(self, dataloader: DataLoader) -> Tuple[float, Dict]: | |
| """ | |
| Validate model | |
| Args: | |
| dataloader: Validation dataloader | |
| Returns: | |
| Average loss and metrics | |
| """ | |
| self.model.eval() | |
| self.metrics_tracker.reset() | |
| total_loss = 0.0 | |
| num_batches = 0 | |
| pbar = tqdm(dataloader, desc=f"Epoch {self.current_epoch} [Val]") | |
| with torch.no_grad(): | |
| for images, masks, metadata in pbar: | |
| images = images.to(self.device) | |
| masks = masks.to(self.device) | |
| # Forward pass | |
| outputs, _ = self.model(images) | |
| # Dataset-aware loss | |
| has_pixel_mask = self.config.has_pixel_mask(self.dataset_name) | |
| losses = self.criterion.combined_loss(outputs, masks, has_pixel_mask) | |
| # Update metrics | |
| probs = torch.sigmoid(outputs) | |
| self.metrics_tracker.update_segmentation( | |
| probs, masks, self.dataset_name | |
| ) | |
| total_loss += losses['total'].item() | |
| num_batches += 1 | |
| pbar.set_postfix({ | |
| 'loss': f"{losses['total'].item():.4f}" | |
| }) | |
| avg_loss = total_loss / num_batches | |
| metrics = self.metrics_tracker.compute_all() | |
| return avg_loss, metrics | |
| def save_checkpoint(self, | |
| filename: str, | |
| is_best: bool = False, | |
| chunk_id: Optional[int] = None): | |
| """Save model checkpoint""" | |
| checkpoint = { | |
| 'epoch': self.current_epoch, | |
| 'model_state_dict': self.model.state_dict(), | |
| 'optimizer_state_dict': self.optimizer.state_dict(), | |
| 'scheduler_state_dict': self.scheduler.state_dict(), | |
| 'best_metric': self.best_metric, | |
| 'dataset': self.dataset_name, | |
| 'chunk_id': chunk_id | |
| } | |
| path = self.checkpoint_dir / filename | |
| torch.save(checkpoint, path) | |
| print(f"Saved checkpoint: {path}") | |
| if is_best: | |
| best_path = self.checkpoint_dir / f'best_{self.dataset_name}.pth' | |
| torch.save(checkpoint, best_path) | |
| print(f"Saved best model: {best_path}") | |
| def load_checkpoint(self, filename: str, reset_epoch: bool = False): | |
| """ | |
| Load model checkpoint | |
| Args: | |
| filename: Checkpoint filename | |
| reset_epoch: If True, reset epoch counter to 0 (useful for chunked training) | |
| """ | |
| path = self.checkpoint_dir / filename | |
| if not path.exists(): | |
| print(f"Checkpoint not found: {path}") | |
| return False | |
| checkpoint = torch.load(path, map_location=self.device) | |
| self.model.load_state_dict(checkpoint['model_state_dict']) | |
| self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) | |
| self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) | |
| if reset_epoch: | |
| self.current_epoch = 0 | |
| print(f"Loaded checkpoint: {path} (epoch counter reset to 0)") | |
| else: | |
| self.current_epoch = checkpoint['epoch'] + 1 # Continue from next epoch | |
| print(f"Loaded checkpoint: {path} (resuming from epoch {self.current_epoch})") | |
| self.best_metric = checkpoint.get('best_metric', 0.0) | |
| return True | |
| def train(self, | |
| epochs: Optional[int] = None, | |
| chunk_start: float = 0.0, | |
| chunk_end: float = 1.0, | |
| chunk_id: Optional[int] = None, | |
| resume_from: Optional[str] = None): | |
| """ | |
| Main training loop | |
| Args: | |
| epochs: Number of epochs (None uses config) | |
| chunk_start: Start ratio for chunked training | |
| chunk_end: End ratio for chunked training | |
| chunk_id: Chunk identifier for logging | |
| resume_from: Checkpoint to resume from | |
| """ | |
| if epochs is None: | |
| epochs = self.config.get('training.epochs', 50) | |
| # Resume if specified | |
| if resume_from: | |
| self.load_checkpoint(resume_from) | |
| # Create dataloaders | |
| train_loader, val_loader = self.create_dataloaders(chunk_start, chunk_end) | |
| print(f"\n{'='*60}") | |
| print(f"Training: {self.dataset_name}") | |
| if chunk_id is not None: | |
| print(f"Chunk: {chunk_id} [{chunk_start*100:.0f}% - {chunk_end*100:.0f}%]") | |
| print(f"Epochs: {epochs}") | |
| print(f"Train samples: {len(train_loader.dataset)}") | |
| print(f"Val samples: {len(val_loader.dataset)}") | |
| print(f"{'='*60}\n") | |
| # Training log file | |
| log_file = self.log_dir / f'{self.dataset_name}_chunk{chunk_id or 0}_log.csv' | |
| with open(log_file, 'w', newline='') as f: | |
| writer = csv.writer(f) | |
| writer.writerow(['epoch', 'train_loss', 'val_loss', | |
| 'train_iou', 'val_iou', 'train_dice', 'val_dice', | |
| 'train_precision', 'val_precision', | |
| 'train_recall', 'val_recall', 'lr']) | |
| for epoch in range(self.current_epoch, epochs): | |
| self.current_epoch = epoch | |
| # Train | |
| train_loss, train_metrics = self.train_epoch(train_loader) | |
| # Validate | |
| val_loss, val_metrics = self.validate(val_loader) | |
| # Update scheduler | |
| self.scheduler.step() | |
| current_lr = self.optimizer.param_groups[0]['lr'] | |
| # Log metrics | |
| self.metrics_tracker.log_epoch(epoch, 'train', train_loss, train_metrics) | |
| self.metrics_tracker.log_epoch(epoch, 'val', val_loss, val_metrics) | |
| # Log to file | |
| with open(log_file, 'a', newline='') as f: | |
| writer = csv.writer(f) | |
| writer.writerow([ | |
| epoch, | |
| f"{train_loss:.4f}", | |
| f"{val_loss:.4f}", | |
| f"{train_metrics.get('iou', 0):.4f}", | |
| f"{val_metrics.get('iou', 0):.4f}", | |
| f"{train_metrics.get('dice', 0):.4f}", | |
| f"{val_metrics.get('dice', 0):.4f}", | |
| f"{train_metrics.get('precision', 0):.4f}", | |
| f"{val_metrics.get('precision', 0):.4f}", | |
| f"{train_metrics.get('recall', 0):.4f}", | |
| f"{val_metrics.get('recall', 0):.4f}", | |
| f"{current_lr:.6f}" | |
| ]) | |
| # Print summary | |
| print(f"\nEpoch {epoch}/{epochs-1}") | |
| print(f" Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}") | |
| print(f" Train IoU: {train_metrics.get('iou', 0):.4f} | Val IoU: {val_metrics.get('iou', 0):.4f}") | |
| print(f" Train Dice: {train_metrics.get('dice', 0):.4f} | Val Dice: {val_metrics.get('dice', 0):.4f}") | |
| print(f" LR: {current_lr:.6f}") | |
| # Save checkpoints | |
| if self.config.get('training.checkpoint.save_every', 5) > 0: | |
| if (epoch + 1) % self.config.get('training.checkpoint.save_every', 5) == 0: | |
| self.save_checkpoint( | |
| f'{self.dataset_name}_chunk{chunk_id or 0}_epoch{epoch}.pth', | |
| chunk_id=chunk_id | |
| ) | |
| # Check for best model | |
| monitor_metric = val_metrics.get('dice', 0) | |
| if monitor_metric > self.best_metric: | |
| self.best_metric = monitor_metric | |
| self.save_checkpoint( | |
| f'{self.dataset_name}_chunk{chunk_id or 0}_best.pth', | |
| is_best=True, | |
| chunk_id=chunk_id | |
| ) | |
| # Early stopping | |
| if self.early_stopping(monitor_metric): | |
| print(f"\nEarly stopping triggered at epoch {epoch}") | |
| break | |
| # Save final checkpoint | |
| self.save_checkpoint( | |
| f'{self.dataset_name}_chunk{chunk_id or 0}_final.pth', | |
| chunk_id=chunk_id | |
| ) | |
| # Save training history | |
| history_file = self.log_dir / f'{self.dataset_name}_chunk{chunk_id or 0}_history.json' | |
| with open(history_file, 'w') as f: | |
| json.dump(self.metrics_tracker.get_history(), f, indent=2) | |
| print(f"\nTraining complete!") | |
| print(f"Best Dice: {self.best_metric:.4f}") | |
| return self.metrics_tracker.get_history() | |
def get_trainer(config, dataset_name: str = 'doctamper') -> Trainer:
    """Factory function for trainer"""
    return Trainer(config, dataset_name=dataset_name)