| |
| """ |
| Unified Fine-tuning Script for Glycan Classification |
| |
| This script fine-tunes a pre-trained Multimodal Glycan BERT model |
| on taxonomy classification tasks (domain, kingdom, phylum, class, |
| order, family, genus, species) and property prediction tasks |
| (immunogenicity, link). |
| |
| Usage: |
| python downstream_tasks/finetune.py \ |
| --task species \ |
| --data_path downstream_tasks/glycan_classification_with_wurcs.csv \ |
| --checkpoint checkpoints/best_multimodal_v3_model.pt \ |
| --vocab data/vocabulary.json \ |
| --output_dir downstream_tasks/results/species |
| """ |
|
|
| import argparse |
| import json |
| import logging |
| import os |
| import random |
| import sys |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
| import math |
| from datetime import datetime |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import CosineAnnealingLR |
| from torch.utils.data import DataLoader |
| import pandas as pd |
| from tqdm import tqdm |
| from sklearn.metrics import ( |
| accuracy_score, f1_score, precision_score, recall_score, |
| matthews_corrcoef, classification_report |
| ) |
|
|
| |
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
| from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig |
| from downstream_tasks.utils.tokenizer import WURCSTokenizer |
| from downstream_tasks.utils.dataset import GlycanClassificationDataset, compute_valid_classes |
|
|
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
| logger = logging.getLogger(__name__) |
|
|
|
|
def set_seed(seed: int):
    """Seed every random number generator this script relies on.

    Covers Python's ``random``, NumPy and PyTorch (CPU). When CUDA is
    available, all CUDA devices are seeded too and cuDNN is switched to
    deterministic, non-benchmarking mode.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
|
|
class GlycanClassifier(nn.Module):
    """
    Classification head on top of a pre-trained sequence encoder.

    The wrapped model's ``seq_embeddings`` and ``seq_layers`` produce per-token
    hidden states. These are pooled into one vector per sample and passed
    through a two-layer MLP that emits class logits.

    Pooling strategies:
    - "attention": learned softmax weighting over non-padded tokens
    - "mono": average the tokens of each residue (via ``residue_ids``), then
      apply learned attention over the residue vectors
    - "max": element-wise maximum over non-padded tokens
    - "first": hidden state of the first token
    - anything else (e.g. "mean"): masked mean over tokens; this is also what
      "mono" falls back to when no ``residue_ids`` are passed to ``forward``

    The first ``freeze_layers`` entries of ``bert.seq_layers`` are frozen.
    """

    def __init__(
        self,
        bert: "MultimodalGlycanBERT",
        num_classes: int,
        dropout: float = 0.25,
        freeze_layers: int = 8,
        pooling_strategy: str = "attention",
        max_residues: int = 32,
    ):
        """
        Args:
            bert: Pre-trained model providing ``seq_embeddings``, ``seq_layers``
                and ``config.seq_hidden_size``.
            num_classes: Number of output logits.
            dropout: Dropout probability used twice inside the classifier MLP.
            freeze_layers: How many leading entries of ``bert.seq_layers`` get
                ``requires_grad = False``.
            pooling_strategy: One of the strategies listed on the class.
            max_residues: Upper bound on the number of distinct residue IDs
                considered per sample by "mono" pooling.
        """
        super().__init__()
        self.bert = bert
        self.num_classes = num_classes
        self.pooling_strategy = pooling_strategy
        self.max_residues = max_residues

        # Freeze the bottom encoder layers; later layers stay trainable.
        for index, layer in enumerate(self.bert.seq_layers):
            if index < freeze_layers:
                for param in layer.parameters():
                    param.requires_grad = False

        hidden_size = bert.config.seq_hidden_size

        # Only the attention-based strategies need a scoring projection.
        if pooling_strategy in ("attention", "mono"):
            self.attention_weights = nn.Linear(hidden_size, 1)

        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, num_classes),
        )

    def forward(self, token_ids, attention_mask, residue_ids=None, **kwargs):
        """
        Forward pass for classification.

        Args:
            token_ids: (batch, seq_len) - Token IDs
            attention_mask: (batch, seq_len) - Attention mask (0 marks padding)
            residue_ids: (batch, seq_len) - Residue ID for each token
                (optional, for mono pooling; negative IDs are ignored)
            **kwargs: Optional ``branch_depths`` / ``linkage_types`` (forwarded
                to the embeddings when they expose ``branch_embeddings``) and
                ``dist_labels`` (enables the auxiliary distance loss).

        Returns:
            (logits, dist_loss): logits is (batch, num_classes); dist_loss is
            the float ``0.0`` when no usable ``dist_labels`` were given,
            otherwise a scalar tensor.
        """
        seq_hidden = self._encode(token_ids, attention_mask, kwargs)
        dist_loss = self._distance_loss(seq_hidden, kwargs.get('dist_labels'))
        pooled = self._pool(seq_hidden, attention_mask, residue_ids)
        return self.classifier(pooled), dist_loss

    def _encode(self, token_ids, attention_mask, extras):
        """Run the embeddings and every sequence layer; return token hidden states."""
        if hasattr(self.bert.seq_embeddings, 'branch_embeddings'):
            seq_hidden = self.bert.seq_embeddings(
                token_ids, extras.get('branch_depths'), extras.get('linkage_types')
            )
        else:
            seq_hidden = self.bert.seq_embeddings(token_ids)

        for layer in self.bert.seq_layers:
            seq_hidden = layer(seq_hidden, attention_mask)
        return seq_hidden

    def _distance_loss(self, seq_hidden, dist_labels):
        """MSE between ``bert.distance_head`` output and labels, skipping -1 entries."""
        if dist_labels is None:
            return 0.0

        dist_predictions = self.bert.distance_head(seq_hidden)
        dist_labels = dist_labels.to(dist_predictions.device)

        # NOTE(review): distance_head's output shape is not visible from this
        # file. If it carries a trailing singleton dimension the labels lack,
        # drop it so MSELoss compares element-wise instead of broadcasting.
        if (
            dist_predictions.dim() == dist_labels.dim() + 1
            and dist_predictions.size(-1) == 1
        ):
            dist_predictions = dist_predictions.squeeze(-1)

        valid = dist_labels != -1  # -1 marks positions without a label
        if not valid.any():
            return 0.0
        return F.mse_loss(dist_predictions[valid], dist_labels[valid].float())

    def _pool(self, seq_hidden, attention_mask, residue_ids):
        """Reduce (batch, seq_len, hidden) to (batch, hidden) per ``pooling_strategy``."""
        strategy = self.pooling_strategy

        if strategy == "first":
            return seq_hidden[:, 0, :]

        if strategy == "max":
            mask_expanded = attention_mask.unsqueeze(-1).float()
            # Push padded positions far below any real activation.
            masked_hidden = seq_hidden * mask_expanded + (1 - mask_expanded) * -1e9
            return masked_hidden.max(dim=1)[0]

        if strategy == "mono" and residue_ids is not None:
            return self._pool_mono(seq_hidden, attention_mask, residue_ids)

        if strategy == "attention":
            scores = self.attention_weights(seq_hidden).squeeze(-1)
            scores = scores.masked_fill(attention_mask == 0, -1e9)
            weights = torch.softmax(scores, dim=1).unsqueeze(-1)
            return (seq_hidden * weights).sum(dim=1)

        # Masked mean: "mean", unknown strategies, and "mono" without residue_ids.
        mask_expanded = attention_mask.unsqueeze(-1).float()
        summed = (seq_hidden * mask_expanded).sum(dim=1)
        counts = mask_expanded.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def _pool_mono(self, seq_hidden, attention_mask, residue_ids):
        """Average tokens per residue, then attention-pool the residue vectors."""
        pooled = []
        for b in range(seq_hidden.size(0)):
            sample_ids = residue_ids[b]
            # torch.unique returns sorted IDs; negative IDs are dropped.
            residues = torch.unique(sample_ids)
            residues = residues[residues >= 0][: self.max_residues]

            if residues.numel() == 0:
                # No residue information: fall back to a masked token mean.
                # The clamp keeps an all-padding row finite instead of 0/0.
                token_mask = attention_mask[b].unsqueeze(-1).float()
                total = (seq_hidden[b] * token_mask).sum(dim=0)
                pooled.append(total / token_mask.sum().clamp(min=1e-9))
                continue

            # (num_residues, seq_len) membership matrix: one matmul yields
            # every per-residue mean. Each listed ID occurs at least once, so
            # the row sums are never zero.
            membership = (sample_ids.unsqueeze(0) == residues.unsqueeze(1)).float()
            res_stack = membership @ seq_hidden[b].float()
            res_stack = res_stack / membership.sum(dim=1, keepdim=True)

            scores = self.attention_weights(res_stack).squeeze(-1)
            weights = torch.softmax(scores, dim=0).unsqueeze(-1)
            pooled.append((res_stack * weights).sum(dim=0))

        return torch.stack(pooled, dim=0)
|
|
|
|
def get_config_from_checkpoint(checkpoint_path: str, device: str) -> MultimodalGlycanBERTConfig:
    """Build the model config stored in a checkpoint.

    Three layouts of ``checkpoint['config']`` are understood:
    - a ready-made ``MultimodalGlycanBERTConfig`` instance (returned as is),
    - a nested mapping with a ``'model'`` section holding ``sequence``,
      ``mass_spectrometry``/``ms``, ``structure_3d``/``structure`` and
      ``fusion`` sub-sections,
    - a flat mapping of keyword arguments for ``MultimodalGlycanBERTConfig``.

    Args:
        checkpoint_path: Path handed to ``torch.load``.
        device: ``map_location`` used while loading.

    Returns:
        The reconstructed config. If the checkpoint carries no config, a
        default ``MultimodalGlycanBERTConfig()`` is returned and a warning is
        logged, because the defaults may not match the stored weights.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)

    config_dict = checkpoint.get('config') if isinstance(checkpoint, dict) else None
    if config_dict is None:
        logger.warning(
            "Checkpoint %s has no 'config' entry; falling back to the default "
            "MultimodalGlycanBERTConfig, which may not match the saved weights",
            checkpoint_path,
        )
        return MultimodalGlycanBERTConfig()

    if isinstance(config_dict, MultimodalGlycanBERTConfig):
        return config_dict

    if 'model' not in config_dict:
        # Flat layout: keys are the config's own keyword arguments.
        return MultimodalGlycanBERTConfig(**config_dict)

    model_cfg = config_dict['model']
    seq_cfg = model_cfg.get('sequence', {})
    ms_cfg = model_cfg.get('mass_spectrometry', model_cfg.get('ms', {}))
    struct_cfg = model_cfg.get('structure_3d', model_cfg.get('structure', {}))
    fusion_cfg = model_cfg.get('fusion', {})

    return MultimodalGlycanBERTConfig(
        seq_vocab_size=seq_cfg.get('vocab_size', 166),
        seq_hidden_size=seq_cfg.get('hidden_size', 768),
        seq_num_layers=seq_cfg.get('num_hidden_layers', 12),
        seq_num_heads=seq_cfg.get('num_attention_heads', 12),
        seq_max_length=seq_cfg.get('max_length', 512),
        ms_vocab_size=ms_cfg.get('vocab_size', 242),
        ms_hidden_size=ms_cfg.get('hidden_size', 256),
        ms_num_layers=ms_cfg.get('num_hidden_layers', 4),
        ms_num_heads=ms_cfg.get('num_attention_heads', 4),
        ms_max_length=ms_cfg.get('max_length', 100),
        struct_vocab_size=struct_cfg.get('vocab_size', 1024),
        struct_hidden_size=struct_cfg.get('hidden_size', 256),
        struct_num_layers=struct_cfg.get('num_hidden_layers', 4),
        struct_num_heads=struct_cfg.get('num_attention_heads', 4),
        struct_max_length=struct_cfg.get('max_length', 100),
        # 'enabled' takes precedence over the older 'use_3d' key.
        use_3d=struct_cfg.get('enabled', struct_cfg.get('use_3d', True)),
        fusion_hidden_size=fusion_cfg.get('fusion_hidden_size', 512),
    )
|
|
|
|
def load_pretrained_bert(checkpoint_path: str, config: MultimodalGlycanBERTConfig, device: str) -> MultimodalGlycanBERT:
    """Load pre-trained BERT from checkpoint using provided config.

    Args:
        checkpoint_path: Path handed to ``torch.load``.
        config: Config used to instantiate ``MultimodalGlycanBERT``.
        device: ``map_location`` used while loading.

    Returns:
        The model with the checkpoint's weights applied. The weights come from
        ``checkpoint['model_state_dict']`` when that key exists; otherwise the
        checkpoint itself is treated as the state dict.

    Loading is non-strict, so key mismatches do not raise. They are logged as
    warnings instead, so a checkpoint that matches few or no parameters is not
    mistaken for a successful load.
    """
    logger.info(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)

    bert = MultimodalGlycanBERT(config)

    state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
    load_result = bert.load_state_dict(state_dict, strict=False)

    # nn.Module.load_state_dict reports mismatches via these two fields;
    # getattr keeps this safe should the model return something else.
    missing = list(getattr(load_result, 'missing_keys', ()))
    unexpected = list(getattr(load_result, 'unexpected_keys', ()))
    if missing:
        logger.warning(
            "%d model parameters were not found in the checkpoint (e.g. %s)",
            len(missing), ", ".join(missing[:5]),
        )
    if unexpected:
        logger.warning(
            "%d checkpoint entries were not used by the model (e.g. %s)",
            len(unexpected), ", ".join(unexpected[:5]),
        )

    logger.info("Loaded pre-trained BERT successfully")
    return bert
|
|
|
|
|
|
def train_epoch(
    model: GlycanClassifier,
    train_loader: DataLoader,
    optimizer: AdamW,
    criterion: nn.Module,
    device: str,
    scheduler=None,
    dist_alpha: float = 0.5,
) -> dict:
    """Run one optimisation pass over ``train_loader``.

    For each batch the loss is ``criterion(logits, labels)`` plus
    ``dist_alpha`` times the auxiliary loss returned by the model. Gradients
    are clipped to a norm of 1.0, and the scheduler (if given) is stepped
    after every optimizer step.

    Returns:
        Dict with the mean per-batch ``loss`` and the epoch ``accuracy``.
    """
    model.train()

    optional_keys = ('residue_ids', 'branch_depths', 'linkage_types', 'dist_labels')
    loss_sum = 0
    predicted = []
    expected = []

    progress = tqdm(train_loader, desc="Training")
    for batch in progress:
        token_ids = batch['token_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Optional inputs are moved to the device when present, else None.
        extras = {
            key: (batch[key].to(device) if key in batch else None)
            for key in optional_keys
        }
        labels = batch['label'].to(device)

        optimizer.zero_grad()

        logits, dist_loss = model(
            token_ids, attention_mask, extras['residue_ids'],
            branch_depths=extras['branch_depths'],
            linkage_types=extras['linkage_types'],
            dist_labels=extras['dist_labels'],
        )

        batch_loss = criterion(logits, labels) + dist_alpha * dist_loss

        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        if scheduler:
            scheduler.step()

        batch_loss_value = batch_loss.item()
        loss_sum += batch_loss_value
        predicted.extend(logits.argmax(dim=1).cpu().numpy())
        expected.extend(labels.cpu().numpy())

        # The model returns a plain float when no auxiliary loss was computed.
        dist_value = dist_loss if isinstance(dist_loss, float) else dist_loss.item()
        progress.set_postfix({
            'loss': f'{batch_loss_value:.4f}',
            'dist': f'{dist_value:.4f}',
        })

    return {
        'loss': loss_sum / len(train_loader),
        'accuracy': accuracy_score(expected, predicted),
    }
|
|
|
|
def _compute_ranking_metrics(
    labels: np.ndarray, probs: np.ndarray
) -> Tuple[Optional[float], Optional[float]]:
    """Best-effort macro AUROC / AUPRC restricted to classes present in ``labels``.

    Args:
        labels: 1-D array of integer class indices.
        probs: (n_samples, n_outputs) array of softmax probabilities.

    Returns:
        ``(auroc, auprc)``; either entry is ``None`` when it could not be
        computed (fewer than two classes present, or scikit-learn rejected
        the input).
    """
    from sklearn.metrics import roc_auc_score, average_precision_score
    from sklearn.preprocessing import label_binarize

    auroc = None
    auprc = None

    present = np.unique(labels)  # sorted class indices that occur in y_true
    if present.size < 2:
        return auroc, auprc

    try:
        scores = probs[:, present]
        if scores.shape[1] != probs.shape[1]:
            # scikit-learn requires one score column per label and rows that
            # sum to 1, so renormalise over the classes that are present.
            scores = scores.astype(np.float64)
            scores = scores / np.clip(scores.sum(axis=1, keepdims=True), 1e-12, None)

        if present.size == 2:
            # Binary metrics take the score of the larger label as "positive".
            positive = (labels == present[1]).astype(int)
            auroc = float(roc_auc_score(positive, scores[:, 1]))
            auprc = float(average_precision_score(positive, scores[:, 1]))
        else:
            auroc = float(roc_auc_score(
                labels, scores, multi_class='ovr', average='macro', labels=present
            ))
            onehot = label_binarize(labels, classes=present)
            auprc = float(average_precision_score(onehot, scores, average='macro'))
    except Exception as exc:
        # These metrics are optional; report the reason instead of hiding it.
        logger.warning("Could not compute AUROC/AUPRC: %s", exc)

    return auroc, auprc


def evaluate(
    model: GlycanClassifier,
    data_loader: DataLoader,
    criterion: nn.Module,
    device: str,
    num_classes: int = None,
    dist_alpha: float = 0.5,
) -> dict:
    """Evaluate model on dataset.

    Args:
        model: Classifier returning ``(logits, dist_loss)``.
        data_loader: Batches with ``token_ids``, ``attention_mask``, ``label``
            and optionally ``residue_ids``, ``branch_depths``,
            ``linkage_types``, ``dist_labels``.
        criterion: Classification loss.
        device: Device the batch tensors are moved to.
        num_classes: Kept for backward compatibility; the number of outputs
            is read from the logits themselves.
        dist_alpha: Weight of the auxiliary distance loss in the reported loss.

    Returns:
        Dict with ``loss`` (mean per batch), ``accuracy``, ``f1_macro``,
        ``f1_weighted``, ``mcc``, ``auroc``, ``auprc`` (the last two may be
        ``None``), plus the raw ``predictions`` and ``labels`` lists.
    """
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            token_ids = batch['token_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            residue_ids = batch['residue_ids'].to(device) if 'residue_ids' in batch else None
            branch_depths = batch['branch_depths'].to(device) if 'branch_depths' in batch else None
            linkage_types = batch['linkage_types'].to(device) if 'linkage_types' in batch else None
            dist_labels = batch['dist_labels'].to(device) if 'dist_labels' in batch else None
            labels = batch['label'].to(device)

            logits, dist_loss = model(
                token_ids, attention_mask, residue_ids,
                branch_depths=branch_depths, linkage_types=linkage_types,
                dist_labels=dist_labels
            )
            loss = criterion(logits, labels) + dist_alpha * dist_loss

            total_loss += loss.item()
            all_probs.extend(torch.softmax(logits, dim=1).cpu().numpy())
            all_preds.extend(logits.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    avg_loss = total_loss / len(data_loader)
    accuracy = accuracy_score(all_labels, all_preds)
    f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    f1_weighted = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    mcc = matthews_corrcoef(all_labels, all_preds)

    auroc, auprc = _compute_ranking_metrics(np.array(all_labels), np.array(all_probs))

    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'f1_macro': f1_macro,
        'f1_weighted': f1_weighted,
        'mcc': mcc,
        'auroc': auroc,
        'auprc': auprc,
        'predictions': all_preds,
        'labels': all_labels,
    }
|
|
|
|
def _json_default(value):
    """Convert NumPy scalars/arrays to plain Python objects for ``json.dump``."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def main():
    """Command-line entry point: fine-tune, early-stop on validation MCC, test.

    Steps: parse arguments, seed the RNGs, optionally compute the set of
    valid classes, load the three data splits, wrap the pre-trained model in
    ``GlycanClassifier``, train with AdamW and a per-step cosine schedule,
    save the best-MCC checkpoint to ``<output_dir>/best_model.pt``, evaluate
    it on the test split and write ``<output_dir>/results.json``.
    """
    parser = argparse.ArgumentParser(description='Fine-tune Glycan BERT for classification')

    # Required paths and task
    parser.add_argument('--task', type=str, required=True,
                        help='Task name (e.g., species, phylum)')
    parser.add_argument('--data_path', type=str, required=True,
                        help='Path to CSV data file')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to pre-trained model checkpoint')
    parser.add_argument('--vocab', type=str, required=True,
                        help='Path to vocabulary.json')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Output directory for results')

    # Hyper-parameters
    parser.add_argument('--batch_size', type=int, default=256,
                        help='Batch size (matching GlycanML for stable gradients)')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--dropout', type=float, default=0.25,
                        help='Dropout rate (increased from 0.1 to combat overfitting)')
    parser.add_argument('--freeze_layers', type=int, default=8,
                        help='Number of bottom layers to freeze (increased from 4 to prevent overfitting)')
    parser.add_argument('--pooling_strategy', type=str, default='attention',
                        choices=['mean', 'first', 'max', 'attention', 'mono'],
                        help='Pooling strategy: attention (recommended), mono (residue-level), mean, first (CLS-style), max')
    parser.add_argument('--max_length', type=int, default=256)
    parser.add_argument('--patience', type=int, default=10,
                        help='Early stopping patience')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--filter_mode', type=str, default='none',
                        choices=['none', 'strict', 'strict_3', 'strict_5'],
                        help='Class filtering: none (use all), strict (n=1), strict_3 (n=3), strict_5 (n=5)')
    parser.add_argument('--dist_alpha', type=float, default=0.0,
                        help='Weight for auxiliary distance/topology loss (default: 0.0 disabled, set >0 to enable)')

    args = parser.parse_args()

    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    logger.info("=" * 70)
    logger.info(f"FINE-TUNING GLYCAN BERT ON {args.task.upper()}")
    logger.info("=" * 70)
    logger.info(f"Data: {args.data_path}")
    logger.info(f"Checkpoint: {args.checkpoint}")
    logger.info(f"Output: {args.output_dir}")
    logger.info(f"Device: {args.device}")
    logger.info(f"Seed: {args.seed}")
    logger.info(f"Filter mode: {args.filter_mode}")

    # Optional class filtering: each mode maps to a minimum sample count.
    valid_classes = None
    if args.filter_mode != 'none':
        min_samples = {'strict': 1, 'strict_3': 3, 'strict_5': 5}.get(args.filter_mode, 1)

        logger.info(f"\nComputing valid classes ({args.filter_mode} mode, min_samples={min_samples})...")
        valid_classes = compute_valid_classes(args.data_path, args.task, min_samples=min_samples)
        logger.info(f" Will use {len(valid_classes)} classes present in all splits")

    # Read the config first so sequences never exceed what the model supports.
    logger.info("\nChecking model capacity from checkpoint...")
    checkpoint_config = get_config_from_checkpoint(args.checkpoint, 'cpu')
    model_max_length = checkpoint_config.seq_max_length

    if args.max_length > model_max_length:
        logger.warning(f" Requested max_length ({args.max_length}) exceeds model capacity ({model_max_length}).")
        logger.warning(f" Overriding max_length to {model_max_length} to prevent size mismatch errors.")
        dataset_max_length = model_max_length
    else:
        dataset_max_length = args.max_length

    logger.info(f"\nLoading data (max_length={dataset_max_length})...")

    def build_dataset(split):
        return GlycanClassificationDataset(
            args.data_path, args.task, split, args.vocab, dataset_max_length,
            valid_classes=valid_classes
        )

    train_dataset = build_dataset('train')
    val_dataset = build_dataset('validation')
    test_dataset = build_dataset('test')

    logger.info(f"\nDataset summary:")
    logger.info(f" Train: {len(train_dataset)} samples")
    logger.info(f" Val: {len(val_dataset)} samples")
    logger.info(f" Test: {len(test_dataset)} samples")
    logger.info(f" Classes: {len(train_dataset.unique_labels)}")

    if args.filter_mode == 'none':
        train_classes = set(train_dataset.unique_labels)
        val_classes = set(val_dataset.unique_labels)
        test_classes = set(test_dataset.unique_labels)
        common_classes = train_classes & val_classes & test_classes
        logger.info(f"\nClass distribution (filter_mode=none):")
        # "X-only" means present in split X and in neither of the other two.
        logger.info(f" Train-only classes: {len(train_classes - val_classes - test_classes)}")
        logger.info(f" Val-only classes: {len(val_classes - train_classes - test_classes)}")
        logger.info(f" Test-only classes: {len(test_classes - train_classes - val_classes)}")
        logger.info(f" Common to all: {len(common_classes)}")

    # Pinned host memory only helps when batches are copied to a CUDA device.
    loader_kwargs = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=str(args.device).startswith('cuda'),
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    logger.info("\nLoading model...")
    bert = load_pretrained_bert(args.checkpoint, checkpoint_config, args.device)

    num_classes = len(train_dataset.unique_labels)
    model = GlycanClassifier(
        bert, num_classes,
        dropout=args.dropout,
        freeze_layers=args.freeze_layers,
        pooling_strategy=args.pooling_strategy,
    ).to(args.device)
    logger.info(f" Pooling strategy: {args.pooling_strategy}")

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f" Total params: {total_params:,}")
    logger.info(f" Trainable params: {trainable_params:,} ({trainable_params/total_params*100:.1f}%)")

    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr,
        weight_decay=args.weight_decay
    )

    # The scheduler is stepped once per batch, so T_max counts optimizer steps.
    total_steps = len(train_loader) * args.epochs
    scheduler = CosineAnnealingLR(optimizer, T_max=total_steps)

    logger.info("\n" + "=" * 70)
    logger.info("TRAINING")
    logger.info("=" * 70)

    best_model_path = os.path.join(args.output_dir, 'best_model.pt')
    # Start below any reachable MCC so the first epoch always writes a
    # checkpoint; a stale file from an earlier run is then never picked up.
    best_val_mcc = float('-inf')
    best_epoch = None
    epochs_without_improvement = 0
    history = []

    for epoch in range(args.epochs):
        logger.info(f"\nEpoch {epoch + 1}/{args.epochs}")

        train_metrics = train_epoch(model, train_loader, optimizer, criterion, args.device, scheduler, dist_alpha=args.dist_alpha)
        val_metrics = evaluate(model, val_loader, criterion, args.device, num_classes, dist_alpha=args.dist_alpha)

        logger.info(f" Train - Loss: {train_metrics['loss']:.4f}, Acc: {train_metrics['accuracy']:.4f}")
        val_log = f" Val - Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.4f}, "
        val_log += f"F1: {val_metrics['f1_macro']:.4f}, MCC: {val_metrics['mcc']:.4f}"
        if val_metrics['auroc'] is not None:
            val_log += f", AUROC: {val_metrics['auroc']:.4f}"
        logger.info(val_log)

        history.append({
            'epoch': epoch + 1,
            'train_loss': train_metrics['loss'],
            'train_acc': train_metrics['accuracy'],
            'val_loss': val_metrics['loss'],
            'val_acc': val_metrics['accuracy'],
            'val_f1': val_metrics['f1_macro'],
            'val_mcc': val_metrics['mcc'],
            'val_auroc': val_metrics['auroc'],
            'val_auprc': val_metrics['auprc'],
        })

        if val_metrics['mcc'] > best_val_mcc:
            best_val_mcc = val_metrics['mcc']
            best_epoch = epoch + 1
            epochs_without_improvement = 0

            torch.save({
                'epoch': best_epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_mcc': best_val_mcc,
                'config': {
                    'task': args.task,
                    'num_classes': num_classes,
                    'classes': train_dataset.unique_labels,
                }
            }, best_model_path)

            logger.info(f" New best MCC: {best_val_mcc:.4f} (saved)")
        else:
            epochs_without_improvement += 1
            logger.info(f" No improvement ({epochs_without_improvement}/{args.patience})")

        if epochs_without_improvement >= args.patience:
            logger.info(f"\nEarly stopping at epoch {epoch + 1}")
            break

    logger.info("\n" + "=" * 70)
    logger.info("TESTING")
    logger.info("=" * 70)

    if best_epoch is None:
        # No epoch ran (e.g. --epochs 0): do not load a checkpoint this run
        # did not write.
        logger.warning("No checkpoint was saved during this run; evaluating the current model weights")
    else:
        best_checkpoint = torch.load(best_model_path, map_location=args.device)
        model.load_state_dict(best_checkpoint['model_state_dict'])

    # Use the same auxiliary-loss weight as training/validation so the
    # reported test loss is comparable.
    test_metrics = evaluate(model, test_loader, criterion, args.device, num_classes, dist_alpha=args.dist_alpha)

    logger.info(f"\nTest Results:")
    logger.info(f" Accuracy: {test_metrics['accuracy']:.4f}")
    logger.info(f" F1-Macro: {test_metrics['f1_macro']:.4f}")
    logger.info(f" F1-Weighted: {test_metrics['f1_weighted']:.4f}")
    logger.info(f" MCC: {test_metrics['mcc']:.4f}")
    if test_metrics['auroc'] is not None:
        logger.info(f" AUROC: {test_metrics['auroc']:.4f}")
    if test_metrics['auprc'] is not None:
        logger.info(f" AUPRC: {test_metrics['auprc']:.4f}")

    results = {
        'task': args.task,
        'filter_mode': args.filter_mode,
        'num_classes': num_classes,
        'classes': train_dataset.unique_labels,
        'train_samples': len(train_dataset),
        'val_samples': len(val_dataset),
        'test_samples': len(test_dataset),
        'best_epoch': best_epoch,
        'test_accuracy': test_metrics['accuracy'],
        'test_f1_macro': test_metrics['f1_macro'],
        'test_f1_weighted': test_metrics['f1_weighted'],
        'test_mcc': test_metrics['mcc'],
        'test_auroc': test_metrics['auroc'],
        'test_auprc': test_metrics['auprc'],
        'config': vars(args),
        'history': history,
    }

    with open(os.path.join(args.output_dir, 'results.json'), 'w') as f:
        # _json_default only runs for values json cannot encode natively
        # (NumPy scalars/arrays), so a long run is not lost at the final dump.
        json.dump(results, f, indent=2, default=_json_default)

    logger.info(f"\nResults saved to {args.output_dir}")
|
|
|
|
| if __name__ == '__main__': |
| main() |
|
|
|
|