""" Bertint V8 — Cross-Attention + Live Bertose Finetuning Architecture: GLYCAN: WURCS → BPE → Bertose (live, freeze layers 0-3) → [B, Lg, 768] ↓ proj(768→512) PROTEIN: precomputed ESM-C → [B, Lp, 960] ↓ ↓ proj(960→512) ↓ 2× CrossAttentionBlock(d=512, 8 heads, FFN=1024) ↓ ↓ ↓ SHARED mask-aware SWE(d=512, S=512, R=64) ↓ ↓ ↓ [B, 512] [B, 512] ↓ element-wise product + sum [B, 1024] ↓ MLP → binding score Key changes from V7: - Per-residue protein embeddings (not mean-pooled) for cross-attention - CrossAttentionBlock: glycan tokens attend to protein residues and vice versa - SWE pooling: variable-length → fixed-length (mask-aware, differentiable) - Product + sum interaction (from Twin Peaks) instead of concat Key changes from V3: - Live Bertose forward pass (not frozen precomputed embeddings) - ESM-C 300M (960-dim, not ESM-C 600M 1152-dim) SWE, CrossAttention, and mask handling ported from V3 (Sessions 5-10). """ import os import sys import math from pathlib import Path from typing import Dict, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F # ============================================================================ # Bertose model imports (same as V7) # ============================================================================ def _default_bertose_root() -> Path: """Resolve the Bertose source root without assuming a Nova-only path.""" env_root = os.environ.get("BERTOSE_ROOT") or os.environ.get("BERTOSE_REPO_ROOT") if env_root: return Path(env_root).expanduser().resolve() here = Path(__file__).resolve() for parent in here.parents: if (parent / "bert_training_v4").exists() and (parent / "model").exists(): return parent return Path("/work/ratul1/supantha/glycan-SD-VS/bert_training_v3/v3.1_cluster_training") BERTOSE_ROOT = _default_bertose_root() def _ensure_bertose_imports(): """Add Bertose source directories to sys.path if not already present.""" roots = [ str(BERTOSE_ROOT), str(BERTOSE_ROOT / "bert_training_v4"), ] for root in roots: if root not in sys.path: sys.path.insert(0, root) def load_bertose_config(): """Create Bertose config matching the V5b checkpoint.""" _ensure_bertose_imports() from model.multimodal_glycan_bert_v3 import MultimodalGlycanBERTConfig return MultimodalGlycanBERTConfig( seq_vocab_size=2200, use_cnn_frontend=True, ) def load_bertose_encoder( checkpoint_path: str, freeze_layers: int = 4 ): """ Load Bertose sequence encoder with pretrained weights. Args: checkpoint_path: Path to pretrained Bertose checkpoint. freeze_layers: Number of transformer layers to freeze (0-indexed). Returns: Tuple of (bertose_config, seq_embeddings, seq_layers). """ _ensure_bertose_imports() from model.multimodal_glycan_bert_v3 import ( MultimodalGlycanBERT, MultimodalGlycanBERTConfig, ) # Load checkpoint ckpt = torch.load(checkpoint_path, map_location="cpu") state_dict = ckpt.get("model_state_dict", ckpt) # Infer vocab size and max position embeddings from checkpoint vocab_size = state_dict["seq_embeddings.token_embeddings.weight"].shape[0] max_pos = state_dict["seq_embeddings.position_embeddings.weight"].shape[0] config = MultimodalGlycanBERTConfig( seq_vocab_size=vocab_size, seq_max_length=max_pos, use_cnn_frontend=True, ) # Instantiate full model, then extract sequence encoder model = MultimodalGlycanBERT(config) missing, unexpected = model.load_state_dict(state_dict, strict=False) loaded = len(state_dict) - len(unexpected) print(f" Loaded {loaded}/{len(state_dict)} pretrained weight tensors") print(f" ({len(missing)} missing in checkpoint, {len(unexpected)} unexpected)") seq_embeddings = model.seq_embeddings seq_layers = model.seq_layers # Freeze embedding layer + first N transformer layers for param in seq_embeddings.parameters(): param.requires_grad = False for i in range(min(freeze_layers, len(seq_layers))): for param in seq_layers[i].parameters(): param.requires_grad = False trainable = sum( p.numel() for p in seq_embeddings.parameters() if p.requires_grad ) trainable += sum( p.numel() for layer in seq_layers for p in layer.parameters() if p.requires_grad ) total = sum(p.numel() for p in seq_embeddings.parameters()) total += sum( p.numel() for layer in seq_layers for p in layer.parameters() ) print( f" Bertose encoder: {total:,} params total, " f"{trainable:,} trainable (frozen layers 0-{freeze_layers - 1})" ) return config, seq_embeddings, seq_layers # ============================================================================ # Differentiable Interpolation (from V3, Sessions 5-6) # ============================================================================ def differentiable_interp1d( x: torch.Tensor, y: torch.Tensor, xnew: torch.Tensor ) -> torch.Tensor: """ Fully differentiable 1D linear interpolation. Gradients flow through y (values) back to theta projection and earlier layers. The original Interp1d.backward from Twin Peaks only returned gradients for xnew (query coords), NOT for y — killing 83% of gradient flow. Fixed in Session 5. Args: x: [B, N] sorted input coordinates (detached) y: [B, N] values at x positions (REQUIRES grad flow!) xnew: [B, R] query coordinates Returns: [B, R] interpolated values """ n_pts = x.shape[1] # Find interpolation indices ind = torch.searchsorted( x.contiguous().detach(), xnew.contiguous().detach() ) ind = ind.clamp(1, n_pts - 1) # Gather neighbor values — preserves gradient flow through y x_lo = torch.gather(x, 1, ind - 1) x_hi = torch.gather(x, 1, ind) y_lo = torch.gather(y, 1, ind - 1) y_hi = torch.gather(y, 1, ind) # Linear interpolation weights denom = (x_hi - x_lo).clamp(min=1e-8) alpha = ((xnew - x_lo) / denom).clamp(0, 1) # Interpolated value — fully differentiable w.r.t. y_lo and y_hi return y_lo + alpha * (y_hi - y_lo) # ============================================================================ # SWE Pooling (from V3, mask-aware fixes from Sessions 6-8) # ============================================================================ class SWE_Pooling(nn.Module): """ Sliced-Wasserstein Embedding pooling. Maps variable-length token embeddings [B, L, d_in] => [B, num_slices]. From Twin Peaks, with mask-aware sorting (Session 6) and degenerate-sample handling (Session 8). """ def __init__( self, d_in: int, num_slices: int, num_ref_points: int, freeze_swe: bool = False, ): super().__init__() self.num_slices = num_slices self.num_ref_points = num_ref_points # Learnable reference distribution ref = torch.linspace(-1, 1, num_ref_points).unsqueeze(1).repeat( 1, num_slices ) self.reference = nn.Parameter(ref, requires_grad=not freeze_swe) # Projection directions (weight-normalized) self.theta = nn.utils.weight_norm( nn.Linear(d_in, num_slices, bias=False), dim=0 ) self.theta.weight_g.data = torch.ones_like(self.theta.weight_g.data) self.theta.weight_g.requires_grad = False nn.init.normal_(self.theta.weight_v) # Weighted aggregation over reference points self.weight = nn.Linear(num_ref_points, 1, bias=False) if freeze_swe: self.theta.weight_v.requires_grad = False self.reference.requires_grad = False def forward( self, x: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: """ Args: x: [B, L, d_in] token embeddings mask: [B, L] attention mask (1=valid, 0=padding) Returns: [B, num_slices] fixed-length representation """ batch_size, seq_len, _ = x.shape device = x.device # Degenerate: single token → just project if seq_len == 1: x_slices = self.theta(x) return x_slices.squeeze(1) # Project onto learned directions x_slices = self.theta(x) # [B, L, num_slices] # MASK-AWARE SORTING: set padding → -inf so they sort to bottom if mask is not None: mask_exp = mask.unsqueeze(-1).expand_as(x_slices) x_slices = x_slices.masked_fill(mask_exp == 0, float("-inf")) x_sorted, _ = torch.sort(x_slices, dim=1) # Strip padding from sorted array if mask is not None: valid_counts = mask.sum(dim=1).long() degenerate_mask = valid_counts < 2 safe_counts = valid_counts.clamp(min=2) max_valid = safe_counts.max().item() # Vectorized gather: take last safe_count values starts = (seq_len - safe_counts).unsqueeze(1) offsets = torch.arange(max_valid, device=device).unsqueeze(0) raw_idx = starts + offsets gather_idx = raw_idx.clamp(max=seq_len - 1).unsqueeze(-1).expand( batch_size, max_valid, self.num_slices ) x_sorted = torch.gather(x_sorted, 1, gather_idx) # Replace -inf with 0.0 for degenerate samples x_sorted = x_sorted.masked_fill( x_sorted == float("-inf"), 0.0 ) n_eff = max_valid else: degenerate_mask = None n_eff = seq_len # Interpolate to fixed reference grid x_coord = ( torch.linspace(0, 1, n_eff, device=device) .unsqueeze(0) .expand(batch_size * self.num_slices, -1) ) x_flat = x_sorted.permute(0, 2, 1).reshape( batch_size * self.num_slices, n_eff ) xnew = ( torch.linspace(0, 1, self.num_ref_points, device=device) .unsqueeze(0) .expand(batch_size * self.num_slices, -1) ) y_intp = differentiable_interp1d(x_coord, x_flat, xnew) x_interp = y_intp.view( batch_size, self.num_slices, self.num_ref_points ).permute(0, 2, 1) # Compare with reference distribution r_expanded = self.reference.expand_as(x_interp) embeddings = (r_expanded - x_interp).permute(0, 2, 1) # Weighted aggregation → [B, num_slices] weighted = self.weight(embeddings).sum(dim=-1) # Zero out degenerate samples if degenerate_mask is not None and degenerate_mask.any(): weighted = weighted.masked_fill( degenerate_mask.unsqueeze(-1), 0.0 ) return weighted # ============================================================================ # Cross-Attention Block (from V3, NaN guard from Session 9) # ============================================================================ class CrossAttentionBlock(nn.Module): """ Bidirectional cross-attention between glycan and protein tokens. Glycan tokens attend to protein residues (Q=glycan, KV=protein) Protein residues attend to glycan tokens (Q=protein, KV=glycan) Includes NaN guard for all-masked keys (Session 9 fix) and padding-position zeroing (Session 7 fix). """ def __init__( self, d_model: int, num_heads: int, ffn_dim: int, dropout: float = 0.1, ): super().__init__() # Glycan → Protein cross-attention self.glycan_cross_attn = nn.MultiheadAttention( d_model, num_heads, dropout=dropout, batch_first=True ) self.glycan_norm1 = nn.LayerNorm(d_model) self.glycan_ffn = nn.Sequential( nn.Linear(d_model, ffn_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(ffn_dim, d_model), nn.Dropout(dropout), ) self.glycan_norm2 = nn.LayerNorm(d_model) # Protein → Glycan cross-attention self.protein_cross_attn = nn.MultiheadAttention( d_model, num_heads, dropout=dropout, batch_first=True ) self.protein_norm1 = nn.LayerNorm(d_model) self.protein_ffn = nn.Sequential( nn.Linear(d_model, ffn_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(ffn_dim, d_model), nn.Dropout(dropout), ) self.protein_norm2 = nn.LayerNorm(d_model) def forward( self, glycan: torch.Tensor, protein: torch.Tensor, glycan_mask: Optional[torch.Tensor] = None, protein_mask: Optional[torch.Tensor] = None, return_attention: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Returns updated (glycan, protein) enriched with cross-modal info. NaN guard: nn.MultiheadAttention produces NaN when ALL keys are masked. We replace NaN→0 so residual preserves query. """ # Convert to key_padding_mask (True = padded) g_key_pad = (~glycan_mask.bool()) if glycan_mask is not None else None p_key_pad = ( (~protein_mask.bool()) if protein_mask is not None else None ) # Glycan attends to protein g_cross, g_attn_weights = self.glycan_cross_attn( query=glycan, key=protein, value=protein, key_padding_mask=p_key_pad, need_weights=return_attention, average_attn_weights=False, ) g_cross = torch.nan_to_num(g_cross, nan=0.0) glycan = self.glycan_norm1(glycan + g_cross) glycan = self.glycan_norm2(glycan + self.glycan_ffn(glycan)) if glycan_mask is not None: glycan = glycan * glycan_mask.unsqueeze(-1) # Protein attends to glycan p_cross, p_attn_weights = self.protein_cross_attn( query=protein, key=glycan, value=glycan, key_padding_mask=g_key_pad, need_weights=return_attention, average_attn_weights=False, ) p_cross = torch.nan_to_num(p_cross, nan=0.0) protein = self.protein_norm1(protein + p_cross) protein = self.protein_norm2(protein + self.protein_ffn(protein)) if protein_mask is not None: protein = protein * protein_mask.unsqueeze(-1) if return_attention: attn_dict = { "glycan_to_protein": g_attn_weights, "protein_to_glycan": p_attn_weights, } return glycan, protein, attn_dict return glycan, protein # ============================================================================ # Bertint V8 Model # ============================================================================ class BertintV8(nn.Module): """ Glycan-protein interaction predictor with cross-attention. Glycan: Live Bertose (partially frozen) → per-token [B, Lg, 768] Protein: Precomputed ESM-C per-residue [B, Lp, 960] Cross-attention: 2 bidirectional layers in shared 512-dim space SWE: Variable-length → fixed [B, 512] for each side Interaction: product + sum → MLP → scalar """ def __init__( self, seq_embeddings: nn.Module, seq_layers: nn.ModuleList, glycan_dim: int = 768, protein_dim: int = 960, shared_dim: int = 512, num_cross_layers: int = 2, num_heads: int = 8, ffn_dim: int = 1024, swe_slices: int = 512, swe_ref_points: int = 64, head_hidden: int = 256, dropout: float = 0.1, separate_swe: bool = False, pooling_mode: str = "swe", interaction_mode: str = "product_sum", use_cross_attention: bool = True, ): """ Args: seq_embeddings: Pretrained Bertose embedding layer. seq_layers: Pretrained Bertose transformer layers. glycan_dim: Bertose output dimension (768). protein_dim: ESM-C per-residue dimension (960). shared_dim: Shared space for cross-attention (512). num_cross_layers: Number of cross-attention blocks. num_heads: Attention heads per block. ffn_dim: FFN hidden dim in cross-attention. swe_slices: Number of SWE projection directions. swe_ref_points: Number of SWE reference distribution points. head_hidden: MLP head hidden dimension. dropout: Dropout rate. separate_swe: If True, use independent SWE modules. pooling_mode: 'swe', 'mean', or 'joint_swe'. interaction_mode: 'product_sum' or 'concat'. use_cross_attention: If False, skip cross-attention. """ super().__init__() self.separate_swe = separate_swe self.pooling_mode = pooling_mode self.interaction_mode = interaction_mode self.use_cross_attention = use_cross_attention print(f" Architecture config:") print(f" cross_attention={use_cross_attention}") print(f" pooling_mode={pooling_mode}") print(f" interaction_mode={interaction_mode}") # === Bertose sequence encoder (partially frozen) === self.seq_embeddings = seq_embeddings self.seq_layers = seq_layers # === Projection to shared space === self.glycan_proj = nn.Sequential( nn.Linear(glycan_dim, shared_dim), nn.LayerNorm(shared_dim), ) self.protein_proj = nn.Sequential( nn.Linear(protein_dim, shared_dim), nn.LayerNorm(shared_dim), ) # === Cross-attention stack (optional) === if use_cross_attention: self.cross_attention = nn.ModuleList([ CrossAttentionBlock( d_model=shared_dim, num_heads=num_heads, ffn_dim=ffn_dim, dropout=dropout, ) for _ in range(num_cross_layers) ]) else: self.cross_attention = nn.ModuleList() # === Pooling === if pooling_mode == "swe": if separate_swe: self.swe_glycan = SWE_Pooling( d_in=shared_dim, num_slices=swe_slices, num_ref_points=swe_ref_points, ) self.swe_protein = SWE_Pooling( d_in=shared_dim, num_slices=swe_slices, num_ref_points=swe_ref_points, ) pool_out_dim = swe_slices else: self.swe = SWE_Pooling( d_in=shared_dim, num_slices=swe_slices, num_ref_points=swe_ref_points, ) pool_out_dim = swe_slices elif pooling_mode == "mean": pool_out_dim = shared_dim elif pooling_mode == "joint_swe": self.swe_joint = SWE_Pooling( d_in=shared_dim, num_slices=swe_slices, num_ref_points=swe_ref_points, ) pool_out_dim = swe_slices # === Regression head === if pooling_mode == "joint_swe": head_input_dim = pool_out_dim else: head_input_dim = 2 * pool_out_dim self.head = nn.Sequential( nn.Linear(head_input_dim, head_hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(head_hidden, head_hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(head_hidden, 1), ) # Initialize (skip SWE weight-normed params) self.apply(self._init_weights) self._count_params() def _init_weights(self, module: nn.Module) -> None: """Xavier init for Linear, skip weight-normed SWE modules.""" if isinstance(module, nn.Linear): if hasattr(module, "weight_v"): return # Preserve SWE initialization nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): nn.init.ones_(module.weight) nn.init.zeros_(module.bias) def _count_params(self) -> None: """Log parameter counts.""" total = sum(p.numel() for p in self.parameters()) trainable = sum( p.numel() for p in self.parameters() if p.requires_grad ) print(f"BertintV8: {total:,} total, {trainable:,} trainable") def _masked_mean_pool(self, x, mask): """Masked mean pooling: average valid tokens only.""" mask_expanded = mask.unsqueeze(-1) x_masked = x * mask_expanded summed = x_masked.sum(dim=1) counts = mask.sum(dim=1, keepdim=True).clamp(min=1) return summed / counts def forward( self, token_ids: torch.Tensor, attention_mask: torch.Tensor, branch_depths: torch.Tensor, linkage_types: torch.Tensor, protein_emb: torch.Tensor, protein_mask: torch.Tensor, log_conc: Optional[torch.Tensor] = None, has_conc: Optional[torch.Tensor] = None, return_attention: bool = False, ) -> torch.Tensor: """ Forward pass with cross-attention. Args: token_ids: [B, Lg] BPE token IDs. attention_mask: [B, Lg] glycan attention mask (1=valid, 0=pad). branch_depths: [B, Lg] branch depth per token. linkage_types: [B, Lg] linkage type per token. protein_emb: [B, Lp, protein_dim] per-residue ESM-C embeddings. protein_mask: [B, Lp] protein attention mask (1=valid, 0=pad). Returns: [B] binding score predictions. """ # === 1. Bertose forward: per-token embeddings === x = self.seq_embeddings(token_ids, branch_depths, linkage_types) for layer in self.seq_layers: x = layer(x, attention_mask) # x: [B, Lg, 768] — all tokens (not just CLS!) # Glycan mask: use the attention_mask from BPE tokenizer glycan_mask = attention_mask # [B, Lg], 1=valid, 0=pad # === 2. Project to shared space === glycan = self.glycan_proj(x) # [B, Lg, 512] protein = self.protein_proj(protein_emb) # [B, Lp, 512] # Zero padding positions after projection (Session 7 fix: # LayerNorm bias produces non-trivial values at padding) glycan = glycan * glycan_mask.unsqueeze(-1) protein = protein * protein_mask.unsqueeze(-1) # === 3. Cross-attention (optional) === all_attention_maps = [] if self.use_cross_attention: for cross_layer in self.cross_attention: if return_attention: glycan, protein, attn_dict = cross_layer( glycan, protein, glycan_mask, protein_mask, return_attention=True, ) all_attention_maps.append(attn_dict) else: glycan, protein = cross_layer( glycan, protein, glycan_mask, protein_mask ) # === 4. Pooling === if self.pooling_mode == "joint_swe": joint_tokens = torch.cat([glycan, protein], dim=1) joint_mask = torch.cat([glycan_mask, protein_mask], dim=1) pooled = self.swe_joint(joint_tokens, joint_mask) return self.head(pooled).squeeze(-1) if self.pooling_mode == "swe": if self.separate_swe: glycan_pooled = self.swe_glycan(glycan, glycan_mask) protein_pooled = self.swe_protein(protein, protein_mask) else: glycan_pooled = self.swe(glycan, glycan_mask) protein_pooled = self.swe(protein, protein_mask) elif self.pooling_mode == "mean": glycan_pooled = self._masked_mean_pool(glycan, glycan_mask) protein_pooled = self._masked_mean_pool(protein, protein_mask) # === 5. Interaction === if self.interaction_mode == "product_sum": interaction = torch.cat([ glycan_pooled * protein_pooled, glycan_pooled + protein_pooled, ], dim=-1) elif self.interaction_mode == "concat": interaction = torch.cat([ glycan_pooled, protein_pooled, ], dim=-1) # === 6. Predict binding score === out = self.head(interaction).squeeze(-1) if return_attention and all_attention_maps: return out, all_attention_maps return out # ============================================================================ # Loss (same as V7 — simple MSE) # ============================================================================ class BertintV8Loss(nn.Module): """MSE loss for regression.""" def __init__(self): super().__init__() self.mse = nn.MSELoss() def forward( self, pred: torch.Tensor, target: torch.Tensor ) -> torch.Tensor: """Compute MSE loss.""" return self.mse(pred, target) # ============================================================================ # Sanity check # ============================================================================ if __name__ == "__main__": print("=" * 60) print("BertintV8 Architecture Sanity Check") print("=" * 60) # Mock Bertose encoder (for testing without cluster) class MockEmbeddings(nn.Module): """Mock Bertose embeddings for local testing.""" def __init__(self, dim: int = 768): super().__init__() self.proj = nn.Linear(64, dim) def forward(self, token_ids, branch_depths, linkage_types): """Return random embeddings.""" batch_size, seq_len = token_ids.shape return torch.randn(batch_size, seq_len, 768) class MockLayer(nn.Module): """Mock transformer layer.""" def forward(self, x, mask): """Identity.""" return x seq_emb = MockEmbeddings() seq_layers = nn.ModuleList([MockLayer() for _ in range(12)]) model = BertintV8( seq_embeddings=seq_emb, seq_layers=seq_layers, glycan_dim=768, protein_dim=960, ) # Simulate batch batch_size = 4 lg = 37 # Glycan: 37 BPE tokens lp = 150 # Protein: 150 residues token_ids = torch.randint(0, 100, (batch_size, lg)) attention_mask = torch.ones(batch_size, lg).float() branch_depths = torch.zeros(batch_size, lg, dtype=torch.long) linkage_types = torch.zeros(batch_size, lg, dtype=torch.long) protein_emb = torch.randn(batch_size, lp, 960) protein_mask = torch.ones(batch_size, lp).float() out = model( token_ids=token_ids, attention_mask=attention_mask, branch_depths=branch_depths, linkage_types=linkage_types, protein_emb=protein_emb, protein_mask=protein_mask, ) print(f"\nInput shapes:") print(f" Glycan tokens: {token_ids.shape}") print(f" Protein emb: {protein_emb.shape}") print(f"\nOutput shape: {out.shape} — values: {out.detach()}") # Test loss loss_fn = BertintV8Loss() target = torch.rand(batch_size) loss = loss_fn(out, target) print(f"\nLoss: {loss.item():.6f}") # Test backward loss.backward() grad_count = sum( 1 for p in model.parameters() if p.grad is not None and p.requires_grad ) total_trainable = sum( 1 for p in model.parameters() if p.requires_grad ) print(f"Gradients: {grad_count}/{total_trainable} trainable params") print(f"\n✅ V8 sanity check passed!") """Bertint V8 model — cross-attention + live Bertose finetuning."""