| """ |
| BERTose model |
| |
| Core glycan representation model with three modalities: |
| - Sequence (WURCS atomic tokenization) |
| - MS (mass spectrometry peaks, RT, intensity) |
| - 3D structure (VQ-VAE discrete tokens, 4 per residue) |
| |
| Each modality has its own encoder, with cross-attention for sequence-structure alignment. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from typing import Dict, Optional, Tuple |
| import math |
|
|
| try: |
| from .bertose_layers import GlycanBERTConfig, GlycanBERTEmbeddings, GlycanBERTLayer |
| except ImportError: |
| from bertose_layers import GlycanBERTConfig, GlycanBERTEmbeddings, GlycanBERTLayer |
|
|
|
|
| class ConvGlycanBERTEmbeddings(nn.Module): |
| """ |
| Improved Convolutional front-end that mixes local WURCS context before the Transformer. |
| |
| Key improvements over original: |
| 1. Position embeddings added BEFORE convolution (provides spatial context to conv) |
| 2. Residual connection (conv enriches embeddings rather than replacing them) |
| 3. Multi-scale convolutions (kernel sizes 3, 5, 7) for better receptive field |
| 4. Proper layer normalization on the residual path |
| """ |
|
|
| def __init__(self, config): |
| super().__init__() |
| self.token_embeddings = nn.Embedding( |
| config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id |
| ) |
| self.position_embeddings = nn.Embedding( |
| config.max_position_embeddings, config.hidden_size |
| ) |
| |
| |
| # Branch depth embeddings: how deeply nested each token sits in the glycan tree |
| max_branch_depth = getattr(config, "max_branch_depth", 8) |
| self.branch_embeddings = nn.Embedding(max_branch_depth, config.hidden_size) |
| |
| |
| |
| # Linkage type embeddings for glycosidic bond classes |
| num_linkage_types = getattr(config, "num_linkage_types", 9) |
| self.linkage_embeddings = nn.Embedding(num_linkage_types, config.hidden_size) |
| |
| |
| kernel_size = getattr(config, "cnn_kernel_size", 3) |
| |
| # Three parallel convolutions with kernel sizes k, k+2, k+4 (3/5/7 by default), |
| # each padded so the sequence length is preserved |
| channels_per_scale = config.hidden_size // 3 |
| self.conv_layers = nn.ModuleList([ |
| nn.Conv1d( |
| in_channels=config.hidden_size, |
| out_channels=channels_per_scale, |
| kernel_size=kernel_size + 2 * i, |
| padding=(kernel_size + 2 * i) // 2, |
| ) |
| for i in range(3) |
| ]) |
| self.conv_activation = nn.GELU() |
| # Project the concatenated multi-scale channels back to hidden_size |
| self.conv_proj = nn.Linear(channels_per_scale * 3, config.hidden_size) |
| |
| self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
| self.conv_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
| self.dropout = nn.Dropout(config.hidden_dropout_prob) |
| self.register_buffer( |
| "position_ids", |
| torch.arange(config.max_position_embeddings).expand((1, -1)), |
| ) |
|
|
| self.hidden_size = config.hidden_size |
|
|
| def forward(self, input_ids, branch_depths=None, linkage_types=None): |
| seq_len = input_ids.shape[1] |
| |
| |
| # Token + position embeddings (positions added BEFORE the conv; see class docstring) |
| x = self.token_embeddings(input_ids) |
| position_ids = self.position_ids[:, :seq_len] |
| x = x + self.position_embeddings(position_ids) |
| |
| |
| if branch_depths is not None: |
| |
| branch_depths = branch_depths.clamp(0, self.branch_embeddings.num_embeddings - 1) |
| x = x + self.branch_embeddings(branch_depths) |
| |
| |
| if linkage_types is not None: |
| linkage_types = linkage_types.clamp(0, self.linkage_embeddings.num_embeddings - 1) |
| x = x + self.linkage_embeddings(linkage_types) |
| |
| x = self.LayerNorm(x) |
| |
| |
| |
| # Conv1d expects (batch, channels, length) |
| conv_in = x.permute(0, 2, 1) |
| |
| |
| conv_outputs = [] |
| for conv in self.conv_layers: |
| conv_out = self.conv_activation(conv(conv_in)) |
| conv_outputs.append(conv_out) |
| |
| |
| # Concatenate the per-scale channels and project back to hidden_size |
| conv_out = torch.cat(conv_outputs, dim=1) |
| conv_out = conv_out.permute(0, 2, 1) |
| conv_out = self.conv_proj(conv_out) |
| |
| |
| # Residual: the conv output enriches rather than replaces the embeddings |
| x = self.conv_norm(x + self.dropout(conv_out)) |
| |
| return x |
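| |
| # A shape walkthrough of the conv front-end (illustrative, assuming hidden_size=768): |
| #   x: (B, L, 768) --permute--> (B, 768, L) |
| #   each conv (kernels 3/5/7): (B, 768, L) -> (B, 256, L) |
| #   concat over channels: (B, 768, L) --permute--> (B, L, 768) |
| #   conv_proj + residual + LayerNorm: (B, L, 768) |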
|
|
|
|
| def create_residue_level_mask( |
| seq_residue_ids: torch.Tensor, |
| struct_residue_ids: torch.Tensor |
| ) -> torch.Tensor: |
| """ |
| Create residue-level attention mask for cross-attention. |
| |
| Maps WURCS tokens to VQ-VAE structural tokens based on residue IDs. |
| A WURCS token with residue_id=0 can only attend to VQ-VAE tokens with residue_id=0. |
| |
| Args: |
| seq_residue_ids: Residue IDs for sequence tokens (batch, N_seq) |
| struct_residue_ids: Residue IDs for structural tokens (batch, N_struct) |
| |
| Returns: |
| Boolean mask (batch, N_seq, N_struct) where True = can attend |
| """ |
| |
| |
| |
| # Pairwise comparison: (batch, N_seq, 1) vs (batch, 1, N_struct) |
| mask = seq_residue_ids.unsqueeze(2) == struct_residue_ids.unsqueeze(1) |
| |
| # Special tokens (residue_id = -1) must not attend to any structural token |
| mask &= (seq_residue_ids.unsqueeze(2) >= 0) |
| |
| return mask |
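| |
| # A worked example (illustrative values): |
| #   seq_rids    = torch.tensor([[0, 0, 1, -1]])   # 4 sequence tokens |
| #   struct_rids = torch.tensor([[0, 1, 1]])       # 3 structural tokens |
| #   create_residue_level_mask(seq_rids, struct_rids)[0] == |
| #       [[True, False, False], |
| #        [True, False, False], |
| #        [False, True, True], |
| #        [False, False, False]]   # residue_id=-1 attends to nothing |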
|
|
|
|
| class MultimodalGlycanBERTConfig: |
| """Configuration for the BERTose model.""" |
| |
| def __init__( |
| self, |
| |
| # Sequence encoder (WURCS tokens) |
| seq_vocab_size: int = 166, |
| seq_hidden_size: int = 768, |
| seq_num_layers: int = 12, |
| seq_num_heads: int = 12, |
| seq_max_length: int = 512, |
| |
| |
| # MS encoder (spectral peak tokens) |
| ms_vocab_size: int = 242, |
| ms_hidden_size: int = 384, |
| ms_num_layers: int = 6, |
| ms_num_heads: int = 6, |
| ms_max_length: int = 150, |
| |
| |
| # 3D structure encoder (VQ-VAE tokens) |
| struct_vocab_size: int = 1024, |
| struct_hidden_size: int = 512, |
| struct_num_layers: int = 8, |
| struct_num_heads: int = 8, |
| struct_max_length: int = 200, |
| use_3d: bool = True, |
| |
| |
| # Cross-attention (sequence -> structure) |
| use_cross_attention: bool = True, |
| cross_attn_num_heads: int = 8, |
| |
| |
| # Fusion head |
| fusion_hidden_size: int = 768, |
| fusion_num_layers: int = 2, |
| |
| |
| # Regularization and initialization |
| hidden_dropout_prob: float = 0.1, |
| attention_probs_dropout_prob: float = 0.1, |
| layer_norm_eps: float = 1e-12, |
| initializer_range: float = 0.02, |
| |
| |
| # Convolutional front-end |
| use_cnn_frontend: bool = True, |
| cnn_kernel_size: int = 3, |
| |
| |
| # Per-modality MLM loss weights |
| seq_loss_weight: float = 0.60, |
| ms_loss_weight: float = 0.15, |
| struct_loss_weight: float = 0.25, |
| |
| |
| # Special tokens |
| pad_token_id: int = 0, |
| mask_token_id: int = 1, |
| ): |
| |
| self.seq_vocab_size = seq_vocab_size |
| self.seq_hidden_size = seq_hidden_size |
| self.seq_num_layers = seq_num_layers |
| self.seq_num_heads = seq_num_heads |
| self.seq_max_length = seq_max_length |
| |
| |
| self.ms_vocab_size = ms_vocab_size |
| # MS token IDs are offset to live after the sequence vocab: |
| # [seq_vocab_size, seq_vocab_size + ms_vocab_size) |
| self.ms_vocab_offset = seq_vocab_size |
| self.ms_total_vocab_size = seq_vocab_size + ms_vocab_size |
| self.ms_hidden_size = ms_hidden_size |
| self.ms_num_layers = ms_num_layers |
| self.ms_num_heads = ms_num_heads |
| self.ms_max_length = ms_max_length |
| |
| |
| self.struct_vocab_size = struct_vocab_size |
| self.struct_hidden_size = struct_hidden_size |
| self.struct_num_layers = struct_num_layers |
| self.struct_num_heads = struct_num_heads |
| self.struct_max_length = struct_max_length |
| self.use_3d = use_3d |
| |
| |
| self.use_cross_attention = use_cross_attention |
| self.cross_attn_num_heads = cross_attn_num_heads |
| |
| |
| self.fusion_hidden_size = fusion_hidden_size |
| self.fusion_num_layers = fusion_num_layers |
| |
| |
| self.hidden_dropout_prob = hidden_dropout_prob |
| self.attention_probs_dropout_prob = attention_probs_dropout_prob |
| self.layer_norm_eps = layer_norm_eps |
| self.initializer_range = initializer_range |
|
|
| |
| self.use_cnn_frontend = use_cnn_frontend |
| self.cnn_kernel_size = cnn_kernel_size |
| |
| |
| self.seq_loss_weight = seq_loss_weight |
| self.ms_loss_weight = ms_loss_weight |
| self.struct_loss_weight = struct_loss_weight |
| # Auxiliary pairwise-distance loss weight (fixed; not exposed as an argument) |
| self.dist_loss_weight = 0.25 |
| |
| |
| self.pad_token_id = pad_token_id |
| self.mask_token_id = mask_token_id |
| |
| def to_seq_config(self) -> GlycanBERTConfig: |
| """Convert to sequence-only config.""" |
| return GlycanBERTConfig( |
| vocab_size=self.seq_vocab_size, |
| hidden_size=self.seq_hidden_size, |
| num_hidden_layers=self.seq_num_layers, |
| num_attention_heads=self.seq_num_heads, |
| intermediate_size=self.seq_hidden_size * 4, |
| hidden_dropout_prob=self.hidden_dropout_prob, |
| attention_probs_dropout_prob=self.attention_probs_dropout_prob, |
| max_position_embeddings=self.seq_max_length, |
| layer_norm_eps=self.layer_norm_eps, |
| pad_token_id=self.pad_token_id, |
| mask_token_id=self.mask_token_id, |
| initializer_range=self.initializer_range, |
| ) |
| |
| def to_ms_config(self) -> GlycanBERTConfig: |
| """Convert to MS-only config.""" |
| return GlycanBERTConfig( |
| vocab_size=self.ms_total_vocab_size, |
| hidden_size=self.ms_hidden_size, |
| num_hidden_layers=self.ms_num_layers, |
| num_attention_heads=self.ms_num_heads, |
| intermediate_size=self.ms_hidden_size * 4, |
| hidden_dropout_prob=self.hidden_dropout_prob, |
| attention_probs_dropout_prob=self.attention_probs_dropout_prob, |
| max_position_embeddings=self.ms_max_length, |
| layer_norm_eps=self.layer_norm_eps, |
| pad_token_id=self.pad_token_id, |
| mask_token_id=self.mask_token_id, |
| initializer_range=self.initializer_range, |
| ) |
| |
| def to_struct_config(self) -> GlycanBERTConfig: |
| """Convert to structure-only config.""" |
| return GlycanBERTConfig( |
| vocab_size=self.struct_vocab_size, |
| hidden_size=self.struct_hidden_size, |
| num_hidden_layers=self.struct_num_layers, |
| num_attention_heads=self.struct_num_heads, |
| intermediate_size=self.struct_hidden_size * 4, |
| hidden_dropout_prob=self.hidden_dropout_prob, |
| attention_probs_dropout_prob=self.attention_probs_dropout_prob, |
| max_position_embeddings=self.struct_max_length, |
| layer_norm_eps=self.layer_norm_eps, |
| pad_token_id=self.pad_token_id, |
| mask_token_id=self.mask_token_id, |
| initializer_range=self.initializer_range, |
| ) |
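| |
| # Vocab layout implied by the defaults (illustrative): |
| #   sequence tokens: IDs 0..165 (seq_vocab_size=166) |
| #   MS tokens:       IDs 166..407 (shifted by ms_vocab_offset=166) |
| #   struct tokens:   IDs 0..1023 in a separate embedding table (struct_vocab_size=1024) |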
|
|
|
|
| |
| |
| |
|
|
| class MonosaccharidePooling(nn.Module): |
| """ |
| Pool token representations to monosaccharide level, then aggregate. |
| |
| This bridges the gap between token-level BERT and monosaccharide-level CNNs/GNNs. |
| Uses monosaccharide_indices from the data to know where each residue starts. |
| """ |
| |
| def __init__(self, hidden_size: int, num_attention_heads: int = 8, dropout: float = 0.1): |
| super().__init__() |
| self.hidden_size = hidden_size |
| |
| |
| self.mono_attention = nn.MultiheadAttention( |
| embed_dim=hidden_size, |
| num_heads=num_attention_heads, |
| dropout=dropout, |
| batch_first=True |
| ) |
| self.mono_norm = nn.LayerNorm(hidden_size) |
| |
| |
| self.glycan_query = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02) |
| self.glycan_attention = nn.MultiheadAttention( |
| embed_dim=hidden_size, |
| num_heads=num_attention_heads, |
| dropout=dropout, |
| batch_first=True |
| ) |
| self.glycan_norm = nn.LayerNorm(hidden_size) |
| |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| residue_ids: torch.Tensor, |
| attention_mask: torch.Tensor = None, |
| ) -> torch.Tensor: |
| """ |
| Pool tokens to monosaccharide level, then to glycan level. |
| |
| Returns: |
| Glycan representation: (batch, hidden_size) |
| """ |
| batch_size = hidden_states.size(0) |
| device = hidden_states.device |
| |
| |
| # Fixed upper bound on monosaccharides per glycan |
| max_residues = 50 |
| |
| |
| mono_reps = torch.zeros(batch_size, max_residues, self.hidden_size, device=device) |
| mono_mask = torch.zeros(batch_size, max_residues, dtype=torch.bool, device=device) |
| |
| # Mean-pool token vectors per residue (simple Python loop over the batch) |
| for b in range(batch_size): |
| unique_residues = torch.unique(residue_ids[b][residue_ids[b] >= 0]) |
| for i, rid in enumerate(unique_residues): |
| if i >= max_residues: |
| break |
| token_mask = residue_ids[b] == rid |
| if attention_mask is not None: |
| token_mask = token_mask & (attention_mask[b] > 0) |
| if token_mask.sum() > 0: |
| mono_reps[b, i] = hidden_states[b][token_mask].mean(dim=0) |
| mono_mask[b, i] = True |
| |
| |
| |
| # nn.MultiheadAttention expects True = ignore this key position |
| # (assumes every sample has at least one valid residue) |
| key_padding_mask = ~mono_mask |
| |
| mono_out, _ = self.mono_attention( |
| mono_reps, mono_reps, mono_reps, |
| key_padding_mask=key_padding_mask |
| ) |
| mono_out = self.mono_norm(mono_reps + mono_out) |
| |
| |
| glycan_query = self.glycan_query.expand(batch_size, -1, -1) |
| glycan_out, _ = self.glycan_attention( |
| glycan_query, mono_out, mono_out, |
| key_padding_mask=key_padding_mask |
| ) |
| glycan_out = self.glycan_norm(glycan_query + glycan_out) |
| |
| return glycan_out.squeeze(1) |
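| |
| # Usage sketch (hypothetical shapes; note this module is defined here but not |
| # wired into MultimodalGlycanBERT below): |
| #   pool = MonosaccharidePooling(hidden_size=768) |
| #   hidden = torch.randn(2, 128, 768)        # token-level encoder output |
| #   rids = torch.randint(0, 10, (2, 128))    # residue id per token |
| #   glycan_vec = pool(hidden, rids)          # -> (2, 768) |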
|
|
|
|
| |
| |
| |
|
|
| |
| MONOSACCHARIDE_VOCAB = { |
| '[PAD_MONO]': 0, '[UNK_MONO]': 1, |
| 'Glc': 2, 'GlcNAc': 3, 'GlcA': 4, 'GlcN': 5, |
| 'Gal': 6, 'GalNAc': 7, 'GalA': 8, 'GalN': 9, |
| 'Man': 10, 'ManNAc': 11, 'ManA': 12, 'ManN': 13, |
| 'Fuc': 14, 'Rha': 15, 'Xyl': 16, 'Ara': 17, |
| 'Neu5Ac': 18, 'Neu5Gc': 19, 'Kdn': 20, 'Sia': 21, |
| 'GalNAcA': 22, 'GlcNAcA': 23, 'IdoA': 24, 'GulA': 25, |
| 'Rib': 26, 'Lyx': 27, 'All': 28, 'Alt': 29, |
| 'Tal': 30, 'Ido': 31, 'Qui': 32, 'Oli': 33, |
| 'Tyv': 34, 'Abe': 35, 'Par': 36, 'Dig': 37, |
| 'Col': 38, 'Dha': 39, 'Kdo': 40, 'Hep': 41, |
| 'NeuroGc': 42, 'Muramic': 43, 'LDManHep': 44, 'DDManHep': 45, |
| 'Bac': 46, 'Pse': 47, 'Leg': 48, 'Aci': 49, |
| '6dTal': 50, 'Fru': 51, 'Tag': 52, 'Sor': 53, |
| 'Psi': 54, 'Sed': 55, 'MurNAc': 56, 'MurNGc': 57, |
| 'Api': 58, 'Erwiniose': 59, 'Yer': 60, 'Thre': 61, |
| |
| } |
|
|
|
|
| class ResidueTypeEmbeddings(nn.Module): |
| """ |
| Learnable embeddings for monosaccharide types. |
| |
| Instead of the model having to learn that 'a1221m' = Fucose from character patterns, |
| we explicitly add a Fucose embedding to all tokens belonging to that residue. |
| """ |
| |
| def __init__(self, hidden_size: int, num_mono_types: int = 70): |
| super().__init__() |
| self.mono_embeddings = nn.Embedding(num_mono_types, hidden_size) |
| self.mono_vocab = MONOSACCHARIDE_VOCAB |
| self.hidden_size = hidden_size |
| |
| def forward( |
| self, |
| token_embeddings: torch.Tensor, |
| residue_ids: torch.Tensor, |
| mono_type_ids: torch.Tensor = None, |
| ) -> torch.Tensor: |
| """ |
| Add residue type embeddings to token embeddings. |
| |
| Args: |
| token_embeddings: Base token embeddings |
| residue_ids: Which residue each token belongs to (-1 for special tokens) |
| mono_type_ids: Monosaccharide type ID for each residue position |
| |
| Returns: |
| Enhanced embeddings with residue type information |
| """ |
| if mono_type_ids is None: |
| return token_embeddings |
| |
| batch_size, seq_len, _ = token_embeddings.shape |
| enhanced = token_embeddings.clone() |
| |
| |
| # Per-token lookup via a double loop; could be vectorized with a gather |
| for b in range(batch_size): |
| for pos in range(seq_len): |
| rid = residue_ids[b, pos].item() |
| if rid >= 0 and rid < mono_type_ids.size(1): |
| mono_id = mono_type_ids[b, rid] |
| enhanced[b, pos] = enhanced[b, pos] + self.mono_embeddings(mono_id) |
| |
| return enhanced |
| |
| @staticmethod |
| def get_mono_type_id(mono_name: str) -> int: |
| """Convert monosaccharide name to type ID.""" |
| return MONOSACCHARIDE_VOCAB.get(mono_name, MONOSACCHARIDE_VOCAB['[UNK_MONO]']) |
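| |
| # Example lookups (values from MONOSACCHARIDE_VOCAB above): |
| #   ResidueTypeEmbeddings.get_mono_type_id('Fuc')        # -> 14 |
| #   ResidueTypeEmbeddings.get_mono_type_id('NotASugar')  # -> 1 ([UNK_MONO]) |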
|
|
|
|
| |
| |
| |
|
|
| class RelativePositionBias(nn.Module): |
| """ |
| Compute relative position bias for attention based on residue IDs. |
| |
| Tokens in the same residue get distance 0. |
| Tokens in adjacent residues get distance ±1. |
| This helps the model understand glycan tree structure. |
| """ |
| |
| def __init__(self, num_heads: int, max_distance: int = 10): |
| super().__init__() |
| self.num_heads = num_heads |
| self.max_distance = max_distance |
| |
| |
| num_distances = 2 * max_distance + 1 |
| self.relative_bias = nn.Embedding(num_distances, num_heads) |
| |
| def forward(self, residue_ids: torch.Tensor) -> torch.Tensor: |
| """ |
| Compute relative position bias. |
| |
| Args: |
| residue_ids: (batch, seq_len) |
| |
| Returns: |
| Bias to add to attention scores: (batch, num_heads, seq_len, seq_len) |
| """ |
| |
| |
| # Signed residue distance: d[i, j] = residue_ids[i] - residue_ids[j] |
| distance = residue_ids.unsqueeze(2) - residue_ids.unsqueeze(1) |
| |
| |
| # Clamp to [-max_distance, max_distance], then shift to a non-negative index |
| distance_clamped = distance.clamp(-self.max_distance, self.max_distance) |
| distance_idx = distance_clamped + self.max_distance |
| |
| |
| # (batch, seq_len, seq_len, num_heads) |
| bias = self.relative_bias(distance_idx) |
| |
| # -> (batch, num_heads, seq_len, seq_len), ready to add to attention scores |
| bias = bias.permute(0, 3, 1, 2) |
| |
| return bias |
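| |
| # Worked example with max_distance=10 (illustrative): |
| #   residue_ids = [0, 0, 1, 3] |
| #   distance[0, 3] = 0 - 3 = -3  -> index -3 + 10 = 7 |
| #   distance[3, 0] = 3 - 0 =  3  -> index 13; same-residue pairs map to index 10 |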
|
|
|
|
| class CrossAttentionLayer(nn.Module): |
| """ |
| Cross-attention layer for sequence-structure alignment. |
| |
| Allows sequence tokens to attend to structural atoms using attention masks. |
| """ |
| |
| def __init__(self, config: MultimodalGlycanBERTConfig): |
| super().__init__() |
| self.num_heads = config.cross_attn_num_heads |
| self.hidden_size = config.seq_hidden_size |
| self.head_dim = self.hidden_size // self.num_heads |
| |
| assert self.hidden_size % self.num_heads == 0, "hidden_size must be divisible by num_heads" |
| |
| |
| self.query = nn.Linear(config.seq_hidden_size, self.hidden_size) |
| self.key = nn.Linear(config.struct_hidden_size, self.hidden_size) |
| self.value = nn.Linear(config.struct_hidden_size, self.hidden_size) |
| |
| self.output = nn.Linear(self.hidden_size, config.seq_hidden_size) |
| self.dropout = nn.Dropout(config.attention_probs_dropout_prob) |
| self.layer_norm = nn.LayerNorm(config.seq_hidden_size, eps=config.layer_norm_eps) |
| |
| def forward( |
| self, |
| seq_hidden: torch.Tensor, |
| struct_hidden: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| ) -> torch.Tensor: |
| """ |
| Apply cross-attention from sequence to structure. |
| |
| Args: |
| seq_hidden: Sequence hidden states |
| struct_hidden: Structure hidden states |
| attention_mask: Boolean mask (True = can attend, False = cannot attend) |
| |
| Returns: |
| Updated sequence hidden states |
| """ |
| batch_size, seq_len, _ = seq_hidden.shape |
| struct_len = struct_hidden.shape[1] |
| |
| |
| Q = self.query(seq_hidden) |
| K = self.key(struct_hidden) |
| V = self.value(struct_hidden) |
| |
| |
| Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) |
| K = K.view(batch_size, struct_len, self.num_heads, self.head_dim).transpose(1, 2) |
| V = V.view(batch_size, struct_len, self.num_heads, self.head_dim).transpose(1, 2) |
| |
| |
| # Scaled dot-product attention scores: (batch, heads, N_seq, N_struct) |
| scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim) |
| |
| |
| if attention_mask is not None: |
| # Broadcast over heads: (batch, N_seq, N_struct) -> (batch, 1, N_seq, N_struct) |
| attention_mask = attention_mask.unsqueeze(1) |
| scores = scores.masked_fill(~attention_mask, -10000.0) |
| |
| attn_weights = torch.softmax(scores, dim=-1) |
| if attention_mask is not None: |
| # Queries with no allowed keys (e.g. special tokens with residue_id=-1) would |
| # otherwise soften into a uniform attention; zero their weights instead |
| attn_weights = attn_weights * attention_mask.any(dim=-1, keepdim=True) |
| attn_weights = self.dropout(attn_weights) |
| |
| |
| context = torch.matmul(attn_weights, V) |
| |
| |
| context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size) |
| |
| |
| output = self.output(context) |
| output = self.dropout(output) |
| |
| |
| output = self.layer_norm(seq_hidden + output) |
| |
| return output |
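| |
| # Shape sketch for one cross-attention call (illustrative): |
| #   seq_hidden:    (B, N_seq, seq_hidden_size) |
| #   struct_hidden: (B, N_struct, struct_hidden_size) |
| #   mask:          (B, N_seq, N_struct), from create_residue_level_mask |
| #   returns:       (B, N_seq, seq_hidden_size) with residual + LayerNorm applied |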
|
|
|
|
| class MultimodalGlycanBERT(nn.Module): |
| """ |
| BERTose model for glycan representation learning. |
| |
| Architecture: |
| 1. Separate encoders for each modality (sequence, MS, 3D structure) |
| 2. Cross-attention for sequence-structure alignment |
| 3. Modality-specific MLM heads |
| 4. Fusion layer for combined representation |
| """ |
| |
| def __init__(self, config: MultimodalGlycanBERTConfig): |
| super().__init__() |
| self.config = config |
| |
| |
| seq_config = config.to_seq_config() |
| seq_config.cnn_kernel_size = config.cnn_kernel_size |
|
|
| if config.use_cnn_frontend: |
| print(f"Enabled convolutional front-end (kernel={config.cnn_kernel_size})") |
| self.seq_embeddings = ConvGlycanBERTEmbeddings(seq_config) |
| else: |
| self.seq_embeddings = GlycanBERTEmbeddings(seq_config) |
| self.seq_layers = nn.ModuleList([GlycanBERTLayer(seq_config) for _ in range(seq_config.num_hidden_layers)]) |
| self.seq_mlm_head = nn.Linear(seq_config.hidden_size, seq_config.vocab_size) |
| |
| |
| ms_config = config.to_ms_config() |
| self.ms_embeddings = GlycanBERTEmbeddings(ms_config) |
| self.ms_layers = nn.ModuleList([GlycanBERTLayer(ms_config) for _ in range(ms_config.num_hidden_layers)]) |
| self.ms_mlm_head = nn.Linear(ms_config.hidden_size, ms_config.vocab_size) |
| |
| |
| if config.use_3d: |
| struct_config = config.to_struct_config() |
| self.struct_embeddings = GlycanBERTEmbeddings(struct_config) |
| self.struct_layers = nn.ModuleList([GlycanBERTLayer(struct_config) for _ in range(struct_config.num_hidden_layers)]) |
| self.struct_mlm_head = nn.Linear(struct_config.hidden_size, struct_config.vocab_size) |
| |
| |
| if config.use_cross_attention: |
| self.cross_attention = CrossAttentionLayer(config) |
| |
| |
| # Project MS / structure pooled vectors to the sequence hidden size for fusion |
| if config.ms_hidden_size != config.seq_hidden_size: |
| self.ms_projection = nn.Linear(config.ms_hidden_size, config.seq_hidden_size) |
| else: |
| self.ms_projection = nn.Identity() |
| |
| if config.use_3d and config.struct_hidden_size != config.seq_hidden_size: |
| self.struct_projection = nn.Linear(config.struct_hidden_size, config.seq_hidden_size) |
| else: |
| self.struct_projection = nn.Identity() |
| |
| |
| |
| # Fusion MLP over the concatenated pooled vectors (seq + ms [+ struct]) |
| fusion_input_size = config.seq_hidden_size * (3 if config.use_3d else 2) |
| self.fusion_layer = nn.Sequential( |
| nn.Linear(fusion_input_size, config.fusion_hidden_size), |
| nn.LayerNorm(config.fusion_hidden_size, eps=config.layer_norm_eps), |
| nn.GELU(), |
| nn.Dropout(config.hidden_dropout_prob), |
| nn.Linear(config.fusion_hidden_size, config.fusion_hidden_size), |
| ) |
| |
| |
| |
| |
| # Auxiliary pairwise-distance head: predicts a distance from |h_i - h_j| |
| self.dist_proj = nn.Linear(config.seq_hidden_size, 128) |
| self.distance_head = nn.Sequential( |
| nn.Linear(128, 64), |
| nn.ReLU(), |
| nn.Linear(64, 1) |
| ) |
| |
| |
| self.apply(self._init_weights) |
| |
| def _init_weights(self, module): |
| """Initialize weights.""" |
| if isinstance(module, nn.Linear): |
| module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) |
| if module.bias is not None: |
| module.bias.data.zero_() |
| elif isinstance(module, nn.Embedding): |
| module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) |
| if module.padding_idx is not None: |
| module.weight.data[module.padding_idx].zero_() |
| elif isinstance(module, nn.LayerNorm): |
| module.bias.data.zero_() |
| module.weight.data.fill_(1.0) |
| |
| def forward( |
| self, |
| seq_token_ids: torch.Tensor, |
| seq_attention_mask: torch.Tensor, |
| seq_residue_ids: torch.Tensor, |
| seq_branch_depths: Optional[torch.Tensor] = None, |
| seq_linkage_types: Optional[torch.Tensor] = None, |
| ms_token_ids: torch.Tensor = None, |
| ms_attention_mask: torch.Tensor = None, |
| has_ms: torch.Tensor = None, |
| struct_token_ids: Optional[torch.Tensor] = None, |
| struct_attention_mask: Optional[torch.Tensor] = None, |
| struct_residue_ids: Optional[torch.Tensor] = None, |
| has_3d: Optional[torch.Tensor] = None, |
| seq_labels: Optional[torch.Tensor] = None, |
| ms_labels: Optional[torch.Tensor] = None, |
| struct_labels: Optional[torch.Tensor] = None, |
| dist_labels: Optional[torch.Tensor] = None, |
| return_dict: bool = True, |
| ) -> Dict[str, torch.Tensor]: |
| """ |
| Forward pass for BERTose. |
| |
| Args: |
| seq_token_ids: (batch_size, seq_len) - Sequence token IDs |
| seq_attention_mask: (batch_size, seq_len) - Sequence attention mask |
| seq_residue_ids: (batch_size, seq_len) - Sequence token residue IDs |
| ms_token_ids: (batch_size, ms_len) - MS token IDs |
| ms_attention_mask: (batch_size, ms_len) - MS attention mask |
| has_ms: (batch_size,) - Boolean mask for samples with MS data |
| struct_token_ids: (batch_size, struct_len) - Structure VQ-VAE token IDs (optional) |
| struct_attention_mask: (batch_size, struct_len) - Structure attention mask (optional) |
| struct_residue_ids: (batch_size, struct_len) - Structure token residue IDs (optional) |
| has_3d: (batch_size,) - Boolean mask for samples with 3D data (optional) |
| seq_labels: (batch_size, seq_len) - Masked sequence labels (optional) |
| ms_labels: (batch_size, ms_len) - Masked MS labels (optional) |
| struct_labels: (batch_size, struct_len) - Masked structure labels (optional) |
| return_dict: Whether to return dict or tuple |
| |
| Returns: |
| Dictionary containing logits, hidden states, losses, etc. |
| """ |
| batch_size = seq_token_ids.shape[0] |
| device = seq_token_ids.device |
| |
| |
| |
| seq_hidden = self.seq_embeddings(seq_token_ids, seq_branch_depths, seq_linkage_types) |
| for layer in self.seq_layers: |
| seq_hidden = layer(seq_hidden, seq_attention_mask) |
| |
| seq_pooled = seq_hidden[:, 0, :]  # first-token ([CLS]-style) pooling |
| seq_logits = self.seq_mlm_head(seq_hidden) |
| |
| |
| |
| |
| # Pairwise-distance features on a reduced 128-d projection to keep the |
| # (B, L, L, C) tensor affordable |
| seq_hidden_small = self.dist_proj(seq_hidden) |
| |
| # (B, L, 1, 128) vs (B, 1, L, 128) -> symmetric |h_i - h_j| features |
| h_i = seq_hidden_small.unsqueeze(2) |
| h_j = seq_hidden_small.unsqueeze(1) |
| h_diff = torch.abs(h_i - h_j) |
| dist_predictions = self.distance_head(h_diff) |
| |
| |
| ms_hidden = None |
| ms_pooled = None |
| ms_logits = None |
| |
| if ms_token_ids is not None: |
| ms_hidden = self.ms_embeddings(ms_token_ids) |
| for layer in self.ms_layers: |
| ms_hidden = layer(ms_hidden, ms_attention_mask) |
| |
| ms_pooled = ms_hidden[:, 0, :] |
| ms_logits = self.ms_mlm_head(ms_hidden) |
| |
| |
| if has_ms is not None: |
| # Zero the pooled vector for samples without real MS data |
| has_ms_expanded = has_ms.unsqueeze(1).float() |
| ms_pooled = ms_pooled * has_ms_expanded |
| |
| |
| struct_pooled = None |
| struct_logits = None |
| struct_hidden = None |
| |
| if self.config.use_3d and struct_token_ids is not None: |
| struct_hidden = self.struct_embeddings(struct_token_ids) |
| for layer in self.struct_layers: |
| struct_hidden = layer(struct_hidden, struct_attention_mask) |
| |
| struct_pooled = struct_hidden[:, 0, :] |
| struct_logits = self.struct_mlm_head(struct_hidden) |
| |
| |
| if has_3d is not None: |
| # Zero the pooled vector for samples without 3D data |
| has_3d_expanded = has_3d.unsqueeze(1).float() |
| struct_pooled = struct_pooled * has_3d_expanded |
| |
| |
| |
| if self.config.use_cross_attention and struct_hidden is not None and struct_residue_ids is not None: |
| |
| |
| residue_mask = create_residue_level_mask( |
| seq_residue_ids=seq_residue_ids, |
| struct_residue_ids=struct_residue_ids, |
| ) |
| |
| |
| seq_hidden = self.cross_attention( |
| seq_hidden=seq_hidden, |
| struct_hidden=struct_hidden, |
| attention_mask=residue_mask, |
| ) |
| |
| |
| seq_pooled = seq_hidden[:, 0, :] |
| |
| |
| |
| if ms_pooled is None: |
| # No MS input in this batch: substitute zeros so fusion still has an MS slot |
| ms_pooled_projected = torch.zeros(batch_size, self.config.seq_hidden_size, device=device) |
| else: |
| ms_pooled_projected = self.ms_projection(ms_pooled) |
| |
| if self.config.use_3d and struct_pooled is not None: |
| struct_pooled_projected = self.struct_projection(struct_pooled) |
| combined = torch.cat([seq_pooled, ms_pooled_projected, struct_pooled_projected], dim=-1) |
| else: |
| combined = torch.cat([seq_pooled, ms_pooled_projected], dim=-1) |
| |
| fused_repr = self.fusion_layer(combined) |
| |
| |
| total_loss = None |
| seq_loss = None |
| ms_loss = None |
| struct_loss = None |
| dist_loss = None |
| |
| if seq_labels is not None: |
| loss_fct = nn.CrossEntropyLoss(ignore_index=-100) |
| seq_loss = loss_fct( |
| seq_logits.view(-1, self.config.seq_vocab_size), |
| seq_labels.view(-1) |
| ) |
| |
| if ms_labels is not None: |
| ms_labels_masked = ms_labels.clone() |
| if has_ms is not None: |
| ms_labels_masked[~has_ms] = -100 |
| |
| if (ms_labels_masked != -100).any(): |
| loss_fct = nn.CrossEntropyLoss(ignore_index=-100) |
| ms_loss = loss_fct( |
| ms_logits.view(-1, self.config.ms_total_vocab_size), |
| ms_labels_masked.view(-1) |
| ) |
| else: |
| ms_loss = torch.tensor(0.0, device=seq_token_ids.device) |
| |
| if self.config.use_3d and struct_labels is not None and struct_logits is not None: |
| struct_labels_masked = struct_labels.clone() |
| if has_3d is not None: |
| struct_labels_masked[~has_3d] = -100 |
| |
| if (struct_labels_masked != -100).any(): |
| loss_fct = nn.CrossEntropyLoss(ignore_index=-100) |
| struct_loss = loss_fct( |
| struct_logits.view(-1, self.config.struct_vocab_size), |
| struct_labels_masked.view(-1) |
| ) |
| else: |
| struct_loss = torch.tensor(0.0, device=seq_token_ids.device) |
| |
| |
| if dist_labels is not None: |
| |
| preds = dist_predictions.squeeze(-1) |
| |
| # Valid pairs: labeled (-1 = unlabeled) and both tokens unpadded |
| valid_mask = (dist_labels != -1) & (seq_attention_mask.unsqueeze(1) * seq_attention_mask.unsqueeze(2) == 1) |
| |
| |
| if not hasattr(self, '_dist_debug_printed'): |
| print(f"[DIST DEBUG] dist_labels shape: {dist_labels.shape}, valid_mask.sum: {valid_mask.sum().item()}") |
| self._dist_debug_printed = True |
| |
| if valid_mask.sum() > 0: |
| |
| loss_fct = nn.MSELoss() |
| dist_loss = loss_fct(preds[valid_mask], dist_labels[valid_mask].float()) |
| else: |
| dist_loss = torch.tensor(0.0, device=seq_token_ids.device) |
| else: |
| |
| if not hasattr(self, '_dist_none_printed'): |
| print("[DIST DEBUG] dist_labels is None!") |
| self._dist_none_printed = True |
| |
| |
| losses = [] |
| if seq_loss is not None: |
| losses.append(self.config.seq_loss_weight * seq_loss) |
| if ms_loss is not None: |
| losses.append(self.config.ms_loss_weight * ms_loss) |
| if struct_loss is not None: |
| losses.append(self.config.struct_loss_weight * struct_loss) |
| if dist_loss is not None: |
| losses.append(self.config.dist_loss_weight * dist_loss) |
| |
| if losses: |
| total_loss = sum(losses) |
| |
| if return_dict: |
| return { |
| 'loss': total_loss, |
| 'seq_loss': seq_loss, |
| 'ms_loss': ms_loss, |
| 'struct_loss': struct_loss, |
| 'dist_loss': dist_loss, |
| 'seq_logits': seq_logits, |
| 'ms_logits': ms_logits, |
| 'struct_logits': struct_logits, |
| 'dist_predictions': dist_predictions, |
| 'seq_hidden': seq_hidden, |
| 'ms_hidden': ms_hidden, |
| 'struct_hidden': struct_hidden, |
| 'seq_pooled': seq_pooled, |
| 'ms_pooled': ms_pooled, |
| 'struct_pooled': struct_pooled, |
| 'fused_repr': fused_repr, |
| } |
| else: |
| return (total_loss, seq_logits, ms_logits, struct_logits, fused_repr) |
| |
| def get_multimodal_representation( |
| self, |
| seq_token_ids: torch.Tensor, |
| seq_attention_mask: torch.Tensor, |
| seq_residue_ids: torch.Tensor, |
| ms_token_ids: torch.Tensor, |
| ms_attention_mask: torch.Tensor, |
| has_ms: torch.Tensor, |
| struct_token_ids: Optional[torch.Tensor] = None, |
| struct_attention_mask: Optional[torch.Tensor] = None, |
| struct_residue_ids: Optional[torch.Tensor] = None, |
| has_3d: Optional[torch.Tensor] = None, |
| ) -> torch.Tensor: |
| """Get fused multimodal representation (for inference).""" |
| outputs = self.forward( |
| seq_token_ids=seq_token_ids, |
| seq_attention_mask=seq_attention_mask, |
| seq_residue_ids=seq_residue_ids, |
| ms_token_ids=ms_token_ids, |
| ms_attention_mask=ms_attention_mask, |
| has_ms=has_ms, |
| struct_token_ids=struct_token_ids, |
| struct_attention_mask=struct_attention_mask, |
| struct_residue_ids=struct_residue_ids, |
| has_3d=has_3d, |
| return_dict=True, |
| ) |
| return outputs['fused_repr'] |
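| |
| # Inference sketch (hypothetical tensors; mirrors the smoke test below): |
| #   model.eval() |
| #   with torch.no_grad(): |
| #       fused = model.get_multimodal_representation( |
| #           seq_token_ids, seq_attention_mask, seq_residue_ids, |
| #           ms_token_ids, ms_attention_mask, has_ms, |
| #       )   # -> (batch, fusion_hidden_size) |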
|
|
|
|
| if __name__ == "__main__": |
| |
| print("="*80) |
| print("Testing BERTose model") |
| print("="*80) |
| |
| |
| config = MultimodalGlycanBERTConfig( |
| seq_vocab_size=166, |
| seq_hidden_size=768, |
| seq_num_layers=12, |
| seq_num_heads=12, |
| ms_vocab_size=242, |
| ms_hidden_size=384, |
| ms_num_layers=6, |
| ms_num_heads=6, |
| struct_vocab_size=1024, |
| struct_hidden_size=512, |
| struct_num_layers=8, |
| struct_num_heads=8, |
| use_3d=True, |
| use_cross_attention=True, |
| seq_loss_weight=0.60, |
| ms_loss_weight=0.15, |
| struct_loss_weight=0.25, |
| ) |
| |
| print(f"\nConfig:") |
| print(f" Sequence vocab: {config.seq_vocab_size}") |
| print(f" MS vocab: {config.ms_vocab_size}") |
| print(f" Structure vocab: {config.struct_vocab_size}") |
| print(f" Loss weights: seq={config.seq_loss_weight}, ms={config.ms_loss_weight}, struct={config.struct_loss_weight}") |
| |
| |
| model = MultimodalGlycanBERT(config) |
| |
| |
| total_params = sum(p.numel() for p in model.parameters()) |
| trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| |
| print(f"\nModel Parameters:") |
| print(f" Total: {total_params:,}") |
| print(f" Trainable: {trainable_params:,}") |
| |
| |
| print(f"\n{'='*80}") |
| print("Testing Forward Pass (with Conv front-end)") |
| print("="*80) |
| |
| batch_size = 4 |
| seq_len = 128 |
| ms_len = 50 |
| struct_len = 40 |
| |
| |
| seq_token_ids = torch.randint(0, config.seq_vocab_size, (batch_size, seq_len)) |
| seq_attention_mask = torch.ones(batch_size, seq_len) |
| |
| # Synthetic residue ids: ~5 sequence tokens per residue |
| seq_residue_ids = torch.div( |
| torch.arange(seq_len), 5, rounding_mode="floor" |
| ).unsqueeze(0).expand(batch_size, -1) |
|
|
| ms_token_ids = torch.randint(config.ms_vocab_offset, config.ms_total_vocab_size, (batch_size, ms_len)) |
| ms_attention_mask = torch.ones(batch_size, ms_len) |
| struct_token_ids = torch.randint(0, config.struct_vocab_size, (batch_size, struct_len)) |
| struct_attention_mask = torch.ones(batch_size, struct_len) |
| |
| # 4 VQ-VAE tokens per residue, matching the module docstring |
| struct_residue_ids = torch.div( |
| torch.arange(struct_len), 4, rounding_mode="floor" |
| ).unsqueeze(0).expand(batch_size, -1) |
|
|
| has_ms = torch.tensor([True, True, False, True]) |
| has_3d = torch.tensor([True, False, True, True]) |
| |
| |
| # Crude MLM targets for the smoke test: supervise only positions that happen |
| # to already hold the mask token id |
| seq_labels = seq_token_ids.clone() |
| seq_labels[seq_labels != config.mask_token_id] = -100 |
| ms_labels = ms_token_ids.clone() |
| ms_labels[ms_labels != config.mask_token_id] = -100 |
| struct_labels = struct_token_ids.clone() |
| struct_labels[struct_labels != config.mask_token_id] = -100 |
| |
| |
| outputs = model( |
| seq_token_ids=seq_token_ids, |
| seq_attention_mask=seq_attention_mask, |
| seq_residue_ids=seq_residue_ids, |
| ms_token_ids=ms_token_ids, |
| ms_attention_mask=ms_attention_mask, |
| has_ms=has_ms, |
| struct_token_ids=struct_token_ids, |
| struct_attention_mask=struct_attention_mask, |
| struct_residue_ids=struct_residue_ids, |
| has_3d=has_3d, |
| seq_labels=seq_labels, |
| ms_labels=ms_labels, |
| struct_labels=struct_labels, |
| ) |
| |
| print(f"\nOutput shapes:") |
| print(f" seq_logits: {outputs['seq_logits'].shape}") |
| print(f" ms_logits: {outputs['ms_logits'].shape}") |
| print(f" struct_logits: {outputs['struct_logits'].shape}") |
| print(f" fused_repr: {outputs['fused_repr'].shape}") |
| |
| print(f"\nLosses:") |
| print(f" Total loss: {outputs['loss'].item():.4f}") |
| print(f" Sequence loss: {outputs['seq_loss'].item():.4f}") |
| print(f" MS loss: {outputs['ms_loss'].item():.4f}") |
| print(f" Structure loss: {outputs['struct_loss'].item():.4f}") |
| |
| print(f"\n{'='*80}") |
| print("Model Test Complete!") |
| print("="*80) |
|
|