import os
import sys
import argparse
import yaml
import datetime
import numpy as np
from pathlib import Path
from sklearn.metrics import f1_score

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.cuda.amp import GradScaler, autocast

sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'hiera'))
from hiera.hiera_mae import HieraClassifier
from data.downstream_dataset import fMRITaskDataset, fMRITaskDataset1, EmoFMRIDataset, HCPtaskDataset
from data.adni_dataset import ADNIDataset
from utils.utils import (MetricLogger, load_config, log_to_file, count_parameters,
                         save_checkpoint, load_checkpoint, LabelScaler)
from utils.optim import create_optimizer, create_lr_scheduler
from utils.ddp import setup_distributed, set_seed, cleanup_distributed


def create_model(config):
    """Create Hiera classifier model from config."""
    task_config = config['task']
    exp_config = config['experiment']
    model_config = config['model']

    # If a pretrained checkpoint is given, prefer the architecture from its pretraining config
    pretrained_checkpoint_path = exp_config.get('pretrained_checkpoint', None)
    if pretrained_checkpoint_path:
        pretrain_config_path = Path(pretrained_checkpoint_path).parent.parent / 'config.yaml'
        if os.path.exists(pretrain_config_path):
            print(f"Loading model architecture from pretrained config: {pretrain_config_path}")
            pretrain_config = load_config(pretrain_config_path)
            model_config = pretrain_config['model']
        else:
            print(f"Warning: Pretrained config not found at {pretrain_config_path}. Using finetune config for model architecture.")

    model = HieraClassifier(
        num_classes=task_config['num_classes'],
        task_type=task_config['task_type'],
        input_size=tuple(model_config['input_size']),
        in_chans=model_config['in_chans'],
        patch_kernel=tuple(model_config['patch_kernel']),
        patch_stride=tuple(model_config['patch_stride']),
        patch_padding=tuple(model_config['patch_padding']),
        q_stride=tuple(model_config['q_stride']),
        mask_unit_size=tuple(model_config['mask_unit_size']),
        embed_dim=model_config['embed_dim'],
        num_heads=model_config['num_heads'],
        stages=tuple(model_config['stages']),
        q_pool=model_config['q_pool'],
        mlp_ratio=model_config['mlp_ratio'],
    )

    # Load pretrained weights if specified
    if pretrained_checkpoint_path:
        if os.path.exists(pretrained_checkpoint_path):
            model.load_pretrained_mae(pretrained_checkpoint_path)
        else:
            print(f"Warning: Pretrained checkpoint not found at {pretrained_checkpoint_path}. Model is randomly initialized.")
    else:
        print("Warning: No pretrained checkpoint specified. Model is randomly initialized.")

    return model
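
# ---------------------------------------------------------------------------
# Illustrative config sketch (not part of this file): the keys below are the
# ones create_model(), create_dataloaders(), and main() read from the YAML
# passed via --config. Values shown as "..." are placeholders and the concrete
# numbers are only examples; adapt them to the actual experiment.
#
#   experiment:
#     seed: 42
#     output_dir: ./outputs/finetune_run
#     pretrained_checkpoint: null        # path to a pretrained MAE checkpoint, or null
#   task:
#     task_type: classification          # or "regression" (then also set mean / std)
#     num_classes: 2
#     csv: ...                           # label CSV path
#   model:
#     input_size: [...]
#     in_chans: ...
#     patch_kernel: [...]
#     patch_stride: [...]
#     patch_padding: [...]
#     q_stride: [...]
#     mask_unit_size: [...]
#     embed_dim: ...
#     num_heads: ...
#     stages: [...]
#     q_pool: ...
#     mlp_ratio: ...
#   data:
#     data_root: ...
#     datasets: [...]
#     train_split_suffixes: [train]
#     val_split_suffixes: [val]
#     test_split_suffixes: [test]
#     input_seq_len: ...
#     batch_size: 4
#     num_workers: 8
#     pin_memory: true
#   training:
#     epochs: 100
#     accum_iter: 1
#     use_amp: true
#     clip_grad: 1.0
#     freeze_encoder: false
#   validation:
#     val_freq: 1
#   logging:
#     print_freq: 20
#     save_freq: 10
# ---------------------------------------------------------------------------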

def create_dataloaders(config, is_distributed, rank, world_size):
    """Create train, validation, and test dataloaders."""
    data_config = config['data']
    task_config = config['task']

    train_dataset = fMRITaskDataset(
        data_root=data_config['data_root'],
        datasets=data_config['datasets'],
        split_suffixes=data_config['train_split_suffixes'],
        crop_length=data_config['input_seq_len'],
        label_csv_path=task_config['csv'],
        task_type=task_config['task_type']
    )
    val_dataset = fMRITaskDataset(
        data_root=data_config['data_root'],
        datasets=data_config['datasets'],
        split_suffixes=data_config['val_split_suffixes'],
        crop_length=data_config['input_seq_len'],
        label_csv_path=task_config['csv'],
        task_type=task_config['task_type']
    )
    test_dataset = fMRITaskDataset(
        data_root=data_config['data_root'],
        datasets=data_config['datasets'],
        split_suffixes=data_config.get('test_split_suffixes', ['test']),
        crop_length=data_config['input_seq_len'],
        label_csv_path=task_config['csv'],
        task_type=task_config['task_type']
    )

    # Create samplers
    if is_distributed:
        train_sampler = DistributedSampler(
            train_dataset, num_replicas=world_size, rank=rank, shuffle=True,
            seed=config['experiment']['seed']
        )
        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
        test_sampler = DistributedSampler(test_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    else:
        train_sampler = None
        val_sampler = None
        test_sampler = None

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=data_config['batch_size'],
        sampler=train_sampler,
        shuffle=(train_sampler is None),
        num_workers=data_config['num_workers'],
        pin_memory=data_config['pin_memory'],
        prefetch_factor=data_config.get('prefetch_factor', 2),
        drop_last=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=data_config['batch_size'],
        sampler=val_sampler,
        shuffle=False,
        num_workers=data_config['num_workers'],
        pin_memory=data_config['pin_memory'],
        prefetch_factor=data_config.get('prefetch_factor', 2),
        drop_last=False
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=data_config['batch_size'],
        sampler=test_sampler,
        shuffle=False,
        num_workers=data_config['num_workers'],
        pin_memory=data_config['pin_memory'],
        prefetch_factor=data_config.get('prefetch_factor', 2),
        drop_last=False
    )

    return train_loader, val_loader, test_loader, train_sampler

def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, scaler,
                    epoch, config, rank, world_size, label_scaler=None, log_file=None):
    """Train for one epoch."""
    model.train()
    metric_logger = MetricLogger(delimiter=" ")
    header = f'Epoch: [{epoch}]'

    train_config = config['training']
    log_config = config['logging']
    task_config = config['task']

    accum_iter = train_config['accum_iter']
    use_amp = train_config['use_amp']
    clip_grad = train_config.get('clip_grad', None)

    optimizer.zero_grad()

    for data_iter_step, (samples, labels) in enumerate(
            metric_logger.log_every(train_loader, log_config['print_freq'], header)):
        # Adjust learning rate per iteration
        if data_iter_step % accum_iter == 0:
            scheduler.step()

        # Move data to GPU
        samples = samples.cuda(rank, non_blocking=True)
        labels = labels.cuda(rank, non_blocking=True)

        # Forward pass with mixed precision
        with autocast(enabled=use_amp):
            outputs = model(samples)

            # Calculate loss based on task type
            if task_config['task_type'] == 'classification':
                if labels.dim() > 1:
                    labels = labels.squeeze()
                loss = criterion(outputs, labels)

                # Calculate accuracy
                _, predicted = outputs.max(1)
                correct = predicted.eq(labels).sum().item()
                accuracy = correct / labels.size(0)
            else:  # regression
                if label_scaler is not None:
                    target_for_loss = label_scaler.transform(labels)
                else:
                    target_for_loss = labels
                loss = criterion(outputs.squeeze(), target_for_loss.squeeze())
                accuracy = 0.0  # Not applicable for regression

        loss = loss / accum_iter

        # Backward pass
        if use_amp:
            scaler.scale(loss).backward()
            if (data_iter_step + 1) % accum_iter == 0:
                if clip_grad is not None:
                    scaler.unscale_(optimizer)
                    nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
        else:
            loss.backward()
            if (data_iter_step + 1) % accum_iter == 0:
                if clip_grad is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
                optimizer.step()
                optimizer.zero_grad()

        # Track the unscaled loss; cross-process averaging happens in synchronize_between_processes()
        loss_value = loss.item() * accum_iter
        if not np.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            sys.exit(1)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        if task_config['task_type'] == 'classification':
            metric_logger.update(acc=accuracy)

    # Gather stats from all processes
    metric_logger.synchronize_between_processes()
    print(f"Averaged stats: {metric_logger}")

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(model, data_loader, criterion, config, rank, epoch=None, label_scaler=None, mode='val'):
    """Evaluate on the validation or test split."""
    model.eval()
    metric_logger = MetricLogger(delimiter=" ")
    header = f'{mode.capitalize()} Epoch: [{epoch}]' if epoch is not None else f'{mode.capitalize()}:'
    task_type = config['task']['task_type']

    all_preds, all_targets = [], []

    for samples, labels in metric_logger.log_every(data_loader, 50, header):
        samples = samples.cuda(rank, non_blocking=True)
        labels = labels.cuda(rank, non_blocking=True)

        outputs = model(samples)

        if task_type == 'classification':
            labels = labels.squeeze().long() if labels.dim() > 1 else labels.long()
            loss = criterion(outputs, labels)
            preds = outputs.argmax(1)
            acc = (preds == labels).float().mean().item()
            metric_logger.update(loss=loss.item(), acc=acc)
            all_preds.append(preds.cpu())
            all_targets.append(labels.cpu())
        else:
            # Regression: fall back to raw labels when no scaler is provided
            if label_scaler is not None:
                target_norm = label_scaler.transform(labels)
            else:
                target_norm = labels
            loss = criterion(outputs.view(-1), target_norm.view(-1))
            metric_logger.update(loss=loss.item())
            all_preds.append(outputs.detach().cpu().view(-1))
            all_targets.append(target_norm.detach().cpu().view(-1))

    if len(all_preds) > 0:
        all_preds = torch.cat(all_preds)
        all_targets = torch.cat(all_targets)
        if task_type == 'classification':
            f1 = f1_score(all_targets.numpy(), all_preds.numpy(), average='weighted')
            metric_logger.update(f1=f1)
        else:
            mse = torch.mean((all_preds - all_targets) ** 2).item()
            mae = torch.mean(torch.abs(all_preds - all_targets)).item()
            ss_res = torch.sum((all_targets - all_preds) ** 2)
            ss_tot = torch.sum((all_targets - all_targets.mean()) ** 2)
            r2 = (1 - ss_res / (ss_tot + 1e-8)).item()
            vx = all_preds - all_preds.mean()
            vy = all_targets - all_targets.mean()
            corr = (torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx ** 2)) * torch.sqrt(torch.sum(vy ** 2)) + 1e-8)).item()
            metric_logger.update(mse=mse, mae=mae, r2=r2, corr=corr)

    metric_logger.synchronize_between_processes()
    if rank == 0:
        print(f"[{mode.upper()}] Global stats: {metric_logger}")

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

def main():
    """Main fine-tuning function."""
    # Parse arguments
    parser = argparse.ArgumentParser(description='Hiera MAE 4D fMRI Downstream Fine-tuning')
    parser.add_argument('--config', type=str, default='configs/finetune_config.yaml',
                        help='Path to config file')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Output directory (overrides config)')
    args = parser.parse_args()

    # Load config
    config = load_config(args.config)

    # Override config with command line arguments
    if args.resume is not None:
        config['experiment']['resume'] = args.resume
    if args.output_dir is not None:
        config['experiment']['output_dir'] = args.output_dir

    # Setup distributed training
    is_distributed, rank, world_size, gpu = setup_distributed()

    # Set random seed
    set_seed(config['experiment']['seed'], rank)

    # Create output directories
    if rank == 0:
        output_dir = Path(config['experiment']['output_dir'])
        checkpoint_dir = output_dir / 'checkpoints'
        log_dir = output_dir / 'logs'
        output_dir.mkdir(parents=True, exist_ok=True)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        log_dir.mkdir(parents=True, exist_ok=True)

        # Save config
        with open(output_dir / 'config.yaml', 'w') as f:
            yaml.dump(config, f, default_flow_style=False)

        # Setup text log file
        log_file = output_dir / 'training_log.txt'
        with open(log_file, 'w') as f:
            f.write(f"Fine-tuning started at {datetime.datetime.now()}\n")
            f.write("=" * 80 + "\n")
            f.write(f"Config: {args.config}\n")
            f.write(f"Output directory: {config['experiment']['output_dir']}\n")
            f.write(f"Task type: {config['task']['task_type']}\n")
            f.write("=" * 80 + "\n\n")
    else:
        checkpoint_dir = None
        log_file = None

    if is_distributed:
        dist.barrier()

    model = create_model(config)
    model = model.cuda(gpu)

    if rank == 0:
        print("\nAnalyzing model architecture...")
        count_parameters(model, verbose=True)

    if is_distributed:
        model = DDP(model, device_ids=[gpu], find_unused_parameters=True)
    model_without_ddp = model.module if is_distributed else model

    if rank == 0:
        print("Creating dataloaders...")
    train_loader, val_loader, test_loader, train_sampler = create_dataloaders(
        config, is_distributed, rank, world_size
    )

    # Build the label scaler for regression targets; stats are read from the config
    # on every rank so the broadcast below has defined tensors everywhere
    label_scaler = None
    if config['task']['task_type'] == 'regression':
        mean_val = config['task']['mean']
        scale_val = config['task']['std']
        if rank == 0:
            print(f"Label scaler statistics from config. Mean: {mean_val:.4f}, Std: {scale_val:.4f}")
        norm_mean = torch.tensor(mean_val, device=gpu, dtype=torch.float32)
        norm_std = torch.tensor(scale_val, device=gpu, dtype=torch.float32)
        if is_distributed:
            dist.broadcast(norm_mean, src=0)
            dist.broadcast(norm_std, src=0)
        label_scaler = LabelScaler(norm_mean, norm_std)

    if rank == 0:
        print(f"Training samples: {len(train_loader.dataset)}")
        print(f"Validation samples: {len(val_loader.dataset)}")
        print(f"Test samples: {len(test_loader.dataset)}")
        print(f"Batches per epoch: {len(train_loader)}")

    # Create loss criterion
    task_config = config['task']
    if task_config['task_type'] == 'classification':
        criterion = nn.CrossEntropyLoss(label_smoothing=0.0)
    else:  # regression
        criterion = nn.MSELoss()

    # Optionally freeze the encoder
    if config['training'].get('freeze_encoder', False):
        if rank == 0:
            print("Freezing encoder weights. Only the head will be trained.")
        for name, param in model_without_ddp.named_parameters():
            if 'head' not in name:
                param.requires_grad = False

        # Log which parameters are trainable
        if rank == 0:
            print("Trainable parameters:")
            for name, param in model_without_ddp.named_parameters():
                if param.requires_grad:
                    print(name)

    # Create optimizer and scheduler
    optimizer = create_optimizer(model_without_ddp, config)
    scheduler = create_lr_scheduler(optimizer, config, len(train_loader))

    # Create gradient scaler for mixed precision
    scaler = GradScaler() if config['training']['use_amp'] else None

    # Load checkpoint if resuming
    start_epoch = 0
    best_metric = 0.0          # For classification: accuracy
    best_loss = float('inf')   # For regression: loss
    if config['experiment'].get('resume', None) is not None:
        start_epoch, best_metric, best_loss = load_checkpoint(
            config['experiment']['resume'], model_without_ddp, optimizer, scheduler, scaler
        )
        print(f"Resumed from epoch {start_epoch}. Best metric: {best_metric:.4f}, Best loss: {best_loss:.4f}")
    else:
        # Initialize best_metric for new run based on task
        if config['task']['task_type'] == 'classification':
            best_metric = 0.0  # Accuracy starts at 0
        else:  # regression
            best_metric = float('inf')

    # Training loop
    if rank == 0:
        print("Starting fine-tuning...")
        print(f"Training from epoch {start_epoch} to {config['training']['epochs']}")

    for epoch in range(start_epoch, config['training']['epochs']):
        if is_distributed and train_sampler is not None:
            train_sampler.set_epoch(epoch)

        # Train for one epoch
        train_stats = train_one_epoch(
            model, train_loader, criterion, optimizer, scheduler, scaler,
            epoch, config, rank, world_size, label_scaler, log_file
        )

        # Log training stats
        if rank == 0:
            log_msg = f"Epoch {epoch} Training - "
            log_msg += " | ".join([f"{k}: {v:.4f}" for k, v in train_stats.items()])
            print(log_msg)
            log_to_file(log_file, log_msg)

        # Validate
        if epoch % config['validation']['val_freq'] == 0 or epoch == config['training']['epochs'] - 1:
            print(f"DEBUG: label_scaler type is {type(label_scaler)}, value is {label_scaler}")
            val_stats = evaluate(
                model, val_loader, criterion, config, rank, epoch, label_scaler, 'val'
            )
            test_stats = evaluate(
                model, test_loader, criterion, config, rank, epoch, label_scaler, 'test'
            )

            # Log validation stats
            if rank == 0:
                log_msg = f"Epoch {epoch} Validation - "
                log_msg += " | ".join([f"{k}: {v:.4f}" for k, v in val_stats.items()])
                print(log_msg)
                log_to_file(log_file, log_msg)

                log_msg = f"Epoch {epoch} Test - "
                log_msg += " | ".join([f"{k}: {v:.4f}" for k, v in test_stats.items()])
                print(log_msg)
                log_to_file(log_file, log_msg)

            # Determine best model based on task type
            if rank == 0:
                if task_config['task_type'] == 'classification':
                    # For classification, higher accuracy is better
                    current_metric = val_stats.get('acc', 0.0)
                    is_best = current_metric > best_metric
                    if is_best:
                        best_metric = current_metric
                        best_loss = val_stats['loss']
                else:
                    # For regression, lower loss is better
                    is_best = val_stats['loss'] < best_loss
                    if is_best:
                        best_loss = val_stats['loss']
                        best_metric = -best_loss  # Store negative loss as metric

                checkpoint_state = {
                    'epoch': epoch + 1,
                    'model_state_dict': model_without_ddp.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_metric': best_metric,
                    'best_loss': best_loss,
                    'config': config,
                    'train_stats': train_stats,
                    'val_stats': val_stats,
                }
                if scaler is not None:
                    checkpoint_state['scaler_state_dict'] = scaler.state_dict()

                save_checkpoint(
                    checkpoint_state, is_best, checkpoint_dir,
                    filename=f'checkpoint_epoch_{epoch}.pth'
                )
                checkpoint_msg = f"Checkpoint saved at epoch {epoch}"
                print(checkpoint_msg)
                log_to_file(log_file, checkpoint_msg)

                if is_best:
                    if task_config['task_type'] == 'classification':
                        best_msg = f"New best validation accuracy: {best_metric:.4f}"
                    else:
                        best_msg = f"New best validation loss: {best_loss:.4f}"
                    print(best_msg)
                    log_to_file(log_file, best_msg)

        # Save periodic checkpoint
        if rank == 0 and (epoch + 1) % config['logging']['save_freq'] == 0:
            checkpoint_state = {
                'epoch': epoch + 1,
                'model_state_dict': model_without_ddp.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_metric': best_metric,
                'best_loss': best_loss,
                'config': config,
            }
            if scaler is not None:
                checkpoint_state['scaler_state_dict'] = scaler.state_dict()

            save_checkpoint(
                checkpoint_state, False, checkpoint_dir,
                filename=f'checkpoint_epoch_{epoch}.pth'
            )

    # Cleanup
    cleanup_distributed()


if __name__ == '__main__':
    main()
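
# ---------------------------------------------------------------------------
# Example launch, as a rough sketch only. The script name ("finetune.py") and
# the assumption that utils.ddp.setup_distributed() reads the standard torchrun
# environment variables (RANK, WORLD_SIZE, LOCAL_RANK) are illustrative, not
# confirmed by this file. A 4-GPU single-node fine-tuning run might look like:
#
#   torchrun --nproc_per_node=4 finetune.py \
#       --config configs/finetune_config.yaml \
#       --output_dir ./outputs/finetune_run
#
# A plain `python finetune.py --config configs/finetune_config.yaml` should
# also work for single-GPU debugging, provided setup_distributed() falls back
# to non-distributed mode when no process-group variables are set.
# ---------------------------------------------------------------------------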