""" Trainer for the Q_theta state-selectivity scorer. Implements two-phase training: Phase 1: DockQ regression (learn complex quality from all data) Phase 2: Selectivity fine-tuning (learn to rank X+ > X- for the same binder) Integrates with Weights & Biases for experiment tracking. """ import os import time import logging import numpy as np import torch import torch.nn as nn from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from scipy.stats import spearmanr from sklearn.metrics import roc_auc_score import wandb logger = logging.getLogger(__name__) class AverageMeter: def __init__(self): self.reset() def reset(self): self.val = 0.0 self.avg = 0.0 self.sum = 0.0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count class AlloDesignerTrainer: """ Two-phase trainer for Q_theta. Phase 1 (DockQ regression): - Minimizes MSE(Q_theta(X, Y), DockQ_label) on all complex types - Learns general complex quality Phase 2 (Selectivity fine-tuning): - Minimizes selectivity margin loss on paired (pos, neg) data - Learns to rank Q(X+, Y) > Q(X-, Y) - Combined: L = L_regression + lambda_rank * L_selectivity """ def __init__(self, model, config, device='cuda'): self.model = model.to(device) self.config = config self.device = device self.use_sam = config.get('optimizer', 'adamw') == 'sam' # Optimizer if self.use_sam: from utils.sam import SAM self.optimizer = SAM( model.parameters(), base_optimizer=AdamW, rho=0.05, lr=config.get('lr', 1e-4), weight_decay=config.get('weight_decay', 1e-4), betas=(0.9, 0.999), ) # SAM wraps AdamW; scheduler goes on base_optimizer sched_optimizer = self.optimizer.base_optimizer else: self.optimizer = AdamW( model.parameters(), lr=config.get('lr', 1e-4), weight_decay=config.get('weight_decay', 1e-4), betas=(0.9, 0.999), ) sched_optimizer = self.optimizer # Learning rate scheduler (warmup + cosine) n_warmup = config.get('warmup_steps', 100) n_total = config.get('max_steps', 5000) warmup_sched = LinearLR(sched_optimizer, start_factor=0.01, end_factor=1.0, total_iters=n_warmup) cosine_sched = CosineAnnealingLR(sched_optimizer, T_max=n_total - n_warmup, eta_min=1e-6) self.scheduler = SequentialLR(sched_optimizer, [warmup_sched, cosine_sched], milestones=[n_warmup]) self.global_step = 0 self.best_val_metric = -float('inf') self.checkpoint_dir = config.get('checkpoint_dir', 'results/checkpoints') os.makedirs(self.checkpoint_dir, exist_ok=True) # ------------------------------------------------------------------ # # Phase 1: DockQ regression # ------------------------------------------------------------------ # def train_step_phase1(self, batch): """Single training step for Phase 1 (DockQ regression).""" self.model.train() node_feats = batch['node_feats'].to(self.device) # [B, N, node_dim] edge_feats = batch['edge_feats'].to(self.device) # [B, N, N, edge_dim] node_mask = batch['node_mask'].to(self.device) # [B, N] labels = batch['label'].to(self.device) # [B] esm_feats = batch['esm_feats'].to(self.device) if 'esm_feats' in batch else None self.optimizer.zero_grad() scores = self.model(node_feats, edge_feats, node_mask, esm_feats=esm_feats) # [B] loss = nn.functional.mse_loss(scores, labels) loss.backward() nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) if self.use_sam: self.optimizer.first_step() # Second forward-backward pass scores2 = self.model(node_feats, edge_feats, node_mask, esm_feats=esm_feats) loss2 = nn.functional.mse_loss(scores2, labels) self.optimizer.zero_grad() loss2.backward() nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) self.optimizer.second_step() else: self.optimizer.step() self.scheduler.step() self.global_step += 1 return {'loss': loss.item(), 'scores': scores.detach(), 'labels': labels} def run_phase1(self, train_loader, val_loader, n_epochs: int = 30, run_name: str = 'phase1'): """Phase 1 training loop.""" logger.info(f"Starting Phase 1 (DockQ regression) for {n_epochs} epochs") wandb.define_metric('phase1/step') wandb.define_metric('phase1/*', step_metric='phase1/step') for epoch in range(n_epochs): # Train train_meter = AverageMeter() all_scores, all_labels = [], [] for batch in train_loader: result = self.train_step_phase1(batch) train_meter.update(result['loss'], n=len(result['scores'])) all_scores.append(result['scores'].cpu().numpy()) all_labels.append(result['labels'].cpu().numpy()) if self.global_step % 50 == 0: wandb.log({ 'phase1/train_loss': result['loss'], 'phase1/lr': self.optimizer.param_groups[0]['lr'], 'phase1/step': self.global_step, }) # Compute Spearman corr on training data all_scores = np.concatenate(all_scores) all_labels = np.concatenate(all_labels) train_spearman = spearmanr(all_scores, all_labels).correlation # Validate val_metrics = self.evaluate_phase1(val_loader) logger.info( f"Phase1 Epoch {epoch+1}/{n_epochs} | " f"Train Loss: {train_meter.avg:.4f} | " f"Train Spearman: {train_spearman:.3f} | " f"Val Loss: {val_metrics['val_loss']:.4f} | " f"Val Spearman: {val_metrics['val_spearman']:.3f} | " f"Val AUC: {val_metrics.get('val_auc', 0):.3f}" ) wandb.log({ 'phase1/epoch': epoch + 1, 'phase1/train_loss_epoch': train_meter.avg, 'phase1/train_spearman': train_spearman, **{f'phase1/{k}': v for k, v in val_metrics.items()}, }) # Checkpoint best model if val_metrics['val_spearman'] > self.best_val_metric: self.best_val_metric = val_metrics['val_spearman'] self.save_checkpoint('best_phase1.pt', extra={'epoch': epoch, 'phase': 1}) logger.info(f" -> New best Phase 1 model (val_spearman={self.best_val_metric:.3f})") logger.info("Phase 1 training complete.") @torch.no_grad() def evaluate_phase1(self, loader): """Evaluate Phase 1 model on val/test set.""" self.model.eval() all_scores, all_labels = [], [] total_loss = 0.0 n_batches = 0 for batch in loader: node_feats = batch['node_feats'].to(self.device) edge_feats = batch['edge_feats'].to(self.device) node_mask = batch['node_mask'].to(self.device) labels = batch['label'].to(self.device) esm_feats = batch['esm_feats'].to(self.device) if 'esm_feats' in batch else None scores = self.model(node_feats, edge_feats, node_mask, esm_feats=esm_feats) loss = nn.functional.mse_loss(scores, labels) total_loss += loss.item() n_batches += 1 all_scores.append(scores.cpu().numpy()) all_labels.append(labels.cpu().numpy()) all_scores = np.concatenate(all_scores) all_labels = np.concatenate(all_labels) spearman = spearmanr(all_scores, all_labels).correlation if np.isnan(spearman): spearman = 0.0 metrics = { 'val_loss': total_loss / max(n_batches, 1), 'val_spearman': float(spearman), } # AUC for binary quality (label > 0.5 = positive) binary_labels = (all_labels > 0.5).astype(int) if binary_labels.sum() > 0 and binary_labels.sum() < len(binary_labels): try: metrics['val_auc'] = roc_auc_score(binary_labels, all_scores) except Exception: pass return metrics # ------------------------------------------------------------------ # # Phase 2: Selectivity fine-tuning # ------------------------------------------------------------------ # def train_step_phase2(self, batch, lambda_rank: float = 1.0, margin: float = 0.2, lambda_ddg: float = 0.1): """Single training step for Phase 2 (selectivity margin + ddG auxiliary).""" self.model.train() pos = batch['pos'] neg = batch['neg'] pos_node = pos['node_feats'].to(self.device) pos_edge = pos['edge_feats'].to(self.device) pos_mask = pos['node_mask'].to(self.device) pos_label = pos['label'].to(self.device) pos_ce = pos.get('contact_energy', None) if pos_ce is not None: pos_ce = pos_ce.to(self.device) neg_node = neg['node_feats'].to(self.device) neg_edge = neg['edge_feats'].to(self.device) neg_mask = neg['node_mask'].to(self.device) pos_esm = pos['esm_feats'].to(self.device) if 'esm_feats' in pos else None neg_esm = neg['esm_feats'].to(self.device) if 'esm_feats' in neg else None self.optimizer.zero_grad() pos_scores = self.model(pos_node, pos_edge, pos_mask, esm_feats=pos_esm) # [B] neg_scores = self.model(neg_node, neg_edge, neg_mask, esm_feats=neg_esm) # [B] # Regression loss on positive examples loss_reg = nn.functional.mse_loss(pos_scores, pos_label) # Selectivity margin loss: pos_score - neg_score > margin loss_margin = nn.functional.relu(margin - (pos_scores - neg_scores)).mean() # InfoNCE-style selectivity loss eps = 1e-6 pos_logit = torch.log(pos_scores.clamp(eps, 1 - eps) / (1 - pos_scores).clamp(eps)) neg_logit = torch.log(neg_scores.clamp(eps, 1 - eps) / (1 - neg_scores).clamp(eps)) log_denom = torch.stack([pos_logit, neg_logit], dim=-1).logsumexp(dim=-1) infonce_loss = -(pos_logit - log_denom).mean() # ddG auxiliary loss: MSE against contact-energy proxy (physics-informed soft label) loss_ddg = torch.tensor(0.0, device=self.device) if pos_ce is not None and pos_ce.shape[0] > 0: # pos_ce is a contact-energy-based ddG proxy in [0, 1] # Align positive score toward the contact energy signal loss_ddg = nn.functional.mse_loss(pos_scores, pos_ce) loss = loss_reg + lambda_rank * (loss_margin + infonce_loss) + lambda_ddg * loss_ddg loss.backward() nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) if self.use_sam: self.optimizer.first_step() # Second forward-backward for SAM pos_scores2 = self.model(pos_node, pos_edge, pos_mask, esm_feats=pos_esm) neg_scores2 = self.model(neg_node, neg_edge, neg_mask, esm_feats=neg_esm) loss_reg2 = nn.functional.mse_loss(pos_scores2, pos_label) loss_margin2 = nn.functional.relu(margin - (pos_scores2 - neg_scores2)).mean() eps2 = 1e-6 pl2 = torch.log(pos_scores2.clamp(eps2, 1-eps2) / (1-pos_scores2).clamp(eps2)) nl2 = torch.log(neg_scores2.clamp(eps2, 1-eps2) / (1-neg_scores2).clamp(eps2)) ld2 = torch.stack([pl2, nl2], dim=-1).logsumexp(dim=-1) infonce2 = -(pl2 - ld2).mean() loss2 = loss_reg2 + lambda_rank * (loss_margin2 + infonce2) self.optimizer.zero_grad() loss2.backward() nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) self.optimizer.second_step() else: self.optimizer.step() self.scheduler.step() self.global_step += 1 selectivity_gap = (pos_scores - neg_scores).mean().item() return { 'loss': loss.item(), 'loss_reg': loss_reg.item(), 'loss_margin': loss_margin.item(), 'loss_infonce': infonce_loss.item(), 'loss_ddg': loss_ddg.item(), 'selectivity_gap': selectivity_gap, 'pos_scores': pos_scores.detach(), 'neg_scores': neg_scores.detach(), } def train_step_phase2_v2(self, batch, lambda_rank: float = 1.0, margin: float = 0.2, lambda_ddg: float = 0.0, lambda_path: float = 0.5): """Phase 2 training step with multi-negative + path monotonicity.""" self.model.train() pos = batch['pos'] neg = batch['neg'] pos_node = pos['node_feats'].to(self.device) pos_edge = pos['edge_feats'].to(self.device) pos_mask = pos['node_mask'].to(self.device) pos_label = pos['label'].to(self.device) pos_ce = pos.get('contact_energy', None) if pos_ce is not None: pos_ce = pos_ce.to(self.device) neg_node = neg['node_feats'].to(self.device) neg_edge = neg['edge_feats'].to(self.device) neg_mask = neg['node_mask'].to(self.device) pos_esm = pos['esm_feats'].to(self.device) if 'esm_feats' in pos else None neg_esm = neg['esm_feats'].to(self.device) if 'esm_feats' in neg else None self.optimizer.zero_grad() pos_scores = self.model(pos_node, pos_edge, pos_mask, esm_feats=pos_esm) neg_scores = self.model(neg_node, neg_edge, neg_mask, esm_feats=neg_esm) # Score path frames if present path_scores = [] path_taus = batch.get('path_taus', []) for path_frame in batch.get('path', []): p_node = path_frame['node_feats'].to(self.device) p_edge = path_frame['edge_feats'].to(self.device) p_mask = path_frame['node_mask'].to(self.device) p_score = self.model(p_node, p_edge, p_mask) path_scores.append(p_score) # Regression loss on positive examples loss_reg = nn.functional.mse_loss(pos_scores, pos_label) # Selectivity margin loss loss_margin = nn.functional.relu(margin - (pos_scores - neg_scores)).mean() # InfoNCE-style selectivity loss eps = 1e-6 pos_logit = torch.log(pos_scores.clamp(eps, 1 - eps) / (1 - pos_scores).clamp(eps)) neg_logit = torch.log(neg_scores.clamp(eps, 1 - eps) / (1 - neg_scores).clamp(eps)) log_denom = torch.stack([pos_logit, neg_logit], dim=-1).logsumexp(dim=-1) infonce_loss = -(pos_logit - log_denom).mean() # ddG auxiliary loss loss_ddg = torch.tensor(0.0, device=self.device) if pos_ce is not None and pos_ce.shape[0] > 0 and lambda_ddg > 0: loss_ddg = nn.functional.mse_loss(pos_scores, pos_ce) # Path monotonicity loss loss_path = torch.tensor(0.0, device=self.device) if path_scores and lambda_path > 0: small_margin = 0.05 for i in range(len(path_scores) - 1): loss_path = loss_path + nn.functional.relu( path_scores[i] - path_scores[i + 1] + small_margin ).mean() # Last path frame < positive score loss_path = loss_path + nn.functional.relu( path_scores[-1] - pos_scores + margin ).mean() # First path frame > negative score loss_path = loss_path + nn.functional.relu( neg_scores - path_scores[0] + small_margin ).mean() loss = (loss_reg + lambda_rank * (loss_margin + infonce_loss) + lambda_ddg * loss_ddg + lambda_path * loss_path) loss.backward() nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) self.optimizer.step() self.scheduler.step() self.global_step += 1 selectivity_gap = (pos_scores - neg_scores).mean().item() return { 'loss': loss.item(), 'loss_reg': loss_reg.item(), 'loss_margin': loss_margin.item(), 'loss_infonce': infonce_loss.item(), 'loss_ddg': loss_ddg.item(), 'loss_path': loss_path.item(), 'selectivity_gap': selectivity_gap, 'pos_scores': pos_scores.detach(), 'neg_scores': neg_scores.detach(), } def run_phase2_path(self, train_loader, val_loader, n_epochs: int = 20, lambda_rank: float = 1.0, margin: float = 0.2, lambda_ddg: float = 0.0, lambda_path: float = 0.5): """Phase 2 with path-aware training loop.""" logger.info(f"Starting Phase 2 (path-aware) for {n_epochs} epochs " f"[lambda_rank={lambda_rank}, lambda_path={lambda_path}]") self.best_val_metric = -float('inf') for epoch in range(n_epochs): loss_meter = AverageMeter() gap_meter = AverageMeter() path_meter = AverageMeter() for batch in train_loader: result = self.train_step_phase2_v2( batch, lambda_rank, margin, lambda_ddg, lambda_path) B = len(result['pos_scores']) loss_meter.update(result['loss'], B) gap_meter.update(result['selectivity_gap'], B) path_meter.update(result['loss_path'], B) if self.global_step % 50 == 0: wandb.log({ 'phase2/train_loss': result['loss'], 'phase2/loss_margin': result['loss_margin'], 'phase2/loss_infonce': result['loss_infonce'], 'phase2/loss_path': result['loss_path'], 'phase2/selectivity_gap': result['selectivity_gap'], 'phase2/lr': self.optimizer.param_groups[0]['lr'], 'phase2/step': self.global_step, }) val_metrics = self.evaluate_phase2(val_loader) logger.info( f"Phase2-Path Epoch {epoch+1}/{n_epochs} | " f"Loss: {loss_meter.avg:.4f} | " f"Gap: {gap_meter.avg:.3f} | " f"Path: {path_meter.avg:.4f} | " f"Val Gap: {val_metrics['val_selectivity_gap']:.3f} | " f"Val Acc: {val_metrics['val_ranking_acc']:.3f}" ) wandb.log({ 'phase2/epoch': epoch + 1, 'phase2/train_loss_epoch': loss_meter.avg, 'phase2/train_gap_epoch': gap_meter.avg, 'phase2/train_path_loss_epoch': path_meter.avg, **{f'phase2/{k}': v for k, v in val_metrics.items()}, }) if val_metrics['val_selectivity_gap'] > self.best_val_metric: self.best_val_metric = val_metrics['val_selectivity_gap'] self.save_checkpoint('best_phase2.pt', extra={'epoch': epoch, 'phase': 2}) logger.info(f" -> New best Phase 2 model (val_gap={self.best_val_metric:.3f})") logger.info("Phase 2 (path-aware) training complete.") def run_phase2(self, train_loader, val_loader, n_epochs: int = 20, lambda_rank: float = 1.0, margin: float = 0.2, lambda_ddg: float = 0.1): """Phase 2 training loop (selectivity fine-tuning + ddG auxiliary).""" logger.info(f"Starting Phase 2 (selectivity fine-tuning) for {n_epochs} epochs " f"[lambda_rank={lambda_rank}, lambda_ddg={lambda_ddg}]") self.best_val_metric = -float('inf') for epoch in range(n_epochs): loss_meter = AverageMeter() gap_meter = AverageMeter() for batch in train_loader: result = self.train_step_phase2(batch, lambda_rank, margin, lambda_ddg) B = len(result['pos_scores']) loss_meter.update(result['loss'], B) gap_meter.update(result['selectivity_gap'], B) if self.global_step % 50 == 0: wandb.log({ 'phase2/train_loss': result['loss'], 'phase2/loss_margin': result['loss_margin'], 'phase2/loss_infonce': result['loss_infonce'], 'phase2/loss_ddg': result['loss_ddg'], 'phase2/selectivity_gap': result['selectivity_gap'], 'phase2/lr': self.optimizer.param_groups[0]['lr'], 'phase2/step': self.global_step, }) # Validate val_metrics = self.evaluate_phase2(val_loader) logger.info( f"Phase2 Epoch {epoch+1}/{n_epochs} | " f"Loss: {loss_meter.avg:.4f} | " f"Gap: {gap_meter.avg:.3f} | " f"Val Gap: {val_metrics['val_selectivity_gap']:.3f} | " f"Val Acc: {val_metrics['val_ranking_acc']:.3f}" ) wandb.log({ 'phase2/epoch': epoch + 1, 'phase2/train_loss_epoch': loss_meter.avg, 'phase2/train_gap_epoch': gap_meter.avg, **{f'phase2/{k}': v for k, v in val_metrics.items()}, }) # Checkpoint if val_metrics['val_selectivity_gap'] > self.best_val_metric: self.best_val_metric = val_metrics['val_selectivity_gap'] self.save_checkpoint('best_phase2.pt', extra={'epoch': epoch, 'phase': 2}) logger.info(f" -> New best Phase 2 model (val_gap={self.best_val_metric:.3f})") logger.info("Phase 2 training complete.") @torch.no_grad() def evaluate_phase2(self, loader): """Evaluate selectivity on paired (pos, neg) val set.""" self.model.eval() all_pos_scores, all_neg_scores = [], [] for batch in loader: if 'pos' not in batch: continue pos = batch['pos'] neg = batch['neg'] pos_esm = pos['esm_feats'].to(self.device) if 'esm_feats' in pos else None neg_esm = neg['esm_feats'].to(self.device) if 'esm_feats' in neg else None pos_scores = self.model( pos['node_feats'].to(self.device), pos['edge_feats'].to(self.device), pos['node_mask'].to(self.device), esm_feats=pos_esm ) neg_scores = self.model( neg['node_feats'].to(self.device), neg['edge_feats'].to(self.device), neg['node_mask'].to(self.device), esm_feats=neg_esm ) all_pos_scores.append(pos_scores.cpu().numpy()) all_neg_scores.append(neg_scores.cpu().numpy()) if not all_pos_scores: return {'val_selectivity_gap': 0.0, 'val_ranking_acc': 0.5} all_pos = np.concatenate(all_pos_scores) all_neg = np.concatenate(all_neg_scores) gap = float((all_pos - all_neg).mean()) acc = float((all_pos > all_neg).mean()) return { 'val_selectivity_gap': gap, 'val_ranking_acc': acc, 'val_pos_score_mean': float(all_pos.mean()), 'val_neg_score_mean': float(all_neg.mean()), } # ------------------------------------------------------------------ # # Checkpointing # ------------------------------------------------------------------ # def save_checkpoint(self, filename: str, extra: dict = None): path = os.path.join(self.checkpoint_dir, filename) state = { 'model_state': self.model.state_dict(), 'optimizer_state': self.optimizer.state_dict(), 'global_step': self.global_step, 'config': self.config, } if extra: state.update(extra) torch.save(state, path) logger.debug(f"Saved checkpoint: {path}") def load_checkpoint(self, filename: str): path = os.path.join(self.checkpoint_dir, filename) if not os.path.exists(path): logger.warning(f"Checkpoint not found: {path}") return False state = torch.load(path, map_location=self.device) self.model.load_state_dict(state['model_state']) self.optimizer.load_state_dict(state['optimizer_state']) self.global_step = state.get('global_step', 0) logger.info(f"Loaded checkpoint from {path} (step {self.global_step})") return True # ------------------------------------------------------------------ # # Full evaluation (test set) # ------------------------------------------------------------------ # @torch.no_grad() def evaluate_test(self, test_loader, phase: int = 2): """Full evaluation on test set with all metrics.""" self.model.eval() all_scores, all_labels, all_types = [], [], [] for batch in test_loader: if 'pos' in batch: # Paired batch for key in ['pos', 'neg']: d = batch[key] d_esm = d['esm_feats'].to(self.device) if 'esm_feats' in d else None scores = self.model( d['node_feats'].to(self.device), d['edge_feats'].to(self.device), d['node_mask'].to(self.device), esm_feats=d_esm ) all_scores.extend(scores.cpu().numpy().tolist()) all_labels.extend(d['label'].numpy().tolist()) all_types.extend(['pos' if key == 'pos' else 'neg'] * len(scores)) else: esm_feats = batch['esm_feats'].to(self.device) if 'esm_feats' in batch else None scores = self.model( batch['node_feats'].to(self.device), batch['edge_feats'].to(self.device), batch['node_mask'].to(self.device), esm_feats=esm_feats ) all_scores.extend(scores.cpu().numpy().tolist()) all_labels.extend(batch['label'].numpy().tolist()) all_types.extend(batch['type']) all_scores = np.array(all_scores) all_labels = np.array(all_labels) metrics = {} # Spearman correlation (all samples) metrics['test_spearman'] = float(spearmanr(all_scores, all_labels).correlation or 0) # AUC (binary: label > 0.5 = positive quality) binary = (all_labels > 0.5).astype(int) if binary.sum() > 0 and binary.sum() < len(binary): try: metrics['test_auc'] = float(roc_auc_score(binary, all_scores)) except Exception: pass # Selectivity gap (pos vs neg_apo pairs) pos_mask = np.array([t == 'pos' or t == 'positive' for t in all_types]) neg_mask = np.array([t == 'neg' or t == 'negative_apo' for t in all_types]) if pos_mask.sum() > 0 and neg_mask.sum() > 0: metrics['test_selectivity_gap'] = float(all_scores[pos_mask].mean() - all_scores[neg_mask].mean()) logger.info(f"Test evaluation: {metrics}") wandb.log({f'test/{k}': v for k, v in metrics.items()}) return metrics, all_scores, all_labels, all_types