#!/usr/bin/env python3
"""
Unified Fine-tuning Script for Glycan Classification

This script fine-tunes a pre-trained Multimodal Glycan BERT model on
taxonomy classification tasks (domain, kingdom, phylum, class, order,
family, genus, species) and property prediction tasks (immunogenicity, link).

Usage:
    python downstream_tasks/finetune.py \
        --task species \
        --data_path downstream_tasks/glycan_classification_with_wurcs.csv \
        --checkpoint checkpoints/best_multimodal_v3_model.pt \
        --vocab data/vocabulary.json \
        --output_dir downstream_tasks/results/species
"""

import argparse
import json
import logging
import os
import random
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import math
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import (
    accuracy_score, f1_score, precision_score, recall_score,
    matthews_corrcoef, classification_report,
    roc_auc_score, average_precision_score,
)
from sklearn.preprocessing import label_binarize

# Add parent to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig
from downstream_tasks.utils.tokenizer import WURCSTokenizer
from downstream_tasks.utils.dataset import GlycanClassificationDataset, compute_valid_classes

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Batch entries that are forwarded to the model as keyword arguments when the
# dataset provides them (and as ``None`` otherwise).
_OPTIONAL_BATCH_KEYS = ('branch_depths', 'linkage_types', 'dist_labels')


def set_seed(seed: int):
    """Set random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN auto-tuning for deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class GlycanClassifier(nn.Module):
    """
    Classification head on top of the sequence branch of a pre-trained BERT.

    Only ``bert.seq_embeddings``, ``bert.seq_layers`` and (when distance
    labels are supplied) ``bert.distance_head`` are used by this module.

    Pooling strategies:
        - "attention": learned attention weights over all non-padding tokens
        - "mono": mean-pool tokens per residue id, then attention over residues
        - "mean": mean over non-padding tokens
        - "first": hidden state of the first token (CLS-style)
        - "max": element-wise max over non-padding tokens

    By default the first 8 sequence layers are frozen and dropout is 0.25.
    """

    def __init__(
        self,
        bert: MultimodalGlycanBERT,
        num_classes: int,
        dropout: float = 0.25,  # Increased from 0.1 to combat overfitting
        freeze_layers: int = 8,  # Increased from 4 to prevent overfitting
        pooling_strategy: str = "attention",  # "mean", "first", "max", "attention", "mono"
    ):
        """
        Args:
            bert: Pre-trained model providing the sequence encoder.
            num_classes: Size of the output logit vector.
            dropout: Dropout probability used twice in the classifier head.
            freeze_layers: Sequence layers with index below this value get
                ``requires_grad = False``.
            pooling_strategy: One of the strategies listed on the class.
        """
        super().__init__()
        self.bert = bert
        self.num_classes = num_classes
        self.pooling_strategy = pooling_strategy

        # Freeze bottom layers
        for i, layer in enumerate(self.bert.seq_layers):
            if i < freeze_layers:
                for param in layer.parameters():
                    param.requires_grad = False

        # Classification head (use sequence hidden size)
        hidden_size = bert.config.seq_hidden_size

        # Attention pooling layer (if using attention or mono strategy)
        if pooling_strategy in ["attention", "mono"]:
            self.attention_weights = nn.Linear(hidden_size, 1)

        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, num_classes),
        )

    def forward(self, token_ids, attention_mask, residue_ids=None, **kwargs):
        """
        Forward pass for classification.

        Args:
            token_ids: (batch, seq_len) - Token IDs
            attention_mask: (batch, seq_len) - Attention mask
            residue_ids: (batch, seq_len) - Residue ID for each token
                (optional, for mono pooling)
            **kwargs: Optional ``branch_depths``, ``linkage_types`` (passed to
                the embedding layer when it has ``branch_embeddings``) and
                ``dist_labels`` (enables the auxiliary distance loss).

        Returns:
            Tuple ``(logits, dist_loss)`` where ``logits`` is
            (batch, num_classes) and ``dist_loss`` is either the float ``0.0``
            or a scalar tensor with the auxiliary MSE distance loss.
        """
        # Get sequence embeddings with branch/linkage info if available
        # Check if seq_embeddings supports the new parameters
        if hasattr(self.bert.seq_embeddings, 'branch_embeddings'):
            seq_hidden = self.bert.seq_embeddings(
                token_ids, kwargs.get('branch_depths'), kwargs.get('linkage_types')
            )
        else:
            seq_hidden = self.bert.seq_embeddings(token_ids)

        # Apply transformer layers
        for layer in self.bert.seq_layers:
            seq_hidden = layer(seq_hidden, attention_mask)

        dist_loss = self._distance_loss(seq_hidden, kwargs.get('dist_labels'))
        pooled = self._pool(seq_hidden, attention_mask, residue_ids)

        # Classify
        logits = self.classifier(pooled)
        return logits, dist_loss

    def _distance_loss(self, seq_hidden, dist_labels):
        """Return the auxiliary distance-reconstruction loss (0.0 if unavailable)."""
        if dist_labels is None:
            return 0.0

        # Predict distances using the pre-trained distance head
        dist_predictions = self.bert.distance_head(seq_hidden)
        dist_labels = dist_labels.to(dist_predictions.device)

        # Entries equal to -1 mark padding and are excluded from the loss.
        valid = dist_labels != -1
        if not valid.any():
            return 0.0
        # Labels are cast to float for MSE.
        return F.mse_loss(dist_predictions[valid], dist_labels[valid].float())

    def _pool(self, seq_hidden, attention_mask, residue_ids):
        """Collapse (batch, seq_len, hidden) to (batch, hidden) per the pooling strategy."""
        strategy = self.pooling_strategy

        if strategy == "first":
            # Use first token (CLS-style)
            return seq_hidden[:, 0, :]

        if strategy == "max":
            # Max pooling over sequence; padding is pushed to a large negative value
            mask_expanded = attention_mask.unsqueeze(-1).float()
            masked = seq_hidden * mask_expanded + (1 - mask_expanded) * -1e9
            return masked.max(dim=1)[0]

        if strategy == "mono" and residue_ids is not None:
            return self._pool_mono(seq_hidden, attention_mask, residue_ids)

        if strategy == "attention":
            # Attention-weighted pooling
            scores = self.attention_weights(seq_hidden).squeeze(-1)  # (batch, seq_len)
            scores = scores.masked_fill(attention_mask == 0, -1e9)
            weights = torch.softmax(scores, dim=1).unsqueeze(-1)  # (batch, seq_len, 1)
            return (seq_hidden * weights).sum(dim=1)

        # "mean" - default; also reached for "mono" when residue_ids is None
        mask_expanded = attention_mask.unsqueeze(-1).float()
        sum_hidden = (seq_hidden * mask_expanded).sum(dim=1)
        sum_mask = mask_expanded.sum(dim=1).clamp(min=1e-9)
        return sum_hidden / sum_mask

    def _pool_mono(self, seq_hidden, attention_mask, residue_ids):
        """Mean-pool tokens within each residue, then attention-pool over residues."""
        max_residues = 32  # Max number of residues per glycan
        pooled_residues = []

        for b in range(seq_hidden.size(0)):
            # Only ids >= 0 count as residues.
            # NOTE(review): negative ids presumably mark padding/special
            # tokens - confirm against the dataset.
            unique_res = torch.unique(residue_ids[b])
            unique_res = unique_res[unique_res >= 0]

            residue_reps = []
            for rid in unique_res[:max_residues]:
                # rid comes from torch.unique of this row, so at least one
                # token matches and the denominator is never zero.
                mask = (residue_ids[b] == rid).float()
                residue_reps.append(
                    (seq_hidden[b] * mask.unsqueeze(-1)).sum(dim=0) / mask.sum()
                )

            if not residue_reps:
                # Fallback to mean pooling; clamp guards an all-zero mask
                mask_expanded = attention_mask[b].unsqueeze(-1).float()
                pooled_residues.append(
                    (seq_hidden[b] * mask_expanded).sum(dim=0)
                    / mask_expanded.sum().clamp(min=1e-9)
                )
            else:
                # Stack residue representations and apply attention
                res_stack = torch.stack(residue_reps, dim=0)  # (num_res, hidden)
                scores = self.attention_weights(res_stack).squeeze(-1)  # (num_res,)
                weights = torch.softmax(scores, dim=0).unsqueeze(-1)  # (num_res, 1)
                pooled_residues.append((res_stack * weights).sum(dim=0))

        return torch.stack(pooled_residues, dim=0)  # (batch, hidden)


def get_config_from_checkpoint(checkpoint_path: str, device: str) -> MultimodalGlycanBERTConfig:
    """
    Extract the model config from a checkpoint.

    Three layouts are handled: a nested ``checkpoint['config']['model']``
    dict with per-modality sub-dicts, a flat ``checkpoint['config']`` dict
    passed straight to the config constructor, and no ``'config'`` key at all
    (library defaults are returned).
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)

    if 'config' in checkpoint:
        config_dict = checkpoint['config']
        if 'model' in config_dict:
            model_cfg = config_dict['model']
            seq_cfg = model_cfg.get('sequence', {})
            ms_cfg = model_cfg.get('mass_spectrometry', model_cfg.get('ms', {}))
            struct_cfg = model_cfg.get('structure_3d', model_cfg.get('structure', {}))
            fusion_cfg = model_cfg.get('fusion', {})

            return MultimodalGlycanBERTConfig(
                seq_vocab_size=seq_cfg.get('vocab_size', 166),
                seq_hidden_size=seq_cfg.get('hidden_size', 768),
                seq_num_layers=seq_cfg.get('num_hidden_layers', 12),
                seq_num_heads=seq_cfg.get('num_attention_heads', 12),
                seq_max_length=seq_cfg.get('max_length', 512),
                ms_vocab_size=ms_cfg.get('vocab_size', 242),
                ms_hidden_size=ms_cfg.get('hidden_size', 256),
                ms_num_layers=ms_cfg.get('num_hidden_layers', 4),
                ms_num_heads=ms_cfg.get('num_attention_heads', 4),
                ms_max_length=ms_cfg.get('max_length', 100),
                struct_vocab_size=struct_cfg.get('vocab_size', 1024),
                struct_hidden_size=struct_cfg.get('hidden_size', 256),
                struct_num_layers=struct_cfg.get('num_hidden_layers', 4),
                struct_num_heads=struct_cfg.get('num_attention_heads', 4),
                struct_max_length=struct_cfg.get('max_length', 100),
                use_3d=struct_cfg.get('enabled', struct_cfg.get('use_3d', True)),
                fusion_hidden_size=fusion_cfg.get('fusion_hidden_size', 512),
            )
        else:
            return MultimodalGlycanBERTConfig(**config_dict)

    return MultimodalGlycanBERTConfig()


def load_pretrained_bert(checkpoint_path: str, config: MultimodalGlycanBERTConfig, device: str) -> MultimodalGlycanBERT:
    """
    Load pre-trained BERT from checkpoint using provided config.

    Weights are loaded with ``strict=False``; any missing or unexpected keys
    are logged as warnings so that a silently mismatched checkpoint is visible.
    """
    logger.info(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Create model
    bert = MultimodalGlycanBERT(config)

    # Load weights with strict=False to handle any minor mismatches
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    load_result = bert.load_state_dict(state_dict, strict=False)

    if load_result.missing_keys:
        logger.warning(
            "%d parameter(s) missing from checkpoint (left at initial values): %s",
            len(load_result.missing_keys), load_result.missing_keys,
        )
    if load_result.unexpected_keys:
        logger.warning(
            "%d unexpected key(s) in checkpoint were ignored: %s",
            len(load_result.unexpected_keys), load_result.unexpected_keys,
        )

    logger.info("Loaded pre-trained BERT successfully")
    return bert


def _forward_batch(model: GlycanClassifier, batch: dict, device: str):
    """Move one batch to ``device``, run the model, return (logits, dist_loss, labels)."""
    token_ids = batch['token_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    residue_ids = batch['residue_ids'].to(device) if 'residue_ids' in batch else None
    extras = {
        key: batch[key].to(device) if key in batch else None
        for key in _OPTIONAL_BATCH_KEYS
    }
    labels = batch['label'].to(device)

    logits, dist_loss = model(token_ids, attention_mask, residue_ids, **extras)
    return logits, dist_loss, labels


def train_epoch(
    model: GlycanClassifier,
    train_loader: DataLoader,
    optimizer: AdamW,
    criterion: nn.Module,
    device: str,
    scheduler=None,
    dist_alpha: float = 0.5,
) -> dict:
    """
    Train for one epoch.

    Args:
        model: Classifier to optimize.
        train_loader: Yields dict batches with at least ``token_ids``,
            ``attention_mask`` and ``label``.
        optimizer: Optimizer stepped once per batch.
        criterion: Classification loss applied to ``(logits, labels)``.
        device: Device the batch tensors are moved to.
        scheduler: Optional LR scheduler, stepped once per batch.
        dist_alpha: Weight of the auxiliary distance loss.

    Returns:
        Dict with the mean per-batch ``loss`` and the epoch ``accuracy``.
    """
    model.train()
    total_loss = 0
    all_preds = []
    all_labels = []

    pbar = tqdm(train_loader, desc="Training")
    for batch in pbar:
        optimizer.zero_grad()
        logits, dist_loss, labels = _forward_batch(model, batch, device)

        # Main task loss
        cls_loss = criterion(logits, labels)

        # Total loss = Classification Loss + alpha * Topology Loss
        # We weight topology loss to avoid overwhelming the main task
        total_batch_loss = cls_loss + dist_alpha * dist_loss

        total_batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        batch_loss_value = total_batch_loss.item()
        total_loss += batch_loss_value
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        # dist_loss is a plain float when no distance labels were available
        # and a scalar tensor otherwise; float() handles both.
        pbar.set_postfix({
            'loss': f'{batch_loss_value:.4f}',
            'dist': f'{float(dist_loss):.4f}',
        })

    return {
        'loss': total_loss / len(train_loader),
        'accuracy': accuracy_score(all_labels, all_preds),
    }


def _try_metric(name: str, metric_fn, *args, **kwargs) -> Optional[float]:
    """Compute an optional metric; log and return None if it cannot be computed."""
    try:
        return float(metric_fn(*args, **kwargs))
    except Exception as exc:  # best-effort metric: never abort a training run
        logger.warning("Could not compute %s: %s", name, exc)
        return None


def _compute_auc_metrics(labels: np.ndarray, probs: np.ndarray) -> Tuple[Optional[float], Optional[float]]:
    """
    Compute macro one-vs-rest AUROC and AUPRC over the classes present in ``labels``.

    ``probs`` has one column per model class. When some classes are absent
    from ``labels`` their columns are dropped and the remaining probabilities
    are renormalized to sum to one, because scikit-learn requires the number
    of score columns to equal the number of labels. When exactly two classes
    are present the higher class index is treated as the positive class.

    Returns:
        ``(auroc, auprc)``; either value is None when it cannot be computed.
    """
    present = np.unique(labels)
    if probs.ndim != 2 or len(present) < 2:
        return None, None
    if present.min() < 0 or present.max() >= probs.shape[1]:
        logger.warning(
            "Labels outside the model's %d output classes; skipping AUROC/AUPRC",
            probs.shape[1],
        )
        return None, None

    scores = probs[:, present]
    if len(present) != probs.shape[1]:
        row_sums = np.clip(scores.sum(axis=1, keepdims=True), 1e-12, None)
        scores = scores / row_sums

    if len(present) == 2:
        # Binary classification
        y_true = (labels == present[1]).astype(int)
        auroc = _try_metric('AUROC', roc_auc_score, y_true, scores[:, 1])
        auprc = _try_metric('AUPRC', average_precision_score, y_true, scores[:, 1])
    else:
        # Multi-class: use one-vs-rest
        auroc = _try_metric(
            'AUROC', roc_auc_score, labels, scores,
            multi_class='ovr', average='macro', labels=present,
        )
        # AUPRC for multi-class: binarize labels
        labels_bin = label_binarize(labels, classes=present)
        auprc = _try_metric(
            'AUPRC', average_precision_score, labels_bin, scores, average='macro',
        )
    return auroc, auprc


def evaluate(
    model: GlycanClassifier,
    data_loader: DataLoader,
    criterion: nn.Module,
    device: str,
    num_classes: Optional[int] = None,
    dist_alpha: float = 0.5,
) -> dict:
    """
    Evaluate model on dataset.

    Args:
        model: Classifier to evaluate (switched to eval mode).
        data_loader: Yields dict batches in the same format as for training.
        criterion: Classification loss applied to ``(logits, labels)``.
        device: Device the batch tensors are moved to.
        num_classes: Accepted for backward compatibility; the number of
            classes is taken from the width of the model's logits.
        dist_alpha: Weight of the auxiliary distance loss in the reported loss.

    Returns:
        Dict with ``loss``, ``accuracy``, ``f1_macro``, ``f1_weighted``,
        ``mcc``, ``auroc``, ``auprc`` (the last two may be None) plus the raw
        ``predictions`` and ``labels`` lists.
    """
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    all_probs = []  # Store probabilities for AUROC/AUPRC

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            logits, dist_loss, labels = _forward_batch(model, batch, device)

            cls_loss = criterion(logits, labels)
            loss = cls_loss + dist_alpha * dist_loss
            total_loss += loss.item()

            all_probs.extend(torch.softmax(logits, dim=1).cpu().numpy())
            all_preds.extend(logits.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    avg_loss = total_loss / len(data_loader)
    accuracy = accuracy_score(all_labels, all_preds)
    f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    f1_weighted = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    mcc = matthews_corrcoef(all_labels, all_preds)

    # Compute AUROC and AUPRC (for multi-class: one-vs-rest)
    auroc, auprc = _compute_auc_metrics(np.array(all_labels), np.array(all_probs))

    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'f1_macro': f1_macro,
        'f1_weighted': f1_weighted,
        'mcc': mcc,
        'auroc': auroc,
        'auprc': auprc,
        'predictions': all_preds,
        'labels': all_labels,
    }


def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface."""
    parser = argparse.ArgumentParser(description='Fine-tune Glycan BERT for classification')

    # Required arguments
    parser.add_argument('--task', type=str, required=True,
                        help='Task name (e.g., species, phylum)')
    parser.add_argument('--data_path', type=str, required=True,
                        help='Path to CSV data file')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to pre-trained model checkpoint')
    parser.add_argument('--vocab', type=str, required=True,
                        help='Path to vocabulary.json')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Output directory for results')

    # Optional arguments
    parser.add_argument('--batch_size', type=int, default=256,
                        help='Batch size (matching GlycanML for stable gradients)')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--dropout', type=float, default=0.25,
                        help='Dropout rate (increased from 0.1 to combat overfitting)')
    parser.add_argument('--freeze_layers', type=int, default=8,
                        help='Number of bottom layers to freeze (increased from 4 to prevent overfitting)')
    parser.add_argument('--pooling_strategy', type=str, default='attention',
                        choices=['mean', 'first', 'max', 'attention', 'mono'],
                        help='Pooling strategy: attention (recommended), mono (residue-level), mean, first (CLS-style), max')
    parser.add_argument('--max_length', type=int, default=256)
    parser.add_argument('--patience', type=int, default=10,
                        help='Early stopping patience')
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--filter_mode', type=str, default='none',
                        choices=['none', 'strict', 'strict_3', 'strict_5'],
                        help='Class filtering: none (use all), strict (n=1), strict_3 (n=3), strict_5 (n=5)')
    parser.add_argument('--dist_alpha', type=float, default=0.0,
                        help='Weight for auxiliary distance/topology loss (default: 0.0 disabled, set >0 to enable)')

    return parser.parse_args()


def main():
    """Fine-tune, early-stop on validation MCC, test the best checkpoint, save results."""
    args = _parse_args()

    # Setup
    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)
    best_model_path = os.path.join(args.output_dir, 'best_model.pt')

    # Log configuration
    logger.info("=" * 70)
    logger.info(f"FINE-TUNING GLYCAN BERT ON {args.task.upper()}")
    logger.info("=" * 70)
    logger.info(f"Data: {args.data_path}")
    logger.info(f"Checkpoint: {args.checkpoint}")
    logger.info(f"Output: {args.output_dir}")
    logger.info(f"Device: {args.device}")
    logger.info(f"Seed: {args.seed}")
    logger.info(f"Filter mode: {args.filter_mode}")

    # Compute valid classes if using strict filtering
    valid_classes = None
    if args.filter_mode != 'none':
        # Determine min_samples from filter mode
        min_samples = {'strict': 1, 'strict_3': 3, 'strict_5': 5}.get(args.filter_mode, 1)
        logger.info(f"\nComputing valid classes ({args.filter_mode} mode, min_samples={min_samples})...")
        valid_classes = compute_valid_classes(args.data_path, args.task, min_samples=min_samples)
        logger.info(f" Will use {len(valid_classes)} classes present in all splits")

    # Load config from checkpoint to get model capacity
    logger.info("\nChecking model capacity from checkpoint...")
    checkpoint_config = get_config_from_checkpoint(args.checkpoint, 'cpu')
    model_max_length = checkpoint_config.seq_max_length

    # Override max_length if it exceeds model capacity
    if args.max_length > model_max_length:
        logger.warning(f" Requested max_length ({args.max_length}) exceeds model capacity ({model_max_length}).")
        logger.warning(f" Overriding max_length to {model_max_length} to prevent size mismatch errors.")
        dataset_max_length = model_max_length
    else:
        dataset_max_length = args.max_length

    # Load data
    logger.info(f"\nLoading data (max_length={dataset_max_length})...")
    train_dataset = GlycanClassificationDataset(
        args.data_path, args.task, 'train', args.vocab, dataset_max_length,
        valid_classes=valid_classes
    )
    val_dataset = GlycanClassificationDataset(
        args.data_path, args.task, 'validation', args.vocab, dataset_max_length,
        valid_classes=valid_classes
    )
    test_dataset = GlycanClassificationDataset(
        args.data_path, args.task, 'test', args.vocab, dataset_max_length,
        valid_classes=valid_classes
    )

    logger.info(f"\nDataset summary:")
    logger.info(f" Train: {len(train_dataset)} samples")
    logger.info(f" Val: {len(val_dataset)} samples")
    logger.info(f" Test: {len(test_dataset)} samples")
    logger.info(f" Classes: {len(train_dataset.unique_labels)}")

    # Report class filtering stats
    if args.filter_mode == 'none':
        train_classes = set(train_dataset.unique_labels)
        val_classes = set(val_dataset.unique_labels)
        test_classes = set(test_dataset.unique_labels)
        common_classes = train_classes & val_classes & test_classes
        logger.info(f"\nClass distribution (filter_mode=none):")
        # "X-only" means present in split X and in neither of the other two.
        logger.info(f" Train-only classes: {len(train_classes - val_classes - test_classes)}")
        logger.info(f" Val-only classes: {len(val_classes - train_classes - test_classes)}")
        logger.info(f" Test-only classes: {len(test_classes - train_classes - val_classes)}")
        logger.info(f" Common to all: {len(common_classes)}")

    # Create dataloaders (pinned memory only helps host-to-GPU copies)
    pin_memory = str(args.device).startswith('cuda')
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=pin_memory
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=pin_memory
    )
    test_loader = DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=pin_memory
    )

    # Load model
    logger.info("\nLoading model...")
    bert = load_pretrained_bert(args.checkpoint, checkpoint_config, args.device)

    # NOTE(review): the head is sized from the training split; this assumes
    # all three splits share one label-to-index mapping - confirm in the dataset.
    num_classes = len(train_dataset.unique_labels)
    model = GlycanClassifier(
        bert, num_classes,
        dropout=args.dropout,
        freeze_layers=args.freeze_layers,
        pooling_strategy=args.pooling_strategy,
    ).to(args.device)
    logger.info(f" Pooling strategy: {args.pooling_strategy}")

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f" Total params: {total_params:,}")
    logger.info(f" Trainable params: {trainable_params:,} ({trainable_params/total_params*100:.1f}%)")

    # Setup training
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr,
        weight_decay=args.weight_decay
    )

    # The scheduler is stepped once per batch, so T_max counts batches.
    total_steps = len(train_loader) * args.epochs
    scheduler = CosineAnnealingLR(optimizer, T_max=total_steps)

    # Training loop
    logger.info("\n" + "=" * 70)
    logger.info("TRAINING")
    logger.info("=" * 70)

    # -inf (not -1) so that the first epoch always produces a checkpoint,
    # even in the degenerate case of a validation MCC of exactly -1.
    best_val_mcc = float('-inf')
    epochs_without_improvement = 0
    history = []

    for epoch in range(args.epochs):
        logger.info(f"\nEpoch {epoch + 1}/{args.epochs}")

        # Train
        train_metrics = train_epoch(model, train_loader, optimizer, criterion,
                                    args.device, scheduler, dist_alpha=args.dist_alpha)

        # Validate
        val_metrics = evaluate(model, val_loader, criterion, args.device,
                               num_classes, dist_alpha=args.dist_alpha)

        logger.info(f" Train - Loss: {train_metrics['loss']:.4f}, Acc: {train_metrics['accuracy']:.4f}")
        val_log = f" Val - Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.4f}, "
        val_log += f"F1: {val_metrics['f1_macro']:.4f}, MCC: {val_metrics['mcc']:.4f}"
        if val_metrics['auroc'] is not None:
            val_log += f", AUROC: {val_metrics['auroc']:.4f}"
        logger.info(val_log)

        history.append({
            'epoch': epoch + 1,
            'train_loss': train_metrics['loss'],
            'train_acc': train_metrics['accuracy'],
            'val_loss': val_metrics['loss'],
            'val_acc': val_metrics['accuracy'],
            'val_f1': val_metrics['f1_macro'],
            'val_mcc': val_metrics['mcc'],
            'val_auroc': val_metrics['auroc'],
            'val_auprc': val_metrics['auprc'],
        })

        # Check for improvement
        if val_metrics['mcc'] > best_val_mcc:
            best_val_mcc = val_metrics['mcc']
            epochs_without_improvement = 0

            # Save best model
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_mcc': best_val_mcc,
                'config': {
                    'task': args.task,
                    'num_classes': num_classes,
                    'classes': train_dataset.unique_labels,
                }
            }, best_model_path)
            logger.info(f" New best MCC: {best_val_mcc:.4f} (saved)")
        else:
            epochs_without_improvement += 1
            logger.info(f" No improvement ({epochs_without_improvement}/{args.patience})")

        # Early stopping
        if epochs_without_improvement >= args.patience:
            logger.info(f"\nEarly stopping at epoch {epoch + 1}")
            break

    # Load best model for testing
    logger.info("\n" + "=" * 70)
    logger.info("TESTING")
    logger.info("=" * 70)

    best_checkpoint = torch.load(best_model_path, map_location=args.device)
    model.load_state_dict(best_checkpoint['model_state_dict'])

    # Use the same auxiliary-loss weight as during training/validation.
    test_metrics = evaluate(model, test_loader, criterion, args.device,
                            num_classes, dist_alpha=args.dist_alpha)

    logger.info(f"\nTest Results:")
    logger.info(f" Accuracy: {test_metrics['accuracy']:.4f}")
    logger.info(f" F1-Macro: {test_metrics['f1_macro']:.4f}")
    logger.info(f" F1-Weighted: {test_metrics['f1_weighted']:.4f}")
    logger.info(f" MCC: {test_metrics['mcc']:.4f}")
    if test_metrics['auroc'] is not None:
        logger.info(f" AUROC: {test_metrics['auroc']:.4f}")
    if test_metrics['auprc'] is not None:
        logger.info(f" AUPRC: {test_metrics['auprc']:.4f}")

    # Save results
    results = {
        'task': args.task,
        'filter_mode': args.filter_mode,
        'num_classes': num_classes,
        'classes': train_dataset.unique_labels,
        'train_samples': len(train_dataset),
        'val_samples': len(val_dataset),
        'test_samples': len(test_dataset),
        'best_epoch': best_checkpoint['epoch'],
        'test_accuracy': test_metrics['accuracy'],
        'test_f1_macro': test_metrics['f1_macro'],
        'test_f1_weighted': test_metrics['f1_weighted'],
        'test_mcc': test_metrics['mcc'],
        'test_auroc': test_metrics['auroc'],
        'test_auprc': test_metrics['auprc'],
        'config': vars(args),
        'history': history,
    }

    with open(os.path.join(args.output_dir, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2)

    logger.info(f"\nResults saved to {args.output_dir}")


if __name__ == '__main__':
    main()