| """ |
| Bertint V8 Training — Cross-Attention + Live Bertose Finetuning |
| |
| Based on V7 trainer with changes for V8 architecture: |
| - Per-residue protein embeddings (variable-length, padded in collate) |
| - protein_mask passed to model for cross-attention |
| - AMP (GradScaler + autocast) built in from the start |
| - Regression only (no classification mode — V7 showed regression wins) |
| """ |
|
|
| import argparse |
| import json |
| import logging |
| import os |
| import random |
| import sys |
| import time |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from scipy.stats import spearmanr, pearsonr |
| from torch.cuda.amp import GradScaler, autocast |
| from torch.utils.data import DataLoader |
|
|
| from bertint_v8 import BertintV8, BertintV8Loss, load_bertose_encoder |
| from dataset_v8 import BertintV8Dataset, collate_fn |
|
|
# Configure the root logger once at import time: INFO level, with a
# "timestamp - level - message" layout. `logger` is the module-level logger
# used by everything else in this file.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
|
| |
| |
| |
|
|
|
|
def set_seed(seed: int = 42) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed to ``random.seed``, ``np.random.seed`` and
            ``torch.manual_seed``; also applied to all CUDA devices when
            CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # CUDA generators are only seeded when a CUDA runtime is actually present.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
|
|
|
|
| |
| |
| |
|
|
|
|
def compute_metrics(
    preds: np.ndarray, targets: np.ndarray
) -> Dict[str, float]:
    """Compute Spearman rho, Pearson r and MSE between predictions and targets.

    Both inputs are converted to flat float64 vectors first, so a trailing
    singleton dimension on one side (e.g. ``(N, 1)`` predictions against
    ``(N,)`` targets) cannot broadcast into an ``(N, N)`` difference matrix,
    and integer/unsigned inputs cannot wrap around in the subtraction.

    Args:
        preds: Predicted values; any array-like that flattens to N elements.
        targets: Ground-truth values; must flatten to the same N elements.

    Returns:
        Dict with keys ``"spearman"``, ``"pearson"`` and ``"mse"``.
        Correlations are reported as 0.0 whenever they are undefined
        (fewer than two samples, a constant input, or a NaN result).
        ``"mse"`` is NaN for empty input.

    Raises:
        ValueError: If the two inputs do not contain the same number of
            elements.
    """
    preds = np.asarray(preds, dtype=np.float64).reshape(-1)
    targets = np.asarray(targets, dtype=np.float64).reshape(-1)
    if preds.size != targets.size:
        raise ValueError(
            f"preds and targets must have the same number of elements, "
            f"got {preds.size} and {targets.size}"
        )

    n_samples = preds.size
    if n_samples == 0:
        return {"spearman": 0.0, "pearson": 0.0, "mse": float("nan")}

    mse = float(np.mean((preds - targets) ** 2))

    # Correlation is undefined for a single sample or when either side has
    # zero variance; report 0.0 directly rather than calling SciPy, which
    # would raise (n < 2) or warn and return NaN (constant input).
    undefined = (
        n_samples < 2
        or preds.max() == preds.min()
        or targets.max() == targets.min()
    )
    if undefined:
        return {"spearman": 0.0, "pearson": 0.0, "mse": mse}

    rho, _ = spearmanr(preds, targets)
    r, _ = pearsonr(preds, targets)
    return {
        # NaN can still occur (e.g. NaN values in the inputs); map it to 0.0.
        "spearman": 0.0 if np.isnan(rho) else float(rho),
        "pearson": 0.0 if np.isnan(r) else float(r),
        "mse": mse,
    }
|
|
|
|
| |
| |
| |
|
|
|
|
class BertintV8Trainer:
    """
    Trainer for BertintV8 with cross-attention and AMP.

    Builds an AdamW optimizer with two parameter groups (encoder / head), a
    per-batch OneCycleLR schedule, and a gradient scaler. Mixed precision is
    switched on only when the selected device is CUDA; on CPU the scaler and
    autocast are explicitly disabled.

    Args:
        model: BertintV8 model.
        criterion: Loss function.
        train_loader: Training data loader.
        val_loader: Validation data loader.
        test_loader: Test data loader.
        output_dir: Directory for checkpoints and results.
        lr_encoder: Learning rate for parameters whose names start with
            ``seq_embeddings`` or ``seq_layers``.
        lr_head: Learning rate for every other trainable parameter.
        weight_decay: Weight decay for AdamW.
        max_grad_norm: Maximum gradient norm for clipping.
        epochs: Number of training epochs.
        patience: Early stopping patience (epochs without a new best
            validation Spearman).
        checkpoint_interval: Save checkpoint every N epochs.
        resume: Whether to resume from last checkpoint.
        warmup_pct: Fraction of total steps used for LR warmup (passed to
            OneCycleLR as ``pct_start``). Values <= 0 fall back to 0.3.
    """

    # Batch entries forwarded to the model as keyword arguments of the same name.
    _MODEL_INPUT_KEYS = (
        "token_ids",
        "attention_mask",
        "branch_depths",
        "linkage_types",
        "protein_emb",
        "protein_mask",
    )

    def __init__(
        self,
        model: BertintV8,
        criterion: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        test_loader: DataLoader,
        output_dir: str,
        lr_encoder: float = 1e-5,
        lr_head: float = 1e-4,
        weight_decay: float = 0.01,
        max_grad_norm: float = 1.0,
        epochs: int = 50,
        patience: int = 15,
        checkpoint_interval: int = 5,
        resume: bool = False,
        warmup_pct: float = 0.0,
    ):
        self.model = model
        self.criterion = criterion
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.output_dir = output_dir
        self.epochs = epochs
        self.patience = patience
        self.checkpoint_interval = checkpoint_interval
        self.resume = resume
        self.max_grad_norm = max_grad_norm

        os.makedirs(output_dir, exist_ok=True)

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model.to(self.device)
        self.criterion.to(self.device)

        # AMP is only active on CUDA. Passing `enabled` explicitly avoids the
        # "CUDA is not available. Disabling" warnings on CPU and lets the
        # training log report the real state.
        self.use_amp = self.device.type == "cuda"
        self.scaler = GradScaler(enabled=self.use_amp)

        # Split trainable parameters into encoder vs. head by name prefix so
        # each group can get its own learning rate.
        encoder_params = []
        head_params = []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if name.startswith(("seq_embeddings", "seq_layers")):
                encoder_params.append(param)
            else:
                head_params.append(param)

        logger.info(
            f" Param groups: encoder={len(encoder_params)} tensors "
            f"(lr={lr_encoder}), head={len(head_params)} tensors "
            f"(lr={lr_head})"
        )

        # Group order matters: index 0 = encoder, index 1 = head. The
        # scheduler's max_lr list and the LR logging rely on it.
        self.optimizer = torch.optim.AdamW(
            [
                {
                    "params": encoder_params,
                    "lr": lr_encoder,
                    "weight_decay": weight_decay,
                },
                {
                    "params": head_params,
                    "lr": lr_head,
                    "weight_decay": weight_decay,
                },
            ]
        )

        # One scheduler step per batch across the whole run.
        total_steps = len(train_loader) * epochs
        pct_start = warmup_pct if warmup_pct > 0 else 0.3

        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=[lr_encoder, lr_head],
            total_steps=total_steps,
            pct_start=pct_start,
            anneal_strategy='cos',
        )
        warmup_steps_actual = int(total_steps * pct_start)
        logger.info(" Scheduler: OneCycleLR per-batch stepping")
        logger.info(
            f" total_steps={total_steps:,}, "
            f"warmup={warmup_steps_actual:,} steps "
            f"({pct_start*100:.0f}%), cosine decay"
        )

        # Mutable training state; overwritten by _resume_from_checkpoint.
        self.start_epoch = 0
        self.best_metric = -float("inf")
        self.patience_counter = 0
        self.history: List[Dict] = []

        if resume:
            self._resume_from_checkpoint()

    def train(self) -> Dict:
        """Full training loop with early stopping.

        Each epoch trains, evaluates on the validation loader, saves
        ``best_model.pt`` whenever validation Spearman improves, appends to
        the history, and checkpoints every ``checkpoint_interval`` epochs.
        Afterwards the best weights are reloaded, evaluated on the test
        loader, and a summary is written to ``results.json``.

        Returns:
            The results dict that is also written to ``results.json``.
        """
        logger.info(f"\nStarting V8 training for {self.epochs} epochs")
        logger.info(f" Device: {self.device}")
        logger.info(f" Train batches: {len(self.train_loader)}")
        logger.info(f" Val batches: {len(self.val_loader)}")
        logger.info(f" AMP: {'enabled' if self.use_amp else 'disabled'}")

        for epoch in range(self.start_epoch, self.epochs):
            t0 = time.time()

            train_loss = self._train_epoch(epoch)
            val_loss, val_metrics = self._eval_epoch(self.val_loader)

            elapsed = time.time() - t0
            rho = val_metrics["spearman"]
            r = val_metrics["pearson"]

            logger.info(
                f" Epoch {epoch + 1:3d} | Train loss={train_loss:.4f} | "
                f"Val loss={val_loss:.4f} rho={rho:.4f} r={r:.4f} | "
                f"{elapsed:.1f}s"
            )

            # Model selection is driven by validation Spearman.
            if rho > self.best_metric:
                self.best_metric = rho
                self.patience_counter = 0
                torch.save(
                    self.model.state_dict(),
                    os.path.join(self.output_dir, "best_model.pt"),
                )
                logger.info(f" * New best: {rho:.4f}")
            else:
                self.patience_counter += 1

            self.history.append(
                {
                    "epoch": epoch + 1,
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "val_metrics": val_metrics,
                    "lr_encoder": self.optimizer.param_groups[0]["lr"],
                    "lr_head": self.optimizer.param_groups[1]["lr"],
                }
            )

            if (epoch + 1) % self.checkpoint_interval == 0:
                self._save_checkpoint(epoch + 1)

            if self.patience_counter >= self.patience:
                logger.info(
                    f" Early stopping at epoch {epoch + 1} "
                    f"(no improvement for {self.patience} epochs)"
                )
                break

        # Final evaluation always uses the best validation weights, not the
        # weights from the last epoch.
        logger.info(f"\n{'=' * 60}")
        logger.info("Loading best model for test evaluation...")
        best_path = os.path.join(self.output_dir, "best_model.pt")
        self.model.load_state_dict(
            torch.load(best_path, map_location=self.device)
        )

        test_loss, test_metrics = self._eval_epoch(self.test_loader)

        logger.info(f"\n{'=' * 60}")
        logger.info("TEST RESULTS:")
        logger.info(f" Spearman rho: {test_metrics['spearman']:.4f}")
        logger.info(f" Pearson r: {test_metrics['pearson']:.4f}")
        logger.info(f" MSE: {test_metrics['mse']:.6f}")
        logger.info(f"{'=' * 60}")

        results = {
            "task_type": "regression",
            "architecture": "cross-attention + SWE + live Bertose",
            "best_metric": self.best_metric,
            "test_metrics": test_metrics,
            "test_loss": test_loss,
            "history": self.history,
        }
        results_path = os.path.join(self.output_dir, "results.json")
        with open(results_path, "w") as f:
            json.dump(results, f, indent=2)
        logger.info(f"Results saved to {results_path}")

        return results

    def _forward_batch(
        self, batch: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Move one batch to the device and run model + loss under autocast.

        Shared by training and evaluation so both feed the model identically.

        Args:
            batch: Mapping holding every key in ``_MODEL_INPUT_KEYS`` plus
                ``"target"``.

        Returns:
            ``(pred, target, loss)``, with ``target`` on ``self.device``.
        """
        inputs = {
            key: batch[key].to(self.device) for key in self._MODEL_INPUT_KEYS
        }
        target = batch["target"].to(self.device)

        with autocast(enabled=self.use_amp):
            pred = self.model(**inputs)
            loss = self.criterion(pred, target)
        return pred, target, loss

    def _train_epoch(self, epoch: int) -> float:
        """Run one training epoch with AMP.

        Args:
            epoch: Zero-based epoch index (only used for log messages).

        Returns:
            Mean training loss over the batches of this epoch.
        """
        self.model.train()
        total_loss = 0.0
        n_batches = len(self.train_loader)

        for batch_idx, batch in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            _, _, loss = self._forward_batch(batch)

            # Gradients must be unscaled before clipping so the norm
            # threshold applies to the true gradient magnitudes.
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.max_grad_norm
            )
            self.scaler.step(self.optimizer)
            self.scaler.update()

            # OneCycleLR is stepped per batch, not per epoch.
            self.scheduler.step()

            total_loss += loss.item()

            if (batch_idx + 1) % 200 == 0:
                avg = total_loss / (batch_idx + 1)
                lr_enc = self.optimizer.param_groups[0]["lr"]
                logger.info(
                    f" [E{epoch + 1}][{batch_idx + 1}/{n_batches}] "
                    f"loss={avg:.4f} lr_enc={lr_enc:.2e}"
                )

        return total_loss / n_batches

    @torch.no_grad()
    def _eval_epoch(
        self, loader: DataLoader
    ) -> Tuple[float, Dict[str, float]]:
        """Run evaluation with AMP.

        Args:
            loader: Data loader to evaluate on.

        Returns:
            ``(mean batch loss, metrics dict from compute_metrics)``.
        """
        self.model.eval()
        total_loss = 0.0
        pred_chunks = []
        target_chunks = []

        for batch in loader:
            pred, target, loss = self._forward_batch(batch)
            total_loss += loss.item()
            # Flatten each batch so predictions and targets are aligned 1-D
            # vectors even if the model returns a 0-d or (B, 1) tensor.
            pred_chunks.append(pred.float().cpu().numpy().reshape(-1))
            target_chunks.append(target.cpu().numpy().reshape(-1))

        avg_loss = total_loss / len(loader)
        metrics = compute_metrics(
            np.concatenate(pred_chunks), np.concatenate(target_chunks)
        )
        return avg_loss, metrics

    def _save_checkpoint(self, epoch: int) -> None:
        """Save full training state for resume.

        Args:
            epoch: Number of completed epochs; stored as the epoch index to
                resume from.
        """
        ckpt = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "scaler_state_dict": self.scaler.state_dict(),
            "best_metric": self.best_metric,
            "patience_counter": self.patience_counter,
            "history": self.history,
        }
        path = os.path.join(self.output_dir, "last_checkpoint.pt")
        torch.save(ckpt, path)
        logger.info(f" [CKPT] Saved epoch {epoch}")

    def _resume_from_checkpoint(self) -> None:
        """Resume training from last checkpoint.

        Restores model, optimizer, scheduler and (when present and non-empty)
        scaler state, plus the epoch counter, best metric, patience counter
        and history. Does nothing except log when no checkpoint file exists.
        """
        ckpt_path = os.path.join(self.output_dir, "last_checkpoint.pt")
        if not os.path.exists(ckpt_path):
            logger.info(" No checkpoint found, starting fresh")
            return

        ckpt = torch.load(ckpt_path, map_location=self.device)
        self.model.load_state_dict(ckpt["model_state_dict"])
        self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        self.scheduler.load_state_dict(ckpt["scheduler_state_dict"])

        # A scaler that was disabled when the checkpoint was written saves an
        # empty state dict, which an enabled GradScaler refuses to load; skip
        # it and keep the freshly initialised scaler instead.
        scaler_state = ckpt.get("scaler_state_dict")
        if scaler_state:
            self.scaler.load_state_dict(scaler_state)

        self.start_epoch = ckpt["epoch"]
        self.best_metric = ckpt["best_metric"]
        self.patience_counter = ckpt["patience_counter"]
        self.history = ckpt["history"]
        logger.info(
            f" Resumed from epoch {self.start_epoch}, "
            f"best={self.best_metric:.4f}"
        )
|
|
|
|
| |
| |
| |
|
|
|
|
def main():
    """Entry point for V8 training.

    Parses the command line, seeds the RNGs, builds the train/val/test
    datasets and loaders, constructs the BertintV8 model on top of the loaded
    Bertose encoder, and runs BertintV8Trainer.train().
    """
    # (flag, add_argument keyword options), registered in this exact order.
    argument_table = [
        # Required paths.
        ("--csv_path", dict(required=True, help="Path to binding data CSV")),
        (
            "--split_path",
            dict(required=True, help="Path to glycan-cold splits JSON"),
        ),
        ("--protein_emb_path", dict(required=True, help="Path to ESM-C HDF5")),
        ("--vocab_path", dict(required=True, help="Path to BPE vocab JSON")),
        (
            "--bertose_checkpoint",
            dict(required=True, help="Bertose checkpoint"),
        ),
        ("--output_dir", dict(required=True, help="Output dir")),
        # Model architecture.
        ("--freeze_layers", dict(type=int, default=4, help="Layers to freeze")),
        ("--shared_dim", dict(type=int, default=512, help="Shared dim")),
        (
            "--num_cross_layers",
            dict(type=int, default=2, help="Cross-attn layers"),
        ),
        ("--num_heads", dict(type=int, default=8, help="Attention heads")),
        ("--swe_slices", dict(type=int, default=512, help="SWE slices")),
        ("--dropout", dict(type=float, default=0.1, help="Dropout rate")),
        ("--protein_dim", dict(type=int, default=960, help="ESM-C dim")),
        (
            "--separate_swe",
            dict(
                action="store_true",
                help="Use separate SWE modules for glycan and protein",
            ),
        ),
        # Optimisation and run control.
        ("--lr_encoder", dict(type=float, default=1e-5, help="Encoder LR")),
        ("--lr_head", dict(type=float, default=1e-4, help="Head LR")),
        (
            "--weight_decay",
            dict(type=float, default=0.01, help="Weight decay"),
        ),
        ("--max_grad_norm", dict(type=float, default=1.0, help="Grad clip")),
        ("--batch_size", dict(type=int, default=32, help="Batch size")),
        ("--epochs", dict(type=int, default=50, help="Max epochs")),
        ("--patience", dict(type=int, default=15, help="Early stopping")),
        (
            "--max_glycan_length",
            dict(type=int, default=256, help="Max glycan len"),
        ),
        (
            "--max_protein_length",
            dict(type=int, default=1024, help="Max protein len"),
        ),
        ("--seed", dict(type=int, default=42, help="Random seed")),
        (
            "--warmup_pct",
            dict(
                type=float,
                default=0.05,
                help="Fraction of total steps for warmup (0.05=5%%, 0.10=10%%)",
            ),
        ),
        ("--target_col", dict(default="target_rank", help="Target column")),
        (
            "--checkpoint_interval",
            dict(type=int, default=5, help="Ckpt every N"),
        ),
        (
            "--resume",
            dict(action="store_true", help="Resume from checkpoint"),
        ),
        # Ablation switches.
        (
            "--pooling_mode",
            dict(
                default="swe",
                choices=["swe", "mean", "joint_swe"],
                help="Pooling strategy: swe (default), mean, or joint_swe",
            ),
        ),
        (
            "--interaction_mode",
            dict(
                default="product_sum",
                choices=["product_sum", "concat"],
                help="Interaction: product_sum (default) or concat",
            ),
        ),
        (
            "--no_cross_attention",
            dict(
                action="store_true",
                help="Disable cross-attention blocks (ablation)",
            ),
        ),
    ]

    parser = argparse.ArgumentParser(description="Bertint V8 Training")
    for flag, options in argument_table:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    set_seed(args.seed)
    logger.info("Bertint V8 Training — Cross-Attention + Live Bertose")
    run_summary = [
        ("freeze_layers", args.freeze_layers),
        ("lr_encoder", args.lr_encoder),
        ("lr_head", args.lr_head),
        ("batch_size", args.batch_size),
        ("shared_dim", args.shared_dim),
        ("cross_layers", args.num_cross_layers),
        ("separate_swe", args.separate_swe),
        ("pooling_mode", args.pooling_mode),
        ("interaction_mode", args.interaction_mode),
        ("cross_attention", not args.no_cross_attention),
    ]
    for label, value in run_summary:
        logger.info(f" {label}={value}")

    # One dataset per split, all sharing the same source files and limits.
    logger.info("\nLoading datasets...")
    datasets = {
        split: BertintV8Dataset(
            args.csv_path,
            args.split_path,
            split,
            args.protein_emb_path,
            args.vocab_path,
            max_glycan_length=args.max_glycan_length,
            max_protein_length=args.max_protein_length,
            target_col=args.target_col,
        )
        for split in ("train", "val", "test")
    }

    # Only the training loader shuffles; it also gets more worker processes.
    loader_settings = {
        "train": dict(shuffle=True, num_workers=4),
        "val": dict(shuffle=False, num_workers=2),
        "test": dict(shuffle=False, num_workers=2),
    }
    loaders = {
        split: DataLoader(
            datasets[split],
            batch_size=args.batch_size,
            pin_memory=True,
            collate_fn=collate_fn,
            **settings,
        )
        for split, settings in loader_settings.items()
    }

    logger.info("\nBuilding model...")
    config, seq_emb, seq_layers = load_bertose_encoder(
        args.bertose_checkpoint, freeze_layers=args.freeze_layers
    )

    model = BertintV8(
        seq_embeddings=seq_emb,
        seq_layers=seq_layers,
        glycan_dim=config.seq_hidden_size,
        protein_dim=args.protein_dim,
        shared_dim=args.shared_dim,
        num_cross_layers=args.num_cross_layers,
        num_heads=args.num_heads,
        swe_slices=args.swe_slices,
        dropout=args.dropout,
        separate_swe=args.separate_swe,
        pooling_mode=args.pooling_mode,
        interaction_mode=args.interaction_mode,
        use_cross_attention=not args.no_cross_attention,
    )

    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    logger.info(f" Total params: {total_params:,}")
    logger.info(f" Trainable: {trainable_params:,}")

    trainer = BertintV8Trainer(
        model=model,
        criterion=BertintV8Loss(),
        train_loader=loaders["train"],
        val_loader=loaders["val"],
        test_loader=loaders["test"],
        output_dir=args.output_dir,
        lr_encoder=args.lr_encoder,
        lr_head=args.lr_head,
        weight_decay=args.weight_decay,
        max_grad_norm=args.max_grad_norm,
        epochs=args.epochs,
        patience=args.patience,
        checkpoint_interval=args.checkpoint_interval,
        resume=args.resume,
        warmup_pct=args.warmup_pct,
    )

    trainer.train()
    logger.info("\nTraining complete!")
|
|
|
|
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|