"""
Bertint V8 Training — Cross-Attention + Live Bertose Finetuning

Based on the V7 trainer, with changes for the V8 architecture:
- Per-residue protein embeddings (variable-length, padded in collate)
- protein_mask passed to the model for cross-attention
- AMP (GradScaler + autocast) built in from the start
- Regression only (no classification mode — V7 showed regression wins)
"""

import argparse
import json
import logging
import os
import random
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from scipy.stats import spearmanr, pearsonr
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader

from bertint_v8 import BertintV8, BertintV8Loss, load_bertose_encoder
from dataset_v8 import BertintV8Dataset, collate_fn

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


# ============================================================================
# Reproducibility
# ============================================================================


def set_seed(seed: int = 42) -> None:
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Note: this does not switch cuDNN into deterministic mode, so GPU runs
    may still differ slightly between invocations.

    Args:
        seed: Seed value applied to every RNG.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# ============================================================================
# Metrics
# ============================================================================


def compute_metrics(
    preds: np.ndarray, targets: np.ndarray
) -> Dict[str, float]:
    """Compute Spearman rho, Pearson r and MSE between predictions and targets.

    Both inputs are flattened to 1-D, so ``(N,)`` and ``(N, 1)`` arrays are
    treated the same. Correlations are reported as 0.0 whenever they are
    undefined (fewer than two samples, or either side constant / NaN).
    MSE is NaN for empty input.

    Args:
        preds: Model predictions (any array-like of numbers).
        targets: Ground-truth values, same number of elements as ``preds``.

    Returns:
        Dict with float values under ``"spearman"``, ``"pearson"``, ``"mse"``.

    Raises:
        ValueError: If ``preds`` and ``targets`` differ in element count.
    """
    preds = np.asarray(preds, dtype=np.float64).ravel()
    targets = np.asarray(targets, dtype=np.float64).ravel()
    if preds.size != targets.size:
        raise ValueError(
            f"preds and targets differ in size: "
            f"{preds.size} vs {targets.size}"
        )

    if preds.size:
        mse = float(np.mean((preds - targets) ** 2))
    else:
        mse = float("nan")

    # Correlation is undefined for < 2 points or zero variance: pearsonr
    # raises for n < 2 and scipy warns + returns NaN for constant input.
    # Report 0.0 directly in those cases (same value the NaN path gives).
    if preds.size < 2 or np.ptp(preds) == 0 or np.ptp(targets) == 0:
        return {"spearman": 0.0, "pearson": 0.0, "mse": mse}

    rho, _ = spearmanr(preds, targets)
    r, _ = pearsonr(preds, targets)
    return {
        "spearman": float(rho) if not np.isnan(rho) else 0.0,
        "pearson": float(r) if not np.isnan(r) else 0.0,
        "mse": mse,
    }


# ============================================================================
# Trainer
# ============================================================================


class BertintV8Trainer:
    """
    Trainer for BertintV8 with cross-attention and AMP.

    Uses AdamW with two parameter groups (Bertose encoder vs. everything
    else), a per-batch OneCycleLR schedule, gradient clipping, mixed
    precision on CUDA, early stopping on validation Spearman rho, and a
    final test evaluation of the best model.

    Args:
        model: BertintV8 model.
        criterion: Loss function, called as ``criterion(pred, target)``.
        train_loader: Training data loader.
        val_loader: Validation data loader (drives model selection).
        test_loader: Test data loader (evaluated once, after training).
        output_dir: Directory for checkpoints and results.
        lr_encoder: Peak (OneCycleLR ``max_lr``) learning rate for
            trainable parameters whose names start with ``seq_embeddings``
            or ``seq_layers`` (the Bertose encoder).
        lr_head: Peak learning rate for all other trainable parameters
            (cross-attention, SWE, head).
        weight_decay: Weight decay for AdamW (both groups).
        max_grad_norm: Maximum gradient norm for clipping.
        epochs: Number of training epochs.
        patience: Early stopping patience, in epochs without a new best
            validation Spearman rho.
        checkpoint_interval: Save a resumable checkpoint every N epochs.
            A checkpoint is also saved whenever a new best model is found.
            Values <= 0 disable the periodic checkpoints.
        resume: Whether to resume from ``output_dir/last_checkpoint.pt``.
        warmup_pct: Fraction of total steps used as the OneCycleLR warmup
            phase (``pct_start``). Values <= 0 fall back to 0.3.
    """

    # Batch keys forwarded to the model as keyword arguments.
    _MODEL_INPUT_KEYS = (
        "token_ids",
        "attention_mask",
        "branch_depths",
        "linkage_types",
        "protein_emb",
        "protein_mask",
    )

    def __init__(
        self,
        model: BertintV8,
        criterion: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        test_loader: DataLoader,
        output_dir: str,
        lr_encoder: float = 1e-5,
        lr_head: float = 1e-4,
        weight_decay: float = 0.01,
        max_grad_norm: float = 1.0,
        epochs: int = 50,
        patience: int = 15,
        checkpoint_interval: int = 5,
        resume: bool = False,
        warmup_pct: float = 0.0,
    ):
        self.model = model
        self.criterion = criterion
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.output_dir = output_dir
        self.epochs = epochs
        self.patience = patience
        self.checkpoint_interval = checkpoint_interval
        self.resume = resume
        self.max_grad_norm = max_grad_norm

        os.makedirs(output_dir, exist_ok=True)

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model.to(self.device)
        self.criterion.to(self.device)

        # Mixed precision only applies on CUDA; torch.cuda.amp would disable
        # itself on CPU anyway (with a warning), so make that explicit.
        self.use_amp = self.device.type == "cuda"
        self.scaler = GradScaler(enabled=self.use_amp)

        # Separate param groups: encoder (small lr) vs rest (larger lr).
        encoder_params = []
        head_params = []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue  # frozen parameters are not handed to the optimizer
            if name.startswith(("seq_embeddings", "seq_layers")):
                encoder_params.append(param)
            else:
                head_params.append(param)
        logger.info(
            f" Param groups: encoder={len(encoder_params)} tensors "
            f"(lr={lr_encoder}), head={len(head_params)} tensors "
            f"(lr={lr_head})"
        )

        self.optimizer = torch.optim.AdamW(
            [
                {
                    "params": encoder_params,
                    "lr": lr_encoder,
                    "weight_decay": weight_decay,
                },
                {
                    "params": head_params,
                    "lr": lr_head,
                    "weight_decay": weight_decay,
                },
            ]
        )

        # OneCycleLR with per-batch stepping (matches Twin Peaks pattern):
        # warmup over pct_start of all steps, then cosine annealing.
        # lr_encoder / lr_head are the per-group peaks (max_lr).
        total_steps = len(train_loader) * epochs
        pct_start = warmup_pct if warmup_pct > 0 else 0.3  # default: 30%
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=[lr_encoder, lr_head],
            total_steps=total_steps,
            pct_start=pct_start,
            anneal_strategy="cos",
        )
        warmup_steps_actual = int(total_steps * pct_start)
        logger.info(" Scheduler: OneCycleLR per-batch stepping")
        logger.info(
            f" total_steps={total_steps:,}, "
            f"warmup={warmup_steps_actual:,} steps "
            f"({pct_start*100:.0f}%), cosine decay"
        )

        # Training state (overwritten by _resume_from_checkpoint).
        self.start_epoch = 0
        self.best_metric = -float("inf")
        self.patience_counter = 0
        self.history: List[Dict] = []

        if resume:
            self._resume_from_checkpoint()

    def train(self) -> Dict:
        """Train with early stopping, then evaluate the best model on test.

        The best model (by validation Spearman rho) is written to
        ``best_model.pt``; the final results are written to
        ``results.json``.

        Returns:
            The results dict that is also saved to ``results.json``.
        """
        logger.info(f"\nStarting V8 training for {self.epochs} epochs")
        logger.info(f" Device: {self.device}")
        logger.info(f" Train batches: {len(self.train_loader)}")
        logger.info(f" Val batches: {len(self.val_loader)}")
        logger.info(f" AMP: {'enabled' if self.use_amp else 'disabled'}")

        best_path = os.path.join(self.output_dir, "best_model.pt")

        for epoch in range(self.start_epoch, self.epochs):
            t0 = time.time()
            train_loss = self._train_epoch(epoch)
            val_loss, val_metrics = self._eval_epoch(self.val_loader)
            elapsed = time.time() - t0

            rho = val_metrics["spearman"]
            r = val_metrics["pearson"]
            logger.info(
                f" Epoch {epoch + 1:3d} | Train loss={train_loss:.4f} | "
                f"Val loss={val_loss:.4f} rho={rho:.4f} r={r:.4f} | "
                f"{elapsed:.1f}s"
            )

            # Track best
            is_best = rho > self.best_metric
            if is_best:
                self.best_metric = rho
                self.patience_counter = 0
                torch.save(self.model.state_dict(), best_path)
                logger.info(f" * New best: {rho:.4f}")
            else:
                self.patience_counter += 1

            self.history.append(
                {
                    "epoch": epoch + 1,
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "val_metrics": val_metrics,
                    "lr_encoder": self.optimizer.param_groups[0]["lr"],
                    "lr_head": self.optimizer.param_groups[1]["lr"],
                }
            )
            # (scheduler.step() is called per-batch in _train_epoch.)

            # Checkpoint periodically, and also whenever best_model.pt was
            # just rewritten. Otherwise a resumed run could restore an older
            # best_metric than the one behind best_model.pt and then
            # overwrite it with a worse model.
            periodic = (
                self.checkpoint_interval > 0
                and (epoch + 1) % self.checkpoint_interval == 0
            )
            if periodic or is_best:
                self._save_checkpoint(epoch + 1)

            # Early stopping
            if self.patience_counter >= self.patience:
                logger.info(
                    f" Early stopping at epoch {epoch + 1} "
                    f"(no improvement for {self.patience} epochs)"
                )
                break

        # Load best and test
        logger.info(f"\n{'=' * 60}")
        logger.info("Loading best model for test evaluation...")
        self.model.load_state_dict(
            torch.load(best_path, map_location=self.device)
        )
        test_loss, test_metrics = self._eval_epoch(self.test_loader)

        logger.info(f"\n{'=' * 60}")
        logger.info("TEST RESULTS:")
        logger.info(f" Spearman rho: {test_metrics['spearman']:.4f}")
        logger.info(f" Pearson r: {test_metrics['pearson']:.4f}")
        logger.info(f" MSE: {test_metrics['mse']:.6f}")
        logger.info(f"{'=' * 60}")

        results = {
            "task_type": "regression",
            "architecture": "cross-attention + SWE + live Bertose",
            "best_metric": self.best_metric,
            "test_metrics": test_metrics,
            "test_loss": test_loss,
            "history": self.history,
        }
        results_path = os.path.join(self.output_dir, "results.json")
        with open(results_path, "w") as f:
            json.dump(results, f, indent=2)
        logger.info(f"Results saved to {results_path}")

        return results

    def _forward(
        self, batch: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Move a collated batch to the device; run model and loss.

        Both the forward pass and the loss run under autocast when AMP is
        enabled.

        Returns:
            ``(pred, target, loss)``.
        """
        inputs = {
            key: batch[key].to(self.device) for key in self._MODEL_INPUT_KEYS
        }
        target = batch["target"].to(self.device)
        with autocast(enabled=self.use_amp):
            pred = self.model(**inputs)
            loss = self.criterion(pred, target)
        return pred, target, loss

    def _train_epoch(self, epoch: int) -> float:
        """Run one training epoch and return the mean per-batch loss.

        The OneCycleLR scheduler is stepped once per batch.

        Args:
            epoch: 0-based epoch index (used only for progress logging).
        """
        self.model.train()
        total_loss = 0.0
        n_batches = len(self.train_loader)

        for batch_idx, batch in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            _, _, loss = self._forward(batch)

            # AMP backward. Unscale first so clipping sees the true
            # (unscaled) gradient norm.
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.max_grad_norm
            )
            # scaler.step skips optimizer.step if grads contain inf/NaN.
            self.scaler.step(self.optimizer)
            self.scaler.update()

            # Per-batch LR scheduling (OneCycleLR)
            self.scheduler.step()

            total_loss += loss.item()

            if (batch_idx + 1) % 200 == 0:
                avg = total_loss / (batch_idx + 1)
                lr_enc = self.optimizer.param_groups[0]["lr"]
                logger.info(
                    f" [E{epoch + 1}][{batch_idx + 1}/{n_batches}] "
                    f"loss={avg:.4f} lr_enc={lr_enc:.2e}"
                )

        return total_loss / n_batches

    @torch.no_grad()
    def _eval_epoch(
        self, loader: DataLoader
    ) -> Tuple[float, Dict[str, float]]:
        """Evaluate on ``loader`` (with AMP on CUDA).

        Returns:
            ``(mean per-batch loss, metrics dict from compute_metrics)``.
        """
        self.model.eval()
        total_loss = 0.0
        pred_chunks = []
        target_chunks = []

        for batch in loader:
            pred, target, loss = self._forward(batch)
            total_loss += loss.item()
            # Flatten per batch so (B,), (B, 1) and 0-d outputs all
            # concatenate into one 1-D array.
            pred_chunks.append(pred.float().reshape(-1).cpu().numpy())
            target_chunks.append(target.reshape(-1).cpu().numpy())

        avg_loss = total_loss / len(loader)
        all_preds = np.concatenate(pred_chunks) if pred_chunks else np.empty(0)
        all_targets = (
            np.concatenate(target_chunks) if target_chunks else np.empty(0)
        )
        metrics = compute_metrics(all_preds, all_targets)
        return avg_loss, metrics

    def _save_checkpoint(self, epoch: int) -> None:
        """Save the full training state to ``last_checkpoint.pt``.

        Args:
            epoch: Number of completed epochs (1-based), used as the resume
                start epoch.
        """
        ckpt = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "scaler_state_dict": self.scaler.state_dict(),
            "best_metric": self.best_metric,
            "patience_counter": self.patience_counter,
            "history": self.history,
        }
        path = os.path.join(self.output_dir, "last_checkpoint.pt")
        torch.save(ckpt, path)
        logger.info(f" [CKPT] Saved epoch {epoch}")

    def _resume_from_checkpoint(self) -> None:
        """Restore training state from ``last_checkpoint.pt``, if present."""
        ckpt_path = os.path.join(self.output_dir, "last_checkpoint.pt")
        if not os.path.exists(ckpt_path):
            logger.info(" No checkpoint found, starting fresh")
            return

        ckpt = torch.load(ckpt_path, map_location=self.device)
        self.model.load_state_dict(ckpt["model_state_dict"])
        self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        self.scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        # A disabled scaler (e.g. a CPU run) saves an empty state, which an
        # enabled scaler refuses to load; skip it in that case.
        if ckpt.get("scaler_state_dict"):
            self.scaler.load_state_dict(ckpt["scaler_state_dict"])
        self.start_epoch = ckpt["epoch"]
        self.best_metric = ckpt["best_metric"]
        self.patience_counter = ckpt["patience_counter"]
        self.history = ckpt["history"]
        logger.info(
            f" Resumed from epoch {self.start_epoch}, "
            f"best={self.best_metric:.4f}"
        )


# ============================================================================
# Main
# ============================================================================


def main() -> None:
    """Entry point for V8 training: parse args, build data and model, train."""
    parser = argparse.ArgumentParser(description="Bertint V8 Training")
    parser.add_argument(
        "--csv_path", required=True, help="Path to binding data CSV"
    )
    parser.add_argument(
        "--split_path", required=True, help="Path to glycan-cold splits JSON"
    )
    parser.add_argument(
        "--protein_emb_path", required=True, help="Path to ESM-C HDF5"
    )
    parser.add_argument(
        "--vocab_path", required=True, help="Path to BPE vocab JSON"
    )
    parser.add_argument(
        "--bertose_checkpoint", required=True, help="Bertose checkpoint"
    )
    parser.add_argument("--output_dir", required=True, help="Output dir")

    # Model architecture
    parser.add_argument(
        "--freeze_layers", type=int, default=4, help="Layers to freeze"
    )
    parser.add_argument(
        "--shared_dim", type=int, default=512, help="Shared dim"
    )
    parser.add_argument(
        "--num_cross_layers", type=int, default=2, help="Cross-attn layers"
    )
    parser.add_argument(
        "--num_heads", type=int, default=8, help="Attention heads"
    )
    parser.add_argument(
        "--swe_slices", type=int, default=512, help="SWE slices"
    )
    parser.add_argument(
        "--dropout", type=float, default=0.1, help="Dropout rate"
    )
    parser.add_argument(
        "--protein_dim", type=int, default=960, help="ESM-C dim"
    )
    parser.add_argument(
        "--separate_swe",
        action="store_true",
        help="Use separate SWE modules for glycan and protein",
    )

    # Training
    parser.add_argument(
        "--lr_encoder", type=float, default=1e-5, help="Encoder LR"
    )
    parser.add_argument(
        "--lr_head", type=float, default=1e-4, help="Head LR"
    )
    parser.add_argument(
        "--weight_decay", type=float, default=0.01, help="Weight decay"
    )
    parser.add_argument(
        "--max_grad_norm", type=float, default=1.0, help="Grad clip"
    )
    parser.add_argument(
        "--batch_size", type=int, default=32, help="Batch size"
    )
    parser.add_argument(
        "--epochs", type=int, default=50, help="Max epochs"
    )
    parser.add_argument(
        "--patience", type=int, default=15, help="Early stopping"
    )
    parser.add_argument(
        "--max_glycan_length", type=int, default=256, help="Max glycan len"
    )
    parser.add_argument(
        "--max_protein_length", type=int, default=1024, help="Max protein len"
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--warmup_pct",
        type=float,
        default=0.05,
        help="Fraction of total steps for warmup (0.05=5%%, 0.10=10%%)",
    )
    parser.add_argument(
        "--target_col", default="target_rank", help="Target column"
    )
    parser.add_argument(
        "--checkpoint_interval", type=int, default=5, help="Ckpt every N"
    )
    parser.add_argument(
        "--resume", action="store_true", help="Resume from checkpoint"
    )

    # Ablation controls
    parser.add_argument(
        "--pooling_mode",
        default="swe",
        choices=["swe", "mean", "joint_swe"],
        help="Pooling strategy: swe (default), mean, or joint_swe",
    )
    parser.add_argument(
        "--interaction_mode",
        default="product_sum",
        choices=["product_sum", "concat"],
        help="Interaction: product_sum (default) or concat",
    )
    parser.add_argument(
        "--no_cross_attention",
        action="store_true",
        help="Disable cross-attention blocks (ablation)",
    )

    args = parser.parse_args()
    set_seed(args.seed)

    logger.info("Bertint V8 Training — Cross-Attention + Live Bertose")
    logger.info(f" freeze_layers={args.freeze_layers}")
    logger.info(f" lr_encoder={args.lr_encoder}")
    logger.info(f" lr_head={args.lr_head}")
    logger.info(f" batch_size={args.batch_size}")
    logger.info(f" shared_dim={args.shared_dim}")
    logger.info(f" cross_layers={args.num_cross_layers}")
    logger.info(f" separate_swe={args.separate_swe}")
    logger.info(f" pooling_mode={args.pooling_mode}")
    logger.info(f" interaction_mode={args.interaction_mode}")
    logger.info(f" cross_attention={not args.no_cross_attention}")

    # Load datasets
    logger.info("\nLoading datasets...")

    def build_dataset(split: str) -> BertintV8Dataset:
        # All splits share the same source files and length limits.
        return BertintV8Dataset(
            args.csv_path,
            args.split_path,
            split,
            args.protein_emb_path,
            args.vocab_path,
            max_glycan_length=args.max_glycan_length,
            max_protein_length=args.max_protein_length,
            target_col=args.target_col,
        )

    train_ds = build_dataset("train")
    val_ds = build_dataset("val")
    test_ds = build_dataset("test")

    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    # Build model
    logger.info("\nBuilding model...")
    config, seq_emb, seq_layers = load_bertose_encoder(
        args.bertose_checkpoint, freeze_layers=args.freeze_layers
    )
    model = BertintV8(
        seq_embeddings=seq_emb,
        seq_layers=seq_layers,
        glycan_dim=config.seq_hidden_size,
        protein_dim=args.protein_dim,
        shared_dim=args.shared_dim,
        num_cross_layers=args.num_cross_layers,
        num_heads=args.num_heads,
        swe_slices=args.swe_slices,
        dropout=args.dropout,
        separate_swe=args.separate_swe,
        pooling_mode=args.pooling_mode,
        interaction_mode=args.interaction_mode,
        use_cross_attention=not args.no_cross_attention,
    )

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    logger.info(f" Total params: {total_params:,}")
    logger.info(f" Trainable: {trainable_params:,}")

    criterion = BertintV8Loss()

    trainer = BertintV8Trainer(
        model=model,
        criterion=criterion,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        output_dir=args.output_dir,
        lr_encoder=args.lr_encoder,
        lr_head=args.lr_head,
        weight_decay=args.weight_decay,
        max_grad_norm=args.max_grad_norm,
        epochs=args.epochs,
        patience=args.patience,
        checkpoint_interval=args.checkpoint_interval,
        resume=args.resume,
        warmup_pct=args.warmup_pct,
    )
    trainer.train()
    logger.info("\nTraining complete!")


if __name__ == "__main__":
    main()