supanthadey1's picture
Add BERTose and AFFINose training code release
1d6f391 verified
Raw
History Blame Contribute Delete
22.6 kB
"""
Bertint V8 Training — Cross-Attention + Live Bertose Finetuning
Based on V7 trainer with changes for V8 architecture:
- Per-residue protein embeddings (variable-length, padded in collate)
- protein_mask passed to model for cross-attention
- AMP (GradScaler + autocast) built in from the start
- Regression only (no classification mode — V7 showed regression wins)
"""
import argparse
import json
import logging
import os
import random
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from scipy.stats import spearmanr, pearsonr
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
from bertint_v8 import BertintV8, BertintV8Loss, load_bertose_encoder
from dataset_v8 import BertintV8Dataset, collate_fn
# Root logging is configured once at import time: INFO level, with every
# record rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by all functions and classes in this file.
logger = logging.getLogger(__name__)
# ============================================================================
# Reproducibility
# ============================================================================
def set_seed(seed: int = 42) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed to ``random.seed``, ``np.random.seed`` and
            ``torch.manual_seed``; when CUDA is available it is also passed
            to ``torch.cuda.manual_seed_all``.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Nothing more to seed on a CPU-only host.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
# ============================================================================
# Metrics
# ============================================================================
def compute_metrics(
    preds: np.ndarray, targets: np.ndarray
) -> Dict[str, float]:
    """Compute Spearman, Pearson, MSE.

    Both inputs are flattened to 1-D before any arithmetic, so a column
    vector of predictions (N, 1) paired with flat targets (N,) is compared
    element-wise instead of silently broadcasting to an (N, N) matrix.

    Args:
        preds: Predicted values; any array-like with one value per sample.
        targets: Ground-truth values with the same number of elements.

    Returns:
        Dict with keys ``"spearman"``, ``"pearson"`` and ``"mse"``. A
        correlation that is undefined (fewer than two samples, a constant
        input, or a NaN result from scipy) is reported as 0.0.

    Raises:
        ValueError: If the inputs are empty or differ in element count.
    """
    preds = np.asarray(preds).ravel()
    targets = np.asarray(targets).ravel()

    if preds.shape != targets.shape:
        raise ValueError(
            f"preds and targets must have the same number of elements, "
            f"got {preds.size} and {targets.size}"
        )
    if preds.size == 0:
        raise ValueError("compute_metrics requires at least one sample")

    mse = float(np.mean((preds - targets) ** 2))

    # Correlations are undefined for a single sample or a constant vector;
    # short-circuit instead of letting scipy raise or warn and return NaN.
    degenerate = (
        preds.size < 2
        or bool(np.all(preds == preds[0]))
        or bool(np.all(targets == targets[0]))
    )
    if degenerate:
        rho = r = float("nan")
    else:
        rho, _ = spearmanr(preds, targets)
        r, _ = pearsonr(preds, targets)

    return {
        "spearman": float(rho) if not np.isnan(rho) else 0.0,
        "pearson": float(r) if not np.isnan(r) else 0.0,
        "mse": mse,
    }
# ============================================================================
# Trainer
# ============================================================================
class BertintV8Trainer:
    """
    Trainer for BertintV8 with cross-attention and AMP.

    Builds an AdamW optimizer with two parameter groups (encoder vs.
    everything else), a per-batch OneCycleLR schedule, and runs training
    with gradient clipping, best-model tracking on validation Spearman,
    early stopping and periodic resumable checkpoints. Mixed precision is
    only switched on when the selected device is CUDA.

    Args:
        model: BertintV8 model.
        criterion: Loss function.
        train_loader: Training data loader.
        val_loader: Validation data loader.
        test_loader: Test data loader.
        output_dir: Directory for checkpoints and results.
        lr_encoder: Learning rate for Bertose encoder layers.
        lr_head: Learning rate for cross-attention, SWE, and head.
        weight_decay: Weight decay for AdamW.
        max_grad_norm: Maximum gradient norm for clipping.
        epochs: Number of training epochs.
        patience: Early stopping patience.
        checkpoint_interval: Save checkpoint every N epochs.
        resume: Whether to resume from last checkpoint.
        warmup_pct: Fraction of total steps passed to OneCycleLR as
            ``pct_start``; values <= 0 fall back to 0.3.
    """

    # Batch keys forwarded to the model as keyword arguments, listed in the
    # order they are moved to the device.
    _MODEL_INPUT_KEYS = (
        "token_ids",
        "attention_mask",
        "branch_depths",
        "linkage_types",
        "protein_emb",
        "protein_mask",
    )

    def __init__(
        self,
        model: BertintV8,
        criterion: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        test_loader: DataLoader,
        output_dir: str,
        lr_encoder: float = 1e-5,
        lr_head: float = 1e-4,
        weight_decay: float = 0.01,
        max_grad_norm: float = 1.0,
        epochs: int = 50,
        patience: int = 15,
        checkpoint_interval: int = 5,
        resume: bool = False,
        warmup_pct: float = 0.0,
    ):
        self.model = model
        self.criterion = criterion
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.output_dir = output_dir
        self.epochs = epochs
        self.patience = patience
        self.checkpoint_interval = checkpoint_interval
        self.resume = resume
        self.max_grad_norm = max_grad_norm
        os.makedirs(output_dir, exist_ok=True)

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model.to(self.device)
        self.criterion.to(self.device)

        # AMP scaler. Mixed precision is requested only on CUDA so that a
        # CPU run neither asks for CUDA autocast nor reports AMP as active.
        self.use_amp = self.device.type == "cuda"
        self.scaler = GradScaler(enabled=self.use_amp)

        # Separate param groups: encoder (small lr) vs rest (larger lr).
        # Frozen parameters are left out of the optimizer entirely.
        encoder_params = []
        head_params = []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if name.startswith(("seq_embeddings", "seq_layers")):
                encoder_params.append(param)
            else:
                head_params.append(param)
        logger.info(
            f" Param groups: encoder={len(encoder_params)} tensors "
            f"(lr={lr_encoder}), head={len(head_params)} tensors "
            f"(lr={lr_head})"
        )
        self.optimizer = torch.optim.AdamW(
            [
                {
                    "params": encoder_params,
                    "lr": lr_encoder,
                    "weight_decay": weight_decay,
                },
                {
                    "params": head_params,
                    "lr": lr_head,
                    "weight_decay": weight_decay,
                },
            ]
        )

        # OneCycleLR with per-batch stepping (matches Twin Peaks pattern)
        # Built-in warmup (pct_start) + cosine annealing
        total_steps = len(train_loader) * epochs
        # Non-positive warmup_pct means "use the 30% default".
        pct_start = warmup_pct if warmup_pct > 0 else 0.3
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=[lr_encoder, lr_head],
            total_steps=total_steps,
            pct_start=pct_start,
            anneal_strategy='cos',
        )
        warmup_steps_actual = int(total_steps * pct_start)
        logger.info(" Scheduler: OneCycleLR per-batch stepping")
        logger.info(
            f" total_steps={total_steps:,}, "
            f"warmup={warmup_steps_actual:,} steps "
            f"({pct_start*100:.0f}%), cosine decay"
        )

        # State
        self.start_epoch = 0
        self.best_metric = -float("inf")
        self.patience_counter = 0
        self.history: List[Dict] = []
        if resume:
            self._resume_from_checkpoint()

    def train(self) -> Dict:
        """Full training loop with early stopping.

        Each epoch trains, evaluates on the validation loader, saves
        ``best_model.pt`` whenever validation Spearman improves, and stops
        early after ``patience`` epochs without improvement. Afterwards the
        best weights are evaluated on the test loader and a summary is
        written to ``results.json`` in ``output_dir``.

        Returns:
            The results dict that was written to ``results.json``.
        """
        logger.info(f"\nStarting V8 training for {self.epochs} epochs")
        logger.info(f" Device: {self.device}")
        logger.info(f" Train batches: {len(self.train_loader)}")
        logger.info(f" Val batches: {len(self.val_loader)}")
        logger.info(f" AMP: {'enabled' if self.use_amp else 'disabled'}")

        for epoch in range(self.start_epoch, self.epochs):
            t0 = time.time()
            train_loss = self._train_epoch(epoch)
            val_loss, val_metrics = self._eval_epoch(self.val_loader)
            elapsed = time.time() - t0
            rho = val_metrics["spearman"]
            r = val_metrics["pearson"]
            logger.info(
                f" Epoch {epoch + 1:3d} | Train loss={train_loss:.4f} | "
                f"Val loss={val_loss:.4f} rho={rho:.4f} r={r:.4f} | "
                f"{elapsed:.1f}s"
            )

            # Track best
            if rho > self.best_metric:
                self.best_metric = rho
                self.patience_counter = 0
                torch.save(
                    self.model.state_dict(),
                    os.path.join(self.output_dir, "best_model.pt"),
                )
                logger.info(f" * New best: {rho:.4f}")
            else:
                self.patience_counter += 1

            # History
            self.history.append(
                {
                    "epoch": epoch + 1,
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "val_metrics": val_metrics,
                    "lr_encoder": self.optimizer.param_groups[0]["lr"],
                    "lr_head": self.optimizer.param_groups[1]["lr"],
                }
            )
            # (scheduler.step() is called per-batch in _train_epoch)

            # Periodic checkpoint
            if (epoch + 1) % self.checkpoint_interval == 0:
                self._save_checkpoint(epoch + 1)

            # Early stopping
            if self.patience_counter >= self.patience:
                logger.info(
                    f" Early stopping at epoch {epoch + 1} "
                    f"(no improvement for {self.patience} epochs)"
                )
                break

        # Load best and test
        logger.info(f"\n{'=' * 60}")
        logger.info("Loading best model for test evaluation...")
        best_path = os.path.join(self.output_dir, "best_model.pt")
        if os.path.exists(best_path):
            self.model.load_state_dict(
                torch.load(best_path, map_location=self.device)
            )
        else:
            # Can happen when no epoch ran in this process (e.g. resuming a
            # finished run) and the file is gone; fall back to the weights
            # currently in memory rather than crashing before the test pass.
            logger.warning(
                f" {best_path} not found; evaluating current model weights"
            )
        test_loss, test_metrics = self._eval_epoch(self.test_loader)
        logger.info(f"\n{'=' * 60}")
        logger.info("TEST RESULTS:")
        logger.info(f" Spearman rho: {test_metrics['spearman']:.4f}")
        logger.info(f" Pearson r: {test_metrics['pearson']:.4f}")
        logger.info(f" MSE: {test_metrics['mse']:.6f}")
        logger.info(f"{'=' * 60}")

        # Save results
        results = {
            "task_type": "regression",
            "architecture": "cross-attention + SWE + live Bertose",
            "best_metric": self.best_metric,
            "test_metrics": test_metrics,
            "test_loss": test_loss,
            "history": self.history,
        }
        results_path = os.path.join(self.output_dir, "results.json")
        with open(results_path, "w") as f:
            json.dump(results, f, indent=2)
        logger.info(f"Results saved to {results_path}")
        return results

    def _forward_batch(
        self, batch: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Move one collated batch to the device and run model + loss.

        Returns:
            ``(pred, target, loss)``; the forward pass and the loss are
            computed under autocast when AMP is active.
        """
        inputs = {
            key: batch[key].to(self.device) for key in self._MODEL_INPUT_KEYS
        }
        target = batch["target"].to(self.device)
        with autocast(enabled=self.use_amp):
            pred = self.model(**inputs)
            loss = self.criterion(pred, target)
        return pred, target, loss

    def _train_epoch(self, epoch: int) -> float:
        """Run one training epoch with AMP.

        Args:
            epoch: Zero-based epoch index (used only for progress logs).

        Returns:
            Mean training loss over the batches of this epoch.
        """
        self.model.train()
        total_loss = 0.0
        n_batches = len(self.train_loader)
        for batch_idx, batch in enumerate(self.train_loader):
            self.optimizer.zero_grad()

            # AMP forward
            _, _, loss = self._forward_batch(batch)

            # AMP backward: gradients are unscaled before clipping so the
            # norm threshold applies to the true gradient magnitudes.
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.max_grad_norm
            )
            self.scaler.step(self.optimizer)
            self.scaler.update()

            # Per-batch LR scheduling (OneCycleLR)
            self.scheduler.step()
            total_loss += loss.item()

            # Progress logging
            if (batch_idx + 1) % 200 == 0:
                avg = total_loss / (batch_idx + 1)
                lr_enc = self.optimizer.param_groups[0]["lr"]
                logger.info(
                    f" [E{epoch + 1}][{batch_idx + 1}/{n_batches}] "
                    f"loss={avg:.4f} lr_enc={lr_enc:.2e}"
                )
        return total_loss / n_batches

    @torch.no_grad()
    def _eval_epoch(
        self, loader: DataLoader
    ) -> Tuple[float, Dict[str, float]]:
        """Run evaluation with AMP.

        Args:
            loader: Data loader to evaluate on.

        Returns:
            ``(mean loss per batch, metrics dict from compute_metrics)``.

        Raises:
            ValueError: If the loader yields no batches.
        """
        self.model.eval()
        total_loss = 0.0
        pred_chunks = []
        target_chunks = []
        for batch in loader:
            pred, target, loss = self._forward_batch(batch)
            total_loss += loss.item()
            # Flatten per batch so a 0-d prediction (batch of one after a
            # squeeze) or an (B, 1) column both contribute plain 1-D values.
            # NOTE(review): assumes one scalar prediction per sample.
            pred_chunks.append(pred.float().cpu().numpy().reshape(-1))
            target_chunks.append(target.cpu().numpy().reshape(-1))

        if not pred_chunks:
            raise ValueError(
                "Evaluation loader produced no batches; cannot compute metrics"
            )

        avg_loss = total_loss / len(pred_chunks)
        metrics = compute_metrics(
            np.concatenate(pred_chunks), np.concatenate(target_chunks)
        )
        return avg_loss, metrics

    def _save_checkpoint(self, epoch: int) -> None:
        """Save full training state for resume.

        Args:
            epoch: Number of completed epochs; stored so that a resumed run
                starts its loop at this index.
        """
        ckpt = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "scaler_state_dict": self.scaler.state_dict(),
            "best_metric": self.best_metric,
            "patience_counter": self.patience_counter,
            "history": self.history,
        }
        path = os.path.join(self.output_dir, "last_checkpoint.pt")
        torch.save(ckpt, path)
        logger.info(f" [CKPT] Saved epoch {epoch}")

    def _resume_from_checkpoint(self) -> None:
        """Resume training from last checkpoint.

        Restores model, optimizer, scheduler, scaler and bookkeeping state
        from ``last_checkpoint.pt`` in ``output_dir``; does nothing (apart
        from a log line) when that file does not exist.
        """
        ckpt_path = os.path.join(self.output_dir, "last_checkpoint.pt")
        if not os.path.exists(ckpt_path):
            logger.info(" No checkpoint found, starting fresh")
            return
        ckpt = torch.load(ckpt_path, map_location=self.device)
        self.model.load_state_dict(ckpt["model_state_dict"])
        self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        self.scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        # A checkpoint written while AMP was disabled stores an empty scaler
        # state; skip it so an enabled scaler keeps its fresh defaults.
        scaler_state = ckpt.get("scaler_state_dict")
        if scaler_state:
            self.scaler.load_state_dict(scaler_state)
        self.start_epoch = ckpt["epoch"]
        self.best_metric = ckpt["best_metric"]
        self.patience_counter = ckpt["patience_counter"]
        self.history = ckpt["history"]
        logger.info(
            f" Resumed from epoch {self.start_epoch}, "
            f"best={self.best_metric:.4f}"
        )
# ============================================================================
# Main
# ============================================================================
def main():
    """Entry point for V8 training.

    Parses command-line arguments, seeds the RNGs, builds the train/val/test
    datasets and loaders, loads the Bertose encoder, assembles the BertintV8
    model and loss, and runs ``BertintV8Trainer.train()``.

    Note that ``--warmup_pct`` defaults to 0.05 here, which is what the
    trainer receives when the flag is omitted.
    """
    parser = argparse.ArgumentParser(description="Bertint V8 Training")

    # Data / paths
    parser.add_argument(
        "--csv_path", required=True, help="Path to binding data CSV"
    )
    parser.add_argument(
        "--split_path", required=True, help="Path to glycan-cold splits JSON"
    )
    parser.add_argument(
        "--protein_emb_path", required=True, help="Path to ESM-C HDF5"
    )
    parser.add_argument(
        "--vocab_path", required=True, help="Path to BPE vocab JSON"
    )
    parser.add_argument(
        "--bertose_checkpoint", required=True, help="Bertose checkpoint"
    )
    parser.add_argument("--output_dir", required=True, help="Output dir")

    # Model architecture
    parser.add_argument(
        "--freeze_layers", type=int, default=4, help="Layers to freeze"
    )
    parser.add_argument(
        "--shared_dim", type=int, default=512, help="Shared dim"
    )
    parser.add_argument(
        "--num_cross_layers", type=int, default=2, help="Cross-attn layers"
    )
    parser.add_argument(
        "--num_heads", type=int, default=8, help="Attention heads"
    )
    parser.add_argument(
        "--swe_slices", type=int, default=512, help="SWE slices"
    )
    parser.add_argument(
        "--dropout", type=float, default=0.1, help="Dropout rate"
    )
    parser.add_argument(
        "--protein_dim", type=int, default=960, help="ESM-C dim"
    )
    parser.add_argument(
        "--separate_swe", action="store_true",
        help="Use separate SWE modules for glycan and protein"
    )

    # Training
    parser.add_argument(
        "--lr_encoder", type=float, default=1e-5, help="Encoder LR"
    )
    parser.add_argument(
        "--lr_head", type=float, default=1e-4, help="Head LR"
    )
    parser.add_argument(
        "--weight_decay", type=float, default=0.01, help="Weight decay"
    )
    parser.add_argument(
        "--max_grad_norm", type=float, default=1.0, help="Grad clip"
    )
    parser.add_argument(
        "--batch_size", type=int, default=32, help="Batch size"
    )
    parser.add_argument(
        "--epochs", type=int, default=50, help="Max epochs"
    )
    parser.add_argument(
        "--patience", type=int, default=15, help="Early stopping"
    )
    parser.add_argument(
        "--max_glycan_length", type=int, default=256, help="Max glycan len"
    )
    parser.add_argument(
        "--max_protein_length", type=int, default=1024, help="Max protein len"
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--warmup_pct", type=float, default=0.05,
        help="Fraction of total steps for warmup (0.05=5%%, 0.10=10%%)"
    )
    parser.add_argument(
        "--target_col", default="target_rank", help="Target column"
    )
    parser.add_argument(
        "--checkpoint_interval", type=int, default=5, help="Ckpt every N"
    )
    parser.add_argument(
        "--resume", action="store_true", help="Resume from checkpoint"
    )

    # Ablation controls
    parser.add_argument(
        "--pooling_mode", default="swe",
        choices=["swe", "mean", "joint_swe"],
        help="Pooling strategy: swe (default), mean, or joint_swe"
    )
    parser.add_argument(
        "--interaction_mode", default="product_sum",
        choices=["product_sum", "concat"],
        help="Interaction: product_sum (default) or concat"
    )
    parser.add_argument(
        "--no_cross_attention", action="store_true",
        help="Disable cross-attention blocks (ablation)"
    )
    args = parser.parse_args()

    set_seed(args.seed)
    logger.info("Bertint V8 Training — Cross-Attention + Live Bertose")
    logger.info(f" freeze_layers={args.freeze_layers}")
    logger.info(f" lr_encoder={args.lr_encoder}")
    logger.info(f" lr_head={args.lr_head}")
    logger.info(f" batch_size={args.batch_size}")
    logger.info(f" shared_dim={args.shared_dim}")
    logger.info(f" cross_layers={args.num_cross_layers}")
    logger.info(f" separate_swe={args.separate_swe}")
    logger.info(f" pooling_mode={args.pooling_mode}")
    logger.info(f" interaction_mode={args.interaction_mode}")
    logger.info(f" cross_attention={not args.no_cross_attention}")

    # Load datasets: the three splits differ only in the split name, and are
    # built in train -> val -> test order.
    logger.info("\nLoading datasets...")
    datasets = {
        split: BertintV8Dataset(
            args.csv_path, args.split_path, split,
            args.protein_emb_path, args.vocab_path,
            max_glycan_length=args.max_glycan_length,
            max_protein_length=args.max_protein_length,
            target_col=args.target_col,
        )
        for split in ("train", "val", "test")
    }

    # Pinned host memory only helps host-to-GPU copies, so request it only
    # when a CUDA device is actually present.
    loader_kwargs = {
        "batch_size": args.batch_size,
        "pin_memory": torch.cuda.is_available(),
        "collate_fn": collate_fn,
    }
    train_loader = DataLoader(
        datasets["train"], shuffle=True, num_workers=4, **loader_kwargs
    )
    val_loader = DataLoader(
        datasets["val"], shuffle=False, num_workers=2, **loader_kwargs
    )
    test_loader = DataLoader(
        datasets["test"], shuffle=False, num_workers=2, **loader_kwargs
    )

    # Build model
    logger.info("\nBuilding model...")
    config, seq_emb, seq_layers = load_bertose_encoder(
        args.bertose_checkpoint, freeze_layers=args.freeze_layers
    )
    model = BertintV8(
        seq_embeddings=seq_emb,
        seq_layers=seq_layers,
        glycan_dim=config.seq_hidden_size,
        protein_dim=args.protein_dim,
        shared_dim=args.shared_dim,
        num_cross_layers=args.num_cross_layers,
        num_heads=args.num_heads,
        swe_slices=args.swe_slices,
        dropout=args.dropout,
        separate_swe=args.separate_swe,
        pooling_mode=args.pooling_mode,
        interaction_mode=args.interaction_mode,
        use_cross_attention=not args.no_cross_attention,
    )
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    logger.info(f" Total params: {total_params:,}")
    logger.info(f" Trainable: {trainable_params:,}")

    # Loss
    criterion = BertintV8Loss()

    # Train
    trainer = BertintV8Trainer(
        model=model,
        criterion=criterion,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        output_dir=args.output_dir,
        lr_encoder=args.lr_encoder,
        lr_head=args.lr_head,
        weight_decay=args.weight_decay,
        max_grad_norm=args.max_grad_norm,
        epochs=args.epochs,
        patience=args.patience,
        checkpoint_interval=args.checkpoint_interval,
        resume=args.resume,
        warmup_pct=args.warmup_pct,
    )
    # train() already writes results.json, so the returned dict is not kept.
    trainer.train()
    logger.info("\nTraining complete!")
if __name__ == "__main__":
main()