bertose-affinose-training-code / code /training /multimodal_masking.py
supanthadey1's picture
Add BERTose and AFFINose training code release
1d6f391 verified
Raw
History Blame Contribute Delete
13.6 kB
"""
Multimodal Masking Strategy for Glycan BERT v3
Implements masked language modeling across three modalities:
- Sequence (WURCS atomic tokenization)
- MS (mass spectrometry peaks)
- Structure (VQ-VAE discrete tokens)
Each modality can be masked independently with different probabilities.
"""
import torch
import random
from typing import List, Tuple, Dict, Optional
try:
from .masking import GlycanMaskingStrategy, HierarchicalMaskingStrategy
except ImportError:
from masking import GlycanMaskingStrategy, HierarchicalMaskingStrategy
class MultimodalMaskingStrategy:
    """
    Masking strategy for multimodal glycan BERT.

    Handles masked-language-model masking across the sequence, MS, and
    structure modalities. Each modality gets its own masker instance so that
    mask probabilities and special/pad/mask token IDs can differ per modality.

    When ``use_hierarchical_masking`` is True the sequence modality uses
    ``HierarchicalMaskingStrategy`` (residue-level + token-level masking);
    MS and structure are always masked at the token level.
    """

    def __init__(
        self,
        # Sequence masking
        seq_vocab_size: int,
        seq_mask_token_id: int,
        seq_pad_token_id: int,
        seq_special_token_ids: List[int],
        # MS masking
        ms_vocab_size: int,
        ms_vocab_offset: int,
        ms_mask_token_id: int,
        ms_pad_token_id: int,
        ms_special_token_ids: List[int],
        # Structure masking
        struct_vocab_size: int,
        struct_mask_token_id: int,
        struct_pad_token_id: int,
        struct_special_token_ids: List[int],
        # Optional parameters
        seq_ambiguous_token_ids: Optional[List[int]] = None,
        seq_mask_prob: float = 0.15,
        ms_mask_prob: float = 0.15,
        struct_mask_prob: float = 0.15,
        # Common masking parameters
        mask_token_prob: float = 0.8,
        random_token_prob: float = 0.1,
        unchanged_prob: float = 0.1,
        # Hierarchical masking option
        use_hierarchical_masking: bool = False,
        token_mask_prob: float = 0.10,  # For hierarchical: token-level prob
        residue_mask_prob: float = 0.10,  # For hierarchical: residue-level prob
        seed: Optional[int] = None,
    ):
        """
        Initialize multimodal masking strategy.

        Args:
            seq_vocab_size: Sequence vocabulary size.
            seq_mask_token_id: Sequence [MASK] token ID.
            seq_pad_token_id: Sequence [PAD] token ID.
            seq_special_token_ids: Sequence special token IDs to never mask.
            ms_vocab_size: MS vocabulary size.
            ms_vocab_offset: MS vocabulary offset (where MS tokens start in the
                combined vocab). The MS masker is built with
                ``ms_vocab_size + ms_vocab_offset`` as its vocab size.
            ms_mask_token_id: MS [MASK] token ID.
            ms_pad_token_id: MS [PAD] token ID.
            ms_special_token_ids: MS special token IDs to never mask.
            struct_vocab_size: Structure vocabulary size.
            struct_mask_token_id: Structure [MASK] token ID.
            struct_pad_token_id: Structure [PAD] token ID.
            struct_special_token_ids: Structure special token IDs to never mask.
            seq_ambiguous_token_ids: Sequence ambiguous token IDs to never mask.
            seq_mask_prob: Probability of masking sequence tokens (standard
                masking only; ignored when hierarchical masking is enabled).
            ms_mask_prob: Probability of masking MS tokens.
            struct_mask_prob: Probability of masking structure tokens.
            mask_token_prob: Probability of replacing a selected token with
                [MASK]. Not forwarded to the hierarchical sequence masker.
            random_token_prob: Probability of replacing with a random token.
                Not forwarded to the hierarchical sequence masker.
            unchanged_prob: Probability of leaving a selected token unchanged.
                Not forwarded to the hierarchical sequence masker.
            use_hierarchical_masking: If True, use hierarchical
                (token + residue level) masking for the sequence modality.
            token_mask_prob: Token-level mask probability for hierarchical.
            residue_mask_prob: Residue-level mask probability for hierarchical.
            seed: Random seed, passed unchanged to every per-modality masker.
        """
        self.use_hierarchical_masking = use_hierarchical_masking

        if use_hierarchical_masking:
            # Hierarchical masking for sequences (residue-level + token-level).
            self.seq_masker = HierarchicalMaskingStrategy(
                vocab_size=seq_vocab_size,
                mask_token_id=seq_mask_token_id,
                pad_token_id=seq_pad_token_id,
                special_token_ids=seq_special_token_ids,
                ambiguous_token_ids=seq_ambiguous_token_ids,
                token_mask_prob=token_mask_prob,
                residue_mask_prob=residue_mask_prob,
                seed=seed,
            )
        else:
            # Standard token-level masking.
            self.seq_masker = GlycanMaskingStrategy(
                vocab_size=seq_vocab_size,
                mask_token_id=seq_mask_token_id,
                pad_token_id=seq_pad_token_id,
                special_token_ids=seq_special_token_ids,
                ambiguous_token_ids=seq_ambiguous_token_ids,
                mask_prob=seq_mask_prob,
                mask_token_prob=mask_token_prob,
                random_token_prob=random_token_prob,
                unchanged_prob=unchanged_prob,
                seed=seed,
            )

        # MS masker (always token-level). Its vocab size is the total combined
        # vocab including the sequence tokens that precede the MS offset.
        self.ms_masker = GlycanMaskingStrategy(
            vocab_size=ms_vocab_size + ms_vocab_offset,
            mask_token_id=ms_mask_token_id,
            pad_token_id=ms_pad_token_id,
            special_token_ids=ms_special_token_ids,
            ambiguous_token_ids=[],  # No ambiguous tokens in MS
            mask_prob=ms_mask_prob,
            mask_token_prob=mask_token_prob,
            random_token_prob=random_token_prob,
            unchanged_prob=unchanged_prob,
            seed=seed,
        )

        # Structure masker (always token-level).
        self.struct_masker = GlycanMaskingStrategy(
            vocab_size=struct_vocab_size,
            mask_token_id=struct_mask_token_id,
            pad_token_id=struct_pad_token_id,
            special_token_ids=struct_special_token_ids,
            ambiguous_token_ids=[],  # No ambiguous tokens in structure
            mask_prob=struct_mask_prob,
            mask_token_prob=mask_token_prob,
            random_token_prob=random_token_prob,
            unchanged_prob=unchanged_prob,
            seed=seed,
        )

        self.ms_vocab_offset = ms_vocab_offset

    def _mask_sequence_modality(
        self,
        seq_token_ids: torch.Tensor,
        seq_residue_ids: Optional[torch.Tensor],
        monosaccharide_names: Optional[List[List[str]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Mask the sequence modality and normalise the masker's output.

        Returns ``(masked_ids, labels, mask_positions, mono_labels)``.
        ``mono_labels`` is None unless the masker returned a dict carrying it.
        """
        if self.use_hierarchical_masking and seq_residue_ids is not None:
            raw = self.seq_masker.mask_sequence(
                seq_token_ids, seq_residue_ids, monosaccharide_names
            )
        else:
            raw = self.seq_masker.mask_sequence(seq_token_ids)

        # The hierarchical masker returns a dict. Handle that shape regardless
        # of which call produced it: tuple-unpacking a dict would silently
        # bind its key strings instead of tensors.
        if isinstance(raw, dict):
            labels = raw['token_labels']
            return raw['masked_input_ids'], labels, labels != -100, raw.get('mono_labels', None)

        masked_ids, labels, mask_positions = raw
        return masked_ids, labels, mask_positions, None

    @staticmethod
    def _drop_labels_for_missing(labels: torch.Tensor, present: Optional[torch.Tensor]) -> None:
        """Set every label of samples whose ``present`` flag is false to -100, in place.

        ``present`` is coerced to a bool tensor on ``labels``' device. Applying
        ``~`` to an integer 0/1 tensor is a bitwise NOT (``~1 == -2``), which
        would then index rows -2/-1 instead of selecting samples.
        """
        if present is None:
            return
        present = torch.as_tensor(present, device=labels.device).bool()
        labels[~present] = -100

    def mask_multimodal_batch(
        self,
        seq_token_ids: torch.Tensor,
        ms_token_ids: torch.Tensor,
        has_ms: Optional[torch.Tensor],
        struct_token_ids: Optional[torch.Tensor] = None,
        has_3d: Optional[torch.Tensor] = None,
        # For hierarchical masking
        seq_residue_ids: Optional[torch.Tensor] = None,
        monosaccharide_names: Optional[List[List[str]]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Apply masking to a multimodal batch.

        Args:
            seq_token_ids: (batch_size, seq_len) - Sequence token IDs.
            ms_token_ids: (batch_size, ms_len) - MS token IDs.
            has_ms: (batch_size,) - Per-sample flag for samples with MS data
                (bool or 0/1 values; None keeps all MS labels).
            struct_token_ids: (batch_size, struct_len) - Structure token IDs (optional).
            has_3d: (batch_size,) - Per-sample flag for samples with 3D data
                (bool or 0/1 values; optional).
            seq_residue_ids: (batch_size, seq_len) - Residue IDs for
                hierarchical masking (optional).
            monosaccharide_names: List of lists of mono names per batch item (optional).

        Returns:
            Dictionary containing:
                - seq_masked_ids: Masked sequence input
                - seq_labels: Sequence labels (-100 for unmasked)
                - ms_masked_ids: Masked MS input
                - ms_labels: MS labels (-100 for unmasked or missing MS)
                - mono_labels: Monosaccharide type labels (if the sequence
                  masker produced them)
                - struct_masked_ids: Masked structure input (if provided)
                - struct_labels: Structure labels (-100 for unmasked or
                  missing 3D, if provided)
                - statistics: Masking statistics per modality
        """
        # Maskers are called in a fixed order (seq, MS, structure) so any
        # shared random state is consumed deterministically.
        seq_masked_ids, seq_labels, seq_mask_positions, mono_labels = self._mask_sequence_modality(
            seq_token_ids, seq_residue_ids, monosaccharide_names
        )

        ms_masked_ids, ms_labels, ms_mask_positions = self.ms_masker.mask_sequence(ms_token_ids)
        # Samples without MS data contribute no MS loss.
        self._drop_labels_for_missing(ms_labels, has_ms)

        result = {
            'seq_masked_ids': seq_masked_ids,
            'seq_labels': seq_labels,
            'ms_masked_ids': ms_masked_ids,
            'ms_labels': ms_labels,
        }
        if mono_labels is not None:
            result['mono_labels'] = mono_labels

        struct_masked_ids = struct_mask_positions = None
        if struct_token_ids is not None:
            struct_masked_ids, struct_labels, struct_mask_positions = self.struct_masker.mask_sequence(
                struct_token_ids
            )
            # Samples without 3D data contribute no structure loss.
            self._drop_labels_for_missing(struct_labels, has_3d)
            result['struct_masked_ids'] = struct_masked_ids
            result['struct_labels'] = struct_labels

        # NOTE(review): statistics use each masker's own mask positions, so
        # positions whose labels were reset to -100 above (missing MS/3D) are
        # still counted, and their inputs stay masked. Confirm this is intended.
        if not self.use_hierarchical_masking:
            seq_stats = self.seq_masker.get_mask_statistics(seq_token_ids, seq_masked_ids, seq_mask_positions)
        else:
            # The hierarchical path reports only the masked-token count.
            seq_stats = {'masked_tokens': seq_mask_positions.sum().item()}

        stats = {
            'seq': seq_stats,
            'ms': self.ms_masker.get_mask_statistics(ms_token_ids, ms_masked_ids, ms_mask_positions),
        }
        if struct_token_ids is not None:
            stats['struct'] = self.struct_masker.get_mask_statistics(
                struct_token_ids, struct_masked_ids, struct_mask_positions
            )

        result['statistics'] = stats
        return result
if __name__ == "__main__":
    # Smoke test: build a masker with toy vocabulary sizes, mask a random
    # batch, and print shapes / statistics so the behaviour can be eyeballed.
    divider = "=" * 80
    print(divider)
    print("Testing Multimodal Masking Strategy")
    print(divider)

    # Toy configuration: sequence IDs in [0, 166), MS IDs offset by 166 into
    # [166, 408), structure IDs in their own [0, 1024) space.
    masker = MultimodalMaskingStrategy(
        seq_vocab_size=166,
        seq_mask_token_id=1,
        seq_pad_token_id=0,
        seq_special_token_ids=[0, 1, 2, 3],  # [PAD], [MASK], [START], [END]
        seq_ambiguous_token_ids=[10, 11, 12],  # x, X, ?
        seq_mask_prob=0.15,
        ms_vocab_size=242,
        ms_vocab_offset=166,
        ms_mask_token_id=1,
        ms_pad_token_id=0,
        ms_special_token_ids=[0, 1, 2, 3],
        ms_mask_prob=0.15,
        struct_vocab_size=1024,
        struct_mask_token_id=1,
        struct_pad_token_id=0,
        struct_special_token_ids=[0, 1],
        struct_mask_prob=0.15,
        seed=42,
    )

    # Random dummy batch that avoids each modality's special-token range.
    # The generation order (seq, MS, structure) is kept fixed.
    n_samples, seq_width, ms_width, struct_width = 4, 50, 30, 40
    seq_token_ids = torch.randint(4, 166, (n_samples, seq_width))
    ms_token_ids = torch.randint(166, 408, (n_samples, ms_width))
    struct_token_ids = torch.randint(2, 1024, (n_samples, struct_width))
    has_ms = torch.tensor([True, True, False, True])
    has_3d = torch.tensor([True, False, True, True])

    result = masker.mask_multimodal_batch(
        seq_token_ids=seq_token_ids,
        ms_token_ids=ms_token_ids,
        has_ms=has_ms,
        struct_token_ids=struct_token_ids,
        has_3d=has_3d,
    )

    print("\nMasked batch shapes:")
    for field_name in (
        "seq_masked_ids",
        "seq_labels",
        "ms_masked_ids",
        "ms_labels",
        "struct_masked_ids",
        "struct_labels",
    ):
        print(f" {field_name}: {result[field_name].shape}")

    for title, modality in (("Sequence", "seq"), ("MS", "ms"), ("Structure", "struct")):
        print(f"\n{title} masking statistics:")
        for stat_name, stat_value in result['statistics'][modality].items():
            print(f" {stat_name}: {stat_value}")

    # Samples flagged as lacking MS / 3D data should only carry -100 labels.
    ms_labels = result['ms_labels']
    print(f"\nMS labels for sample 2 (no MS): {ms_labels[2].unique()}")
    print(f"MS labels for sample 0 (has MS): {ms_labels[0].unique()[:10]}")
    struct_labels = result['struct_labels']
    print(f"\nStructure labels for sample 1 (no 3D): {struct_labels[1].unique()}")
    print(f"Structure labels for sample 0 (has 3D): {struct_labels[0].unique()[:10]}")

    print("\n" + divider)
    print("Multimodal Masking Test Complete!")
    print(divider)