supanthadey1's picture
Add BERTose and AFFINose training code release
1d6f391 verified
Raw
History Blame Contribute Delete
29.7 kB
#!/usr/bin/env python3
"""
Unified Fine-tuning Script for Glycan Classification
This script fine-tunes a pre-trained Multimodal Glycan BERT model
on taxonomy classification tasks (domain, kingdom, phylum, class,
order, family, genus, species) and property prediction tasks
(immunogenicity, link).
Usage:
python downstream_tasks/finetune.py \
--task species \
--data_path downstream_tasks/glycan_classification_with_wurcs.csv \
--checkpoint checkpoints/best_multimodal_v3_model.pt \
--vocab data/vocabulary.json \
--output_dir downstream_tasks/results/species
"""
import argparse
import json
import logging
import os
import random
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import math
from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import (
accuracy_score, f1_score, precision_score, recall_score,
matthews_corrcoef, classification_report
)
# Add parent to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERT, MultimodalGlycanBERTConfig
from downstream_tasks.utils.tokenizer import WURCSTokenizer
from downstream_tasks.utils.dataset import GlycanClassificationDataset, compute_valid_classes
# Root logging configuration for the script: INFO level, timestamped lines.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
# Module-scoped logger used by every function in this file.
logger = logging.getLogger(__name__)
def set_seed(seed: int):
    """
    Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``.

    When CUDA is available, all GPU generators are seeded too, and cuDNN is
    put in deterministic, non-benchmarking mode so repeated runs select the
    same kernels.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
class GlycanClassifier(nn.Module):
    """
    Classification head on top of a pre-trained Multimodal Glycan BERT.

    Only the sequence branch of the backbone is used: ``bert.seq_embeddings``
    followed by every layer in ``bert.seq_layers``. The per-token hidden
    states are pooled into one vector per glycan. A two-layer MLP then maps
    that vector to class logits:
    Dropout -> Linear(h, h // 2) -> GELU -> Dropout -> Linear(h // 2, num_classes).

    Pooling strategies (``pooling_strategy``):
        "attention": learned softmax weights over non-padding tokens.
        "mono": mean-pool the attended tokens of each residue
            (``residue_ids >= 0``), then apply learned softmax weights over the
            residues. Falls back to mean pooling for the whole batch when
            ``residue_ids`` is not given. Also falls back for any sample that
            has no attended residue tokens.
        "mean": mean over non-padding tokens. Also used for any unrecognised
            strategy name.
        "max": element-wise max over non-padding tokens.
        "first": hidden state of the first token (CLS-style).

    The bottom ``freeze_layers`` transformer layers are frozen (default 8, up
    from 4 in earlier versions, to limit overfitting). The embeddings and the
    remaining layers stay trainable.
    """

    def __init__(
        self,
        bert: "MultimodalGlycanBERT",
        num_classes: int,
        dropout: float = 0.25,  # Increased from 0.1 to combat overfitting
        freeze_layers: int = 8,  # Increased from 4 to prevent overfitting
        pooling_strategy: str = "attention",  # "mean", "first", "max", "attention", "mono"
        max_residues: int = 32,
    ):
        """
        Args:
            bert: Pre-trained backbone. It must expose ``config.seq_hidden_size``,
                ``seq_embeddings`` and ``seq_layers``. It must also expose
                ``distance_head`` when distance labels are passed to ``forward``.
            num_classes: Number of output logits.
            dropout: Dropout probability used twice inside the MLP head.
            freeze_layers: Number of bottom ``seq_layers`` whose parameters get
                ``requires_grad = False``.
            pooling_strategy: One of "attention", "mono", "mean", "max", "first".
            max_residues: Maximum number of residues per glycan that mono
                pooling attends over (lowest residue ids first).
        """
        super().__init__()
        self.bert = bert
        self.num_classes = num_classes
        self.pooling_strategy = pooling_strategy
        self.max_residues = max_residues

        # Freeze the bottom `freeze_layers` transformer layers.
        for index, layer in enumerate(self.bert.seq_layers):
            if index < freeze_layers:
                for param in layer.parameters():
                    param.requires_grad = False

        hidden_size = bert.config.seq_hidden_size
        # Token/residue scoring layer, only needed by the attention-based poolers.
        if pooling_strategy in ("attention", "mono"):
            self.attention_weights = nn.Linear(hidden_size, 1)

        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, num_classes),
        )

    def forward(self, token_ids, attention_mask, residue_ids=None, **kwargs):
        """
        Classify a batch of tokenised glycans.

        Args:
            token_ids: (batch, seq_len) token ids.
            attention_mask: (batch, seq_len); 0 marks padding.
            residue_ids: optional (batch, seq_len) residue index per token
                (negative for tokens outside any residue); used by "mono".
            **kwargs: optional ``branch_depths`` / ``linkage_types`` (forwarded
                to embeddings that support them) and ``dist_labels``
                (pairwise distance targets, -1 = ignore) for the auxiliary
                topology loss.

        Returns:
            ``(logits, dist_loss)``. ``logits`` has shape (batch, num_classes).
            ``dist_loss`` is the float 0.0 when no valid distance labels are
            given, otherwise a scalar MSE tensor.
        """
        seq_hidden = self._encode(token_ids, attention_mask, kwargs)
        dist_loss = self._distance_loss(seq_hidden, kwargs.get('dist_labels'))
        pooled = self._pool(seq_hidden, attention_mask, residue_ids)
        return self.classifier(pooled), dist_loss

    def _encode(self, token_ids, attention_mask, extras):
        """Run the sequence embeddings and all transformer layers."""
        embeddings = self.bert.seq_embeddings
        # Newer embedding modules also consume branch-depth / linkage features.
        if hasattr(embeddings, 'branch_embeddings'):
            hidden = embeddings(token_ids, extras.get('branch_depths'), extras.get('linkage_types'))
        else:
            hidden = embeddings(token_ids)
        for layer in self.bert.seq_layers:
            hidden = layer(hidden, attention_mask)
        return hidden

    def _distance_loss(self, seq_hidden, dist_labels):
        """MSE between the pre-trained distance head and labels != -1 (0.0 if none)."""
        if dist_labels is None:
            return 0.0
        predictions = self.bert.distance_head(seq_hidden)
        dist_labels = dist_labels.to(predictions.device)
        valid = dist_labels != -1
        if not valid.any():
            return 0.0
        return F.mse_loss(predictions[valid], dist_labels[valid].to(predictions.dtype))

    @staticmethod
    def _mask_fill_value(dtype: torch.dtype) -> float:
        """Large negative fill for masked positions, kept representable in fp16/bf16."""
        return max(-1e9, torch.finfo(dtype).min)

    @staticmethod
    def _mean_pool(hidden, mask):
        """Mean of `hidden` over non-masked positions (axis -2); zero vector if none."""
        weights = mask.unsqueeze(-1).to(hidden.dtype)
        return (hidden * weights).sum(dim=-2) / weights.sum(dim=-2).clamp(min=1e-9)

    def _pool(self, seq_hidden, attention_mask, residue_ids):
        """Reduce (batch, seq_len, hidden) to (batch, hidden) per `pooling_strategy`."""
        strategy = self.pooling_strategy
        if strategy == "first":
            return seq_hidden[:, 0, :]
        if strategy == "max":
            padding = (attention_mask == 0).unsqueeze(-1)
            fill = self._mask_fill_value(seq_hidden.dtype)
            return seq_hidden.masked_fill(padding, fill).max(dim=1).values
        if strategy == "mono" and residue_ids is not None:
            return self._mono_pool(seq_hidden, attention_mask, residue_ids)
        if strategy == "attention":
            scores = self.attention_weights(seq_hidden).squeeze(-1)  # (batch, seq_len)
            scores = scores.masked_fill(attention_mask == 0, self._mask_fill_value(scores.dtype))
            weights = torch.softmax(scores, dim=1).unsqueeze(-1)  # (batch, seq_len, 1)
            return (seq_hidden * weights).sum(dim=1)
        # "mean" and any unrecognised strategy (also "mono" without residue_ids).
        return self._mean_pool(seq_hidden, attention_mask)

    def _mono_pool(self, seq_hidden, attention_mask, residue_ids):
        """Mean-pool each residue's attended tokens, then attention-pool over residues."""
        pooled = []
        for hidden, mask, res_ids in zip(seq_hidden, attention_mask, residue_ids):
            attended = mask != 0
            # Padding tokens must not contribute, whatever residue id they carry.
            valid_ids = res_ids[(res_ids >= 0) & attended]
            residue_reps = [
                hidden[(res_ids == rid) & attended].mean(dim=0)
                for rid in torch.unique(valid_ids)[: self.max_residues]
            ]
            if not residue_reps:
                pooled.append(self._mean_pool(hidden, mask))
                continue
            stacked = torch.stack(residue_reps, dim=0)  # (num_res, hidden)
            scores = self.attention_weights(stacked).squeeze(-1)  # (num_res,)
            weights = torch.softmax(scores, dim=0).unsqueeze(-1)  # (num_res, 1)
            pooled.append((stacked * weights).sum(dim=0))
        return torch.stack(pooled, dim=0)
def _branch_config_kwargs(prefix, section, vocab_size, hidden_size, num_layers, num_heads, max_length):
    """
    Translate one encoder section of a nested checkpoint config into flat
    ``MultimodalGlycanBERTConfig`` keyword arguments named ``<prefix>_*``.

    Each value is read from ``section`` and falls back to the supplied default.
    """
    return {
        f'{prefix}_vocab_size': section.get('vocab_size', vocab_size),
        f'{prefix}_hidden_size': section.get('hidden_size', hidden_size),
        f'{prefix}_num_layers': section.get('num_hidden_layers', num_layers),
        f'{prefix}_num_heads': section.get('num_attention_heads', num_heads),
        f'{prefix}_max_length': section.get('max_length', max_length),
    }


def get_config_from_checkpoint(checkpoint_path: str, device: str) -> MultimodalGlycanBERTConfig:
    """
    Rebuild the model config stored inside a checkpoint.

    Three layouts are handled:
    - no ``'config'`` entry: the default ``MultimodalGlycanBERTConfig()``;
    - a flat ``'config'`` dict: passed straight through as keyword arguments;
    - a nested ``config['model']`` dict with ``sequence``,
      ``mass_spectrometry``/``ms``, ``structure_3d``/``structure`` and
      ``fusion`` sections: mapped field by field, with defaults for
      anything absent.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if 'config' not in checkpoint:
        return MultimodalGlycanBERTConfig()

    config_dict = checkpoint['config']
    if 'model' not in config_dict:
        return MultimodalGlycanBERTConfig(**config_dict)

    model_cfg = config_dict['model']
    seq_section = model_cfg.get('sequence', {})
    ms_section = model_cfg.get('mass_spectrometry', model_cfg.get('ms', {}))
    struct_section = model_cfg.get('structure_3d', model_cfg.get('structure', {}))
    fusion_section = model_cfg.get('fusion', {})

    return MultimodalGlycanBERTConfig(
        **_branch_config_kwargs('seq', seq_section, 166, 768, 12, 12, 512),
        **_branch_config_kwargs('ms', ms_section, 242, 256, 4, 4, 100),
        **_branch_config_kwargs('struct', struct_section, 1024, 256, 4, 4, 100),
        use_3d=struct_section.get('enabled', struct_section.get('use_3d', True)),
        fusion_hidden_size=fusion_section.get('fusion_hidden_size', 512),
    )
def load_pretrained_bert(checkpoint_path: str, config: MultimodalGlycanBERTConfig, device: str) -> MultimodalGlycanBERT:
    """
    Build a ``MultimodalGlycanBERT`` from ``config`` and load pre-trained weights.

    The checkpoint may wrap the weights under ``'model_state_dict'`` or be a
    bare state dict. Loading is non-strict, so small architectural differences
    do not abort fine-tuning. Any missing or unexpected keys are logged,
    because with ``strict=False`` a wholesale key mismatch (e.g. a prefixed
    state dict) would otherwise leave the backbone at its initial weights
    without any warning.

    Args:
        checkpoint_path: Path to the checkpoint file.
        config: Config used to construct the backbone.
        device: ``map_location`` for ``torch.load``.

    Returns:
        The backbone with the checkpoint weights applied.
    """
    logger.info(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)

    bert = MultimodalGlycanBERT(config)

    state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
    incompatible = bert.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        logger.warning(
            f"{len(incompatible.missing_keys)} model parameter(s) not found in checkpoint "
            f"(left at initial values), e.g. {incompatible.missing_keys[:5]}"
        )
    if incompatible.unexpected_keys:
        logger.warning(
            f"{len(incompatible.unexpected_keys)} checkpoint key(s) not used by the model, "
            f"e.g. {incompatible.unexpected_keys[:5]}"
        )

    logger.info("Loaded pre-trained BERT successfully")
    return bert
def train_epoch(
    model: GlycanClassifier,
    train_loader: DataLoader,
    optimizer: AdamW,
    criterion: nn.Module,
    device: str,
    scheduler=None,
    dist_alpha: float = 0.5,
) -> dict:
    """
    Train ``model`` for one pass over ``train_loader``.

    Gradients are clipped to a global norm of 1.0 before every optimizer step.
    The scheduler (if given) is stepped once per batch.

    Args:
        model: Classifier returning ``(logits, dist_loss)``.
        train_loader: Yields dicts with ``token_ids``, ``attention_mask`` and
            ``label``, plus optional ``residue_ids``, ``branch_depths``,
            ``linkage_types`` and ``dist_labels``.
        optimizer: Optimizer over the trainable parameters.
        criterion: Classification loss applied to ``(logits, labels)``.
        device: Device the batch tensors are moved to.
        scheduler: Optional per-batch LR scheduler.
        dist_alpha: Weight of the auxiliary distance (topology) loss. When it
            is 0, ``dist_labels`` are not passed to the model, so the
            auxiliary distance head is not run at all.

    Returns:
        Dict with the mean per-batch ``loss`` and the training ``accuracy``.
    """
    model.train()
    use_topology = dist_alpha != 0
    total_loss = 0.0
    all_preds = []
    all_labels = []

    def _optional(batch, key):
        # Optional structural features may be absent from a batch.
        return batch[key].to(device) if key in batch else None

    pbar = tqdm(train_loader, desc="Training")
    for batch in pbar:
        token_ids = batch['token_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        dist_labels = _optional(batch, 'dist_labels') if use_topology else None

        optimizer.zero_grad()
        logits, dist_loss = model(
            token_ids, attention_mask, _optional(batch, 'residue_ids'),
            branch_depths=_optional(batch, 'branch_depths'),
            linkage_types=_optional(batch, 'linkage_types'),
            dist_labels=dist_labels,
        )

        # Total loss = classification loss + alpha * topology loss; alpha keeps
        # the auxiliary term from overwhelming the main task.
        cls_loss = criterion(logits, labels)
        batch_loss = cls_loss + dist_alpha * dist_loss

        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        if scheduler:
            scheduler.step()

        loss_value = batch_loss.item()
        total_loss += loss_value
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        # dist_loss is either the float 0.0 or a scalar tensor; float() handles both.
        pbar.set_postfix({'loss': f'{loss_value:.4f}', 'dist': f'{float(dist_loss):.4f}'})

    avg_loss = total_loss / len(train_loader)
    accuracy = accuracy_score(all_labels, all_preds)

    return {
        'loss': avg_loss,
        'accuracy': accuracy,
    }
def _ranking_metrics(labels: np.ndarray, probs: np.ndarray) -> Tuple[Optional[float], Optional[float]]:
    """
    Compute macro one-vs-rest AUROC and AUPRC from class probabilities.

    Args:
        labels: (n,) integer class indices.
        probs: (n, num_classes) softmax probabilities; column ``c`` scores class ``c``.

    Returns:
        ``(auroc, auprc)``. Both are ``None`` when fewer than two classes occur
        in ``labels``, since ranking metrics need positives and negatives.
    """
    from sklearn.metrics import roc_auc_score, average_precision_score

    present = np.unique(labels)
    if len(present) < 2:
        return None, None

    if probs.shape[1] == 2:
        # Genuine binary task: class 1 is the positive class.
        positive_scores = probs[:, 1]
        return (
            float(roc_auc_score(labels, positive_scores)),
            float(average_precision_score(labels, positive_scores)),
        )

    # Multi-class: one-vs-rest for every class that occurs in `labels`, then an
    # unweighted (macro) mean. When all classes occur this equals sklearn's
    # multi_class='ovr', average='macro'. Unlike sklearn's `labels=` route, it
    # also works when some classes are absent from this split.
    aurocs, auprcs = [], []
    for cls in present:
        is_positive = (labels == cls).astype(int)
        aurocs.append(roc_auc_score(is_positive, probs[:, cls]))
        auprcs.append(average_precision_score(is_positive, probs[:, cls]))
    return float(np.mean(aurocs)), float(np.mean(auprcs))


def evaluate(
    model: GlycanClassifier,
    data_loader: DataLoader,
    criterion: nn.Module,
    device: str,
    num_classes: int = None,
    dist_alpha: float = 0.5,
) -> dict:
    """
    Evaluate ``model`` on ``data_loader`` without gradient tracking.

    Args:
        model: Classifier returning ``(logits, dist_loss)``.
        data_loader: Same batch format as used for training.
        criterion: Classification loss applied to ``(logits, labels)``.
        device: Device the batch tensors are moved to.
        num_classes: Kept for backward compatibility. The class count used for
            AUROC/AUPRC is taken from the width of the logits.
        dist_alpha: Weight of the auxiliary distance loss in the reported
            ``loss``. When it is 0, distance labels are not sent to the model.

    Returns:
        Dict with ``loss``, ``accuracy``, ``f1_macro``, ``f1_weighted``,
        ``mcc``, ``auroc``, ``auprc`` (``None`` when undefined), plus the raw
        ``predictions`` and ``labels``.
    """
    model.eval()
    use_topology = dist_alpha != 0
    total_loss = 0.0
    all_preds = []
    all_labels = []
    all_probs = []  # Per-sample class probabilities for AUROC/AUPRC

    def _optional(batch, key):
        # Optional structural features may be absent from a batch.
        return batch[key].to(device) if key in batch else None

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            labels = batch['label'].to(device)
            logits, dist_loss = model(
                batch['token_ids'].to(device),
                batch['attention_mask'].to(device),
                _optional(batch, 'residue_ids'),
                branch_depths=_optional(batch, 'branch_depths'),
                linkage_types=_optional(batch, 'linkage_types'),
                dist_labels=_optional(batch, 'dist_labels') if use_topology else None,
            )

            loss = criterion(logits, labels) + dist_alpha * dist_loss
            total_loss += loss.item()

            all_probs.extend(torch.softmax(logits, dim=1).cpu().numpy())
            all_preds.extend(logits.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    avg_loss = total_loss / len(data_loader)
    accuracy = accuracy_score(all_labels, all_preds)
    f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    f1_weighted = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    mcc = matthews_corrcoef(all_labels, all_preds)

    auroc = None
    auprc = None
    try:
        auroc, auprc = _ranking_metrics(np.asarray(all_labels), np.asarray(all_probs))
    except (ValueError, IndexError) as exc:
        # Ranking metrics are best-effort; report why they are missing.
        logger.warning(f"Could not compute AUROC/AUPRC: {exc}")

    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'f1_macro': f1_macro,
        'f1_weighted': f1_weighted,
        'mcc': mcc,
        'auroc': auroc,
        'auprc': auprc,
        'predictions': all_preds,
        'labels': all_labels,
    }
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line interface of the fine-tuning script."""
    parser = argparse.ArgumentParser(description='Fine-tune Glycan BERT for classification')

    # Required arguments
    parser.add_argument('--task', type=str, required=True,
                        help='Task name (e.g., species, phylum)')
    parser.add_argument('--data_path', type=str, required=True,
                        help='Path to CSV data file')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to pre-trained model checkpoint')
    parser.add_argument('--vocab', type=str, required=True,
                        help='Path to vocabulary.json')
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Output directory for results')

    # Optional arguments
    parser.add_argument('--batch_size', type=int, default=256,
                        help='Batch size (matching GlycanML for stable gradients)')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--dropout', type=float, default=0.25,
                        help='Dropout rate (increased from 0.1 to combat overfitting)')
    parser.add_argument('--freeze_layers', type=int, default=8,
                        help='Number of bottom layers to freeze (increased from 4 to prevent overfitting)')
    parser.add_argument('--pooling_strategy', type=str, default='attention',
                        choices=['mean', 'first', 'max', 'attention', 'mono'],
                        help='Pooling strategy: attention (recommended), mono (residue-level), mean, first (CLS-style), max')
    parser.add_argument('--max_length', type=int, default=256)
    parser.add_argument('--patience', type=int, default=10,
                        help='Early stopping patience')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--filter_mode', type=str, default='none',
                        choices=['none', 'strict', 'strict_3', 'strict_5'],
                        help='Class filtering: none (use all), strict (n=1), strict_3 (n=3), strict_5 (n=5)')
    parser.add_argument('--dist_alpha', type=float, default=0.0,
                        help='Weight for auxiliary distance/topology loss (default: 0.0 disabled, set >0 to enable)')
    return parser


# `min_samples` passed to compute_valid_classes for each strict filter mode.
_FILTER_MIN_SAMPLES = {'strict': 1, 'strict_3': 3, 'strict_5': 5}


def main():
    """
    Fine-tune a pre-trained glycan BERT on one classification task.

    Steps:
    1. Parse CLI arguments.
    2. Build the train/validation/test datasets.
    3. Attach a GlycanClassifier head to the checkpointed backbone.
    4. Train with early stopping on validation MCC, saving ``best_model.pt``
       in ``--output_dir``.
    5. Evaluate the best checkpoint on the test split and write
       ``results.json`` next to it.

    Raises:
        RuntimeError: If no checkpoint was saved during training (e.g.
            ``--epochs 0``), so there is nothing to test.
    """
    args = _build_arg_parser().parse_args()

    # Setup
    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    # Log configuration
    logger.info("=" * 70)
    logger.info(f"FINE-TUNING GLYCAN BERT ON {args.task.upper()}")
    logger.info("=" * 70)
    logger.info(f"Data: {args.data_path}")
    logger.info(f"Checkpoint: {args.checkpoint}")
    logger.info(f"Output: {args.output_dir}")
    logger.info(f"Device: {args.device}")
    logger.info(f"Seed: {args.seed}")
    logger.info(f"Filter mode: {args.filter_mode}")

    # Compute valid classes if using strict filtering
    valid_classes = None
    if args.filter_mode != 'none':
        min_samples = _FILTER_MIN_SAMPLES.get(args.filter_mode, 1)
        logger.info(f"\nComputing valid classes ({args.filter_mode} mode, min_samples={min_samples})...")
        valid_classes = compute_valid_classes(args.data_path, args.task, min_samples=min_samples)
        logger.info(f" Will use {len(valid_classes)} classes present in all splits")

    # Load config from checkpoint to get model capacity
    logger.info("\nChecking model capacity from checkpoint...")
    checkpoint_config = get_config_from_checkpoint(args.checkpoint, 'cpu')
    model_max_length = checkpoint_config.seq_max_length

    # Clamp max_length to the sequence length the checkpointed model supports.
    if args.max_length > model_max_length:
        logger.warning(f" Requested max_length ({args.max_length}) exceeds model capacity ({model_max_length}).")
        logger.warning(f" Overriding max_length to {model_max_length} to prevent size mismatch errors.")
        dataset_max_length = model_max_length
    else:
        dataset_max_length = args.max_length

    # Load data
    logger.info(f"\nLoading data (max_length={dataset_max_length})...")
    train_dataset = GlycanClassificationDataset(
        args.data_path, args.task, 'train', args.vocab, dataset_max_length,
        valid_classes=valid_classes
    )
    val_dataset = GlycanClassificationDataset(
        args.data_path, args.task, 'validation', args.vocab, dataset_max_length,
        valid_classes=valid_classes
    )
    test_dataset = GlycanClassificationDataset(
        args.data_path, args.task, 'test', args.vocab, dataset_max_length,
        valid_classes=valid_classes
    )

    logger.info("\nDataset summary:")
    logger.info(f" Train: {len(train_dataset)} samples")
    logger.info(f" Val: {len(val_dataset)} samples")
    logger.info(f" Test: {len(test_dataset)} samples")
    logger.info(f" Classes: {len(train_dataset.unique_labels)}")

    # Report class filtering stats
    if args.filter_mode == 'none':
        train_classes = set(train_dataset.unique_labels)
        val_classes = set(val_dataset.unique_labels)
        test_classes = set(test_dataset.unique_labels)
        common_classes = train_classes & val_classes & test_classes
        logger.info("\nClass distribution (filter_mode=none):")
        logger.info(f" Train-only classes: {len(train_classes - common_classes)}")
        logger.info(f" Val-only classes: {len(val_classes - train_classes - test_classes)}")
        logger.info(f" Test-only classes: {len(test_classes - train_classes - val_classes)}")
        logger.info(f" Common to all: {len(common_classes)}")

    # Create dataloaders. Pinned host memory only helps host-to-GPU copies.
    pin_memory = torch.device(args.device).type == 'cuda'
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=pin_memory
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=pin_memory
    )
    test_loader = DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=pin_memory
    )

    # Load model
    logger.info("\nLoading model...")
    bert = load_pretrained_bert(args.checkpoint, checkpoint_config, args.device)
    num_classes = len(train_dataset.unique_labels)
    model = GlycanClassifier(
        bert, num_classes,
        dropout=args.dropout,
        freeze_layers=args.freeze_layers,
        pooling_strategy=args.pooling_strategy,
    ).to(args.device)
    logger.info(f" Pooling strategy: {args.pooling_strategy}")

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f" Total params: {total_params:,}")
    logger.info(f" Trainable params: {trainable_params:,} ({trainable_params/total_params*100:.1f}%)")

    # Setup training: the cosine schedule is stepped per batch over all epochs.
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr,
        weight_decay=args.weight_decay
    )
    total_steps = len(train_loader) * args.epochs
    scheduler = CosineAnnealingLR(optimizer, T_max=total_steps)

    # Training loop
    logger.info("\n" + "=" * 70)
    logger.info("TRAINING")
    logger.info("=" * 70)

    best_model_path = os.path.join(args.output_dir, 'best_model.pt')
    # -inf guarantees the first finite validation MCC is saved.
    best_val_mcc = float('-inf')
    saved_checkpoint = False
    epochs_without_improvement = 0
    history = []

    for epoch in range(args.epochs):
        logger.info(f"\nEpoch {epoch + 1}/{args.epochs}")

        train_metrics = train_epoch(model, train_loader, optimizer, criterion, args.device, scheduler, dist_alpha=args.dist_alpha)
        val_metrics = evaluate(model, val_loader, criterion, args.device, num_classes, dist_alpha=args.dist_alpha)

        logger.info(f" Train - Loss: {train_metrics['loss']:.4f}, Acc: {train_metrics['accuracy']:.4f}")
        val_log = f" Val - Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.4f}, "
        val_log += f"F1: {val_metrics['f1_macro']:.4f}, MCC: {val_metrics['mcc']:.4f}"
        if val_metrics['auroc'] is not None:
            val_log += f", AUROC: {val_metrics['auroc']:.4f}"
        logger.info(val_log)

        history.append({
            'epoch': epoch + 1,
            'train_loss': train_metrics['loss'],
            'train_acc': train_metrics['accuracy'],
            'val_loss': val_metrics['loss'],
            'val_acc': val_metrics['accuracy'],
            'val_f1': val_metrics['f1_macro'],
            'val_mcc': val_metrics['mcc'],
            'val_auroc': val_metrics['auroc'],
            'val_auprc': val_metrics['auprc'],
        })

        # Check for improvement
        if val_metrics['mcc'] > best_val_mcc:
            best_val_mcc = val_metrics['mcc']
            epochs_without_improvement = 0
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_mcc': best_val_mcc,
                'config': {
                    'task': args.task,
                    'num_classes': num_classes,
                    'classes': train_dataset.unique_labels,
                }
            }, best_model_path)
            saved_checkpoint = True
            logger.info(f" New best MCC: {best_val_mcc:.4f} (saved)")
        else:
            epochs_without_improvement += 1
            logger.info(f" No improvement ({epochs_without_improvement}/{args.patience})")

        # Early stopping
        if epochs_without_improvement >= args.patience:
            logger.info(f"\nEarly stopping at epoch {epoch + 1}")
            break

    # Load best model for testing
    logger.info("\n" + "=" * 70)
    logger.info("TESTING")
    logger.info("=" * 70)

    if not saved_checkpoint:
        # Never fall back to a stale best_model.pt from an earlier run.
        raise RuntimeError(
            f"No checkpoint was saved during training (epochs={args.epochs}); "
            "cannot run the test evaluation."
        )
    best_checkpoint = torch.load(best_model_path, map_location=args.device)
    model.load_state_dict(best_checkpoint['model_state_dict'])

    # Use the same topology-loss weight as training and validation.
    test_metrics = evaluate(model, test_loader, criterion, args.device, num_classes, dist_alpha=args.dist_alpha)

    logger.info("\nTest Results:")
    logger.info(f" Accuracy: {test_metrics['accuracy']:.4f}")
    logger.info(f" F1-Macro: {test_metrics['f1_macro']:.4f}")
    logger.info(f" F1-Weighted: {test_metrics['f1_weighted']:.4f}")
    logger.info(f" MCC: {test_metrics['mcc']:.4f}")
    if test_metrics['auroc'] is not None:
        logger.info(f" AUROC: {test_metrics['auroc']:.4f}")
    if test_metrics['auprc'] is not None:
        logger.info(f" AUPRC: {test_metrics['auprc']:.4f}")

    # Save results
    results = {
        'task': args.task,
        'filter_mode': args.filter_mode,
        'num_classes': num_classes,
        'classes': train_dataset.unique_labels,
        'train_samples': len(train_dataset),
        'val_samples': len(val_dataset),
        'test_samples': len(test_dataset),
        'best_epoch': best_checkpoint['epoch'],
        'test_accuracy': test_metrics['accuracy'],
        'test_f1_macro': test_metrics['f1_macro'],
        'test_f1_weighted': test_metrics['f1_weighted'],
        'test_mcc': test_metrics['mcc'],
        'test_auroc': test_metrics['auroc'],
        'test_auprc': test_metrics['auprc'],
        'config': vars(args),
        'history': history,
    }

    with open(os.path.join(args.output_dir, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2)

    logger.info(f"\nResults saved to {args.output_dir}")


if __name__ == '__main__':
    main()