"""
BERTose model

Core glycan representation model with three modalities:
- Sequence (WURCS atomic tokenization)
- MS (mass spectrometry peaks, RT, intensity)
- 3D structure (VQ-VAE discrete tokens, 4 per residue)

Each modality has its own encoder, with cross-attention for
sequence-structure alignment.
"""

import logging
import math
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

try:
    from .bertose_layers import GlycanBERTConfig, GlycanBERTEmbeddings, GlycanBERTLayer
except ImportError:
    from bertose_layers import GlycanBERTConfig, GlycanBERTEmbeddings, GlycanBERTLayer

logger = logging.getLogger(__name__)


class ConvGlycanBERTEmbeddings(nn.Module):
    """
    Convolutional front-end that mixes local WURCS context before the Transformer.

    Pipeline implemented by ``forward``:
    1. Token + position (+ optional branch-depth / linkage-type) embeddings
       are summed and layer-normalised, so the convolution sees positional
       context.
    2. Three parallel 1-D convolutions with kernel sizes ``k``, ``k + 2`` and
       ``k + 4`` (``k = config.cnn_kernel_size``, default 3) are applied,
       concatenated and projected back to ``hidden_size``.
    3. The convolution output is added residually to the embeddings and
       layer-normalised.
    """

    def __init__(self, config):
        """
        Args:
            config: Object exposing ``vocab_size``, ``hidden_size``,
                ``pad_token_id``, ``max_position_embeddings``,
                ``layer_norm_eps`` and ``hidden_dropout_prob``. The optional
                attributes ``max_branch_depth`` (default 8),
                ``num_linkage_types`` (default 9) and ``cnn_kernel_size``
                (default 3) are read with ``getattr``.

        Raises:
            ValueError: If ``cnn_kernel_size`` is not a positive odd integer.
        """
        super().__init__()
        self.token_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )

        # Branch depth embeddings encode depth in the glycan tree.
        max_branch_depth = getattr(config, "max_branch_depth", 8)
        self.branch_embeddings = nn.Embedding(max_branch_depth, config.hidden_size)

        # Linkage type embeddings encode glycosidic bond chemistry.
        # 0=none, 1=1-3, 2=1-4, 3=1-6, etc.
        num_linkage_types = getattr(config, "num_linkage_types", 9)
        self.linkage_embeddings = nn.Embedding(num_linkage_types, config.hidden_size)

        # Multi-scale convolutions for different receptive fields.
        kernel_size = getattr(config, "cnn_kernel_size", 3)
        if kernel_size < 1 or kernel_size % 2 == 0:
            # padding = k // 2 only preserves the sequence length for odd k;
            # an even kernel would make the residual addition in forward()
            # fail with a shape mismatch, so reject it up front.
            raise ValueError(
                f"cnn_kernel_size must be a positive odd integer, got {kernel_size}"
            )

        # Each scale gets a third of the channels (e.g. 3 x 256 for 768).
        # conv_proj maps the concatenation back to hidden_size, which also
        # covers hidden sizes that are not divisible by three.
        channels_per_scale = config.hidden_size // 3
        scale_kernels = [kernel_size + 2 * i for i in range(3)]
        self.conv_layers = nn.ModuleList([
            nn.Conv1d(
                in_channels=config.hidden_size,
                out_channels=channels_per_scale,
                kernel_size=k,
                padding=k // 2,  # "same" padding for odd kernels
            )
            for k in scale_kernels
        ])
        self.conv_activation = nn.GELU()
        self.conv_proj = nn.Linear(channels_per_scale * 3, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.conv_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
        )
        self.hidden_size = config.hidden_size

    def forward(self, input_ids, branch_depths=None, linkage_types=None):
        """
        Embed tokens and enrich them with multi-scale convolution features.

        Args:
            input_ids: (batch, seq_len) token IDs.
            branch_depths: Optional (batch, seq_len) integer depths; values
                are clamped into the branch-embedding table range.
            linkage_types: Optional (batch, seq_len) integer linkage IDs;
                values are clamped into the linkage-embedding table range.

        Returns:
            Tensor of shape (batch, seq_len, hidden_size).

        Raises:
            ValueError: If ``seq_len`` exceeds the position-embedding table.
        """
        seq_len = input_ids.shape[1]
        max_len = self.position_ids.shape[1]
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position_embeddings {max_len}"
            )

        # Step 1: token + position embeddings first, so the convolution
        # operates on position-aware features.
        x = self.token_embeddings(input_ids)  # (batch, seq, hidden)
        x = x + self.position_embeddings(self.position_ids[:, :seq_len])

        if branch_depths is not None:
            depth_ids = branch_depths.clamp(0, self.branch_embeddings.num_embeddings - 1)
            x = x + self.branch_embeddings(depth_ids)

        if linkage_types is not None:
            linkage_ids = linkage_types.clamp(0, self.linkage_embeddings.num_embeddings - 1)
            x = x + self.linkage_embeddings(linkage_ids)

        x = self.LayerNorm(x)

        # Step 2: multi-scale convolution. Conv1d expects (batch, hidden, seq).
        conv_in = x.permute(0, 2, 1)
        scale_features = [self.conv_activation(conv(conv_in)) for conv in self.conv_layers]
        conv_out = torch.cat(scale_features, dim=1).permute(0, 2, 1)  # (batch, seq, 3*c)
        conv_out = self.conv_proj(conv_out)  # (batch, seq, hidden)

        # Step 3: residual connection - the convolution enriches rather than
        # replaces the embeddings.
        return self.conv_norm(x + self.dropout(conv_out))


def create_residue_level_mask(
    seq_residue_ids: torch.Tensor,  # (batch, N_seq)
    struct_residue_ids: torch.Tensor,  # (batch, N_struct)
) -> torch.Tensor:
    """
    Create a residue-level attention mask for cross-attention.

    A sequence token may attend to a structural token only when both carry
    the same residue ID and the sequence token's residue ID is non-negative.
    Because equality is required, a structural token with a negative ID can
    never be attended to either.

    Args:
        seq_residue_ids: Residue IDs for sequence tokens (batch, N_seq).
        struct_residue_ids: Residue IDs for structural tokens (batch, N_struct).

    Returns:
        Boolean mask (batch, N_seq, N_struct) where True = can attend.
        Rows belonging to sequence tokens with a negative residue ID are
        entirely False.
    """
    seq_ids = seq_residue_ids.unsqueeze(2)  # (batch, N_seq, 1)
    struct_ids = struct_residue_ids.unsqueeze(1)  # (batch, 1, N_struct)

    same_residue = seq_ids == struct_ids  # (batch, N_seq, N_struct)
    # Negative residue IDs mark tokens that do not belong to any residue
    # (special / padding tokens); they must not match each other.
    return same_residue & (seq_ids >= 0)


class MultimodalGlycanBERTConfig:
    """Configuration for the BERTose model."""

    def __init__(
        self,
        # Sequence modality
        seq_vocab_size: int = 166,
        seq_hidden_size: int = 768,
        seq_num_layers: int = 12,
        seq_num_heads: int = 12,
        seq_max_length: int = 512,
        # MS modality
        ms_vocab_size: int = 242,
        ms_hidden_size: int = 384,
        ms_num_layers: int = 6,
        ms_num_heads: int = 6,
        ms_max_length: int = 150,
        # 3D structure modality
        struct_vocab_size: int = 1024,  # VQ-VAE codebook size
        struct_hidden_size: int = 512,
        struct_num_layers: int = 8,
        struct_num_heads: int = 8,
        struct_max_length: int = 200,
        use_3d: bool = True,
        # Cross-attention
        use_cross_attention: bool = True,
        cross_attn_num_heads: int = 8,
        # Fusion
        fusion_hidden_size: int = 768,
        fusion_num_layers: int = 2,
        # Training
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        layer_norm_eps: float = 1e-12,
        initializer_range: float = 0.02,
        # Conv front-end
        use_cnn_frontend: bool = True,
        cnn_kernel_size: int = 3,
        # Loss weights
        seq_loss_weight: float = 0.60,
        ms_loss_weight: float = 0.15,
        struct_loss_weight: float = 0.25,
        # Token IDs
        pad_token_id: int = 0,
        mask_token_id: int = 1,
        # Appended last so existing positional callers keep working.
        dist_loss_weight: float = 0.25,
    ):
        # Sequence config
        self.seq_vocab_size = seq_vocab_size
        self.seq_hidden_size = seq_hidden_size
        self.seq_num_layers = seq_num_layers
        self.seq_num_heads = seq_num_heads
        self.seq_max_length = seq_max_length

        # MS config. MS token IDs are offset to start right after the
        # sequence vocabulary, so the MS encoder's embedding table and MLM
        # head cover seq_vocab_size + ms_vocab_size entries.
        self.ms_vocab_size = ms_vocab_size
        self.ms_vocab_offset = seq_vocab_size
        self.ms_total_vocab_size = seq_vocab_size + ms_vocab_size
        self.ms_hidden_size = ms_hidden_size
        self.ms_num_layers = ms_num_layers
        self.ms_num_heads = ms_num_heads
        self.ms_max_length = ms_max_length

        # Structure config
        self.struct_vocab_size = struct_vocab_size
        self.struct_hidden_size = struct_hidden_size
        self.struct_num_layers = struct_num_layers
        self.struct_num_heads = struct_num_heads
        self.struct_max_length = struct_max_length
        self.use_3d = use_3d

        # Cross-attention config
        self.use_cross_attention = use_cross_attention
        self.cross_attn_num_heads = cross_attn_num_heads

        # Fusion config
        self.fusion_hidden_size = fusion_hidden_size
        self.fusion_num_layers = fusion_num_layers

        # Training config
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range

        # Conv front-end
        self.use_cnn_frontend = use_cnn_frontend
        self.cnn_kernel_size = cnn_kernel_size

        # Loss weights
        self.seq_loss_weight = seq_loss_weight
        self.ms_loss_weight = ms_loss_weight
        self.struct_loss_weight = struct_loss_weight
        self.dist_loss_weight = dist_loss_weight

        # Token IDs
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id

    def _build_encoder_config(
        self,
        vocab_size: int,
        hidden_size: int,
        num_layers: int,
        num_heads: int,
        max_length: int,
    ) -> GlycanBERTConfig:
        """Build a single-modality config sharing this config's training settings."""
        return GlycanBERTConfig(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_layers,
            num_attention_heads=num_heads,
            intermediate_size=hidden_size * 4,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=max_length,
            layer_norm_eps=self.layer_norm_eps,
            pad_token_id=self.pad_token_id,
            mask_token_id=self.mask_token_id,
            initializer_range=self.initializer_range,
        )

    def to_seq_config(self) -> GlycanBERTConfig:
        """Convert to sequence-only config."""
        return self._build_encoder_config(
            self.seq_vocab_size,
            self.seq_hidden_size,
            self.seq_num_layers,
            self.seq_num_heads,
            self.seq_max_length,
        )

    def to_ms_config(self) -> GlycanBERTConfig:
        """Convert to MS-only config (vocabulary includes the sequence offset)."""
        return self._build_encoder_config(
            self.ms_total_vocab_size,
            self.ms_hidden_size,
            self.ms_num_layers,
            self.ms_num_heads,
            self.ms_max_length,
        )

    def to_struct_config(self) -> GlycanBERTConfig:
        """Convert to structure-only config."""
        return self._build_encoder_config(
            self.struct_vocab_size,
            self.struct_hidden_size,
            self.struct_num_layers,
            self.struct_num_heads,
            self.struct_max_length,
        )


# =============================================================================
# Improvement #1: Monosaccharide-Level Pooling
# =============================================================================

class MonosaccharidePooling(nn.Module):
    """
    Pool token representations to monosaccharide level, then aggregate.

    Tokens sharing a residue ID are mean-pooled into one vector per residue,
    the residue vectors attend to each other, and a learned query attends
    over them to produce a single glycan vector.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        max_residues: int = 50,
    ):
        """
        Args:
            hidden_size: Width of the token representations.
            num_attention_heads: Heads used by both attention modules.
            dropout: Attention dropout probability.
            max_residues: Number of residue slots per sample; residues beyond
                this count (in sorted residue-ID order) are dropped.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.max_residues = max_residues

        # Attention pooling over monosaccharide representations
        self.mono_attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_attention_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.mono_norm = nn.LayerNorm(hidden_size)

        # Final aggregation to single glycan representation
        self.glycan_query = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.glycan_attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_attention_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.glycan_norm = nn.LayerNorm(hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,  # (batch, seq_len, hidden)
        residue_ids: torch.Tensor,  # (batch, seq_len) - which residue each token belongs to
        attention_mask: torch.Tensor = None,  # (batch, seq_len)
    ) -> torch.Tensor:
        """
        Pool tokens to monosaccharide level, then to glycan level.

        Args:
            hidden_states: Token representations (batch, seq_len, hidden).
            residue_ids: Residue ID per token; negative IDs are ignored.
            attention_mask: Optional mask; tokens with value <= 0 are
                excluded from the per-residue mean.

        Returns:
            Glycan representation: (batch, hidden_size)
        """
        batch_size = hidden_states.size(0)
        max_residues = self.max_residues

        # Allocate in the input's dtype/device so half-precision models work.
        mono_reps = hidden_states.new_zeros(batch_size, max_residues, self.hidden_size)
        mono_mask = torch.zeros(
            batch_size, max_residues, dtype=torch.bool, device=hidden_states.device
        )

        for b in range(batch_size):
            sample_ids = residue_ids[b]
            valid_tokens = attention_mask[b] > 0 if attention_mask is not None else None
            # torch.unique returns sorted IDs, so slot i holds the i-th
            # smallest residue ID of the sample.
            unique_residues = torch.unique(sample_ids[sample_ids >= 0])
            for i, rid in enumerate(unique_residues[:max_residues]):
                token_mask = sample_ids == rid
                if valid_tokens is not None:
                    token_mask = token_mask & valid_tokens
                if token_mask.any():
                    mono_reps[b, i] = hidden_states[b][token_mask].mean(dim=0)
                    mono_mask[b, i] = True

        # nn.MultiheadAttention expects True = ignore.
        key_padding_mask = ~mono_mask
        # A row with every key masked makes the attention softmax return NaN.
        # For samples without any valid residue, expose one (all-zero) slot
        # so the output stays finite.
        empty_rows = ~mono_mask.any(dim=1)
        if empty_rows.any():
            key_padding_mask[empty_rows, 0] = False

        mono_out, _ = self.mono_attention(
            mono_reps, mono_reps, mono_reps, key_padding_mask=key_padding_mask
        )
        mono_out = self.mono_norm(mono_reps + mono_out)

        # Aggregate to single glycan representation using learned query
        glycan_query = self.glycan_query.expand(batch_size, -1, -1)
        glycan_out, _ = self.glycan_attention(
            glycan_query, mono_out, mono_out, key_padding_mask=key_padding_mask
        )
        glycan_out = self.glycan_norm(glycan_query + glycan_out)

        return glycan_out.squeeze(1)  # (batch, hidden)


# =============================================================================
# Improvement #2: Residue Type Embeddings
# =============================================================================

# Common monosaccharide types vocabulary
MONOSACCHARIDE_VOCAB = {
    '[PAD_MONO]': 0, '[UNK_MONO]': 1,
    'Glc': 2, 'GlcNAc': 3, 'GlcA': 4, 'GlcN': 5,
    'Gal': 6, 'GalNAc': 7, 'GalA': 8, 'GalN': 9,
    'Man': 10, 'ManNAc': 11, 'ManA': 12, 'ManN': 13,
    'Fuc': 14, 'Rha': 15, 'Xyl': 16, 'Ara': 17,
    'Neu5Ac': 18, 'Neu5Gc': 19, 'Kdn': 20, 'Sia': 21,
    'GalNAcA': 22, 'GlcNAcA': 23, 'IdoA': 24, 'GulA': 25,
    'Rib': 26, 'Lyx': 27, 'All': 28, 'Alt': 29,
    'Tal': 30, 'Ido': 31, 'Qui': 32, 'Oli': 33,
    'Tyv': 34, 'Abe': 35, 'Par': 36, 'Dig': 37,
    'Col': 38, 'Dha': 39, 'Kdo': 40, 'Hep': 41,
    'NeuroGc': 42, 'Muramic': 43, 'LDManHep': 44, 'DDManHep': 45,
    'Bac': 46, 'Pse': 47, 'Leg': 48, 'Aci': 49,
    '6dTal': 50, 'Fru': 51, 'Tag': 52, 'Sor': 53,
    'Psi': 54, 'Sed': 55, 'MurNAc': 56, 'MurNGc': 57,
    'Api': 58, 'Erwiniose': 59, 'Yer': 60, 'Thre': 61,
    # Add more as needed, up to ~70
}


class ResidueTypeEmbeddings(nn.Module):
    """
    Learnable embeddings for monosaccharide types.

    Instead of the model having to learn that 'a1221m' = Fucose from
    character patterns, a per-type embedding is added to every token that
    belongs to a residue of that type.
    """

    def __init__(self, hidden_size: int, num_mono_types: int = 70):
        """
        Args:
            hidden_size: Width of the token embeddings.
            num_mono_types: Size of the monosaccharide-type embedding table.
        """
        super().__init__()
        self.mono_embeddings = nn.Embedding(num_mono_types, hidden_size)
        self.mono_vocab = MONOSACCHARIDE_VOCAB
        self.hidden_size = hidden_size

    def forward(
        self,
        token_embeddings: torch.Tensor,  # (batch, seq_len, hidden)
        residue_ids: torch.Tensor,  # (batch, seq_len)
        mono_type_ids: torch.Tensor = None,  # (batch, max_residues)
    ) -> torch.Tensor:
        """
        Add residue type embeddings to token embeddings.

        Args:
            token_embeddings: Base token embeddings.
            residue_ids: Which residue each token belongs to (negative for
                tokens outside any residue).
            mono_type_ids: Monosaccharide type ID for each residue position.

        Returns:
            Embeddings with the residue-type embedding added at every token
            whose residue ID indexes a column of ``mono_type_ids``; all other
            tokens are returned unchanged. The input tensor is returned as-is
            when ``mono_type_ids`` is None.
        """
        if mono_type_ids is None:
            return token_embeddings

        num_slots = mono_type_ids.size(1)
        if num_slots == 0:
            return token_embeddings.clone()

        # Tokens whose residue ID points at a real residue slot.
        in_range = (residue_ids >= 0) & (residue_ids < num_slots)

        # Vectorised lookup (replaces a per-token Python loop): clamp so the
        # gather is always in bounds, then neutralise out-of-range positions.
        safe_ids = residue_ids.clamp(0, num_slots - 1).long()
        mono_ids = torch.gather(mono_type_ids, 1, safe_ids)
        # Out-of-range positions may have gathered an arbitrary (padding)
        # type ID; point them at index 0 so the embedding lookup is valid.
        mono_ids = mono_ids.masked_fill(~in_range, 0)

        type_emb = self.mono_embeddings(mono_ids).to(token_embeddings.dtype)
        return torch.where(
            in_range.unsqueeze(-1), token_embeddings + type_emb, token_embeddings
        )

    @staticmethod
    def get_mono_type_id(mono_name: str) -> int:
        """Convert monosaccharide name to type ID ([UNK_MONO] if unknown)."""
        return MONOSACCHARIDE_VOCAB.get(mono_name, MONOSACCHARIDE_VOCAB['[UNK_MONO]'])


# =============================================================================
# Improvement #4: Relative Position Encoding for Glycan Trees
# =============================================================================

class RelativePositionBias(nn.Module):
    """
    Compute relative position bias for attention based on residue IDs.

    The bias is indexed by the signed difference of residue IDs: tokens in
    the same residue get distance 0, tokens whose residue IDs differ by one
    get distance +/-1, and differences are clamped to ``max_distance``.
    """

    def __init__(self, num_heads: int, max_distance: int = 10):
        """
        Args:
            num_heads: Number of attention heads (one bias value per head).
            max_distance: Largest absolute residue-ID difference that gets
                its own bias entry.
        """
        super().__init__()
        self.num_heads = num_heads
        self.max_distance = max_distance

        # One learnable bias per head for each distance in [-max, +max].
        self.relative_bias = nn.Embedding(2 * max_distance + 1, num_heads)

    def forward(self, residue_ids: torch.Tensor) -> torch.Tensor:
        """
        Compute relative position bias.

        Args:
            residue_ids: (batch, seq_len)

        Returns:
            Bias to add to attention scores: (batch, num_heads, seq_len, seq_len)
        """
        # NOTE(review): negative residue IDs (if used for special tokens) take
        # part in the subtraction like any other ID - confirm this is intended.
        pairwise = residue_ids.unsqueeze(2) - residue_ids.unsqueeze(1)  # (batch, seq, seq)

        # Clamp and shift into the embedding index range [0, 2 * max_distance].
        bucket = pairwise.clamp(-self.max_distance, self.max_distance) + self.max_distance

        # (batch, seq, seq, heads) -> (batch, heads, seq, seq)
        return self.relative_bias(bucket).permute(0, 3, 1, 2)


class CrossAttentionLayer(nn.Module):
    """
    Cross-attention layer for sequence-structure alignment.

    Queries come from sequence hidden states, keys/values from structure
    hidden states. The result is added residually to the sequence hidden
    states and layer-normalised.
    """

    def __init__(self, config: MultimodalGlycanBERTConfig):
        super().__init__()
        self.num_heads = config.cross_attn_num_heads
        self.hidden_size = config.seq_hidden_size
        assert self.hidden_size % self.num_heads == 0, "hidden_size must be divisible by num_heads"
        self.head_dim = self.hidden_size // self.num_heads

        # Query from sequence, Key/Value from structure (VQ-VAE tokens)
        self.query = nn.Linear(config.seq_hidden_size, self.hidden_size)
        self.key = nn.Linear(config.struct_hidden_size, self.hidden_size)
        self.value = nn.Linear(config.struct_hidden_size, self.hidden_size)

        self.output = nn.Linear(self.hidden_size, config.seq_hidden_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.layer_norm = nn.LayerNorm(config.seq_hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        seq_hidden: torch.Tensor,  # (batch, seq_len, seq_hidden)
        struct_hidden: torch.Tensor,  # (batch, struct_len, struct_hidden)
        attention_mask: Optional[torch.Tensor] = None,  # (batch, seq_len, struct_len)
    ) -> torch.Tensor:
        """
        Apply cross-attention from sequence to structure.

        Args:
            seq_hidden: Sequence hidden states.
            struct_hidden: Structure hidden states.
            attention_mask: Boolean mask (True = can attend, False = cannot).

        Returns:
            Updated sequence hidden states, same shape as ``seq_hidden``.
        """
        batch_size, seq_len, _ = seq_hidden.shape
        struct_len = struct_hidden.shape[1]

        def split_heads(t: torch.Tensor, length: int) -> torch.Tensor:
            # (batch, length, hidden) -> (batch, heads, length, head_dim)
            return t.view(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.query(seq_hidden), seq_len)
        k = split_heads(self.key(struct_hidden), struct_len)
        v = split_heads(self.value(struct_hidden), struct_len)

        # (batch, heads, seq_len, struct_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # Additive mask: 0 where attending is allowed, -10000 elsewhere.
            # Built in the score dtype so half-precision models keep a
            # consistent dtype through softmax and the value matmul.
            # NOTE: a query row with no allowed key receives the same offset
            # everywhere, so its softmax attends over all structure tokens.
            blocked = ~attention_mask.bool().unsqueeze(1)  # (batch, 1, seq, struct)
            scores = scores + blocked.to(scores.dtype) * -10000.0

        attn_weights = self.dropout(torch.softmax(scores, dim=-1))

        # (batch, heads, seq_len, head_dim) -> (batch, seq_len, hidden)
        context = torch.matmul(attn_weights, v)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)

        output = self.dropout(self.output(context))

        # Residual connection + layer norm
        return self.layer_norm(seq_hidden + output)


class MultimodalGlycanBERT(nn.Module):
    """
    BERTose model for glycan representation learning.

    Architecture:
    1. Separate encoders for each modality (sequence, MS, 3D structure)
    2. Cross-attention for sequence-structure alignment
    3. Modality-specific MLM heads
    4. Fusion layer for combined representation
    5. Pairwise distance head on the sequence hidden states
    """

    def __init__(self, config: MultimodalGlycanBERTConfig):
        super().__init__()
        self.config = config

        # ===== Sequence Encoder =====
        seq_config = config.to_seq_config()
        seq_config.cnn_kernel_size = config.cnn_kernel_size
        if config.use_cnn_frontend:
            logger.info("Enabled convolutional front-end (kernel=%s)", config.cnn_kernel_size)
            self.seq_embeddings = ConvGlycanBERTEmbeddings(seq_config)
        else:
            self.seq_embeddings = GlycanBERTEmbeddings(seq_config)
        self.seq_layers = nn.ModuleList(
            [GlycanBERTLayer(seq_config) for _ in range(seq_config.num_hidden_layers)]
        )
        self.seq_mlm_head = nn.Linear(seq_config.hidden_size, seq_config.vocab_size)

        # ===== MS Encoder =====
        ms_config = config.to_ms_config()
        self.ms_embeddings = GlycanBERTEmbeddings(ms_config)
        self.ms_layers = nn.ModuleList(
            [GlycanBERTLayer(ms_config) for _ in range(ms_config.num_hidden_layers)]
        )
        self.ms_mlm_head = nn.Linear(ms_config.hidden_size, ms_config.vocab_size)

        # ===== Structure Encoder (VQ-VAE tokens) =====
        if config.use_3d:
            struct_config = config.to_struct_config()
            self.struct_embeddings = GlycanBERTEmbeddings(struct_config)
            self.struct_layers = nn.ModuleList(
                [GlycanBERTLayer(struct_config) for _ in range(struct_config.num_hidden_layers)]
            )
            self.struct_mlm_head = nn.Linear(struct_config.hidden_size, struct_config.vocab_size)

            # Cross-attention layer (sequence -> VQ-VAE structural tokens)
            if config.use_cross_attention:
                self.cross_attention = CrossAttentionLayer(config)

        # ===== Projection layers (align hidden sizes) =====
        if config.ms_hidden_size != config.seq_hidden_size:
            self.ms_projection = nn.Linear(config.ms_hidden_size, config.seq_hidden_size)
        else:
            self.ms_projection = nn.Identity()

        if config.use_3d and config.struct_hidden_size != config.seq_hidden_size:
            self.struct_projection = nn.Linear(config.struct_hidden_size, config.seq_hidden_size)
        else:
            self.struct_projection = nn.Identity()

        # ===== Fusion Layer =====
        # Concatenate seq + ms (+ struct when 3D is enabled).
        fusion_input_size = config.seq_hidden_size * (3 if config.use_3d else 2)
        self.fusion_layer = nn.Sequential(
            nn.Linear(fusion_input_size, config.fusion_hidden_size),
            nn.LayerNorm(config.fusion_hidden_size, eps=config.layer_norm_eps),
            nn.GELU(),
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.fusion_hidden_size, config.fusion_hidden_size),
        )

        # ===== Distance Prediction Head (Topology) =====
        # Project down to 128 dimensions first: the pairwise tensor is
        # (batch, seq, seq, dim), so a smaller dim cuts memory use.
        self.dist_proj = nn.Linear(config.seq_hidden_size, 128)
        self.distance_head = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, initializer_range) and reset LayerNorm."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    @staticmethod
    def _masked_lm_loss(
        logits: torch.Tensor,
        labels: torch.Tensor,
        vocab_size: int,
        has_modality: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Cross-entropy over labels != -100; zero when no label is valid.

        Samples whose ``has_modality`` flag is False have all their labels
        ignored. Returning a zero placeholder avoids the NaN that
        cross-entropy yields when every target is ignored.
        """
        if has_modality is not None:
            labels = labels.clone()
            labels[~has_modality.bool()] = -100
        if not (labels != -100).any():
            return torch.tensor(0.0, device=logits.device)
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        return loss_fct(logits.reshape(-1, vocab_size), labels.reshape(-1))

    def forward(
        self,
        seq_token_ids: torch.Tensor,
        seq_attention_mask: torch.Tensor,
        seq_residue_ids: torch.Tensor,
        seq_branch_depths: Optional[torch.Tensor] = None,
        seq_linkage_types: Optional[torch.Tensor] = None,
        ms_token_ids: torch.Tensor = None,
        ms_attention_mask: torch.Tensor = None,
        has_ms: torch.Tensor = None,
        struct_token_ids: Optional[torch.Tensor] = None,
        struct_attention_mask: Optional[torch.Tensor] = None,
        struct_residue_ids: Optional[torch.Tensor] = None,
        has_3d: Optional[torch.Tensor] = None,
        seq_labels: Optional[torch.Tensor] = None,
        ms_labels: Optional[torch.Tensor] = None,
        struct_labels: Optional[torch.Tensor] = None,
        dist_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[Dict[str, Optional[torch.Tensor]], Tuple[Optional[torch.Tensor], ...]]:
        """
        Forward pass for BERTose.

        Args:
            seq_token_ids: (batch_size, seq_len) - Sequence token IDs
            seq_attention_mask: (batch_size, seq_len) - Sequence attention mask
            seq_residue_ids: (batch_size, seq_len) - Sequence token residue IDs
            seq_branch_depths: (batch_size, seq_len) - Branch depths (optional)
            seq_linkage_types: (batch_size, seq_len) - Linkage type IDs (optional)
            ms_token_ids: (batch_size, ms_len) - MS token IDs (optional)
            ms_attention_mask: (batch_size, ms_len) - MS attention mask
            has_ms: (batch_size,) - Boolean mask for samples with MS data
            struct_token_ids: (batch_size, struct_len) - Structure VQ-VAE token IDs (optional)
            struct_attention_mask: (batch_size, struct_len) - Structure attention mask (optional)
            struct_residue_ids: (batch_size, struct_len) - Structure token residue IDs (optional)
            has_3d: (batch_size,) - Boolean mask for samples with 3D data (optional)
            seq_labels: (batch_size, seq_len) - Masked sequence labels (optional)
            ms_labels: (batch_size, ms_len) - Masked MS labels (optional)
            struct_labels: (batch_size, struct_len) - Masked structure labels (optional)
            dist_labels: (batch_size, seq_len, seq_len) - Pairwise distance
                targets, -1 marks pairs to ignore (optional)
            return_dict: Whether to return dict or tuple

        Returns:
            If ``return_dict``: dictionary with losses, logits, hidden states,
            pooled vectors and ``fused_repr``; entries of a modality that was
            not provided are None. Otherwise the tuple
            ``(total_loss, seq_logits, ms_logits, struct_logits, fused_repr)``.
            A modality that is missing entirely contributes a zero vector to
            the fusion layer.
        """
        batch_size = seq_token_ids.shape[0]
        device = seq_token_ids.device

        # ===== Sequence Encoder =====
        # Branch depths and linkage types give the embeddings tree awareness.
        seq_hidden = self.seq_embeddings(seq_token_ids, seq_branch_depths, seq_linkage_types)
        for layer in self.seq_layers:
            seq_hidden = layer(seq_hidden, seq_attention_mask)

        seq_pooled = seq_hidden[:, 0, :]  # [CLS] token
        seq_logits = self.seq_mlm_head(seq_hidden)

        # ===== Distance Predictions (Topology) =====
        # Computed before cross-attention, from the sequence-only states.
        seq_hidden_small = self.dist_proj(seq_hidden)  # (batch, seq_len, 128)
        h_diff = torch.abs(
            seq_hidden_small.unsqueeze(2) - seq_hidden_small.unsqueeze(1)
        )  # (batch, seq_len, seq_len, 128)
        dist_predictions = self.distance_head(h_diff)  # (batch, seq_len, seq_len, 1)

        # ===== MS Encoder =====
        ms_hidden = None
        ms_pooled = None
        ms_logits = None
        if ms_token_ids is not None:
            ms_hidden = self.ms_embeddings(ms_token_ids)
            for layer in self.ms_layers:
                ms_hidden = layer(ms_hidden, ms_attention_mask)

            ms_pooled = ms_hidden[:, 0, :]  # [CLS] token
            ms_logits = self.ms_mlm_head(ms_hidden)

            # Zero out MS representations for samples without MS data
            if has_ms is not None:
                ms_pooled = ms_pooled * has_ms.unsqueeze(1).to(ms_pooled.dtype)

        # ===== Structure Encoder =====
        struct_pooled = None
        struct_logits = None
        struct_hidden = None

        if self.config.use_3d and struct_token_ids is not None:
            struct_hidden = self.struct_embeddings(struct_token_ids)
            for layer in self.struct_layers:
                struct_hidden = layer(struct_hidden, struct_attention_mask)

            struct_pooled = struct_hidden[:, 0, :]  # [CLS] token
            struct_logits = self.struct_mlm_head(struct_hidden)

            # Zero out structure representations for samples without 3D data
            if has_3d is not None:
                struct_pooled = struct_pooled * has_3d.unsqueeze(1).to(struct_pooled.dtype)

        # ===== Cross-Attention (Sequence -> VQ-VAE Structural Tokens) =====
        # Requires encoded structure tokens; without them it is skipped
        # instead of failing on a None input.
        if (
            self.config.use_cross_attention
            and struct_hidden is not None
            and struct_residue_ids is not None
        ):
            # WURCS token with residue_id=r -> VQ-VAE tokens with residue_id=r
            residue_mask = create_residue_level_mask(
                seq_residue_ids=seq_residue_ids,
                struct_residue_ids=struct_residue_ids,
            )  # (batch, N_seq, N_struct)

            seq_hidden = self.cross_attention(
                seq_hidden=seq_hidden,
                struct_hidden=struct_hidden,
                attention_mask=residue_mask,
            )

            # Update seq_pooled after cross-attention
            seq_pooled = seq_hidden[:, 0, :]

        # ===== Fusion =====
        # The fusion layer has a fixed input width, so a modality that was
        # not provided is represented by zeros (the same value used for
        # individual samples flagged by has_ms / has_3d).
        ms_for_fusion = ms_pooled
        if ms_for_fusion is None:
            ms_for_fusion = seq_hidden.new_zeros(batch_size, self.config.ms_hidden_size)
        fusion_inputs = [seq_pooled, self.ms_projection(ms_for_fusion)]

        if self.config.use_3d:
            struct_for_fusion = struct_pooled
            if struct_for_fusion is None:
                struct_for_fusion = seq_hidden.new_zeros(
                    batch_size, self.config.struct_hidden_size
                )
            fusion_inputs.append(self.struct_projection(struct_for_fusion))

        fused_repr = self.fusion_layer(torch.cat(fusion_inputs, dim=-1))

        # ===== Compute Losses =====
        total_loss = None
        seq_loss = None
        ms_loss = None
        struct_loss = None
        dist_loss = None

        if seq_labels is not None:
            seq_loss = self._masked_lm_loss(
                seq_logits, seq_labels, self.config.seq_vocab_size
            )

        if ms_labels is not None and ms_logits is not None:
            ms_loss = self._masked_lm_loss(
                ms_logits, ms_labels, self.config.ms_total_vocab_size, has_ms
            )

        if self.config.use_3d and struct_labels is not None and struct_logits is not None:
            struct_loss = self._masked_lm_loss(
                struct_logits, struct_labels, self.config.struct_vocab_size, has_3d
            )

        # ===== Distance Loss (Topology) =====
        if dist_labels is not None:
            preds = dist_predictions.squeeze(-1)  # (batch, seq, seq)

            # Valid pairs: labelled (!= -1) and both tokens are real
            # (non-padding) according to the attention mask.
            pair_mask = seq_attention_mask.unsqueeze(1) * seq_attention_mask.unsqueeze(2)
            valid_mask = (dist_labels != -1) & (pair_mask == 1)

            # One-time diagnostic; only pay for .item() when DEBUG is on.
            if not getattr(self, '_dist_debug_printed', False):
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(
                        "dist_labels shape: %s, valid_mask.sum: %d",
                        tuple(dist_labels.shape),
                        int(valid_mask.sum().item()),
                    )
                self._dist_debug_printed = True

            if valid_mask.any():
                # MSE loss on valid positions only
                loss_fct = nn.MSELoss()
                dist_loss = loss_fct(
                    preds[valid_mask], dist_labels[valid_mask].to(preds.dtype)
                )
            else:
                dist_loss = torch.tensor(0.0, device=device)
        elif not getattr(self, '_dist_none_printed', False):
            logger.debug("dist_labels is None; distance loss skipped")
            self._dist_none_printed = True

        # Weighted combination of whichever losses were computed.
        weighted = [
            weight * loss
            for weight, loss in (
                (self.config.seq_loss_weight, seq_loss),
                (self.config.ms_loss_weight, ms_loss),
                (self.config.struct_loss_weight, struct_loss),
                (self.config.dist_loss_weight, dist_loss),
            )
            if loss is not None
        ]
        if weighted:
            total_loss = sum(weighted)

        if not return_dict:
            return (total_loss, seq_logits, ms_logits, struct_logits, fused_repr)

        return {
            'loss': total_loss,
            'seq_loss': seq_loss,
            'ms_loss': ms_loss,
            'struct_loss': struct_loss,
            'dist_loss': dist_loss,
            'seq_logits': seq_logits,
            'ms_logits': ms_logits,
            'struct_logits': struct_logits,
            'dist_predictions': dist_predictions,
            'seq_hidden': seq_hidden,
            'ms_hidden': ms_hidden,
            'struct_hidden': struct_hidden,
            'seq_pooled': seq_pooled,
            'ms_pooled': ms_pooled,
            'struct_pooled': struct_pooled,
            'fused_repr': fused_repr,
        }

    def get_multimodal_representation(
        self,
        seq_token_ids: torch.Tensor,
        seq_attention_mask: torch.Tensor,
        seq_residue_ids: torch.Tensor,
        ms_token_ids: torch.Tensor,
        ms_attention_mask: torch.Tensor,
        has_ms: torch.Tensor,
        struct_token_ids: Optional[torch.Tensor] = None,
        struct_attention_mask: Optional[torch.Tensor] = None,
        struct_residue_ids: Optional[torch.Tensor] = None,
        has_3d: Optional[torch.Tensor] = None,
        seq_branch_depths: Optional[torch.Tensor] = None,
        seq_linkage_types: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Get fused multimodal representation (for inference).

        Runs ``forward`` without labels and returns its ``fused_repr`` entry,
        shape (batch_size, fusion_hidden_size).
        """
        outputs = self.forward(
            seq_token_ids=seq_token_ids,
            seq_attention_mask=seq_attention_mask,
            seq_residue_ids=seq_residue_ids,
            seq_branch_depths=seq_branch_depths,
            seq_linkage_types=seq_linkage_types,
            ms_token_ids=ms_token_ids,
            ms_attention_mask=ms_attention_mask,
            has_ms=has_ms,
            struct_token_ids=struct_token_ids,
            struct_attention_mask=struct_attention_mask,
            struct_residue_ids=struct_residue_ids,
            has_3d=has_3d,
            return_dict=True,
        )
        return outputs['fused_repr']


def _run_smoke_test() -> None:
    """Build the default-sized model and run one forward pass on random data."""
    banner = "=" * 80

    print(banner)
    print("Testing BERTose model")
    print(banner)

    # Create config
    config = MultimodalGlycanBERTConfig(
        seq_vocab_size=166,
        seq_hidden_size=768,
        seq_num_layers=12,
        seq_num_heads=12,
        ms_vocab_size=242,
        ms_hidden_size=384,
        ms_num_layers=6,
        ms_num_heads=6,
        struct_vocab_size=1024,
        struct_hidden_size=512,
        struct_num_layers=8,
        struct_num_heads=8,
        use_3d=True,
        use_cross_attention=True,
        seq_loss_weight=0.60,
        ms_loss_weight=0.15,
        struct_loss_weight=0.25,
    )

    print("\nConfig:")
    print(f"  Sequence vocab: {config.seq_vocab_size}")
    print(f"  MS vocab: {config.ms_vocab_size}")
    print(f"  Structure vocab: {config.struct_vocab_size}")
    print(f"  Loss weights: seq={config.seq_loss_weight}, ms={config.ms_loss_weight}, struct={config.struct_loss_weight}")

    # Create model
    model = MultimodalGlycanBERT(config)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("\nModel Parameters:")
    print(f"  Total: {total_params:,}")
    print(f"  Trainable: {trainable_params:,}")

    # Test forward pass
    print(f"\n{banner}")
    print("Testing Forward Pass (with Conv front-end)")
    print(banner)

    batch_size = 4
    seq_len = 128
    ms_len = 50
    struct_len = 40

    # Create dummy inputs
    seq_token_ids = torch.randint(0, config.seq_vocab_size, (batch_size, seq_len))
    seq_attention_mask = torch.ones(batch_size, seq_len)
    # Approximate: ~5 tokens per residue
    seq_residue_ids = torch.div(
        torch.arange(seq_len), 5, rounding_mode="floor"
    ).unsqueeze(0).expand(batch_size, -1)

    ms_token_ids = torch.randint(
        config.ms_vocab_offset, config.ms_total_vocab_size, (batch_size, ms_len)
    )
    ms_attention_mask = torch.ones(batch_size, ms_len)

    struct_token_ids = torch.randint(0, config.struct_vocab_size, (batch_size, struct_len))
    struct_attention_mask = torch.ones(batch_size, struct_len)
    # Approximate: 4 tokens per residue for VQ-VAE tokens
    struct_residue_ids = torch.div(
        torch.arange(struct_len), 4, rounding_mode="floor"
    ).unsqueeze(0).expand(batch_size, -1)

    has_ms = torch.tensor([True, True, False, True])
    has_3d = torch.tensor([True, False, True, True])

    # Create labels for MLM: keep only positions holding the mask token.
    seq_labels = seq_token_ids.clone()
    seq_labels[seq_labels != config.mask_token_id] = -100

    ms_labels = ms_token_ids.clone()
    ms_labels[ms_labels != config.mask_token_id] = -100

    struct_labels = struct_token_ids.clone()
    struct_labels[struct_labels != config.mask_token_id] = -100

    # Forward pass
    outputs = model(
        seq_token_ids=seq_token_ids,
        seq_attention_mask=seq_attention_mask,
        seq_residue_ids=seq_residue_ids,
        ms_token_ids=ms_token_ids,
        ms_attention_mask=ms_attention_mask,
        has_ms=has_ms,
        struct_token_ids=struct_token_ids,
        struct_attention_mask=struct_attention_mask,
        struct_residue_ids=struct_residue_ids,
        has_3d=has_3d,
        seq_labels=seq_labels,
        ms_labels=ms_labels,
        struct_labels=struct_labels,
    )

    print("\nOutput shapes:")
    print(f"  seq_logits: {outputs['seq_logits'].shape}")
    print(f"  ms_logits: {outputs['ms_logits'].shape}")
    print(f"  struct_logits: {outputs['struct_logits'].shape}")
    print(f"  fused_repr: {outputs['fused_repr'].shape}")

    print("\nLosses:")
    print(f"  Total loss: {outputs['loss'].item():.4f}")
    print(f"  Sequence loss: {outputs['seq_loss'].item():.4f}")
    print(f"  MS loss: {outputs['ms_loss'].item():.4f}")
    print(f"  Structure loss: {outputs['struct_loss'].item():.4f}")

    print(f"\n{banner}")
    print("Model Test Complete!")
    print(banner)


if __name__ == "__main__":
    _run_smoke_test()