| """ |
AFFINose interaction model — cross-attention with live BERTose encoding.

Architecture:
    GLYCAN:  WURCS → BPE → BERTose (live, freeze layers 0-3) → [B, Lg, 768]
                                                                     ↓ proj(768→512)
    PROTEIN: precomputed ESM-C → [B, Lp, 960]                        ↓
                                     ↓ proj(960→512)                 ↓
             2× CrossAttentionBlock(d=512, 8 heads, FFN=1024)        ↓
                                     ↓                               ↓
             SHARED mask-aware SWE(d=512, S=512, R=64)               ↓
                                     ↓                               ↓
                                 [B, 512]                        [B, 512]
                                     ↓ concat(element-wise product, element-wise sum)
                                 [B, 1024]
                                     ↓ MLP → binding score

This release exposes the manuscript-facing AFFINose architecture: BERTose
glycan tokens, per-residue ESM-C protein embeddings, bidirectional
cross-attention, pooled fusion and scalar interaction scoring.
| """ |
|
|
| import os |
| import sys |
| import math |
| from pathlib import Path |
| from typing import Dict, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
def _default_bertose_root() -> Path:
    """Locate the BERTose source tree without hard-coding a local path.

    Resolution order:
      1. ``BERTOSE_ROOT`` or, when that is unset or empty,
         ``BERTOSE_REPO_ROOT`` from the environment (user-expanded and
         resolved).
      2. The nearest ancestor directory of this file that contains either a
         ``src`` entry or a ``bertose_model.py`` entry.
      3. The directory that holds this file.
    """
    for var_name in ("BERTOSE_ROOT", "BERTOSE_REPO_ROOT"):
        configured = os.environ.get(var_name)
        if configured:
            return Path(configured).expanduser().resolve()

    this_file = Path(__file__).resolve()
    markers = ("src", "bertose_model.py")
    nearest = next(
        (
            ancestor
            for ancestor in this_file.parents
            if any((ancestor / marker).exists() for marker in markers)
        ),
        None,
    )
    return this_file.parent if nearest is None else nearest


# Resolved once at import time; used to extend ``sys.path`` before the
# BERTose model classes are imported.
BERTOSE_ROOT = _default_bertose_root()
|
|
|
|
def _ensure_bertose_imports():
    """Make the BERTose sources importable by extending ``sys.path``.

    The candidates are this file's directory, ``BERTOSE_ROOT`` and
    ``BERTOSE_ROOT / "src"``. Each candidate that is not already listed is
    pushed to the front of ``sys.path``, so the last candidate ends up with
    the highest import priority.
    """
    module_dir = Path(__file__).resolve().parent
    candidates = (module_dir, BERTOSE_ROOT, BERTOSE_ROOT / "src")
    for entry in map(str, candidates):
        if entry in sys.path:
            continue
        sys.path.insert(0, entry)
|
|
|
|
def load_bertose_config():
    """Build the BERTose config used with the glycan encoder checkpoint.

    The config class is imported lazily, first from ``model.bertose_model``
    and, if that module cannot be found, from a top-level ``bertose_model``.

    Returns:
        A ``MultimodalGlycanBERTConfig`` with ``seq_vocab_size=2200`` and
        the CNN frontend enabled.
    """
    _ensure_bertose_imports()
    try:
        from model.bertose_model import (
            MultimodalGlycanBERTConfig as config_cls,
        )
    except ModuleNotFoundError:
        from bertose_model import MultimodalGlycanBERTConfig as config_cls

    settings = {"seq_vocab_size": 2200, "use_cnn_frontend": True}
    return config_cls(**settings)
|
|
|
|
def load_bertose_encoder(
    checkpoint_path: str, freeze_layers: int = 4
):
    """
    Load the BERTose sequence encoder with pretrained weights.

    The vocabulary size and maximum sequence length are read from the shapes
    of the checkpoint's token and position embedding tables, so the model is
    built to match the checkpoint rather than a fixed config.

    Args:
        checkpoint_path: Path to the pretrained BERTose checkpoint. The file
            may hold the state dict directly or under ``"model_state_dict"``.
        freeze_layers: Number of leading transformer layers to freeze, i.e.
            layers ``0 .. freeze_layers - 1``. Values larger than the number
            of layers freeze all of them; values <= 0 freeze none. The
            embedding module is always frozen.

    Returns:
        Tuple of (bertose_config, seq_embeddings, seq_layers).

    Raises:
        KeyError: If the checkpoint lacks the sequence embedding tables.
    """
    _ensure_bertose_imports()
    try:
        from model.bertose_model import (
            MultimodalGlycanBERT,
            MultimodalGlycanBERTConfig,
        )
    except ModuleNotFoundError:
        from bertose_model import (
            MultimodalGlycanBERT,
            MultimodalGlycanBERTConfig,
        )

    # NOTE(review): torch.load unpickles the file. Only load checkpoints from
    # a trusted source; consider ``weights_only=True`` if the checkpoint is
    # known to contain tensors only.
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    state_dict = ckpt.get("model_state_dict", ckpt)

    # Infer the architecture-defining sizes from the saved embedding tables.
    vocab_size = state_dict["seq_embeddings.token_embeddings.weight"].shape[0]
    max_pos = state_dict["seq_embeddings.position_embeddings.weight"].shape[0]
    config = MultimodalGlycanBERTConfig(
        seq_vocab_size=vocab_size,
        seq_max_length=max_pos,
        use_cnn_frontend=True,
    )

    # strict=False: the checkpoint and the model may each hold keys the other
    # lacks; both counts are reported below.
    model = MultimodalGlycanBERT(config)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    loaded = len(state_dict) - len(unexpected)
    print(f" Loaded {loaded}/{len(state_dict)} pretrained weight tensors")
    print(f" ({len(missing)} missing in checkpoint, {len(unexpected)} unexpected)")

    seq_embeddings = model.seq_embeddings
    seq_layers = model.seq_layers

    # Embeddings are always frozen; transformer layers only up to the
    # requested count, bounded by how many layers actually exist.
    for param in seq_embeddings.parameters():
        param.requires_grad = False
    num_frozen = max(0, min(freeze_layers, len(seq_layers)))
    for layer_idx in range(num_frozen):
        for param in seq_layers[layer_idx].parameters():
            param.requires_grad = False

    encoder_params = list(seq_embeddings.parameters())
    encoder_params += [p for layer in seq_layers for p in layer.parameters()]
    total = sum(p.numel() for p in encoder_params)
    trainable = sum(p.numel() for p in encoder_params if p.requires_grad)

    # Report the range that was really frozen (previously "0--1" for
    # freeze_layers=0 and an out-of-range bound for oversized requests).
    if num_frozen > 0:
        frozen_desc = f"frozen layers 0-{num_frozen - 1}"
    else:
        frozen_desc = "no frozen layers"
    print(
        f" BERTose encoder: {total:,} params total, "
        f"{trainable:,} trainable ({frozen_desc})"
    )

    return config, seq_embeddings, seq_layers
|
|
|
|
| |
| |
| |
|
|
|
|
def differentiable_interp1d(
    x: torch.Tensor, y: torch.Tensor, xnew: torch.Tensor
) -> torch.Tensor:
    """
    Fully differentiable 1D linear interpolation, applied row by row.

    The bracketing indices are found on detached copies of ``x`` and
    ``xnew``; the interpolation itself is ordinary tensor arithmetic, so
    gradients flow through ``y`` (and through ``x`` / ``xnew`` if they
    require grad). Queries outside the range of ``x`` are clamped to the
    first / last value.

    Args:
        x: [B, N] input coordinates, sorted ascending along dim 1.
        y: [B, N] values at the ``x`` positions.
        xnew: [B, R] query coordinates.

    Returns:
        [B, R] interpolated values. With a single support point (N == 1)
        the result is that point's value repeated for every query.
    """
    n_pts = x.shape[1]

    # A single support point has no interval to interpolate over; the index
    # arithmetic below would gather at position -1 and raise. Return the
    # constant value instead (still differentiable w.r.t. y).
    if n_pts == 1:
        return y + torch.zeros_like(xnew)

    # Index of the right-hand neighbour of each query, kept inside [1, N-1]
    # so that both neighbours exist.
    upper = torch.searchsorted(
        x.contiguous().detach(), xnew.contiguous().detach()
    )
    upper = upper.clamp(1, n_pts - 1)
    lower = upper - 1

    x_lo = torch.gather(x, 1, lower)
    x_hi = torch.gather(x, 1, upper)
    y_lo = torch.gather(y, 1, lower)
    y_hi = torch.gather(y, 1, upper)

    # Guard against zero-width intervals, then clamp the blend weight so
    # out-of-range queries do not extrapolate.
    denom = (x_hi - x_lo).clamp(min=1e-8)
    alpha = ((xnew - x_lo) / denom).clamp(0, 1)

    return y_lo + alpha * (y_hi - y_lo)
|
|
|
|
| |
| |
| |
|
|
|
|
class SWE_Pooling(nn.Module):
    """
    Sliced-Wasserstein Embedding pooling.

    Turns variable-length token embeddings [B, L, d_in] into a fixed-size
    vector [B, num_slices]: tokens are projected onto ``num_slices``
    directions, each slice is sorted, resampled at ``num_ref_points``
    positions, compared against a reference and reduced by a linear layer.

    Padding is handled through the optional mask; samples with fewer than
    two valid tokens produce an all-zero output.
    """

    def __init__(
        self,
        d_in: int,
        num_slices: int,
        num_ref_points: int,
        freeze_swe: bool = False,
    ):
        """
        Args:
            d_in: Token embedding dimension.
            num_slices: Number of projection directions (output size).
            num_ref_points: Number of points in the reference set.
            freeze_swe: If True, the reference and the projection directions
                are not trainable.
        """
        super().__init__()
        self.num_slices = num_slices
        self.num_ref_points = num_ref_points

        # Reference: the same evenly spaced grid on [-1, 1] for every slice,
        # stored as [num_ref_points, num_slices].
        grid = torch.linspace(-1, 1, num_ref_points)
        self.reference = nn.Parameter(
            grid.unsqueeze(1).repeat(1, num_slices),
            requires_grad=not freeze_swe,
        )

        # Projection directions: weight-normalised with the gain fixed at 1,
        # so only the direction (weight_v) can be learned.
        projection = nn.Linear(d_in, num_slices, bias=False)
        self.theta = nn.utils.weight_norm(projection, dim=0)
        self.theta.weight_g.data = torch.ones_like(self.theta.weight_g.data)
        self.theta.weight_g.requires_grad = False
        nn.init.normal_(self.theta.weight_v)

        # Linear reduction over the reference points.
        self.weight = nn.Linear(num_ref_points, 1, bias=False)

        if freeze_swe:
            self.theta.weight_v.requires_grad = False
            self.reference.requires_grad = False

    def _keep_valid_tail(self, ordered, mask, seq_len):
        """Drop the padded head of each ascending-sorted slice.

        Padded entries were set to -inf before sorting, so the valid values
        sit at the end of dim 1. Returns the gathered values
        [B, widest, num_slices], the per-sample "fewer than two valid
        tokens" flags and ``widest`` (largest valid count, at least 2).
        """
        valid_counts = mask.sum(dim=1).long()
        degenerate = valid_counts < 2
        safe_counts = valid_counts.clamp(min=2)
        widest = safe_counts.max().item()

        # NOTE(review): samples shorter than ``widest`` are filled by
        # repeating their last (largest) value, and the interpolation grid
        # spans ``widest`` points, so a sample's output appears to depend on
        # the longest sample in the batch — confirm this is intended.
        first_valid = (seq_len - safe_counts).unsqueeze(1)
        steps = torch.arange(widest, device=ordered.device).unsqueeze(0)
        index = (first_valid + steps).clamp(max=seq_len - 1)
        index = index.unsqueeze(-1).expand(
            ordered.shape[0], widest, self.num_slices
        )
        tail = torch.gather(ordered, 1, index)

        # Any -inf left over (degenerate samples) is replaced by 0.
        tail = tail.masked_fill(tail == float("-inf"), 0.0)
        return tail, degenerate, widest

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: [B, L, d_in] token embeddings
            mask: [B, L] attention mask (1=valid, 0=padding)

        Returns:
            [B, num_slices] fixed-length representation
        """
        batch_size, seq_len, _ = x.shape
        device = x.device

        projected = self.theta(x)

        # A single token cannot be sorted/resampled: return its projection.
        if seq_len == 1:
            return projected.squeeze(1)

        if mask is None:
            ordered, _ = torch.sort(projected, dim=1)
            degenerate = None
            n_eff = seq_len
        else:
            # Push padded positions to the front of the ascending sort.
            padded = mask.unsqueeze(-1).expand_as(projected) == 0
            ordered, _ = torch.sort(
                projected.masked_fill(padded, float("-inf")), dim=1
            )
            ordered, degenerate, n_eff = self._keep_valid_tail(
                ordered, mask, seq_len
            )

        # Resample every (sample, slice) row from n_eff evenly spaced points
        # on [0, 1] to num_ref_points evenly spaced points on [0, 1].
        rows = batch_size * self.num_slices
        support = (
            torch.linspace(0, 1, n_eff, device=device)
            .unsqueeze(0)
            .expand(rows, -1)
        )
        values = ordered.permute(0, 2, 1).reshape(rows, n_eff)
        queries = (
            torch.linspace(0, 1, self.num_ref_points, device=device)
            .unsqueeze(0)
            .expand(rows, -1)
        )
        resampled = differentiable_interp1d(support, values, queries)
        resampled = resampled.view(
            batch_size, self.num_slices, self.num_ref_points
        ).permute(0, 2, 1)

        # Difference to the reference, then reduce over reference points.
        deviation = (
            self.reference.expand_as(resampled) - resampled
        ).permute(0, 2, 1)
        pooled = self.weight(deviation).sum(dim=-1)

        # Samples without at least two valid tokens carry no signal.
        if degenerate is not None and degenerate.any():
            pooled = pooled.masked_fill(degenerate.unsqueeze(-1), 0.0)

        return pooled
|
|
|
|
| |
| |
| |
|
|
|
|
class CrossAttentionBlock(nn.Module):
    """
    Bidirectional cross-attention between glycan and protein tokens.

    Glycan tokens attend to protein residues (Q=glycan, KV=protein), then
    protein residues attend to the *updated* glycan tokens (Q=protein,
    KV=glycan). Each direction is a post-norm residual attention step
    followed by a post-norm residual feed-forward step.

    Samples whose key sequence is entirely padding receive a zero
    cross-attention update (the residual keeps the query unchanged), and
    padded query positions are zeroed on output.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        ffn_dim: int,
        dropout: float = 0.1,
    ):
        """
        Args:
            d_model: Token dimension shared by both modalities.
            num_heads: Number of attention heads.
            ffn_dim: Hidden size of the feed-forward sub-layers.
            dropout: Dropout rate for attention and feed-forward layers.
        """
        super().__init__()

        # Glycan stream: glycan queries over protein keys/values.
        self.glycan_cross_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.glycan_norm1 = nn.LayerNorm(d_model)
        self.glycan_ffn = self._make_ffn(d_model, ffn_dim, dropout)
        self.glycan_norm2 = nn.LayerNorm(d_model)

        # Protein stream: protein queries over glycan keys/values.
        self.protein_cross_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.protein_norm1 = nn.LayerNorm(d_model)
        self.protein_ffn = self._make_ffn(d_model, ffn_dim, dropout)
        self.protein_norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def _make_ffn(d_model: int, ffn_dim: int, dropout: float) -> nn.Sequential:
        """Build the Linear → GELU → Dropout → Linear → Dropout sub-layer."""
        return nn.Sequential(
            nn.Linear(d_model, ffn_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_dim, d_model),
            nn.Dropout(dropout),
        )

    @staticmethod
    def _key_padding(mask: Optional[torch.Tensor]):
        """Convert a 1=valid mask into a key-padding mask safe for softmax.

        Returns ``(key_padding_mask, all_padded)`` where ``all_padded`` is a
        [B] bool tensor flagging samples with no valid key. For those
        samples key 0 is un-masked so the attention softmax never sees a row
        of only -inf; their output is zeroed afterwards by the caller.
        Both values are None when ``mask`` is None.
        """
        if mask is None:
            return None, None
        key_pad = ~mask.bool()
        all_padded = key_pad.all(dim=1)
        if key_pad.shape[1] > 0:
            # Branch-free: only rows flagged in all_padded change.
            key_pad[:, 0] = key_pad[:, 0] & ~all_padded
        return key_pad, all_padded

    @staticmethod
    def _zero_unattended(
        update: torch.Tensor, all_padded: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Zero the attention update for samples that had no valid key."""
        update = torch.nan_to_num(update, nan=0.0)
        if all_padded is not None:
            update = update.masked_fill(all_padded.view(-1, 1, 1), 0.0)
        return update

    def forward(
        self,
        glycan: torch.Tensor,
        protein: torch.Tensor,
        glycan_mask: Optional[torch.Tensor] = None,
        protein_mask: Optional[torch.Tensor] = None,
        return_attention: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns updated (glycan, protein) enriched with cross-modal info.

        Args:
            glycan: [B, Lg, d_model] glycan tokens.
            protein: [B, Lp, d_model] protein tokens.
            glycan_mask: [B, Lg] mask (1=valid, 0=padding) or None.
            protein_mask: [B, Lp] mask (1=valid, 0=padding) or None.
            return_attention: If True, a third element is returned: a dict
                with per-head weights under ``"glycan_to_protein"``
                ([B, heads, Lg, Lp]) and ``"protein_to_glycan"``
                ([B, heads, Lp, Lg]).

        All-masked keys: softmax over a fully masked row yields NaN, and
        replacing NaN only in the forward output still leaves NaN in the
        graph, which turns the attention weight gradients into NaN. The
        masks are therefore made safe *before* attention and the affected
        samples' update (and reported weights) are set to zero afterwards.
        """
        g_key_pad, g_empty = self._key_padding(glycan_mask)
        p_key_pad, p_empty = self._key_padding(protein_mask)

        # Glycan ← protein
        g_cross, g_attn_weights = self.glycan_cross_attn(
            query=glycan, key=protein, value=protein,
            key_padding_mask=p_key_pad,
            need_weights=return_attention,
            average_attn_weights=False,
        )
        g_cross = self._zero_unattended(g_cross, p_empty)
        glycan = self.glycan_norm1(glycan + g_cross)
        glycan = self.glycan_norm2(glycan + self.glycan_ffn(glycan))
        if glycan_mask is not None:
            glycan = glycan * glycan_mask.unsqueeze(-1)

        # Protein ← (updated) glycan
        p_cross, p_attn_weights = self.protein_cross_attn(
            query=protein, key=glycan, value=glycan,
            key_padding_mask=g_key_pad,
            need_weights=return_attention,
            average_attn_weights=False,
        )
        p_cross = self._zero_unattended(p_cross, g_empty)
        protein = self.protein_norm1(protein + p_cross)
        protein = self.protein_norm2(protein + self.protein_ffn(protein))
        if protein_mask is not None:
            protein = protein * protein_mask.unsqueeze(-1)

        if return_attention:
            # Report "no attention" rather than the placeholder key-0
            # weights for samples that had no valid key.
            if p_empty is not None:
                g_attn_weights = g_attn_weights.masked_fill(
                    p_empty.view(-1, 1, 1, 1), 0.0
                )
            if g_empty is not None:
                p_attn_weights = p_attn_weights.masked_fill(
                    g_empty.view(-1, 1, 1, 1), 0.0
                )
            attn_dict = {
                "glycan_to_protein": g_attn_weights,
                "protein_to_glycan": p_attn_weights,
            }
            return glycan, protein, attn_dict

        return glycan, protein
|
|
|
|
| |
| |
| |
|
|
|
|
class AffinoseInteractionModel(nn.Module):
    """
    Glycan-protein interaction predictor with cross-attention.

    Glycan: Live BERTose (partially frozen) → per-token [B, Lg, 768]
    Protein: Precomputed ESM-C per-residue [B, Lp, 960]
    Cross-attention: 2 bidirectional layers in shared 512-dim space
    SWE: Variable-length → fixed [B, 512] for each side
    Interaction: concat(product, sum) → MLP → scalar
    """

    # Submodules handed in with pretrained weights; never re-initialised.
    _PRETRAINED_SUBMODULES = ("seq_embeddings", "seq_layers")
    _POOLING_MODES = ("swe", "mean", "joint_swe")
    _INTERACTION_MODES = ("product_sum", "concat")

    def __init__(
        self,
        seq_embeddings: nn.Module,
        seq_layers: nn.ModuleList,
        glycan_dim: int = 768,
        protein_dim: int = 960,
        shared_dim: int = 512,
        num_cross_layers: int = 2,
        num_heads: int = 8,
        ffn_dim: int = 1024,
        swe_slices: int = 512,
        swe_ref_points: int = 64,
        head_hidden: int = 256,
        dropout: float = 0.1,
        separate_swe: bool = False,
        pooling_mode: str = "swe",
        interaction_mode: str = "product_sum",
        use_cross_attention: bool = True,
    ):
        """
        Args:
            seq_embeddings: Pretrained BERTose embedding layer.
            seq_layers: Pretrained BERTose transformer layers.
            glycan_dim: BERTose output dimension (768).
            protein_dim: ESM-C per-residue dimension (960).
            shared_dim: Shared space for cross-attention (512).
            num_cross_layers: Number of cross-attention blocks.
            num_heads: Attention heads per block.
            ffn_dim: FFN hidden dim in cross-attention.
            swe_slices: Number of SWE projection directions.
            swe_ref_points: Number of SWE reference distribution points.
            head_hidden: MLP head hidden dimension.
            dropout: Dropout rate.
            separate_swe: If True, use independent SWE modules per modality
                (only relevant for ``pooling_mode='swe'``).
            pooling_mode: 'swe', 'mean', or 'joint_swe'.
            interaction_mode: 'product_sum' or 'concat' (not used by
                'joint_swe', which pools both modalities together).
            use_cross_attention: If False, skip cross-attention.

        Raises:
            ValueError: If ``pooling_mode`` or ``interaction_mode`` is not
                one of the supported values.
        """
        super().__init__()

        # Fail early with a clear message instead of an unbound-variable
        # error later in construction or in forward().
        if pooling_mode not in self._POOLING_MODES:
            raise ValueError(
                f"Unknown pooling_mode {pooling_mode!r}; "
                f"expected one of {self._POOLING_MODES}"
            )
        if interaction_mode not in self._INTERACTION_MODES:
            raise ValueError(
                f"Unknown interaction_mode {interaction_mode!r}; "
                f"expected one of {self._INTERACTION_MODES}"
            )

        self.separate_swe = separate_swe
        self.pooling_mode = pooling_mode
        self.interaction_mode = interaction_mode
        self.use_cross_attention = use_cross_attention

        print(f" Architecture config:")
        print(f" cross_attention={use_cross_attention}")
        print(f" pooling_mode={pooling_mode}")
        print(f" interaction_mode={interaction_mode}")

        # Live glycan encoder (pretrained, possibly partially frozen).
        self.seq_embeddings = seq_embeddings
        self.seq_layers = seq_layers

        # Projections of both modalities into the shared space.
        self.glycan_proj = nn.Sequential(
            nn.Linear(glycan_dim, shared_dim),
            nn.LayerNorm(shared_dim),
        )
        self.protein_proj = nn.Sequential(
            nn.Linear(protein_dim, shared_dim),
            nn.LayerNorm(shared_dim),
        )

        # Cross-attention stack (empty when disabled).
        if use_cross_attention:
            self.cross_attention = nn.ModuleList([
                CrossAttentionBlock(
                    d_model=shared_dim,
                    num_heads=num_heads,
                    ffn_dim=ffn_dim,
                    dropout=dropout,
                )
                for _ in range(num_cross_layers)
            ])
        else:
            self.cross_attention = nn.ModuleList()

        # Pooling modules and the resulting per-modality feature size.
        if pooling_mode == "swe":
            if separate_swe:
                self.swe_glycan = SWE_Pooling(
                    d_in=shared_dim, num_slices=swe_slices,
                    num_ref_points=swe_ref_points,
                )
                self.swe_protein = SWE_Pooling(
                    d_in=shared_dim, num_slices=swe_slices,
                    num_ref_points=swe_ref_points,
                )
            else:
                self.swe = SWE_Pooling(
                    d_in=shared_dim, num_slices=swe_slices,
                    num_ref_points=swe_ref_points,
                )
            pool_out_dim = swe_slices
        elif pooling_mode == "mean":
            pool_out_dim = shared_dim
        else:  # "joint_swe"
            self.swe_joint = SWE_Pooling(
                d_in=shared_dim, num_slices=swe_slices,
                num_ref_points=swe_ref_points,
            )
            pool_out_dim = swe_slices

        # joint_swe feeds one pooled vector to the head; the other modes
        # feed two vectors' worth of features (product+sum or concat).
        if pooling_mode == "joint_swe":
            head_input_dim = pool_out_dim
        else:
            head_input_dim = 2 * pool_out_dim

        self.head = nn.Sequential(
            nn.Linear(head_input_dim, head_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden, head_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden, 1),
        )

        # Initialise only the modules created here. Applying the init to
        # the whole model would also overwrite the pretrained BERTose
        # Linear/LayerNorm weights passed in by the caller.
        for child_name, child in self.named_children():
            if child_name in self._PRETRAINED_SUBMODULES:
                continue
            child.apply(self._init_weights)
        self._count_params()

    def _init_weights(self, module: nn.Module) -> None:
        """Xavier init for Linear, skip weight-normed SWE modules."""
        if isinstance(module, nn.Linear):
            # Weight-normed layers expose weight_v; they keep their own init.
            if hasattr(module, "weight_v"):
                return
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def _count_params(self) -> None:
        """Log parameter counts."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        print(f"AffinoseInteractionModel: {total:,} total, {trainable:,} trainable")

    def _masked_mean_pool(self, x, mask):
        """Masked mean pooling: average valid tokens only.

        Samples without any valid token yield zeros (the count is clamped
        to 1 to avoid division by zero).
        """
        mask_expanded = mask.unsqueeze(-1)
        x_masked = x * mask_expanded
        summed = x_masked.sum(dim=1)
        counts = mask.sum(dim=1, keepdim=True).clamp(min=1)
        return summed / counts

    def forward(
        self,
        token_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        branch_depths: torch.Tensor,
        linkage_types: torch.Tensor,
        protein_emb: torch.Tensor,
        protein_mask: torch.Tensor,
        log_conc: Optional[torch.Tensor] = None,
        has_conc: Optional[torch.Tensor] = None,
        return_attention: bool = False,
    ) -> torch.Tensor:
        """
        Forward pass with cross-attention.

        Args:
            token_ids: [B, Lg] BPE token IDs.
            attention_mask: [B, Lg] glycan attention mask (1=valid, 0=pad).
            branch_depths: [B, Lg] branch depth per token.
            linkage_types: [B, Lg] linkage type per token.
            protein_emb: [B, Lp, protein_dim] per-residue ESM-C embeddings.
            protein_mask: [B, Lp] protein attention mask (1=valid, 0=pad).
            log_conc: Accepted for interface compatibility; not used here.
            has_conc: Accepted for interface compatibility; not used here.
            return_attention: If True and at least one cross-attention
                layer ran, also return the per-layer attention dicts.

        Returns:
            [B] binding score predictions, or a tuple
            ``(scores, attention_maps)`` when ``return_attention`` is True
            and attention maps were collected.
        """
        # Live glycan encoding.
        x = self.seq_embeddings(token_ids, branch_depths, linkage_types)
        for layer in self.seq_layers:
            x = layer(x, attention_mask)

        glycan_mask = attention_mask

        # Project to the shared space and zero the padded positions.
        glycan = self.glycan_proj(x)
        protein = self.protein_proj(protein_emb)
        glycan = glycan * glycan_mask.unsqueeze(-1)
        protein = protein * protein_mask.unsqueeze(-1)

        # Bidirectional cross-attention.
        all_attention_maps = []
        if self.use_cross_attention:
            for cross_layer in self.cross_attention:
                if return_attention:
                    glycan, protein, attn_dict = cross_layer(
                        glycan, protein, glycan_mask, protein_mask,
                        return_attention=True,
                    )
                    all_attention_maps.append(attn_dict)
                else:
                    glycan, protein = cross_layer(
                        glycan, protein, glycan_mask, protein_mask
                    )

        # Pooling + fusion into the head's input features.
        if self.pooling_mode == "joint_swe":
            joint_tokens = torch.cat([glycan, protein], dim=1)
            joint_mask = torch.cat([glycan_mask, protein_mask], dim=1)
            features = self.swe_joint(joint_tokens, joint_mask)
        else:
            if self.pooling_mode == "swe":
                if self.separate_swe:
                    glycan_pooled = self.swe_glycan(glycan, glycan_mask)
                    protein_pooled = self.swe_protein(protein, protein_mask)
                else:
                    glycan_pooled = self.swe(glycan, glycan_mask)
                    protein_pooled = self.swe(protein, protein_mask)
            else:  # "mean"
                glycan_pooled = self._masked_mean_pool(glycan, glycan_mask)
                protein_pooled = self._masked_mean_pool(
                    protein, protein_mask
                )

            if self.interaction_mode == "product_sum":
                features = torch.cat([
                    glycan_pooled * protein_pooled,
                    glycan_pooled + protein_pooled,
                ], dim=-1)
            else:  # "concat"
                features = torch.cat([
                    glycan_pooled, protein_pooled,
                ], dim=-1)

        # Single exit so that return_attention is honoured in every
        # pooling mode (joint_swe previously dropped the attention maps).
        out = self.head(features).squeeze(-1)
        if return_attention and all_attention_maps:
            return out, all_attention_maps
        return out
|
|
|
|
| |
| |
| |
|
|
|
|
class AffinoseInteractionLoss(nn.Module):
    """Regression objective: mean squared error between scores and targets."""

    def __init__(self):
        super().__init__()
        # Default nn.MSELoss settings (mean reduction).
        self.mse = nn.MSELoss()

    def forward(
        self, pred: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        """Return the MSE between ``pred`` and ``target`` as a scalar."""
        criterion = self.mse
        loss = criterion(pred, target)
        return loss
|
|
|
|
| |
| |
| |
if __name__ == "__main__":
    # Stand-alone architecture smoke test: builds the model around mock
    # BERTose components, runs one forward/backward pass on random data and
    # reports shapes, loss and gradient coverage.
    print("=" * 60)
    print("AFFINose interaction model architecture sanity check")
    print("=" * 60)

    class MockEmbeddings(nn.Module):
        """Mock BERTose embeddings for local testing."""

        def __init__(self, dim: int = 768):
            super().__init__()
            # Remember the requested width; forward() previously ignored it
            # and always produced 768 features.
            self.dim = dim

        def forward(self, token_ids, branch_depths, linkage_types):
            """Return random embeddings of shape [B, L, dim]."""
            batch_size, seq_len = token_ids.shape
            return torch.randn(batch_size, seq_len, self.dim)

    class MockLayer(nn.Module):
        """Mock transformer layer."""

        def forward(self, x, mask):
            """Identity."""
            return x

    glycan_dim = 768
    protein_dim = 960

    seq_emb = MockEmbeddings(dim=glycan_dim)
    seq_layers = nn.ModuleList([MockLayer() for _ in range(12)])

    model = AffinoseInteractionModel(
        seq_embeddings=seq_emb,
        seq_layers=seq_layers,
        glycan_dim=glycan_dim,
        protein_dim=protein_dim,
    )

    # Random batch with no padding.
    batch_size = 4
    lg = 37
    lp = 150

    token_ids = torch.randint(0, 100, (batch_size, lg))
    attention_mask = torch.ones(batch_size, lg).float()
    branch_depths = torch.zeros(batch_size, lg, dtype=torch.long)
    linkage_types = torch.zeros(batch_size, lg, dtype=torch.long)
    protein_emb = torch.randn(batch_size, lp, protein_dim)
    protein_mask = torch.ones(batch_size, lp).float()

    out = model(
        token_ids=token_ids,
        attention_mask=attention_mask,
        branch_depths=branch_depths,
        linkage_types=linkage_types,
        protein_emb=protein_emb,
        protein_mask=protein_mask,
    )

    print(f"\nInput shapes:")
    print(f" Glycan tokens: {token_ids.shape}")
    print(f" Protein emb: {protein_emb.shape}")
    print(f"\nOutput shape: {out.shape} — values: {out.detach()}")

    # Actually verify something before claiming success below.
    if out.shape != (batch_size,):
        raise RuntimeError(
            f"Unexpected output shape {tuple(out.shape)}; "
            f"expected ({batch_size},)"
        )
    if not torch.isfinite(out).all():
        raise RuntimeError("Model produced non-finite scores")

    # Loss on random targets.
    loss_fn = AffinoseInteractionLoss()
    target = torch.rand(batch_size)
    loss = loss_fn(out, target)
    print(f"\nLoss: {loss.item():.6f}")

    # Backward pass: count how many trainable tensors received a gradient.
    loss.backward()
    grad_count = sum(
        1 for p in model.parameters() if p.grad is not None and p.requires_grad
    )
    total_trainable = sum(
        1 for p in model.parameters() if p.requires_grad
    )
    print(f"Gradients: {grad_count}/{total_trainable} trainable params")

    print(f"\nAffinose sanity check passed.")
|
|