"""
ESM2-Style Masking Strategy for Glycan BERT

Implements masked language modeling following the ESM2 approach:
- Mask 15% of tokens randomly
- 80% replaced with [MASK]
- 10% replaced with random token
- 10% unchanged (for robustness)
"""

import random
from typing import Iterable, List, Optional, Set, Tuple

import torch


def _maskable_positions(input_ids: torch.Tensor, excluded_ids: Iterable[int]) -> torch.Tensor:
    """Return a boolean tensor (same shape as input_ids) that is True where the token is not excluded."""
    maskable = torch.ones_like(input_ids, dtype=torch.bool)
    for token_id in excluded_ids:
        maskable &= (input_ids != token_id)
    return maskable


def _sample_replacement_token(vocab_size: int, forbidden_ids: Set[int]) -> int:
    """Draw a uniform random token ID from [0, vocab_size) that is not in forbidden_ids.

    Uses rejection sampling on the global `random` module.

    Raises:
        ValueError: If every ID in [0, vocab_size) is forbidden (sampling
            would otherwise never terminate).
    """
    n_forbidden_in_vocab = sum(1 for tok in forbidden_ids if 0 <= tok < vocab_size)
    if n_forbidden_in_vocab >= vocab_size:
        raise ValueError(
            "No valid token available for random replacement: every ID in "
            f"[0, {vocab_size}) is a special, pad, or mask token"
        )
    while True:
        candidate = random.randint(0, vocab_size - 1)
        if candidate not in forbidden_ids:
            return candidate


class GlycanMaskingStrategy:
    """
    Masking strategy for glycan sequences following ESM2.
    """

    def __init__(
        self,
        vocab_size: int,
        mask_token_id: int,
        pad_token_id: int,
        special_token_ids: List[int],
        ambiguous_token_ids: Optional[List[int]] = None,
        mask_prob: float = 0.15,
        mask_token_prob: float = 0.8,
        random_token_prob: float = 0.1,
        unchanged_prob: float = 0.1,
        seed: Optional[int] = None
    ):
        """
        Initialize masking strategy.

        Args:
            vocab_size: Size of vocabulary
            mask_token_id: ID of [MASK] token
            pad_token_id: ID of [PAD] token
            special_token_ids: List of special token IDs to never mask
            ambiguous_token_ids: List of ambiguous token IDs to never mask (x, X, ?, u, d, o)
            mask_prob: Probability of masking a token (default: 0.15)
            mask_token_prob: Probability of replacing with [MASK] (default: 0.8)
            random_token_prob: Probability of replacing with random token (default: 0.1)
            unchanged_prob: Probability of leaving unchanged (default: 0.1)
            seed: Random seed for reproducibility. When given, it seeds the
                *global* `random` and `torch` generators.
        """
        assert abs(mask_token_prob + random_token_prob + unchanged_prob - 1.0) < 1e-6, \
            "Masking probabilities must sum to 1.0"

        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
        self.special_token_ids = set(special_token_ids)
        self.ambiguous_token_ids = set(ambiguous_token_ids) if ambiguous_token_ids else set()
        self.mask_prob = mask_prob
        self.mask_token_prob = mask_token_prob
        self.random_token_prob = random_token_prob
        self.unchanged_prob = unchanged_prob

        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)

    def mask_sequence(
        self,
        input_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Apply masking to a batch of sequences.

        Padding, special and ambiguous tokens are never selected. For every
        row with at least one maskable token, at least one token is selected
        (even when `mask_prob` rounds down to zero tokens).

        Args:
            input_ids: Tensor of shape (batch_size, seq_len) with token IDs

        Returns:
            Tuple of:
            - masked_input_ids: Input with masks applied
            - labels: Original token IDs for masked positions (-100 for unmasked)
            - mask_positions: Boolean tensor indicating masked positions

        Raises:
            ValueError: If a random replacement is required but every
                vocabulary ID is a special, pad, or mask token.
        """
        batch_size, _ = input_ids.shape

        # Labels default to -100 (positions we don't predict)
        labels = torch.full_like(input_ids, -100)

        # Never mask padding, special tokens, or ambiguous tokens
        # (for ambiguous tokens we don't know the ground truth)
        excluded = {self.pad_token_id} | self.special_token_ids | self.ambiguous_token_ids
        maskable = _maskable_positions(input_ids, excluded)

        # Randomly select mask_prob of the maskable tokens in each row
        mask_positions = torch.zeros_like(input_ids, dtype=torch.bool)
        for i in range(batch_size):
            maskable_indices = maskable[i].nonzero(as_tuple=True)[0]
            n_maskable = len(maskable_indices)
            if n_maskable == 0:
                continue

            n_to_mask = max(1, int(n_maskable * self.mask_prob))
            chosen = maskable_indices[torch.randperm(n_maskable)[:n_to_mask]]
            mask_positions[i, chosen] = True

        # Store original tokens for masked positions
        labels[mask_positions] = input_ids[mask_positions]

        masked_input_ids = input_ids.clone()

        # Random replacements must not inject [PAD], [MASK] or special tokens
        forbidden_random = self.special_token_ids | {self.pad_token_id, self.mask_token_id}
        random_threshold = self.mask_token_prob + self.random_token_prob

        # For each masked position, decide what to do
        for batch_idx, pos_idx in zip(*mask_positions.nonzero(as_tuple=True)):
            rand_val = random.random()

            if rand_val < self.mask_token_prob:
                # Replace with [MASK]
                masked_input_ids[batch_idx, pos_idx] = self.mask_token_id
            elif rand_val < random_threshold:
                # Replace with a random ordinary token
                masked_input_ids[batch_idx, pos_idx] = _sample_replacement_token(
                    self.vocab_size, forbidden_random
                )
            # else: leave unchanged (unchanged_prob)

        return masked_input_ids, labels, mask_positions

    def get_mask_statistics(
        self,
        input_ids: torch.Tensor,
        masked_input_ids: torch.Tensor,
        mask_positions: torch.Tensor
    ) -> dict:
        """
        Calculate statistics about masking for logging.

        A random replacement that happens to equal the original token is
        indistinguishable from "unchanged" and is counted as unchanged.

        Args:
            input_ids: Original input IDs
            masked_input_ids: Masked input IDs
            mask_positions: Boolean mask indicating masked positions

        Returns:
            Dictionary with masking statistics
        """
        total_tokens = (input_ids != self.pad_token_id).sum().item()
        masked_tokens = mask_positions.sum().item()

        original_at_mask = input_ids[mask_positions]
        current_at_mask = masked_input_ids[mask_positions]

        # Count each masking type
        is_mask_token = current_at_mask == self.mask_token_id
        mask_token_count = is_mask_token.sum().item()
        random_token_count = (~is_mask_token & (current_at_mask != original_at_mask)).sum().item()
        unchanged_count = masked_tokens - mask_token_count - random_token_count

        # Count ambiguous tokens in batch
        ambiguous_tokens = 0
        for ambig_id in self.ambiguous_token_ids:
            ambiguous_tokens += (input_ids == ambig_id).sum().item()

        stats = {
            'total_tokens': total_tokens,
            'masked_tokens': masked_tokens,
            'mask_percentage': masked_tokens / total_tokens * 100 if total_tokens > 0 else 0,
            'mask_token_count': mask_token_count,
            'random_token_count': random_token_count,
            'unchanged_count': unchanged_count,
            'ambiguous_tokens': ambiguous_tokens,
            'ambiguous_percentage': ambiguous_tokens / total_tokens * 100 if total_tokens > 0 else 0
        }

        return stats


class MonosaccharideMaskingStrategy:
    """
    Monosaccharide-level masking strategy for Glycan BERT.

    Instead of masking individual tokens, this masks entire monosaccharides
    and optionally predicts the monosaccharide type (Glc, Gal, etc.).

    This forces the model to learn holistic monosaccharide semantics
    rather than just local token patterns.
    """

    # Common monosaccharide types. Index 0 (the empty-string entry) doubles as
    # the fallback ID for names that are not in this list.
    MONO_TYPES = [
        '', 'Glc', 'Gal', 'Man', 'Fuc', 'Xyl', 'Rha', 'Ara',
        'GlcNAc', 'GalNAc', 'ManNAc', 'GlcA', 'GalA', 'IdoA',
        'Neu5Ac', 'Neu5Gc', 'Kdn', 'GlcN', 'GalN',
        'Hex', 'HexNAc', 'dHex', 'Pent', 'Sia',
        'GlcS', 'GalS', 'Ido', 'All', 'Alt', 'Gul', 'Tal'
    ]

    def __init__(
        self,
        vocab_size: int,
        mask_token_id: int,
        pad_token_id: int,
        special_token_ids: List[int],
        mask_prob: float = 0.15,  # Probability to mask a residue
        predict_mono_type: bool = True,  # If True, predict mono type; if False, predict tokens
        seed: Optional[int] = None
    ):
        """
        Initialize monosaccharide masking strategy.

        Args:
            vocab_size: Size of vocabulary
            mask_token_id: ID of [MASK] token
            pad_token_id: ID of [PAD] token
            special_token_ids: List of special token IDs to never mask
            mask_prob: Probability of masking a residue (default: 0.15)
            predict_mono_type: If True, labels are mono type IDs; if False, labels are token IDs
            seed: Random seed. When given, it seeds the *global* `random`
                and `torch` generators.
        """
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
        self.special_token_ids = set(special_token_ids)
        self.mask_prob = mask_prob
        self.predict_mono_type = predict_mono_type

        # Build mono type vocabulary
        self.mono_to_id = {m: i for i, m in enumerate(self.MONO_TYPES)}
        self.id_to_mono = {i: m for i, m in enumerate(self.MONO_TYPES)}
        self.num_mono_types = len(self.MONO_TYPES)

        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)

    def mask_sequence(
        self,
        input_ids: torch.Tensor,
        residue_ids: torch.Tensor,
        monosaccharide_names: Optional[List[List[str]]] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Apply monosaccharide-level masking.

        Every token of a selected residue is replaced with [MASK]. At least
        one residue is selected for every row that has a real residue.

        Args:
            input_ids: (batch, seq_len) token IDs
            residue_ids: (batch, seq_len) residue ID for each token
                        (-1=special, -2=linkage, >=0=residue)
            monosaccharide_names: List of lists of monosaccharide names per batch item,
                        indexed by residue ID

        Returns:
            Tuple of:
            - masked_input_ids: Input with entire residue tokens masked
            - token_labels: Original token IDs for masked positions (-100 for unmasked)
            - mono_labels: (batch,) monosaccharide type ID per batch item (-100 if none).
              There is a single slot per item, so when several residues are
              masked in one item only the last processed residue's type is kept.
            - mask_positions: Boolean tensor indicating masked token positions
        """
        batch_size, _ = input_ids.shape

        # Initialize outputs
        masked_input_ids = input_ids.clone()
        token_labels = torch.full_like(input_ids, -100)
        mono_labels = torch.full((batch_size,), -100, dtype=torch.long, device=input_ids.device)
        mask_positions = torch.zeros_like(input_ids, dtype=torch.bool)

        want_mono_labels = self.predict_mono_type and monosaccharide_names is not None

        for b in range(batch_size):
            # Find unique residues (>=0 are real residues)
            unique_residues = torch.unique(residue_ids[b])
            real_residues = unique_residues[unique_residues >= 0].tolist()

            if not real_residues:
                continue

            # Select residues to mask
            n_to_mask = max(1, int(len(real_residues) * self.mask_prob))
            random.shuffle(real_residues)

            for rid in real_residues[:n_to_mask]:
                # Find all tokens belonging to this residue
                token_mask = residue_ids[b] == rid

                # Store original tokens as labels, then mask the whole residue
                token_labels[b, token_mask] = input_ids[b, token_mask]
                masked_input_ids[b, token_mask] = self.mask_token_id
                mask_positions[b, token_mask] = True

                # Get monosaccharide type label (unknown names fall back to ID 0)
                if want_mono_labels and rid < len(monosaccharide_names[b]):
                    mono_name = monosaccharide_names[b][rid]
                    mono_labels[b] = self.mono_to_id.get(mono_name, 0)

        return masked_input_ids, token_labels, mono_labels, mask_positions

    def get_mono_type_id(self, mono_name: str) -> int:
        """Convert monosaccharide name to ID (0 for names not in MONO_TYPES)."""
        return self.mono_to_id.get(mono_name, 0)


class HierarchicalMaskingStrategy:
    """
    Hierarchical masking combining token-level and monosaccharide-level.

    Novel approach:
    1. Token-level MLM: Predict individual masked tokens (like BERT)
    2. Residue-level MLM: Predict monosaccharide types from masked residues
    3. Global contrastive: Align sequence with MS/3D representations

    This provides multi-scale supervision for better glycan understanding.
    """

    def __init__(
        self,
        vocab_size: int,
        mask_token_id: int,
        pad_token_id: int,
        special_token_ids: List[int],
        ambiguous_token_ids: Optional[List[int]] = None,
        token_mask_prob: float = 0.10,  # Lower since we also mask residues
        residue_mask_prob: float = 0.10,  # Mask some whole residues
        seed: Optional[int] = None
    ):
        """
        Initialize hierarchical masking.

        Args:
            vocab_size: Size of vocabulary
            mask_token_id: ID of [MASK] token
            pad_token_id: ID of [PAD] token
            special_token_ids: Special tokens to never mask
            ambiguous_token_ids: Ambiguous tokens to never mask
            token_mask_prob: Probability to mask individual tokens
            residue_mask_prob: Probability to mask entire residues
            seed: Random seed. When given, it seeds the *global* `random`
                and `torch` generators.
        """
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
        self.special_token_ids = set(special_token_ids)
        self.ambiguous_token_ids = set(ambiguous_token_ids) if ambiguous_token_ids else set()
        self.token_mask_prob = token_mask_prob
        self.residue_mask_prob = residue_mask_prob

        # Mono type vocabulary (same as MonosaccharideMaskingStrategy)
        self.MONO_TYPES = MonosaccharideMaskingStrategy.MONO_TYPES
        self.mono_to_id = {m: i for i, m in enumerate(self.MONO_TYPES)}
        self.num_mono_types = len(self.MONO_TYPES)

        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)

    def mask_sequence(
        self,
        input_ids: torch.Tensor,
        residue_ids: torch.Tensor,
        monosaccharide_names: Optional[List[List[str]]] = None
    ) -> dict:
        """
        Apply hierarchical masking at both token and residue levels.

        Step 1 masks whole residues (all of their tokens become [MASK]).
        Step 2 applies 80/10/10 token-level masking to the remaining maskable
        tokens. Unlike step 1, both counts are rounded down and may be zero.

        Args:
            input_ids: (batch, seq_len) token IDs
            residue_ids: (batch, seq_len) residue ID for each token
                        (values >= 0 are treated as real residues)
            monosaccharide_names: List of lists of monosaccharide names per batch item,
                        indexed by residue ID

        Returns:
            Dictionary with:
            - masked_input_ids: Input with masks applied
            - token_labels: Token-level labels for MLM (-100 for unmasked)
            - residue_mask: Which residues were completely masked
              (one set of residue IDs per batch item)
            - mono_labels: Monosaccharide type labels for masked residues
              (one {residue ID: mono type ID} dict per batch item)

        Raises:
            ValueError: If a random replacement is required but every
                vocabulary ID is a special, pad, or mask token.
        """
        batch_size, _ = input_ids.shape

        masked_input_ids = input_ids.clone()
        token_labels = torch.full_like(input_ids, -100)
        residue_mask = []
        mono_labels = []

        # Tokens that are never eligible for token-level masking
        never_masked = {self.pad_token_id} | self.special_token_ids | self.ambiguous_token_ids
        # Random replacements must not inject [PAD], [MASK] or special tokens
        # (consistent with GlycanMaskingStrategy)
        forbidden_random = self.special_token_ids | {self.pad_token_id, self.mask_token_id}

        for b in range(batch_size):
            # Step 1: Select residues to mask entirely
            unique_residues = torch.unique(residue_ids[b])
            real_residues = unique_residues[unique_residues >= 0].tolist()

            batch_residue_mask = set()
            batch_mono_labels = {}

            if real_residues:
                n_residue_mask = max(0, int(len(real_residues) * self.residue_mask_prob))
                random.shuffle(real_residues)

                for rid in real_residues[:n_residue_mask]:
                    token_mask = residue_ids[b] == rid

                    # Store labels, then mask tokens
                    token_labels[b, token_mask] = input_ids[b, token_mask]
                    masked_input_ids[b, token_mask] = self.mask_token_id

                    batch_residue_mask.add(rid)

                    # Get mono label (unknown names fall back to ID 0)
                    if monosaccharide_names is not None and rid < len(monosaccharide_names[b]):
                        mono_name = monosaccharide_names[b][rid]
                        batch_mono_labels[rid] = self.mono_to_id.get(mono_name, 0)

            # Step 2: Mask additional individual tokens (not in already-masked residues)
            maskable = _maskable_positions(input_ids[b], never_masked)
            for rid in batch_residue_mask:
                maskable &= (residue_ids[b] != rid)

            maskable_indices = maskable.nonzero(as_tuple=True)[0]

            if len(maskable_indices) > 0:
                n_to_mask = max(0, int(len(maskable_indices) * self.token_mask_prob))
                perm = torch.randperm(len(maskable_indices))[:n_to_mask]

                for idx in maskable_indices[perm]:
                    token_labels[b, idx] = input_ids[b, idx]

                    # 80% mask, 10% random, 10% unchanged
                    rand_val = random.random()
                    if rand_val < 0.8:
                        masked_input_ids[b, idx] = self.mask_token_id
                    elif rand_val < 0.9:
                        masked_input_ids[b, idx] = _sample_replacement_token(
                            self.vocab_size, forbidden_random
                        )

            residue_mask.append(batch_residue_mask)
            mono_labels.append(batch_mono_labels)

        return {
            'masked_input_ids': masked_input_ids,
            'token_labels': token_labels,
            'residue_mask': residue_mask,
            'mono_labels': mono_labels,
        }