import argparse
import copy

import numpy as np
import torch
from omegaconf import OmegaConf
from sklearn.metrics import roc_auc_score
from torch import nn, optim

from barista.data.braintreebank_dataset import BrainTreebankDataset
from barista.models.model import Barista
from barista.models.utils import seed_everything


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Fine-tune Barista model on BrainTreebank dataset"
    )
    parser.add_argument(
        "--dataset_config",
        type=str,
        default="barista/config/braintreebank.yaml",
        help="Path to dataset configuration file",
    )
    parser.add_argument(
        "--train_config",
        type=str,
        default="barista/config/train.yaml",
        help="Path to training configuration file",
    )
    parser.add_argument(
        "--model_config",
        type=str,
        default="barista/config/model.yaml",
        help="Path to model configuration file",
    )
    parser.add_argument(
        "--override",
        type=str,
        nargs="+",
        default=[],
        help="Override config parameters (e.g., --override epochs=50 optimization.finetune_lr=1e-4)",
    )
    return parser.parse_args()


def load_configs(args):
    """Load all configuration files."""
    dataset_config = OmegaConf.load(args.dataset_config)
    train_config = OmegaConf.load(args.train_config)
    model_config = OmegaConf.load(args.model_config)
    assert (
        len(dataset_config.finetune_sessions) == 1
    ), "Specify exactly one session for finetuning"
    return dataset_config, train_config, model_config
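
# How overrides are applied (illustrative, derived from the parsing logic below):
# "--override epochs=50 optimization.finetune_lr=1e-4" is turned into the nested
# dict {"epochs": 50, "optimization": {"finetune_lr": 1e-4}}, which is then
# merged into every loaded config that shares a top-level key with it.
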
def apply_overrides(config_dict, overrides):
    """Apply command-line overrides to configs using dot notation."""
    if not overrides:
        return config_dict

    override_dict = {}
    for override in overrides:
        if "=" not in override:
            raise ValueError(
                f"Invalid override format: {override}. Expected format: key=value"
            )
        key, value = override.split("=", 1)
        # Parse ints, floats, lists, and booleans; anything else stays a string.
        try:
            value = int(value)
        except ValueError:
            try:
                value = float(value)
            except ValueError:
                if value.startswith("[") or value in ("True", "False"):  # list, bool
                    value = eval(value)

        keys = key.split(".")
        current = override_dict
        for k in keys[:-1]:
            if k not in current:
                current[k] = {}
            current = current[k]
        current[keys[-1]] = value

    # Convert override dict to OmegaConf and merge
    override_conf = OmegaConf.create(override_dict)

    # Determine which config to merge based on keys
    merged_configs = {}
    for config_name, config in config_dict.items():
        config_keys = set(OmegaConf.to_container(config).keys())
        override_keys = set(override_dict.keys())
        if config_keys.intersection(override_keys):
            merged_configs[config_name] = OmegaConf.merge(config, override_conf)
        else:
            merged_configs[config_name] = config

    if merged_configs.get("train") is not None:
        merged_configs["train"] = OmegaConf.merge(
            merged_configs["train"], override_conf
        )
    return merged_configs


def setup_dataloaders(dataset_config, train_config):
    """Initialize dataset and create dataloaders."""
    dataset = BrainTreebankDataset(dataset_config)
    train_dataloader = dataset.get_dataloader("train", train_config)
    val_dataloader = dataset.get_dataloader("val", train_config)
    test_dataloader = dataset.get_dataloader("test", train_config)

    print(f"Train: {len(train_dataloader.dataset.metadata)} samples")
    print(f"Val: {len(val_dataloader.dataset.metadata)} samples")
    print(f"Test: {len(test_dataloader.dataset.metadata)} samples")

    dataset.check_no_common_segment(train_dataloader, val_dataloader, test_dataloader)
    return dataset, train_dataloader, val_dataloader, test_dataloader


def get_optimizer(model, finetune_lr=1e-4, new_param_lr=1e-3):
    """Create optimizer with different learning rates for task and upstream parameters."""
    task_params, upstream_params = [], []
    for _, p in model.get_task_params():
        if p.requires_grad:
            task_params.append(p)
    for _, p in model.get_upstream_params():
        if p.requires_grad:
            upstream_params.append(p)
    params = [
        {"params": upstream_params, "lr": finetune_lr},
        {"params": task_params, "lr": new_param_lr},
    ]
    optimizer = optim.AdamW(params, lr=finetune_lr, weight_decay=1e-2)
    return optimizer


def get_lr_scheduler(optimizer):
    """Create learning rate scheduler with warmup and exponential decay."""
    milestone = 5
    lr_schedulers_list = [
        torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=0.2,
            end_factor=1.0,
            total_iters=milestone,
        ),
        torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99),
    ]
    lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        lr_schedulers_list,
        milestones=[milestone],
    )
    return lr_scheduler


def load_pretrained_weights(model, checkpoint_path, device):
    """Load pretrained weights, excluding masked_recon and multi_head_fc layers."""
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint)
    print(f"Pretrained weights loaded from {checkpoint_path}")
    return model


def freeze_tokenizer(model):
    """Freeze all tokenizer parameters."""
    for p in model.tokenizer.parameters():
        p.requires_grad = False


def print_number_of_params(model):
    """Print total and trainable parameter counts."""
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params}\t Trainable params: {trainable_params}")
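
# Schedule note (derived from get_lr_scheduler above): the learning rate warms up
# linearly from 0.2x to 1.0x of each parameter group's base LR over the first
# 5 epochs, then decays by a factor of 0.99 per epoch; run_epoch() below steps
# the scheduler once per training epoch.
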
def run_epoch(
    model, dataloader, criterion, device, optimizer=None, scheduler=None, train=False
):
    """Run one epoch of training or evaluation."""
    if train:
        model.train()
    else:
        model.eval()

    all_preds = []
    all_labels = []
    running_loss = 0
    for batch in dataloader:
        x = [x_item.to(device) for x_item in batch.x]
        y = batch.labels.flatten().long().to(device)

        if train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(train):
            logits = model(
                x,
                subject_sessions=batch.subject_sessions,
            )
            loss = criterion(logits, y)

        if train:
            loss.backward()
            optimizer.step()

        running_loss += loss.item() * y.size(0)
        probs = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy()
        labels = y.detach().cpu().numpy()
        all_preds.append(probs)
        all_labels.append(labels)

    if train:
        # step scheduler at epoch interval
        scheduler.step()

    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    try:
        auc = roc_auc_score(all_labels, all_preds)
    except ValueError:
        # AUC is undefined when only one class is present in the labels
        auc = float("nan")
    avg_loss = running_loss / len(dataloader.dataset)
    return avg_loss, auc


def finetune_model(model, train_dataloader, val_dataloader, train_config, device):
    """Finetune the model and track best validation performance."""
    criterion = nn.CrossEntropyLoss()
    optimizer = get_optimizer(
        model,
        finetune_lr=train_config.optimization.finetune_lr,
        new_param_lr=train_config.optimization.new_param_lr,
    )
    scheduler = get_lr_scheduler(optimizer)

    best_val_auc = -1
    best_state = None
    num_epochs = train_config.epochs
    for epoch in range(num_epochs):
        train_loss, train_auc = run_epoch(
            model, train_dataloader, criterion, device, optimizer, scheduler, train=True
        )
        val_loss, val_auc = evaluate_model(model, val_dataloader, criterion, device)
        print(
            f"Epoch {epoch+1}/{num_epochs} "
            f"- Train Loss: {train_loss:.4f}, AUC: {train_auc:.4f} "
            f"- Val Loss: {val_loss:.4f}, AUC: {val_auc:.4f}"
        )

        # Track best model by validation AUC
        if best_state is None or val_auc > best_val_auc:
            best_val_auc = val_auc
            best_state = {
                "epoch": epoch + 1,
                "model": copy.deepcopy(model.state_dict()),
                "optimizer": copy.deepcopy(optimizer.state_dict()),
                "scheduler": copy.deepcopy(scheduler.state_dict()),
                "val_auc": val_auc,
            }
    return best_state, criterion


def evaluate_model(model, test_dataloader, criterion, device):
    """Evaluate the model on a held-out dataloader (validation or test)."""
    test_loss, test_auc = run_epoch(
        model, test_dataloader, criterion, device, train=False
    )
    return test_loss, test_auc
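
# Checkpointing note: finetune_model() keeps a deep copy of the state with the
# best validation AUC; main() reports test AUC both for the last-epoch weights
# ("LAST TEST AUC") and for that best-validation checkpoint ("BEST TEST AUC").
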
def main():
    """Main training pipeline."""
    # Parse arguments and load configs
    args = parse_args()
    dataset_config, train_config, model_config = load_configs(args)
    configs = {"dataset": dataset_config, "train": train_config, "model": model_config}
    configs = apply_overrides(configs, args.override)
    dataset_config = configs["dataset"]
    train_config = configs["train"]
    model_config = configs["model"]

    # Set random seed
    seed_everything(train_config.seed)

    # Setup data
    dataset, train_dataloader, val_dataloader, test_dataloader = setup_dataloaders(
        dataset_config, train_config
    )

    # Get fine-tuning session info
    ft_session = dataset_config.finetune_sessions[0]
    ft_session_n_chans = dataset.metadata.get_subject_session_full_d_data()[ft_session][
        -1
    ]

    # Initialize model
    device = train_config.device
    model = Barista(model_config, dataset.metadata)

    # Load pretrained weights
    if train_config.checkpoint_path:
        print("Running pretrained model")
        model = load_pretrained_weights(model, train_config.checkpoint_path, device)
        # Freeze tokenizer
        if train_config.optimization.freeze_tokenizer:
            freeze_tokenizer(model)
    else:
        print("Running non-pretrained model")

    # Create downstream head and move to device
    model.create_downstream_head(n_chans=ft_session_n_chans, output_dim=2)
    model.to(device)
    print_number_of_params(model)

    # Finetune model
    best_state, criterion = finetune_model(
        model, train_dataloader, val_dataloader, train_config, device
    )
    print(f"\nBEST VAL AUC: {best_state['val_auc']:.4f}")

    # Evaluate the last-epoch weights on the test set
    _, last_test_auc = evaluate_model(model, test_dataloader, criterion, device)
    print(f"LAST TEST AUC: {last_test_auc:.4f}")

    # Load the best-validation checkpoint and evaluate it on the test set
    model.load_state_dict(best_state["model"])
    _, test_auc = evaluate_model(model, test_dataloader, criterion, device)
    print(f"BEST TEST AUC: {test_auc:.4f}")


if __name__ == "__main__":
    main()
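
# Example invocation (the script filename is illustrative; the config paths are
# the argparse defaults above):
#   python finetune_braintreebank.py \
#       --train_config barista/config/train.yaml \
#       --override epochs=50 optimization.finetune_lr=1e-4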