| """ |
| Multimodal Masking Strategy for Glycan BERT v3 |
| |
| Implements masked language modeling across three modalities: |
| - Sequence (WURCS atomic tokenization) |
| - MS (mass spectrometry peaks) |
| - Structure (VQ-VAE discrete tokens) |
| |
| Each modality can be masked independently with different probabilities. |
| """ |
|
|
| import torch |
| import random |
| from typing import List, Tuple, Dict, Optional |
|
|
| try: |
| from .masking import GlycanMaskingStrategy, HierarchicalMaskingStrategy |
| except ImportError: |
| from masking import GlycanMaskingStrategy, HierarchicalMaskingStrategy |
|
|
|
|
class MultimodalMaskingStrategy:
    """
    Masking strategy for multimodal glycan BERT.

    Holds one masker per modality (sequence, MS, structure) and applies them
    independently to a batch.

    When ``use_hierarchical_masking`` is True, the sequence modality uses
    ``HierarchicalMaskingStrategy`` (residue-level + token-level masking).
    Otherwise it uses ``GlycanMaskingStrategy``. MS and structure always use
    ``GlycanMaskingStrategy``.
    """

    def __init__(
        self,
        # --- Sequence modality ---
        seq_vocab_size: int,
        seq_mask_token_id: int,
        seq_pad_token_id: int,
        seq_special_token_ids: List[int],
        # --- MS modality ---
        ms_vocab_size: int,
        ms_vocab_offset: int,
        ms_mask_token_id: int,
        ms_pad_token_id: int,
        ms_special_token_ids: List[int],
        # --- Structure modality ---
        struct_vocab_size: int,
        struct_mask_token_id: int,
        struct_pad_token_id: int,
        struct_special_token_ids: List[int],
        # --- Per-modality masking probabilities ---
        seq_ambiguous_token_ids: Optional[List[int]] = None,
        seq_mask_prob: float = 0.15,
        ms_mask_prob: float = 0.15,
        struct_mask_prob: float = 0.15,
        # --- Replacement split, shared by all flat (non-hierarchical) maskers ---
        mask_token_prob: float = 0.8,
        random_token_prob: float = 0.1,
        unchanged_prob: float = 0.1,
        # --- Hierarchical sequence masking ---
        use_hierarchical_masking: bool = False,
        token_mask_prob: float = 0.10,
        residue_mask_prob: float = 0.10,
        seed: Optional[int] = None,
    ):
        """
        Initialize multimodal masking strategy.

        Args:
            seq_vocab_size: Sequence vocabulary size
            seq_mask_token_id: Sequence [MASK] token ID
            seq_pad_token_id: Sequence [PAD] token ID
            seq_special_token_ids: Sequence special token IDs to never mask
            seq_ambiguous_token_ids: Sequence ambiguous token IDs to never mask
            seq_mask_prob: Probability of masking sequence tokens
                (ignored when ``use_hierarchical_masking`` is True)

            ms_vocab_size: MS vocabulary size
            ms_vocab_offset: MS vocabulary offset (where MS tokens start in combined vocab)
            ms_mask_token_id: MS [MASK] token ID
            ms_pad_token_id: MS [PAD] token ID
            ms_special_token_ids: MS special token IDs to never mask
            ms_mask_prob: Probability of masking MS tokens

            struct_vocab_size: Structure vocabulary size
            struct_mask_token_id: Structure [MASK] token ID
            struct_pad_token_id: Structure [PAD] token ID
            struct_special_token_ids: Structure special token IDs to never mask
            struct_mask_prob: Probability of masking structure tokens

            mask_token_prob: Probability of replacing with [MASK]
            random_token_prob: Probability of replacing with random token
            unchanged_prob: Probability of leaving unchanged
                (these three are not passed to the hierarchical sequence masker)

            use_hierarchical_masking: If True, use hierarchical (token+residue level)
                masking for the sequence modality
            token_mask_prob: Token-level mask probability for hierarchical
            residue_mask_prob: Residue-level mask probability for hierarchical

            seed: Random seed, forwarded unchanged to every per-modality masker
        """
        self.use_hierarchical_masking = use_hierarchical_masking

        if use_hierarchical_masking:
            self.seq_masker = HierarchicalMaskingStrategy(
                vocab_size=seq_vocab_size,
                mask_token_id=seq_mask_token_id,
                pad_token_id=seq_pad_token_id,
                special_token_ids=seq_special_token_ids,
                ambiguous_token_ids=seq_ambiguous_token_ids,
                token_mask_prob=token_mask_prob,
                residue_mask_prob=residue_mask_prob,
                seed=seed,
            )
        else:
            self.seq_masker = GlycanMaskingStrategy(
                vocab_size=seq_vocab_size,
                mask_token_id=seq_mask_token_id,
                pad_token_id=seq_pad_token_id,
                special_token_ids=seq_special_token_ids,
                ambiguous_token_ids=seq_ambiguous_token_ids,
                mask_prob=seq_mask_prob,
                mask_token_prob=mask_token_prob,
                random_token_prob=random_token_prob,
                unchanged_prob=unchanged_prob,
                seed=seed,
            )

        # NOTE(review): the MS masker gets vocab_size = ms_vocab_size + ms_vocab_offset,
        # i.e. the whole combined id range. If GlycanMaskingStrategy draws its
        # "random token" replacements from [0, vocab_size), MS positions may be
        # replaced by sequence-vocab ids. Confirm against GlycanMaskingStrategy.
        # NOTE(review): all three maskers receive the same `seed`. If each keeps
        # its own RNG seeded from it, their random streams are identical.
        # Confirm whether correlated masks across modalities are intended.
        self.ms_masker = GlycanMaskingStrategy(
            vocab_size=ms_vocab_size + ms_vocab_offset,
            mask_token_id=ms_mask_token_id,
            pad_token_id=ms_pad_token_id,
            special_token_ids=ms_special_token_ids,
            ambiguous_token_ids=[],
            mask_prob=ms_mask_prob,
            mask_token_prob=mask_token_prob,
            random_token_prob=random_token_prob,
            unchanged_prob=unchanged_prob,
            seed=seed,
        )

        self.struct_masker = GlycanMaskingStrategy(
            vocab_size=struct_vocab_size,
            mask_token_id=struct_mask_token_id,
            pad_token_id=struct_pad_token_id,
            special_token_ids=struct_special_token_ids,
            ambiguous_token_ids=[],
            mask_prob=struct_mask_prob,
            mask_token_prob=mask_token_prob,
            random_token_prob=random_token_prob,
            unchanged_prob=unchanged_prob,
            seed=seed,
        )

        self.ms_vocab_offset = ms_vocab_offset

    @staticmethod
    def _ignore_absent_samples(labels: torch.Tensor, present) -> None:
        """Set (in place) every label of samples flagged as lacking the modality to -100."""
        if present is None:
            return
        # Cast to bool on the labels' device: `~` on an integer 0/1 tensor is a
        # bitwise NOT (-1/-2), which would turn the boolean mask into integer
        # row indices and clear the wrong rows.
        present = torch.as_tensor(present, dtype=torch.bool, device=labels.device)
        labels[~present] = -100

    def _mask_sequence_modality(
        self,
        seq_token_ids: torch.Tensor,
        seq_residue_ids: Optional[torch.Tensor],
        monosaccharide_names: Optional[List[List[str]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Mask the sequence modality; return (masked_ids, labels, mask_positions, mono_labels)."""
        if not self.use_hierarchical_masking:
            masked_ids, labels, positions = self.seq_masker.mask_sequence(seq_token_ids)
            return masked_ids, labels, positions, None

        if seq_residue_ids is None:
            # The hierarchical masker is invoked with residue ids and returns a
            # dict; calling it the flat way cannot yield a valid 3-tuple.
            raise ValueError(
                "use_hierarchical_masking=True requires seq_residue_ids, got None"
            )

        hier_result = self.seq_masker.mask_sequence(
            seq_token_ids, seq_residue_ids, monosaccharide_names
        )
        labels = hier_result['token_labels']
        return (
            hier_result['masked_input_ids'],
            labels,
            labels != -100,
            hier_result.get('mono_labels', None),
        )

    def mask_multimodal_batch(
        self,
        seq_token_ids: torch.Tensor,
        ms_token_ids: torch.Tensor,
        has_ms: torch.Tensor,
        struct_token_ids: Optional[torch.Tensor] = None,
        has_3d: Optional[torch.Tensor] = None,
        seq_residue_ids: Optional[torch.Tensor] = None,
        monosaccharide_names: Optional[List[List[str]]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Apply masking to a multimodal batch.

        Args:
            seq_token_ids: (batch_size, seq_len) - Sequence token IDs
            ms_token_ids: (batch_size, ms_len) - MS token IDs
            has_ms: (batch_size,) - Bool (or 0/1) flags for samples with MS data;
                None disables per-sample filtering of MS labels
            struct_token_ids: (batch_size, struct_len) - Structure token IDs (optional)
            has_3d: (batch_size,) - Bool (or 0/1) flags for samples with 3D data (optional)
            seq_residue_ids: (batch_size, seq_len) - Residue IDs; required when
                hierarchical masking is enabled
            monosaccharide_names: List of lists of mono names per batch item (optional)

        Returns:
            Dictionary containing:
                - seq_masked_ids: Masked sequence input
                - seq_labels: Sequence labels (-100 for unmasked)
                - ms_masked_ids: Masked MS input
                - ms_labels: MS labels (-100 for unmasked and for samples without MS)
                - struct_masked_ids: Masked structure input (if provided)
                - struct_labels: Structure labels (-100 for unmasked and for
                  samples without 3D, if provided)
                - mono_labels: Monosaccharide type labels (if hierarchical and
                  returned by the hierarchical masker)
                - statistics: dict of per-modality statistics dicts ('seq', 'ms',
                  and 'struct' when structure tokens are given). In hierarchical
                  mode 'seq' only has 'masked_tokens'.

        Raises:
            ValueError: if hierarchical masking is enabled but seq_residue_ids is None.
        """
        seq_masked_ids, seq_labels, seq_mask_positions, mono_labels = (
            self._mask_sequence_modality(seq_token_ids, seq_residue_ids, monosaccharide_names)
        )

        ms_masked_ids, ms_labels, ms_mask_positions = self.ms_masker.mask_sequence(ms_token_ids)
        # Samples without MS data contribute no MS loss.
        self._ignore_absent_samples(ms_labels, has_ms)

        result = {
            'seq_masked_ids': seq_masked_ids,
            'seq_labels': seq_labels,
            'ms_masked_ids': ms_masked_ids,
            'ms_labels': ms_labels,
        }
        if mono_labels is not None:
            result['mono_labels'] = mono_labels

        if struct_token_ids is not None:
            struct_masked_ids, struct_labels, struct_mask_positions = (
                self.struct_masker.mask_sequence(struct_token_ids)
            )
            # Samples without 3D data contribute no structure loss.
            self._ignore_absent_samples(struct_labels, has_3d)
            result['struct_masked_ids'] = struct_masked_ids
            result['struct_labels'] = struct_labels

        # Statistics are gathered only after all masking calls, preserving the
        # original order of calls into the maskers.
        if self.use_hierarchical_masking:
            seq_stats = {'masked_tokens': seq_mask_positions.sum().item()}
        else:
            seq_stats = self.seq_masker.get_mask_statistics(
                seq_token_ids, seq_masked_ids, seq_mask_positions
            )

        stats = {
            'seq': seq_stats,
            'ms': self.ms_masker.get_mask_statistics(
                ms_token_ids, ms_masked_ids, ms_mask_positions
            ),
        }
        if struct_token_ids is not None:
            stats['struct'] = self.struct_masker.get_mask_statistics(
                struct_token_ids, struct_masked_ids, struct_mask_positions
            )

        result['statistics'] = stats
        return result
|
|
|
|
if __name__ == "__main__":
    banner = "=" * 80

    def _report(title, section):
        """Print a heading followed by one indented `key: value` line per entry."""
        print(title)
        for stat_name, stat_value in section.items():
            print(f" {stat_name}: {stat_value}")

    print(banner)
    print("Testing Multimodal Masking Strategy")
    print(banner)

    # Toy configuration: 166 sequence ids, 242 MS ids starting at 166,
    # and a 1024-entry structure codebook.
    masker = MultimodalMaskingStrategy(
        seq_vocab_size=166,
        seq_mask_token_id=1,
        seq_pad_token_id=0,
        seq_special_token_ids=[0, 1, 2, 3],
        seq_ambiguous_token_ids=[10, 11, 12],
        seq_mask_prob=0.15,
        ms_vocab_size=242,
        ms_vocab_offset=166,
        ms_mask_token_id=1,
        ms_pad_token_id=0,
        ms_special_token_ids=[0, 1, 2, 3],
        ms_mask_prob=0.15,
        struct_vocab_size=1024,
        struct_mask_token_id=1,
        struct_pad_token_id=0,
        struct_special_token_ids=[0, 1],
        struct_mask_prob=0.15,
        seed=42,
    )

    batch_size, seq_len, ms_len, struct_len = 4, 50, 30, 40

    # Random non-special ids for each modality (generated in this order).
    seq_token_ids = torch.randint(4, 166, (batch_size, seq_len))
    ms_token_ids = torch.randint(166, 408, (batch_size, ms_len))
    struct_token_ids = torch.randint(2, 1024, (batch_size, struct_len))
    has_ms = torch.tensor([True, True, False, True])
    has_3d = torch.tensor([True, False, True, True])

    result = masker.mask_multimodal_batch(
        seq_token_ids=seq_token_ids,
        ms_token_ids=ms_token_ids,
        has_ms=has_ms,
        struct_token_ids=struct_token_ids,
        has_3d=has_3d,
    )
    stats = result['statistics']
    ms_labels = result['ms_labels']
    struct_labels = result['struct_labels']

    print("\nMasked batch shapes:")
    print(f" seq_masked_ids: {result['seq_masked_ids'].shape}")
    print(f" seq_labels: {result['seq_labels'].shape}")
    print(f" ms_masked_ids: {result['ms_masked_ids'].shape}")
    print(f" ms_labels: {result['ms_labels'].shape}")
    print(f" struct_masked_ids: {result['struct_masked_ids'].shape}")
    print(f" struct_labels: {result['struct_labels'].shape}")

    _report("\nSequence masking statistics:", stats['seq'])
    _report("\nMS masking statistics:", stats['ms'])
    _report("\nStructure masking statistics:", stats['struct'])

    # Samples flagged as missing a modality should carry only -100 labels.
    print(f"\nMS labels for sample 2 (no MS): {ms_labels[2].unique()}")
    print(f"MS labels for sample 0 (has MS): {ms_labels[0].unique()[:10]}")

    print(f"\nStructure labels for sample 1 (no 3D): {struct_labels[1].unique()}")
    print(f"Structure labels for sample 0 (has 3D): {struct_labels[0].unique()[:10]}")

    print(f"\n{banner}")
    print("Multimodal Masking Test Complete!")
    print(banner)
|
|
|
|