|
|
"""
|
|
|
Training Script for TransMIL + Query2Label Hybrid Model
|
|
|
|
|
|
Supports:
|
|
|
- End-to-end training with ResNet-50 backbone
|
|
|
- Mixed precision training (AMP) for memory efficiency
|
|
|
- Gradient accumulation for larger effective batch size
|
|
|
- Gradient checkpointing for ResNet-50
|
|
|
- AsymmetricLoss for multi-label imbalance
|
|
|
- Multi-label evaluation metrics (mAP, per-class AP, F1)
|
|
|
"""
|
|
|
|
|
|
import sys
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
import argparse
|
|
|
import yaml
|
|
|
from pathlib import Path
|
|
|
from datetime import datetime
|
|
|
import json
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.optim as optim
|
|
|
from torch.cuda.amp import autocast, GradScaler
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
import numpy as np
|
|
|
from tqdm import tqdm
|
|
|
from sklearn.metrics import average_precision_score, f1_score
|
|
|
|
|
|
|
|
|
from models.transmil_q2l import TransMIL_Query2Label_E2E
|
|
|
from thyroid_dataset import create_dataloaders
|
|
|
|
|
|
|
|
|
# AsymmetricLoss handles the positive/negative imbalance typical of
# multi-label data; if it cannot be imported, fall back gracefully —
# train() substitutes BCEWithLogitsLoss when this name is None.
try:
    from models.aslloss import AsymmetricLossOptimized
except ImportError:
    print("Warning: Could not import AsymmetricLoss.")
    AsymmetricLossOptimized = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_multilabel_metrics(preds, targets, threshold=0.5):
    """
    Compute multi-label classification metrics.

    Args:
        preds: [N, num_class] numpy array of probabilities
        targets: [N, num_class] numpy array of binary labels
        threshold: Classification threshold for F1 score

    Returns:
        dict with mAP, per-class AP, F1 scores
    """
    num_classes = targets.shape[1]

    # Per-class average precision; a class with no positive samples in
    # `targets` is undefined and recorded as NaN (excluded from mAP below).
    per_class_ap = [
        average_precision_score(targets[:, c], preds[:, c])
        if targets[:, c].sum() > 0
        else np.nan
        for c in range(num_classes)
    ]

    # Binarize probabilities at the given threshold for the F1 variants.
    binarized = (preds >= threshold).astype(int)

    return {
        'mAP': np.nanmean(per_class_ap),
        'per_class_AP': per_class_ap,
        'F1_micro': f1_score(targets, binarized, average='micro', zero_division=0),
        'F1_macro': f1_score(targets, binarized, average='macro', zero_division=0),
        'F1_samples': f1_score(targets, binarized, average='samples', zero_division=0),
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_epoch(model, dataloader, criterion, optimizer, scaler, device, config, epoch):
    """
    Train for one epoch with gradient accumulation and mixed precision.

    Args:
        model: TransMIL_Query2Label_E2E model
        dataloader: Training dataloader
        criterion: Multi-label loss (AsymmetricLoss or BCEWithLogitsLoss)
        optimizer: AdamW optimizer
        scaler: GradScaler for AMP (ignored when AMP is disabled)
        device: torch.device
        config: Config dict
        epoch: Current epoch number (used for the progress bar only)

    Returns:
        Average (unscaled) loss over all batches in the epoch.
    """
    model.train()

    total_loss = 0.0
    accumulation_steps = config['training']['gradient_accumulation_steps']
    use_amp = config['training']['use_amp']

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}")

    optimizer.zero_grad()
    # Tracks whether the most recent batch triggered an optimizer step,
    # so leftover gradients can be flushed after the loop.
    stepped_on_last_batch = True

    for i, batch in enumerate(pbar):
        images = batch['images'].to(device)
        labels = batch['labels'].to(device)
        num_instances_per_case = batch['num_instances_per_case']

        # autocast(enabled=...) unifies the AMP and full-precision paths:
        # with enabled=False it is a no-op context.
        with autocast(enabled=use_amp):
            logits = model(images, num_instances_per_case)
            loss = criterion(logits, labels)
            # Scale down so accumulated gradients average over the
            # effective (larger) batch.
            loss = loss / accumulation_steps

        if use_amp:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        stepped_on_last_batch = (i + 1) % accumulation_steps == 0
        if stepped_on_last_batch:
            if use_amp:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

        # Report the unscaled per-batch loss.
        total_loss += loss.item() * accumulation_steps
        pbar.set_postfix({'loss': loss.item() * accumulation_steps})

    # BUGFIX: when len(dataloader) is not a multiple of accumulation_steps,
    # the original code silently dropped the tail batches' update and let
    # their stale gradients leak into the next epoch. Flush them here.
    if not stepped_on_last_batch:
        if use_amp:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad()

    return total_loss / len(dataloader)
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def validate(model, dataloader, criterion, device, config):
    """
    Evaluate the model on a dataloader and compute multi-label metrics.

    Args:
        model: TransMIL_Query2Label_E2E model
        dataloader: Validation dataloader
        criterion: Multi-label loss used for the reported val loss
        device: torch.device
        config: Config dict (unused here, kept for a uniform call signature)

    Returns:
        dict with loss and metrics (mAP, F1, etc.)
    """
    model.eval()

    loss_sum = 0.0
    pred_chunks = []
    target_chunks = []

    for batch in tqdm(dataloader, desc="Validating"):
        images = batch['images'].to(device)
        labels = batch['labels'].to(device)
        num_instances_per_case = batch['num_instances_per_case']

        logits = model(images, num_instances_per_case)
        loss_sum += criterion(logits, labels).item()

        # Convert logits to per-class probabilities for metric computation.
        pred_chunks.append(torch.sigmoid(logits).cpu().numpy())
        target_chunks.append(labels.cpu().numpy())

    all_preds = np.concatenate(pred_chunks, axis=0)
    all_targets = np.concatenate(target_chunks, axis=0)

    results = compute_multilabel_metrics(all_preds, all_targets)
    results['loss'] = loss_sum / len(dataloader)

    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(config, resume_from=None):
    """
    Main training function: builds the model, optimizer, scheduler and loss
    from `config`, runs the epoch loop with validation, checkpointing and
    TensorBoard logging, then evaluates on the test split.

    Args:
        config: Config dictionary from YAML (expects 'training' and 'model'
            sections; see the key accesses below for the required fields)
        resume_from: Optional checkpoint path to resume training
    """

    # --- Device setup ---------------------------------------------------
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nUsing device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    # --- Output directories and logging ---------------------------------
    save_dir = Path(config['training']['save_dir'])
    save_dir.mkdir(parents=True, exist_ok=True)

    # One timestamped TensorBoard run per invocation.
    log_dir = save_dir / 'logs' / datetime.now().strftime('%Y%m%d_%H%M%S')
    writer = SummaryWriter(log_dir)

    # Snapshot the config next to the checkpoints for reproducibility.
    with open(save_dir / 'config.yaml', 'w') as f:
        yaml.dump(config, f)

    # --- Data -----------------------------------------------------------
    print("\nCreating dataloaders...")
    train_loader, val_loader, test_loader = create_dataloaders(config)

    # --- Model ----------------------------------------------------------
    print("\nCreating model...")
    model = TransMIL_Query2Label_E2E(
        num_class=config['model']['num_class'],
        hidden_dim=config['model']['hidden_dim'],
        nheads=config['model']['nheads'],
        num_decoder_layers=config['model']['num_decoder_layers'],
        pretrained_resnet=config['model']['pretrained_resnet'],
        use_checkpointing=config['training']['gradient_checkpointing'],
        use_ppeg=config['model'].get('use_ppeg', False)
    )
    model = model.to(device)

    # Parameter counts for a quick sanity check of the configuration.
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # --- Optimizer and LR schedule --------------------------------------
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['training']['lr'],
        weight_decay=config['training']['weight_decay']
    )

    scheduler_type = config['training'].get('scheduler', 'cosine')
    if scheduler_type == 'cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=config['training']['epochs'],
            eta_min=1e-6
        )
    elif scheduler_type == 'onecycle':
        # NOTE(review): OneCycleLR is configured per-batch
        # (steps_per_epoch=len(train_loader)) but is never stepped anywhere —
        # the epoch loop below explicitly skips it and train_epoch does not
        # receive the scheduler. With 'onecycle' the LR therefore stays at its
        # initial value. Fixing this requires stepping the scheduler inside
        # train_epoch; confirm intended behavior before relying on 'onecycle'.
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=config['training']['lr'],
            epochs=config['training']['epochs'],
            steps_per_epoch=len(train_loader)
        )
    else:
        scheduler = None

    # --- Loss -----------------------------------------------------------
    # Prefer AsymmetricLoss for multi-label imbalance; fall back to plain
    # BCE if the import at module level failed.
    if AsymmetricLossOptimized is not None:
        criterion = AsymmetricLossOptimized(
            gamma_neg=config['training']['gamma_neg'],
            gamma_pos=config['training']['gamma_pos'],
            clip=config['training']['clip'],
            eps=1e-5
        )
    else:

        print("Warning: Using BCEWithLogitsLoss instead of AsymmetricLoss")
        criterion = nn.BCEWithLogitsLoss()

    # GradScaler only when mixed precision is enabled.
    scaler = GradScaler() if config['training']['use_amp'] else None

    # --- Optional resume -------------------------------------------------
    start_epoch = 0
    best_map = 0.0

    if resume_from is not None and Path(resume_from).exists():
        print(f"\nResuming from {resume_from}")
        checkpoint = torch.load(resume_from, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_map = checkpoint.get('best_map', 0.0)
        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        print(f"Resumed from epoch {start_epoch}, best mAP: {best_map:.4f}")

    # --- Epoch loop -------------------------------------------------------
    print(f"\nStarting training for {config['training']['epochs']} epochs...")
    print("="*80)

    for epoch in range(start_epoch, config['training']['epochs']):

        train_loss = train_epoch(model, train_loader, criterion, optimizer, scaler, device, config, epoch)

        val_metrics = validate(model, val_loader, criterion, device, config)

        # Per-epoch schedulers step here; 'onecycle' is skipped (see NOTE above).
        if scheduler is not None:
            if scheduler_type == 'onecycle':
                pass
            else:
                scheduler.step()

        # TensorBoard logging.
        current_lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/val', val_metrics['loss'], epoch)
        writer.add_scalar('Metrics/mAP', val_metrics['mAP'], epoch)
        writer.add_scalar('Metrics/F1_micro', val_metrics['F1_micro'], epoch)
        writer.add_scalar('Metrics/F1_macro', val_metrics['F1_macro'], epoch)
        writer.add_scalar('LR', current_lr, epoch)

        # Console summary.
        print(f"\nEpoch {epoch}/{config['training']['epochs']}")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss: {val_metrics['loss']:.4f}")
        print(f"  mAP: {val_metrics['mAP']:.4f}")
        print(f"  F1 (micro): {val_metrics['F1_micro']:.4f}")
        print(f"  F1 (macro): {val_metrics['F1_macro']:.4f}")
        print(f"  LR: {current_lr:.6f}")

        # Model selection is by validation mAP.
        is_best = val_metrics['mAP'] > best_map
        if is_best:
            best_map = val_metrics['mAP']

        if (epoch + 1) % config['training']['save_freq'] == 0 or is_best:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
                'train_loss': train_loss,
                'val_metrics': val_metrics,
                'best_map': best_map,
                'config': config
            }

            # 'latest' is overwritten every time we save.
            torch.save(checkpoint, save_dir / 'checkpoint_latest.pth')

            if is_best:
                torch.save(checkpoint, save_dir / 'checkpoint_best.pth')
                print(f"  ✓ Saved best model (mAP: {best_map:.4f})")

            # Periodic per-epoch snapshot.
            if (epoch + 1) % config['training']['save_freq'] == 0:
                torch.save(checkpoint, save_dir / f'checkpoint_epoch_{epoch}.pth')

    print("\n" + "="*80)
    print(f"Training completed! Best mAP: {best_map:.4f}")
    print(f"Checkpoints saved to: {save_dir}")

    writer.close()

    # --- Final test evaluation -------------------------------------------
    # NOTE(review): this evaluates the LAST model state, not the best
    # checkpoint — confirm this is intended.
    print("\nEvaluating on test set...")
    test_metrics = validate(model, test_loader, criterion, device, config)
    print(f"\nTest Results:")
    print(f"  mAP: {test_metrics['mAP']:.4f}")
    print(f"  F1 (micro): {test_metrics['F1_micro']:.4f}")
    print(f"  F1 (macro): {test_metrics['F1_macro']:.4f}")

    # per_class_AP may contain NaN for classes without positives; json.dump
    # emits non-standard 'NaN' tokens for those.
    with open(save_dir / 'test_results.json', 'w') as f:
        json.dump({k: float(v) if not isinstance(v, list) else v
                   for k, v in test_metrics.items()}, f, indent=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Parse CLI arguments, load the YAML config, and launch training."""
    parser = argparse.ArgumentParser(description='Train TransMIL + Query2Label Hybrid Model')
    parser.add_argument('--config', type=str, default='hybrid_model/config.yaml',
                        help='Path to config file')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')
    args = parser.parse_args()

    # Load the experiment configuration.
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)

    banner = "=" * 80
    print(banner)
    print("TransMIL + Query2Label Hybrid Model Training")
    print(banner)
    print(f"\nConfig: {args.config}")
    if args.resume:
        print(f"Resume from: {args.resume}")

    train(config, resume_from=args.resume)


if __name__ == "__main__":
    main()
|
|
|
|