"""
Training Script for TransMIL + Query2Label Hybrid Model

Supports:
- End-to-end training with ResNet-50 backbone
- Mixed precision training (AMP) for memory efficiency
- Gradient accumulation for larger effective batch size
- Gradient checkpointing for ResNet-50
- AsymmetricLoss for multi-label imbalance
- Multi-label evaluation metrics (mAP, per-class AP, F1)
"""

import sys
# sys.path.append('query2labels/lib/models')
# sys.path.append('/XYFS01/HDD_POOL/sysu_gbli2/sysu_gbli2xy_1/chenshiyu/ThyroidAgent/ThyroidRegion/HintsVer3/query2labels/lib/')
import os
import argparse
import math
import yaml
from pathlib import Path
from datetime import datetime
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from tqdm import tqdm
from sklearn.metrics import average_precision_score, f1_score

# Import model and dataset
from models.transmil_q2l import TransMIL_Query2Label_E2E
from thyroid_dataset import create_dataloaders

# Import AsymmetricLoss (optional: train() falls back to BCEWithLogitsLoss
# when it is unavailable).
try:
    from models.aslloss import AsymmetricLossOptimized
except ImportError:
    print("Warning: Could not import AsymmetricLoss.")
    AsymmetricLossOptimized = None


# ============================================================================
# Metrics
# ============================================================================

def compute_multilabel_metrics(preds, targets, threshold=0.5):
    """
    Compute multi-label classification metrics.

    Args:
        preds: [N, num_class] numpy array of probabilities
        targets: [N, num_class] numpy array of binary labels
        threshold: Classification threshold for F1 score

    Returns:
        dict with keys:
            'mAP': mean of the per-class APs, ignoring classes that have no
                positive sample (NaN if no class has a positive)
            'per_class_AP': list of per-class APs, NaN for classes with no
                positive sample
            'F1_micro', 'F1_macro', 'F1_samples': F1 scores of
                ``preds >= threshold`` against ``targets``
    """
    metrics = {}

    # Mean Average Precision (mAP). AP is undefined for a class with no
    # positives, so such classes are recorded as NaN and excluded from the mean.
    aps = []
    for i in range(targets.shape[1]):
        if targets[:, i].sum() > 0:
            aps.append(average_precision_score(targets[:, i], preds[:, i]))
        else:
            aps.append(np.nan)

    metrics['mAP'] = np.nanmean(aps)
    metrics['per_class_AP'] = aps

    # F1 Score at threshold
    preds_binary = (preds >= threshold).astype(int)
    metrics['F1_micro'] = f1_score(targets, preds_binary, average='micro', zero_division=0)
    metrics['F1_macro'] = f1_score(targets, preds_binary, average='macro', zero_division=0)
    metrics['F1_samples'] = f1_score(targets, preds_binary, average='samples', zero_division=0)

    return metrics


def _to_serializable(metrics):
    """Return a copy of a metrics dict with numpy scalars converted to Python floats.

    List/array values (e.g. 'per_class_AP') become lists of floats. Plain
    floats keep checkpoints loadable with ``torch.load(weights_only=True)``
    and keep the dict JSON-serializable.
    """
    out = {}
    for key, value in metrics.items():
        if isinstance(value, (list, tuple, np.ndarray)):
            out[key] = [float(x) for x in value]
        else:
            out[key] = float(value)
    return out


# ============================================================================
# Training Functions
# ============================================================================

def train_epoch(model, dataloader, criterion, optimizer, scaler, device, config, epoch,
                scheduler=None):
    """
    Train for one epoch with gradient accumulation and mixed precision.

    An optimizer step is taken every ``gradient_accumulation_steps`` batches
    and also after the final batch of the epoch. Without that last step, a
    trailing partial accumulation window would be computed and then discarded.

    Args:
        model: TransMIL_Query2Label_E2E model
        dataloader: Training dataloader
        criterion: AsymmetricLoss (or BCEWithLogitsLoss fallback)
        optimizer: AdamW optimizer
        scaler: GradScaler for AMP (None when AMP is disabled)
        device: torch.device
        config: Config dict
        epoch: Current epoch number
        scheduler: Optional LR scheduler stepped once after every optimizer
            update (e.g. OneCycleLR). Do not pass epoch-level schedulers here.

    Returns:
        Average (un-divided) loss per batch for the epoch
    """
    model.train()
    total_loss = 0.0
    accumulation_steps = config['training']['gradient_accumulation_steps']
    use_amp = config['training']['use_amp']
    num_batches = len(dataloader)

    # Progress bar
    pbar = tqdm(dataloader, desc=f"Epoch {epoch}")

    optimizer.zero_grad()

    for i, batch in enumerate(pbar):
        images = batch['images'].to(device)  # [B*N_total, 3, H, W]
        labels = batch['labels'].to(device)  # [B, num_class]
        num_instances_per_case = batch['num_instances_per_case']  # [B]

        # Forward pass; the loss is divided so accumulated gradients average
        # over the accumulation window.
        if use_amp:
            with autocast():
                logits = model(images, num_instances_per_case)
                loss = criterion(logits, labels)
                loss = loss / accumulation_steps
        else:
            logits = model(images, num_instances_per_case)
            loss = criterion(logits, labels)
            loss = loss / accumulation_steps

        # Backward pass
        if use_amp:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # Optimizer step at the end of each window and at the end of the epoch.
        # NOTE: the final partial window is still divided by accumulation_steps,
        # so its update is slightly smaller; AdamW is largely insensitive to this.
        is_last_batch = (i + 1) == num_batches
        if (i + 1) % accumulation_steps == 0 or is_last_batch:
            if use_amp:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()
            if scheduler is not None:
                scheduler.step()

        # Track loss (undo the accumulation division for reporting)
        batch_loss = loss.item() * accumulation_steps
        total_loss += batch_loss
        pbar.set_postfix({'loss': batch_loss})

    return total_loss / num_batches


@torch.no_grad()
def validate(model, dataloader, criterion, device, config):
    """
    Validate model with multi-label metrics.

    Args:
        model: TransMIL_Query2Label_E2E model
        dataloader: Validation dataloader
        criterion: AsymmetricLoss (or BCEWithLogitsLoss fallback)
        device: torch.device
        config: Config dict (not used by this function)

    Returns:
        dict with 'loss' (mean per-batch loss) plus the metrics from
        compute_multilabel_metrics (mAP, per_class_AP, F1_*)
    """
    model.eval()
    total_loss = 0.0
    all_preds = []
    all_targets = []

    for batch in tqdm(dataloader, desc="Validating"):
        images = batch['images'].to(device)
        labels = batch['labels'].to(device)
        num_instances_per_case = batch['num_instances_per_case']

        # Forward pass
        logits = model(images, num_instances_per_case)
        loss = criterion(logits, labels)

        # Sigmoid for independent per-class (multi-label) probabilities
        preds = torch.sigmoid(logits)

        all_preds.append(preds.cpu().numpy())
        all_targets.append(labels.cpu().numpy())
        total_loss += loss.item()

    # Concatenate all batches
    all_preds = np.concatenate(all_preds, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)

    metrics = compute_multilabel_metrics(all_preds, all_targets)
    metrics['loss'] = total_loss / len(dataloader)

    return metrics


# ============================================================================
# Main Training Loop
# ============================================================================

def train(config, resume_from=None):
    """
    Main training function.

    Builds data, model, optimizer, scheduler and loss from ``config``. It
    trains for ``config['training']['epochs']`` epochs and saves checkpoints
    to ``save_dir``. Finally it evaluates the best-by-val-mAP checkpoint, or
    the final weights if none was saved, on the test set.

    Args:
        config: Config dictionary from YAML
        resume_from: Optional checkpoint path to resume training
    """
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nUsing device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    # Create save directory
    save_dir = Path(config['training']['save_dir'])
    save_dir.mkdir(parents=True, exist_ok=True)

    # Create tensorboard writer
    log_dir = save_dir / 'logs' / datetime.now().strftime('%Y%m%d_%H%M%S')
    writer = SummaryWriter(log_dir)

    # Save config
    with open(save_dir / 'config.yaml', 'w') as f:
        yaml.dump(config, f)

    # Create dataloaders
    print("\nCreating dataloaders...")
    train_loader, val_loader, test_loader = create_dataloaders(config)

    # Create model
    print("\nCreating model...")
    model = TransMIL_Query2Label_E2E(
        num_class=config['model']['num_class'],
        hidden_dim=config['model']['hidden_dim'],
        nheads=config['model']['nheads'],
        num_decoder_layers=config['model']['num_decoder_layers'],
        pretrained_resnet=config['model']['pretrained_resnet'],
        use_checkpointing=config['training']['gradient_checkpointing'],
        use_ppeg=config['model'].get('use_ppeg', False)
    )
    model = model.to(device)

    # Print model stats
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Create optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['training']['lr'],
        weight_decay=config['training']['weight_decay']
    )

    # Create scheduler
    scheduler_type = config['training'].get('scheduler', 'cosine')
    if scheduler_type == 'cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=config['training']['epochs'],
            eta_min=1e-6
        )
    elif scheduler_type == 'onecycle':
        # OneCycleLR is stepped once per *optimizer update*. With gradient
        # accumulation, train_epoch performs ceil(len(loader) / accum) updates
        # per epoch, and the schedule length must match that exactly.
        accumulation_steps = config['training']['gradient_accumulation_steps']
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=config['training']['lr'],
            epochs=config['training']['epochs'],
            steps_per_epoch=math.ceil(len(train_loader) / accumulation_steps)
        )
    else:
        scheduler = None

    # Only OneCycleLR is stepped inside the batch loop; others step per epoch.
    per_step_scheduler = scheduler if scheduler_type == 'onecycle' else None

    # Create loss function
    if AsymmetricLossOptimized is not None:
        criterion = AsymmetricLossOptimized(
            gamma_neg=config['training']['gamma_neg'],
            gamma_pos=config['training']['gamma_pos'],
            clip=config['training']['clip'],
            eps=1e-5
        )
    else:
        # Fallback to BCEWithLogitsLoss
        print("Warning: Using BCEWithLogitsLoss instead of AsymmetricLoss")
        criterion = nn.BCEWithLogitsLoss()

    # Mixed precision scaler
    scaler = GradScaler() if config['training']['use_amp'] else None

    # Resume from checkpoint if specified
    start_epoch = 0
    best_map = 0.0

    if resume_from is not None and Path(resume_from).exists():
        print(f"\nResuming from {resume_from}")
        checkpoint = torch.load(resume_from, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_map = checkpoint.get('best_map', 0.0)
        # A checkpoint saved without a scheduler stores None here; skip it.
        if scheduler is not None and checkpoint.get('scheduler_state_dict') is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        if scaler is not None and checkpoint.get('scaler_state_dict') is not None:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        print(f"Resumed from epoch {start_epoch}, best mAP: {best_map:.4f}")

    # Training loop
    print(f"\nStarting training for {config['training']['epochs']} epochs...")
    print("="*80)

    for epoch in range(start_epoch, config['training']['epochs']):
        # Train (OneCycleLR, if used, is stepped inside train_epoch)
        train_loss = train_epoch(model, train_loader, criterion, optimizer,
                                 scaler, device, config, epoch,
                                 scheduler=per_step_scheduler)

        # Validate
        val_metrics = validate(model, val_loader, criterion, device, config)

        # Epoch-level schedulers step once per epoch
        if scheduler is not None and per_step_scheduler is None:
            scheduler.step()

        # Log metrics
        current_lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/val', val_metrics['loss'], epoch)
        writer.add_scalar('Metrics/mAP', val_metrics['mAP'], epoch)
        writer.add_scalar('Metrics/F1_micro', val_metrics['F1_micro'], epoch)
        writer.add_scalar('Metrics/F1_macro', val_metrics['F1_macro'], epoch)
        writer.add_scalar('LR', current_lr, epoch)

        # Print epoch summary
        print(f"\nEpoch {epoch}/{config['training']['epochs']}")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss: {val_metrics['loss']:.4f}")
        print(f"  mAP: {val_metrics['mAP']:.4f}")
        print(f"  F1 (micro): {val_metrics['F1_micro']:.4f}")
        print(f"  F1 (macro): {val_metrics['F1_macro']:.4f}")
        print(f"  LR: {current_lr:.6f}")

        # Save checkpoint (NaN mAP compares False, so it never counts as best)
        is_best = val_metrics['mAP'] > best_map
        if is_best:
            best_map = val_metrics['mAP']

        if (epoch + 1) % config['training']['save_freq'] == 0 or is_best:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
                'scaler_state_dict': scaler.state_dict() if scaler is not None else None,
                'train_loss': train_loss,
                # Plain floats keep the file loadable with weights_only=True.
                'val_metrics': _to_serializable(val_metrics),
                'best_map': float(best_map),
                'config': config
            }

            # Save latest checkpoint
            torch.save(checkpoint, save_dir / 'checkpoint_latest.pth')

            # Save best checkpoint
            if is_best:
                torch.save(checkpoint, save_dir / 'checkpoint_best.pth')
                print(f"  ✓ Saved best model (mAP: {best_map:.4f})")

            # Save periodic checkpoint
            if (epoch + 1) % config['training']['save_freq'] == 0:
                torch.save(checkpoint, save_dir / f'checkpoint_epoch_{epoch}.pth')

    print("\n" + "="*80)
    print(f"Training completed! Best mAP: {best_map:.4f}")
    print(f"Checkpoints saved to: {save_dir}")
    writer.close()

    # Final test evaluation on the model selected by validation mAP.
    # NOTE(review): assumes any checkpoint_best.pth in save_dir belongs to this
    # run (or to the run being resumed).
    best_path = save_dir / 'checkpoint_best.pth'
    if best_path.exists():
        print(f"\nLoading best checkpoint from {best_path}")
        best_checkpoint = torch.load(best_path, map_location=device)
        model.load_state_dict(best_checkpoint['model_state_dict'])
    else:
        print("\nNo best checkpoint found; evaluating final model weights")

    print("\nEvaluating on test set...")
    test_metrics = validate(model, test_loader, criterion, device, config)

    print("\nTest Results:")
    print(f"  mAP: {test_metrics['mAP']:.4f}")
    print(f"  F1 (micro): {test_metrics['F1_micro']:.4f}")
    print(f"  F1 (macro): {test_metrics['F1_macro']:.4f}")

    # Save test results
    with open(save_dir / 'test_results.json', 'w') as f:
        json.dump(_to_serializable(test_metrics), f, indent=2)


# ============================================================================
# Main
# ============================================================================

def main():
    """Parse CLI arguments, load the YAML config and run training."""
    parser = argparse.ArgumentParser(description='Train TransMIL + Query2Label Hybrid Model')
    parser.add_argument('--config', type=str, default='hybrid_model/config.yaml',
                        help='Path to config file')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')
    args = parser.parse_args()

    # Load config
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)

    print("="*80)
    print("TransMIL + Query2Label Hybrid Model Training")
    print("="*80)
    print(f"\nConfig: {args.config}")
    if args.resume:
        print(f"Resume from: {args.resume}")

    # Train
    train(config, resume_from=args.resume)


if __name__ == "__main__":
    main()