"""
Multimodal Masking Strategy for Glycan BERT v3

Implements masked language modeling across three modalities:
- Sequence (WURCS atomic tokenization)
- MS (mass spectrometry peaks)
- Structure (VQ-VAE discrete tokens)

Each modality can be masked independently with different probabilities.
"""

import torch
import random
from typing import Any, List, Tuple, Dict, Optional

try:
    from .masking import GlycanMaskingStrategy, HierarchicalMaskingStrategy
except ImportError:
    from masking import GlycanMaskingStrategy, HierarchicalMaskingStrategy


class MultimodalMaskingStrategy:
    """
    Masking strategy for multimodal glycan BERT.

    Handles masking across sequence, MS, and structure modalities.
    Sequence masking can optionally be hierarchical (residue-level +
    token-level); MS and structure masking are always token-level.
    """

    def __init__(
        self,
        # Sequence masking
        seq_vocab_size: int,
        seq_mask_token_id: int,
        seq_pad_token_id: int,
        seq_special_token_ids: List[int],
        # MS masking
        ms_vocab_size: int,
        ms_vocab_offset: int,
        ms_mask_token_id: int,
        ms_pad_token_id: int,
        ms_special_token_ids: List[int],
        # Structure masking
        struct_vocab_size: int,
        struct_mask_token_id: int,
        struct_pad_token_id: int,
        struct_special_token_ids: List[int],
        # Optional parameters
        seq_ambiguous_token_ids: Optional[List[int]] = None,
        seq_mask_prob: float = 0.15,
        ms_mask_prob: float = 0.15,
        struct_mask_prob: float = 0.15,
        # Common masking parameters
        mask_token_prob: float = 0.8,
        random_token_prob: float = 0.1,
        unchanged_prob: float = 0.1,
        # Hierarchical masking option
        use_hierarchical_masking: bool = False,
        token_mask_prob: float = 0.10,  # For hierarchical: token-level prob
        residue_mask_prob: float = 0.10,  # For hierarchical: residue-level prob
        seed: Optional[int] = None,
    ):
        """
        Initialize multimodal masking strategy.

        Args:
            seq_vocab_size: Sequence vocabulary size
            seq_mask_token_id: Sequence [MASK] token ID
            seq_pad_token_id: Sequence [PAD] token ID
            seq_special_token_ids: Sequence special token IDs to never mask
            seq_ambiguous_token_ids: Sequence ambiguous token IDs to never mask
            seq_mask_prob: Probability of masking sequence tokens with the
                token-level masker. Also used when hierarchical masking is
                enabled but a batch arrives without residue IDs.
            ms_vocab_size: MS vocabulary size
            ms_vocab_offset: MS vocabulary offset (where MS tokens start in combined vocab)
            ms_mask_token_id: MS [MASK] token ID
            ms_pad_token_id: MS [PAD] token ID
            ms_special_token_ids: MS special token IDs to never mask
            ms_mask_prob: Probability of masking MS tokens
            struct_vocab_size: Structure vocabulary size
            struct_mask_token_id: Structure [MASK] token ID
            struct_pad_token_id: Structure [PAD] token ID
            struct_special_token_ids: Structure special token IDs to never mask
            struct_mask_prob: Probability of masking structure tokens
            mask_token_prob: Probability of replacing with [MASK]
            random_token_prob: Probability of replacing with random token
            unchanged_prob: Probability of leaving unchanged
            use_hierarchical_masking: If True, use hierarchical (token+residue level)
                masking for sequences whenever residue IDs are supplied
            token_mask_prob: Token-level mask probability for hierarchical
            residue_mask_prob: Residue-level mask probability for hierarchical
            seed: Random seed for reproducibility. NOTE(review): the same seed
                is handed to every per-modality masker; if each masker keeps
                its own generator, their random streams may be correlated —
                confirm against the masking module.
        """
        self.use_hierarchical_masking = use_hierarchical_masking

        def build_seq_token_masker():
            # Standard token-level masker for the sequence modality.
            return GlycanMaskingStrategy(
                vocab_size=seq_vocab_size,
                mask_token_id=seq_mask_token_id,
                pad_token_id=seq_pad_token_id,
                special_token_ids=seq_special_token_ids,
                ambiguous_token_ids=seq_ambiguous_token_ids,
                mask_prob=seq_mask_prob,
                mask_token_prob=mask_token_prob,
                random_token_prob=random_token_prob,
                unchanged_prob=unchanged_prob,
                seed=seed,
            )

        if use_hierarchical_masking:
            # Use hierarchical masking for sequences
            self.seq_masker = HierarchicalMaskingStrategy(
                vocab_size=seq_vocab_size,
                mask_token_id=seq_mask_token_id,
                pad_token_id=seq_pad_token_id,
                special_token_ids=seq_special_token_ids,
                ambiguous_token_ids=seq_ambiguous_token_ids,
                token_mask_prob=token_mask_prob,
                residue_mask_prob=residue_mask_prob,
                seed=seed,
            )
            # The hierarchical masker is called with residue IDs and returns a
            # dict, so it cannot serve the tuple-returning token-level path.
            # Keep a token-level masker for batches without residue IDs.
            self._seq_token_masker = build_seq_token_masker()
        else:
            # Standard token-level masking
            self.seq_masker = build_seq_token_masker()
            self._seq_token_masker = self.seq_masker

        # MS masker (always token-level)
        # NOTE(review): vocab_size spans the combined vocab (sequence + MS).
        # If GlycanMaskingStrategy draws random replacement tokens from
        # [0, vocab_size), MS inputs may receive sequence-vocabulary IDs
        # (< ms_vocab_offset) — confirm against the masking module.
        self.ms_masker = GlycanMaskingStrategy(
            vocab_size=ms_vocab_size + ms_vocab_offset,  # Total vocab including sequence
            mask_token_id=ms_mask_token_id,
            pad_token_id=ms_pad_token_id,
            special_token_ids=ms_special_token_ids,
            ambiguous_token_ids=[],  # No ambiguous tokens in MS
            mask_prob=ms_mask_prob,
            mask_token_prob=mask_token_prob,
            random_token_prob=random_token_prob,
            unchanged_prob=unchanged_prob,
            seed=seed,
        )

        # Structure masker (always token-level)
        self.struct_masker = GlycanMaskingStrategy(
            vocab_size=struct_vocab_size,
            mask_token_id=struct_mask_token_id,
            pad_token_id=struct_pad_token_id,
            special_token_ids=struct_special_token_ids,
            ambiguous_token_ids=[],  # No ambiguous tokens in structure
            mask_prob=struct_mask_prob,
            mask_token_prob=mask_token_prob,
            random_token_prob=random_token_prob,
            unchanged_prob=unchanged_prob,
            seed=seed,
        )

        self.ms_vocab_offset = ms_vocab_offset

    @staticmethod
    def _ignore_missing_modality(labels: torch.Tensor, present: Any) -> None:
        """Set ``labels`` to -100 in place for batch rows where ``present`` is False.

        ``present`` is converted to a bool tensor on ``labels``' device, so
        bool, int (0/1) tensors and plain Python sequences all behave the same.
        Bitwise ``~`` on an int tensor would otherwise produce wrong indices.
        """
        present = torch.as_tensor(present, dtype=torch.bool, device=labels.device)
        labels[~present] = -100

    def mask_multimodal_batch(
        self,
        seq_token_ids: torch.Tensor,
        ms_token_ids: torch.Tensor,
        has_ms: torch.Tensor,
        struct_token_ids: Optional[torch.Tensor] = None,
        has_3d: Optional[torch.Tensor] = None,
        # For hierarchical masking
        seq_residue_ids: Optional[torch.Tensor] = None,
        monosaccharide_names: Optional[List[List[str]]] = None,
    ) -> Dict[str, Any]:
        """
        Apply masking to a multimodal batch.

        Args:
            seq_token_ids: (batch_size, seq_len) - Sequence token IDs
            ms_token_ids: (batch_size, ms_len) - MS token IDs
            has_ms: (batch_size,) - Boolean mask for samples with MS data
            struct_token_ids: (batch_size, struct_len) - Structure token IDs (optional)
            has_3d: (batch_size,) - Boolean mask for samples with 3D data (optional)
            seq_residue_ids: (batch_size, seq_len) - Residue IDs for hierarchical
                masking (optional). If hierarchical masking is enabled but this
                is None, the batch uses standard token-level masking.
            monosaccharide_names: List of lists of mono names per batch item (optional)

        Returns:
            Dictionary containing:
                - seq_masked_ids: Masked sequence input
                - seq_labels: Sequence labels (-100 for unmasked)
                - ms_masked_ids: Masked MS input
                - ms_labels: MS labels (-100 for unmasked and for samples without MS)
                - struct_masked_ids: Masked structure input (if provided)
                - struct_labels: Structure labels (-100 for unmasked and for
                  samples without 3D data, if provided)
                - mono_labels: Monosaccharide type labels (if hierarchical)
                - statistics: Masking statistics per modality

        Note:
            Samples without MS / 3D data still have their inputs masked; only
            their labels are set to -100. The statistics come from the
            maskers' own mask positions, so they include those samples too.
        """
        use_hierarchical = self.use_hierarchical_masking and seq_residue_ids is not None

        # Mask sequence (different for hierarchical vs standard)
        mono_labels = None
        if use_hierarchical:
            # Hierarchical masking returns dict
            hier_result = self.seq_masker.mask_sequence(
                seq_token_ids, seq_residue_ids, monosaccharide_names
            )
            seq_masked_ids = hier_result['masked_input_ids']
            seq_labels = hier_result['token_labels']
            seq_mask_positions = (seq_labels != -100)
            mono_labels = hier_result.get('mono_labels', None)
        else:
            # Standard token-level masking (also the fallback when residue IDs are missing)
            seq_masked_ids, seq_labels, seq_mask_positions = (
                self._seq_token_masker.mask_sequence(seq_token_ids)
            )

        # Mask MS
        ms_masked_ids, ms_labels, ms_mask_positions = self.ms_masker.mask_sequence(ms_token_ids)

        # Exclude samples without MS data from the MS loss
        if has_ms is not None:
            self._ignore_missing_modality(ms_labels, has_ms)

        result = {
            'seq_masked_ids': seq_masked_ids,
            'seq_labels': seq_labels,
            'ms_masked_ids': ms_masked_ids,
            'ms_labels': ms_labels,
        }

        if mono_labels is not None:
            result['mono_labels'] = mono_labels

        # Mask structure if provided
        struct_masked_ids = None
        struct_mask_positions = None
        if struct_token_ids is not None:
            struct_masked_ids, struct_labels, struct_mask_positions = (
                self.struct_masker.mask_sequence(struct_token_ids)
            )

            # Exclude samples without 3D data from the structure loss
            if has_3d is not None:
                self._ignore_missing_modality(struct_labels, has_3d)

            result['struct_masked_ids'] = struct_masked_ids
            result['struct_labels'] = struct_labels

        # Compute statistics. The hierarchical masker has no
        # get_mask_statistics, so only a masked-token count is reported for it.
        if use_hierarchical:
            seq_stats = {'masked_tokens': seq_mask_positions.sum().item()}
        else:
            seq_stats = self._seq_token_masker.get_mask_statistics(
                seq_token_ids, seq_masked_ids, seq_mask_positions
            )

        stats = {
            'seq': seq_stats,
            'ms': self.ms_masker.get_mask_statistics(ms_token_ids, ms_masked_ids, ms_mask_positions),
        }

        if struct_token_ids is not None and struct_mask_positions is not None:
            stats['struct'] = self.struct_masker.get_mask_statistics(
                struct_token_ids, struct_masked_ids, struct_mask_positions
            )

        result['statistics'] = stats

        return result


if __name__ == "__main__":
    # Test multimodal masking
    print("="*80)
    print("Testing Multimodal Masking Strategy")
    print("="*80)

    # Create masking strategy
    masker = MultimodalMaskingStrategy(
        seq_vocab_size=166,
        seq_mask_token_id=1,
        seq_pad_token_id=0,
        seq_special_token_ids=[0, 1, 2, 3],  # [PAD], [MASK], [START], [END]
        seq_ambiguous_token_ids=[10, 11, 12],  # x, X, ?
        seq_mask_prob=0.15,
        ms_vocab_size=242,
        ms_vocab_offset=166,
        ms_mask_token_id=1,
        ms_pad_token_id=0,
        ms_special_token_ids=[0, 1, 2, 3],
        ms_mask_prob=0.15,
        struct_vocab_size=1024,
        struct_mask_token_id=1,
        struct_pad_token_id=0,
        struct_special_token_ids=[0, 1],
        struct_mask_prob=0.15,
        seed=42,
    )

    # Create dummy batch
    batch_size = 4
    seq_len = 50
    ms_len = 30
    struct_len = 40

    seq_token_ids = torch.randint(4, 166, (batch_size, seq_len))
    ms_token_ids = torch.randint(166, 408, (batch_size, ms_len))
    struct_token_ids = torch.randint(2, 1024, (batch_size, struct_len))
    has_ms = torch.tensor([True, True, False, True])
    has_3d = torch.tensor([True, False, True, True])

    # Apply masking
    result = masker.mask_multimodal_batch(
        seq_token_ids=seq_token_ids,
        ms_token_ids=ms_token_ids,
        has_ms=has_ms,
        struct_token_ids=struct_token_ids,
        has_3d=has_3d,
    )

    print("\nMasked batch shapes:")
    print(f"  seq_masked_ids: {result['seq_masked_ids'].shape}")
    print(f"  seq_labels: {result['seq_labels'].shape}")
    print(f"  ms_masked_ids: {result['ms_masked_ids'].shape}")
    print(f"  ms_labels: {result['ms_labels'].shape}")
    print(f"  struct_masked_ids: {result['struct_masked_ids'].shape}")
    print(f"  struct_labels: {result['struct_labels'].shape}")

    print("\nSequence masking statistics:")
    for key, value in result['statistics']['seq'].items():
        print(f"  {key}: {value}")

    print("\nMS masking statistics:")
    for key, value in result['statistics']['ms'].items():
        print(f"  {key}: {value}")

    print("\nStructure masking statistics:")
    for key, value in result['statistics']['struct'].items():
        print(f"  {key}: {value}")

    # Check that MS labels are zeroed for samples without MS
    print(f"\nMS labels for sample 2 (no MS): {result['ms_labels'][2].unique()}")
    print(f"MS labels for sample 0 (has MS): {result['ms_labels'][0].unique()[:10]}")

    # Check that structure labels are zeroed for samples without 3D
    print(f"\nStructure labels for sample 1 (no 3D): {result['struct_labels'][1].unique()}")
    print(f"Structure labels for sample 0 (has 3D): {result['struct_labels'][0].unique()[:10]}")

    print(f"\n{'='*80}")
    print("Multimodal Masking Test Complete!")
    print("="*80)