| """ |
| ESM2-Style Masking Strategy for Glycan BERT |
| |
| Implements masked language modeling following the ESM2 approach: |
| - Mask 15% of tokens randomly |
| - 80% replaced with [MASK] |
| - 10% replaced with random token |
| - 10% unchanged (for robustness) |
| """ |
|
|
| import torch |
| import random |
| from typing import List, Tuple |
|
|
|
|
class GlycanMaskingStrategy:
    """
    Masking strategy for glycan sequences following ESM2.

    For each sequence, ``mask_prob`` of the maskable tokens are selected
    (at least one whenever ``mask_prob > 0`` and a maskable token exists).
    Each selected token is then replaced with [MASK] (``mask_token_prob``),
    replaced with a random vocabulary token (``random_token_prob``), or left
    unchanged (the remainder). Padding, special and ambiguous tokens are never
    selected.

    Note: passing ``seed`` reseeds the *global* ``random`` and ``torch`` RNGs.
    """

    def __init__(
        self,
        vocab_size: int,
        mask_token_id: int,
        pad_token_id: int,
        special_token_ids: List[int],
        ambiguous_token_ids: List[int] = None,
        mask_prob: float = 0.15,
        mask_token_prob: float = 0.8,
        random_token_prob: float = 0.1,
        unchanged_prob: float = 0.1,
        seed: int = None
    ):
        """
        Initialize masking strategy.

        Args:
            vocab_size: Size of vocabulary; random replacements are drawn from
                ``range(vocab_size)``.
            mask_token_id: ID of [MASK] token
            pad_token_id: ID of [PAD] token
            special_token_ids: List of special token IDs to never mask
            ambiguous_token_ids: List of ambiguous token IDs to never mask (x, X, ?, u, d, o)
            mask_prob: Fraction of maskable tokens selected per sequence
                (default: 0.15). A value <= 0 disables masking.
            mask_token_prob: Probability of replacing a selected token with [MASK] (default: 0.8)
            random_token_prob: Probability of replacing with random token (default: 0.1)
            unchanged_prob: Probability of leaving unchanged (default: 0.1)
            seed: Random seed for reproducibility (reseeds the global RNGs)

        Raises:
            AssertionError: If the three replacement probabilities do not sum to 1.0.
        """
        assert abs(mask_token_prob + random_token_prob + unchanged_prob - 1.0) < 1e-6, \
            "Masking probabilities must sum to 1.0"

        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
        self.special_token_ids = set(special_token_ids)
        self.ambiguous_token_ids = set(ambiguous_token_ids) if ambiguous_token_ids else set()

        self.mask_prob = mask_prob
        self.mask_token_prob = mask_token_prob
        self.random_token_prob = random_token_prob
        # Kept for introspection; the "unchanged" case is the implicit remainder.
        self.unchanged_prob = unchanged_prob

        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)

    def mask_sequence(
        self,
        input_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Apply masking to a batch of sequences.

        Args:
            input_ids: Tensor of shape (batch_size, seq_len) with token IDs

        Returns:
            Tuple of:
            - masked_input_ids: Input with masks applied
            - labels: Original token IDs for masked positions (-100 for unmasked)
            - mask_positions: Boolean tensor indicating masked positions
        """
        batch_size, _ = input_ids.shape

        labels = torch.full_like(input_ids, -100)

        # Positions eligible for selection: not padding, special or ambiguous.
        maskable = input_ids != self.pad_token_id
        for token_id in self.special_token_ids | self.ambiguous_token_ids:
            maskable &= input_ids != token_id

        # Select a fixed fraction of the eligible positions in every row.
        mask_positions = torch.zeros_like(input_ids, dtype=torch.bool)
        if self.mask_prob > 0:
            for i in range(batch_size):
                maskable_indices = maskable[i].nonzero(as_tuple=True)[0]
                n_maskable = len(maskable_indices)
                if n_maskable == 0:
                    continue
                n_to_mask = max(1, int(n_maskable * self.mask_prob))
                chosen = maskable_indices[torch.randperm(n_maskable)[:n_to_mask]]
                mask_positions[i, chosen] = True

        labels[mask_positions] = input_ids[mask_positions]
        masked_input_ids = input_ids.clone()

        # Random replacements must be real tokens: never padding, special or
        # [MASK] (a drawn [MASK] would also be miscounted by get_mask_statistics).
        excluded = self.special_token_ids | {self.pad_token_id, self.mask_token_id}
        can_replace = any(t not in excluded for t in range(self.vocab_size))
        random_threshold = self.mask_token_prob + self.random_token_prob

        for batch_idx, pos_idx in zip(*mask_positions.nonzero(as_tuple=True)):
            rand_val = random.random()
            if rand_val < self.mask_token_prob:
                masked_input_ids[batch_idx, pos_idx] = self.mask_token_id
            elif rand_val < random_threshold and can_replace:
                random_token = random.randint(0, self.vocab_size - 1)
                while random_token in excluded:
                    random_token = random.randint(0, self.vocab_size - 1)
                masked_input_ids[batch_idx, pos_idx] = random_token
            # Otherwise the token stays unchanged but still carries a label.

        return masked_input_ids, labels, mask_positions

    def get_mask_statistics(
        self,
        input_ids: torch.Tensor,
        masked_input_ids: torch.Tensor,
        mask_positions: torch.Tensor
    ) -> dict:
        """
        Calculate statistics about masking for logging.

        ``total_tokens`` counts every non-padding token, including special and
        ambiguous tokens. A random replacement that happens to draw the original
        token is counted under ``unchanged_count``.

        Args:
            input_ids: Original input IDs
            masked_input_ids: Masked input IDs
            mask_positions: Boolean mask indicating masked positions

        Returns:
            Dictionary with masking statistics
        """
        total_tokens = (input_ids != self.pad_token_id).sum().item()
        masked_tokens = mask_positions.sum().item()

        new_values = masked_input_ids[mask_positions]
        old_values = input_ids[mask_positions]
        mask_token_count = (new_values == self.mask_token_id).sum().item()
        random_token_count = ((new_values != self.mask_token_id) &
                              (new_values != old_values)).sum().item()
        unchanged_count = masked_tokens - mask_token_count - random_token_count

        ambiguous_tokens = sum(
            (input_ids == ambig_id).sum().item() for ambig_id in self.ambiguous_token_ids
        )

        stats = {
            'total_tokens': total_tokens,
            'masked_tokens': masked_tokens,
            'mask_percentage': masked_tokens / total_tokens * 100 if total_tokens > 0 else 0,
            'mask_token_count': mask_token_count,
            'random_token_count': random_token_count,
            'unchanged_count': unchanged_count,
            'ambiguous_tokens': ambiguous_tokens,
            'ambiguous_percentage': ambiguous_tokens / total_tokens * 100 if total_tokens > 0 else 0
        }

        return stats
|
|
|
|
class MonosaccharideMaskingStrategy:
    """
    Monosaccharide-level masking strategy for Glycan BERT.

    Instead of masking individual tokens, this masks entire monosaccharides
    and optionally predicts the monosaccharide type (Glc, Gal, etc.).

    This forces the model to learn holistic monosaccharide semantics rather
    than just local token patterns.
    """

    # Monosaccharide type vocabulary; index 0 ('<UNK>') is used for unknown names.
    MONO_TYPES = [
        '<UNK>', 'Glc', 'Gal', 'Man', 'Fuc', 'Xyl', 'Rha', 'Ara',
        'GlcNAc', 'GalNAc', 'ManNAc', 'GlcA', 'GalA', 'IdoA',
        'Neu5Ac', 'Neu5Gc', 'Kdn', 'GlcN', 'GalN', 'Hex', 'HexNAc',
        'dHex', 'Pent', 'Sia', 'GlcS', 'GalS', 'Ido', 'All', 'Alt', 'Gul', 'Tal'
    ]

    def __init__(
        self,
        vocab_size: int,
        mask_token_id: int,
        pad_token_id: int,
        special_token_ids: List[int],
        mask_prob: float = 0.15,
        predict_mono_type: bool = True,
        seed: int = None
    ):
        """
        Initialize monosaccharide masking strategy.

        Args:
            vocab_size: Size of vocabulary
            mask_token_id: ID of [MASK] token
            pad_token_id: ID of [PAD] token
            special_token_ids: List of special token IDs to never mask
            mask_prob: Fraction of residues masked per sequence (default: 0.15);
                at least one residue is masked when > 0, none when <= 0.
            predict_mono_type: If True, mono_labels are filled with mono type IDs;
                if False, mono_labels stay -100 (token_labels are always returned)
            seed: Random seed (reseeds the global ``random`` and ``torch`` RNGs)
        """
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
        self.special_token_ids = set(special_token_ids)
        self.mask_prob = mask_prob
        self.predict_mono_type = predict_mono_type

        # Name <-> ID lookup tables for the monosaccharide type vocabulary.
        self.mono_to_id = {m: i for i, m in enumerate(self.MONO_TYPES)}
        self.id_to_mono = {i: m for i, m in enumerate(self.MONO_TYPES)}
        self.num_mono_types = len(self.MONO_TYPES)

        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)

    def mask_sequence(
        self,
        input_ids: torch.Tensor,
        residue_ids: torch.Tensor,
        monosaccharide_names: List[List[str]] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Apply monosaccharide-level masking.

        Padding and special tokens are never masked, even if ``residue_ids``
        assigns them to a residue.

        Args:
            input_ids: (batch, seq_len) token IDs
            residue_ids: (batch, seq_len) residue ID for each token (-1=special, -2=linkage, >=0=residue)
            monosaccharide_names: List of lists of monosaccharide names per batch item;
                may be shorter than the batch (missing rows get no mono label)

        Returns:
            Tuple of:
            - masked_input_ids: Input with entire residue tokens masked
            - token_labels: Original token IDs for masked positions (-100 for unmasked)
            - mono_labels: (batch,) monosaccharide type ID per sequence (-100 if none).
              NOTE: only one label fits per sequence; when several residues are
              masked, it holds the type of the last one processed.
            - mask_positions: Boolean tensor indicating masked token positions
        """
        batch_size, _ = input_ids.shape

        masked_input_ids = input_ids.clone()
        token_labels = torch.full_like(input_ids, -100)
        mono_labels = torch.full((batch_size,), -100, dtype=torch.long, device=input_ids.device)
        mask_positions = torch.zeros_like(input_ids, dtype=torch.bool)

        if self.mask_prob <= 0:
            return masked_input_ids, token_labels, mono_labels, mask_positions

        for b in range(batch_size):
            row_ids = input_ids[b]
            row_residues = residue_ids[b]

            # Tokens that may be hidden: never padding or special tokens.
            maskable = row_ids != self.pad_token_id
            for special_id in self.special_token_ids:
                maskable &= row_ids != special_id

            # Candidate residues are those owning at least one maskable token.
            real_residues = torch.unique(row_residues[maskable & (row_residues >= 0)]).tolist()
            if not real_residues:
                continue

            n_to_mask = max(1, int(len(real_residues) * self.mask_prob))
            random.shuffle(real_residues)

            names = None
            if (self.predict_mono_type and monosaccharide_names is not None
                    and b < len(monosaccharide_names)):
                names = monosaccharide_names[b]

            for rid in real_residues[:n_to_mask]:
                token_mask = (row_residues == rid) & maskable

                # Every token of the residue becomes [MASK] and keeps its label.
                token_labels[b, token_mask] = row_ids[token_mask]
                masked_input_ids[b, token_mask] = self.mask_token_id
                mask_positions[b, token_mask] = True

                if names is not None and rid < len(names):
                    mono_labels[b] = self.get_mono_type_id(names[rid])

        return masked_input_ids, token_labels, mono_labels, mask_positions

    def get_mono_type_id(self, mono_name: str) -> int:
        """Convert monosaccharide name to ID (0, i.e. '<UNK>', for unknown names)."""
        return self.mono_to_id.get(mono_name, 0)
|
|
|
|
class HierarchicalMaskingStrategy:
    """
    Hierarchical masking combining token-level and monosaccharide-level.

    Novel approach:
    1. Token-level MLM: Predict individual masked tokens (like BERT)
    2. Residue-level MLM: Predict monosaccharide types from masked residues
    3. Global contrastive: Align sequence with MS/3D representations

    This class produces the masking for objectives 1 and 2 only; objective 3
    is not implemented here.

    This provides multi-scale supervision for better glycan understanding.
    """

    def __init__(
        self,
        vocab_size: int,
        mask_token_id: int,
        pad_token_id: int,
        special_token_ids: List[int],
        ambiguous_token_ids: List[int] = None,
        token_mask_prob: float = 0.10,
        residue_mask_prob: float = 0.10,
        seed: int = None,
        mask_token_prob: float = 0.8,
        random_token_prob: float = 0.1,
    ):
        """
        Initialize hierarchical masking.

        Args:
            vocab_size: Size of vocabulary
            mask_token_id: ID of [MASK] token
            pad_token_id: ID of [PAD] token
            special_token_ids: Special tokens to never mask
            ambiguous_token_ids: Ambiguous tokens to never mask at token level
            token_mask_prob: Fraction of remaining maskable tokens to mask individually
                (at least one when > 0, none when <= 0)
            residue_mask_prob: Fraction of residues to mask entirely
                (at least one when > 0, none when <= 0)
            seed: Random seed (reseeds the global ``random`` and ``torch`` RNGs)
            mask_token_prob: For token-level masks, probability of [MASK] (default: 0.8)
            random_token_prob: For token-level masks, probability of a random
                non-special token (default: 0.1); the remainder stays unchanged

        Raises:
            ValueError: If mask_token_prob / random_token_prob are negative or
                sum to more than 1.0.
        """
        if (mask_token_prob < 0 or random_token_prob < 0
                or mask_token_prob + random_token_prob > 1.0 + 1e-6):
            raise ValueError(
                "mask_token_prob and random_token_prob must be non-negative "
                "and sum to at most 1.0"
            )

        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
        self.special_token_ids = set(special_token_ids)
        self.ambiguous_token_ids = set(ambiguous_token_ids) if ambiguous_token_ids else set()
        self.token_mask_prob = token_mask_prob
        self.residue_mask_prob = residue_mask_prob
        self.mask_token_prob = mask_token_prob
        self.random_token_prob = random_token_prob

        # Share the monosaccharide type vocabulary with MonosaccharideMaskingStrategy.
        self.MONO_TYPES = MonosaccharideMaskingStrategy.MONO_TYPES
        self.mono_to_id = {m: i for i, m in enumerate(self.MONO_TYPES)}
        self.num_mono_types = len(self.MONO_TYPES)

        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)

    @staticmethod
    def _num_to_select(n_candidates: int, prob: float) -> int:
        """Return floor(n_candidates * prob), but at least 1 when prob > 0 and candidates exist."""
        if prob <= 0 or n_candidates == 0:
            return 0
        return max(1, int(n_candidates * prob))

    def mask_sequence(
        self,
        input_ids: torch.Tensor,
        residue_ids: torch.Tensor,
        monosaccharide_names: List[List[str]] = None
    ) -> dict:
        """
        Apply hierarchical masking at both token and residue levels.

        Residue level runs first; tokens of masked residues are then excluded
        from token-level selection. Padding and special tokens are never masked.

        Args:
            input_ids: (batch, seq_len) token IDs
            residue_ids: (batch, seq_len) residue ID per token (>= 0 for residue tokens)
            monosaccharide_names: Optional per-sequence residue names; may be
                shorter than the batch (missing rows get no mono labels)

        Returns:
            Dictionary with:
            - masked_input_ids: Input with masks applied
            - token_labels: Token-level labels for MLM (-100 for unmasked)
            - residue_mask: Per sequence, the set of residue IDs completely masked
            - mono_labels: Per sequence, dict residue ID -> monosaccharide type ID
        """
        batch_size, _ = input_ids.shape

        masked_input_ids = input_ids.clone()
        token_labels = torch.full_like(input_ids, -100)
        residue_mask = []
        mono_labels = []

        # Random replacements must be real tokens: never padding, special or [MASK].
        excluded = self.special_token_ids | {self.pad_token_id, self.mask_token_id}
        can_replace = any(t not in excluded for t in range(self.vocab_size))
        random_threshold = self.mask_token_prob + self.random_token_prob

        for b in range(batch_size):
            row_ids = input_ids[b]
            row_residues = residue_ids[b]

            not_pad_or_special = row_ids != self.pad_token_id
            for special_id in self.special_token_ids:
                not_pad_or_special &= row_ids != special_id

            names = None
            if monosaccharide_names is not None and b < len(monosaccharide_names):
                names = monosaccharide_names[b]

            batch_residue_mask = set()
            batch_mono_labels = {}

            # --- Residue level: hide every token of the chosen residues. ---
            real_residues = torch.unique(
                row_residues[not_pad_or_special & (row_residues >= 0)]
            ).tolist()
            if real_residues:
                n_residue_mask = self._num_to_select(len(real_residues), self.residue_mask_prob)
                random.shuffle(real_residues)
                for rid in real_residues[:n_residue_mask]:
                    token_mask = (row_residues == rid) & not_pad_or_special
                    token_labels[b, token_mask] = row_ids[token_mask]
                    masked_input_ids[b, token_mask] = self.mask_token_id
                    batch_residue_mask.add(rid)

                    if names is not None and rid < len(names):
                        batch_mono_labels[rid] = self.mono_to_id.get(names[rid], 0)

            # --- Token level: BERT-style masking on what remains. ---
            maskable = not_pad_or_special.clone()
            for ambig_id in self.ambiguous_token_ids:
                maskable &= row_ids != ambig_id
            # Tokens of fully masked residues already carry labels.
            for rid in batch_residue_mask:
                maskable &= row_residues != rid

            maskable_indices = maskable.nonzero(as_tuple=True)[0]
            if len(maskable_indices) > 0:
                n_to_mask = self._num_to_select(len(maskable_indices), self.token_mask_prob)
                perm = torch.randperm(len(maskable_indices))[:n_to_mask]
                for idx in maskable_indices[perm].tolist():
                    token_labels[b, idx] = row_ids[idx]

                    rand_val = random.random()
                    if rand_val < self.mask_token_prob:
                        masked_input_ids[b, idx] = self.mask_token_id
                    elif rand_val < random_threshold and can_replace:
                        random_token = random.randint(0, self.vocab_size - 1)
                        while random_token in excluded:
                            random_token = random.randint(0, self.vocab_size - 1)
                        masked_input_ids[b, idx] = random_token
                    # Otherwise the token stays unchanged but still carries a label.

            residue_mask.append(batch_residue_mask)
            mono_labels.append(batch_mono_labels)

        return {
            'masked_input_ids': masked_input_ids,
            'token_labels': token_labels,
            'residue_mask': residue_mask,
            'mono_labels': mono_labels,
        }
|
|
|
|
|
|