# Uploaded by supanthadey1 in commit 1d6f391
# ("Add BERTose and AFFINose training code release").
"""
ESM2-Style Masking Strategies for Glycan BERT
Implements masked language modeling following the ESM2 (BERT-style) approach:
- Mask 15% of tokens randomly
- 80% replaced with [MASK]
- 10% replaced with random token
- 10% unchanged (for robustness)
Also provides monosaccharide-level (whole-residue) masking and a hierarchical
strategy that combines residue-level and token-level masking.
"""
import torch
import random
from typing import List, Tuple
class GlycanMaskingStrategy:
    """
    Token-level masking strategy for glycan sequences following ESM2/BERT.

    In each sequence a fraction ``mask_prob`` of the maskable tokens (not
    padding, not special, not ambiguous) is selected. At least one token is
    selected whenever the sequence has any maskable token. Each selected
    token is then:
    - replaced by [MASK] with probability ``mask_token_prob``,
    - replaced by a random ordinary token with probability
      ``random_token_prob``,
    - otherwise left unchanged.
    """

    def __init__(
        self,
        vocab_size: int,
        mask_token_id: int,
        pad_token_id: int,
        special_token_ids: List[int],
        ambiguous_token_ids: List[int] = None,
        mask_prob: float = 0.15,
        mask_token_prob: float = 0.8,
        random_token_prob: float = 0.1,
        unchanged_prob: float = 0.1,
        seed: int = None
    ):
        """
        Initialize masking strategy.

        Args:
            vocab_size: Size of vocabulary; token ids are assumed to lie in
                [0, vocab_size).
            mask_token_id: ID of [MASK] token
            pad_token_id: ID of [PAD] token
            special_token_ids: List of special token IDs to never mask
            ambiguous_token_ids: List of ambiguous token IDs to never mask
                (x, X, ?, u, d, o)
            mask_prob: Fraction of maskable tokens selected per sequence
                (default: 0.15)
            mask_token_prob: Probability a selected token becomes [MASK]
                (default: 0.8)
            random_token_prob: Probability a selected token becomes a random
                ordinary token (default: 0.1)
            unchanged_prob: Probability a selected token is left unchanged
                (default: 0.1)
            seed: If given, reseeds the global ``random`` and ``torch`` RNGs
                (a process-wide side effect).
        """
        assert abs(mask_token_prob + random_token_prob + unchanged_prob - 1.0) < 1e-6, \
            "Masking probabilities must sum to 1.0"
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
        self.special_token_ids = set(special_token_ids)
        self.ambiguous_token_ids = set(ambiguous_token_ids) if ambiguous_token_ids else set()
        self.mask_prob = mask_prob
        self.mask_token_prob = mask_token_prob
        self.random_token_prob = random_token_prob
        self.unchanged_prob = unchanged_prob
        # Ids eligible as "random token" replacements: everything except
        # special tokens, [PAD] and [MASK]. Precomputing the pool avoids
        # rejection sampling, which could never terminate when the vocabulary
        # has no ordinary token. Excluding [MASK] keeps the random branch
        # distinguishable from the [MASK] branch.
        excluded = self.special_token_ids | {self.pad_token_id, self.mask_token_id}
        self._replacement_ids = [t for t in range(vocab_size) if t not in excluded]
        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)

    def mask_sequence(
        self,
        input_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Apply masking to a batch of sequences.

        Args:
            input_ids: Tensor of shape (batch_size, seq_len) with token IDs

        Returns:
            Tuple of:
            - masked_input_ids: Input with masks applied
            - labels: Original token IDs for masked positions (-100 for unmasked)
            - mask_positions: Boolean tensor indicating masked positions
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Padding and special tokens are never masked. Ambiguous tokens are
        # skipped too, because their ground truth is unknown.
        maskable = torch.ones_like(input_ids, dtype=torch.bool)
        for token_id in {self.pad_token_id} | self.special_token_ids | self.ambiguous_token_ids:
            maskable &= (input_ids != token_id)

        # Select mask_prob of each row's maskable tokens (at least one).
        mask_positions = torch.zeros_like(input_ids, dtype=torch.bool)
        for i in range(batch_size):
            maskable_indices = maskable[i].nonzero(as_tuple=True)[0]
            n_maskable = len(maskable_indices)
            if n_maskable == 0:
                continue
            n_to_mask = max(1, int(n_maskable * self.mask_prob))
            chosen = maskable_indices[torch.randperm(n_maskable)[:n_to_mask]]
            mask_positions[i, chosen] = True

        # Labels hold the original token at selected positions only.
        labels = torch.full_like(input_ids, -100)
        labels[mask_positions] = input_ids[mask_positions]

        # One uniform draw per position decides the [MASK] / random /
        # unchanged split. Only draws at selected positions are used. This is
        # done with whole-tensor operations instead of a per-element Python
        # loop.
        roll = torch.rand(input_ids.shape, device=device)
        use_mask_token = mask_positions & (roll < self.mask_token_prob)
        use_random_token = (
            mask_positions
            & ~use_mask_token
            & (roll < self.mask_token_prob + self.random_token_prob)
        )

        masked_input_ids = input_ids.clone()
        masked_input_ids[use_mask_token] = self.mask_token_id
        n_random = int(use_random_token.sum().item())
        if n_random > 0:
            if self._replacement_ids:
                pool = torch.tensor(self._replacement_ids, dtype=input_ids.dtype, device=device)
                picks = torch.randint(len(self._replacement_ids), (n_random,), device=device)
                masked_input_ids[use_random_token] = pool[picks]
            else:
                # No ordinary token exists to sample from; fall back to [MASK].
                masked_input_ids[use_random_token] = self.mask_token_id
        # Remaining selected positions keep their original token (unchanged_prob).
        return masked_input_ids, labels, mask_positions

    def get_mask_statistics(
        self,
        input_ids: torch.Tensor,
        masked_input_ids: torch.Tensor,
        mask_positions: torch.Tensor
    ) -> dict:
        """
        Calculate statistics about masking for logging.

        Args:
            input_ids: Original input IDs
            masked_input_ids: Masked input IDs
            mask_positions: Boolean mask indicating masked positions

        Returns:
            Dictionary with masking statistics. ``total_tokens`` counts all
            non-pad tokens, special tokens included. A random replacement
            that happens to equal the original token is counted as
            ``unchanged_count``.
        """
        total_tokens = (input_ids != self.pad_token_id).sum().item()
        masked_tokens = mask_positions.sum().item()
        # Compare only the selected positions.
        masked_values = masked_input_ids[mask_positions]
        original_values = input_ids[mask_positions]
        mask_token_count = (masked_values == self.mask_token_id).sum().item()
        random_token_count = ((masked_values != self.mask_token_id) &
                              (masked_values != original_values)).sum().item()
        unchanged_count = masked_tokens - mask_token_count - random_token_count
        # Count ambiguous tokens in batch
        ambiguous_tokens = 0
        for ambig_id in self.ambiguous_token_ids:
            ambiguous_tokens += (input_ids == ambig_id).sum().item()
        stats = {
            'total_tokens': total_tokens,
            'masked_tokens': masked_tokens,
            'mask_percentage': masked_tokens / total_tokens * 100 if total_tokens > 0 else 0,
            'mask_token_count': mask_token_count,
            'random_token_count': random_token_count,
            'unchanged_count': unchanged_count,
            'ambiguous_tokens': ambiguous_tokens,
            'ambiguous_percentage': ambiguous_tokens / total_tokens * 100 if total_tokens > 0 else 0
        }
        return stats
class MonosaccharideMaskingStrategy:
    """
    Residue-level masking strategy for Glycan BERT.

    Instead of masking isolated tokens, every token of a selected
    monosaccharide residue is replaced by [MASK]. Optionally, a
    monosaccharide type id (an index into ``MONO_TYPES``) is also produced
    as a target. This pushes the model to recover whole-residue identity
    rather than only local token patterns.
    """
    # Monosaccharide type vocabulary. Names not listed here map to index 0
    # ('<UNK>').
    MONO_TYPES = [
        '<UNK>', 'Glc', 'Gal', 'Man', 'Fuc', 'Xyl', 'Rha', 'Ara',
        'GlcNAc', 'GalNAc', 'ManNAc', 'GlcA', 'GalA', 'IdoA',
        'Neu5Ac', 'Neu5Gc', 'Kdn', 'GlcN', 'GalN', 'Hex', 'HexNAc',
        'dHex', 'Pent', 'Sia', 'GlcS', 'GalS', 'Ido', 'All', 'Alt', 'Gul', 'Tal'
    ]

    def __init__(
        self,
        vocab_size: int,
        mask_token_id: int,
        pad_token_id: int,
        special_token_ids: List[int],
        mask_prob: float = 0.15,  # Probability to mask a residue
        predict_mono_type: bool = True,  # If True, predict mono type; if False, predict tokens
        seed: int = None
    ):
        """
        Configure residue-level masking.

        Args:
            vocab_size: Size of vocabulary
            mask_token_id: ID of [MASK] token
            pad_token_id: ID of [PAD] token
            special_token_ids: Special token IDs (stored as a set)
            mask_prob: Fraction of residues masked per sequence (default: 0.15)
            predict_mono_type: When True, mask_sequence fills mono_labels from
                monosaccharide_names; when False, mono_labels stay -100
            seed: If given, reseeds the global ``random`` and ``torch`` RNGs
        """
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
        self.special_token_ids = set(special_token_ids)
        self.mask_prob = mask_prob
        self.predict_mono_type = predict_mono_type
        # Bidirectional lookup between type names and their indices.
        self.id_to_mono = dict(enumerate(self.MONO_TYPES))
        self.mono_to_id = {name: idx for idx, name in self.id_to_mono.items()}
        self.num_mono_types = len(self.MONO_TYPES)
        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)

    def mask_sequence(
        self,
        input_ids: torch.Tensor,
        residue_ids: torch.Tensor,
        monosaccharide_names: List[List[str]] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Mask whole residues in each sequence of the batch.

        Args:
            input_ids: (batch, seq_len) token IDs
            residue_ids: (batch, seq_len) residue ID for each token
                (-1=special, -2=linkage, >=0=residue)
            monosaccharide_names: Per-batch-item lists of residue names,
                indexed by residue id

        Returns:
            Tuple of:
            - masked_input_ids: Input with every token of each chosen residue
              set to [MASK]
            - token_labels: Original token IDs at masked positions (-100 elsewhere)
            - mono_labels: (batch,) monosaccharide type ID per row (-100 if none)
            - mask_positions: Boolean tensor of masked token positions
        """
        n_rows = input_ids.shape[0]
        masked_input_ids = input_ids.clone()
        token_labels = torch.full_like(input_ids, -100)
        mono_labels = torch.full((n_rows,), -100, dtype=torch.long, device=input_ids.device)
        mask_positions = torch.zeros_like(input_ids, dtype=torch.bool)
        want_types = self.predict_mono_type and monosaccharide_names is not None

        for row in range(n_rows):
            row_residues = residue_ids[row]
            # Real residues have non-negative ids. torch.unique returns them sorted.
            candidates = [rid for rid in torch.unique(row_residues).tolist() if rid >= 0]
            if not candidates:
                continue
            random.shuffle(candidates)
            chosen = candidates[:max(1, int(len(candidates) * self.mask_prob))]

            # Union of the token positions of all chosen residues.
            selected = torch.zeros_like(row_residues, dtype=torch.bool)
            for rid in chosen:
                selected |= row_residues == rid
            token_labels[row, selected] = input_ids[row, selected]
            masked_input_ids[row, selected] = self.mask_token_id
            mask_positions[row, selected] = True

            if want_types:
                names = monosaccharide_names[row]
                # NOTE(review): mono_labels holds a single id per row. When
                # several residues are masked, only the last chosen residue
                # (in shuffled order) that has a name survives. Confirm this
                # is intended.
                for rid in chosen:
                    if rid < len(names):
                        mono_labels[row] = self.get_mono_type_id(names[rid])

        return masked_input_ids, token_labels, mono_labels, mask_positions

    def get_mono_type_id(self, mono_name: str) -> int:
        """Return the MONO_TYPES index of *mono_name*, or 0 ('<UNK>') if unknown."""
        try:
            return self.mono_to_id[mono_name]
        except KeyError:
            return 0
class HierarchicalMaskingStrategy:
    """
    Hierarchical masking combining token-level and monosaccharide-level.
    Novel approach:
    1. Token-level MLM: Predict individual masked tokens (like BERT)
    2. Residue-level MLM: Predict monosaccharide types from masked residues
    3. Global contrastive: Align sequence with MS/3D representations
    This provides multi-scale supervision for better glycan understanding.
    (This class implements the masking for steps 1 and 2.)
    """

    def __init__(
        self,
        vocab_size: int,
        mask_token_id: int,
        pad_token_id: int,
        special_token_ids: List[int],
        ambiguous_token_ids: List[int] = None,
        token_mask_prob: float = 0.10,  # Lower since we also mask residues
        residue_mask_prob: float = 0.10,  # Mask some whole residues
        seed: int = None,
        mask_token_prob: float = 0.8,
        random_token_prob: float = 0.1,
    ):
        """
        Initialize hierarchical masking.

        Args:
            vocab_size: Size of vocabulary; token ids are assumed to lie in
                [0, vocab_size).
            mask_token_id: ID of [MASK] token
            pad_token_id: ID of [PAD] token
            special_token_ids: Special tokens to never mask
            ambiguous_token_ids: Ambiguous tokens excluded from token-level masking
            token_mask_prob: Fraction of remaining maskable tokens to mask individually
            residue_mask_prob: Fraction of residues to mask entirely
            seed: If given, reseeds the global ``random`` and ``torch`` RNGs
            mask_token_prob: For token-level masking, probability a selected
                token becomes [MASK] (default 0.8)
            random_token_prob: For token-level masking, probability a selected
                token becomes a random ordinary token (default 0.1). The
                remainder is left unchanged.

        Raises:
            ValueError: If mask_token_prob / random_token_prob are negative
                or sum to more than 1.
        """
        if mask_token_prob < 0 or random_token_prob < 0 or mask_token_prob + random_token_prob > 1.0 + 1e-6:
            raise ValueError(
                "mask_token_prob and random_token_prob must be non-negative and sum to at most 1.0"
            )
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
        self.special_token_ids = set(special_token_ids)
        self.ambiguous_token_ids = set(ambiguous_token_ids) if ambiguous_token_ids else set()
        self.token_mask_prob = token_mask_prob
        self.residue_mask_prob = residue_mask_prob
        self.mask_token_prob = mask_token_prob
        self.random_token_prob = random_token_prob
        # Mono type vocabulary (same as MonosaccharideMaskingStrategy)
        self.MONO_TYPES = MonosaccharideMaskingStrategy.MONO_TYPES
        self.mono_to_id = {m: i for i, m in enumerate(self.MONO_TYPES)}
        self.num_mono_types = len(self.MONO_TYPES)
        # Ids eligible as "random token" replacements. Special tokens, [PAD]
        # and [MASK] are excluded, matching GlycanMaskingStrategy, so that
        # corruption never injects structural tokens into the input.
        excluded = self.special_token_ids | {self.pad_token_id, self.mask_token_id}
        self._replacement_ids = [t for t in range(vocab_size) if t not in excluded]
        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)

    def _random_replacement_id(self) -> int:
        """Return a uniformly chosen ordinary token id, or [MASK] if none exist."""
        if not self._replacement_ids:
            return self.mask_token_id
        return random.choice(self._replacement_ids)

    def mask_sequence(
        self,
        input_ids: torch.Tensor,
        residue_ids: torch.Tensor,
        monosaccharide_names: List[List[str]] = None
    ) -> dict:
        """
        Apply hierarchical masking at both token and residue levels.

        Args:
            input_ids: (batch, seq_len) token IDs
            residue_ids: (batch, seq_len) residue id per token. Values >= 0
                are real residues; negative values are never residue-masked.
            monosaccharide_names: Optional per-batch-item lists of residue
                names, indexed by residue id

        Returns:
            Dictionary with:
            - masked_input_ids: Input with masks applied
            - token_labels: Token-level labels for MLM (-100 for unmasked)
            - residue_mask: list (per batch item) of the set of residue ids
              that were completely masked
            - mono_labels: list (per batch item) of {residue id: mono type id}
              for masked residues with a known name
        """
        batch_size, seq_len = input_ids.shape
        masked_input_ids = input_ids.clone()
        token_labels = torch.full_like(input_ids, -100)
        residue_mask = []
        mono_labels = []
        for b in range(batch_size):
            row_ids = input_ids[b]
            row_residues = residue_ids[b]

            # Step 1: mask whole residues.
            # NOTE(review): ambiguous tokens inside a chosen residue are masked
            # and labelled here, unlike the token-level step below, which skips
            # them. Confirm whether they should be excluded from token_labels.
            real_residues = [r for r in torch.unique(row_residues).tolist() if r >= 0]
            batch_residue_mask = set()
            batch_mono_labels = {}
            if real_residues:
                n_residue_mask = int(len(real_residues) * self.residue_mask_prob)
                random.shuffle(real_residues)
                for rid in real_residues[:n_residue_mask]:
                    token_mask = row_residues == rid
                    token_labels[b, token_mask] = row_ids[token_mask]
                    masked_input_ids[b, token_mask] = self.mask_token_id
                    batch_residue_mask.add(rid)
                    if monosaccharide_names is not None and rid < len(monosaccharide_names[b]):
                        mono_name = monosaccharide_names[b][rid]
                        batch_mono_labels[rid] = self.mono_to_id.get(mono_name, 0)

            # Step 2: mask extra individual tokens outside the masked
            # residues, skipping padding, special and ambiguous tokens.
            maskable = torch.ones(seq_len, dtype=torch.bool, device=input_ids.device)
            for token_id in {self.pad_token_id} | self.special_token_ids | self.ambiguous_token_ids:
                maskable &= (row_ids != token_id)
            for rid in batch_residue_mask:
                maskable &= (row_residues != rid)
            maskable_indices = maskable.nonzero(as_tuple=True)[0]
            if len(maskable_indices) > 0:
                n_to_mask = int(len(maskable_indices) * self.token_mask_prob)
                perm = torch.randperm(len(maskable_indices))[:n_to_mask]
                for idx in maskable_indices[perm]:
                    token_labels[b, idx] = row_ids[idx]
                    # BERT-style split: [MASK] / random ordinary token / unchanged.
                    rand_val = random.random()
                    if rand_val < self.mask_token_prob:
                        masked_input_ids[b, idx] = self.mask_token_id
                    elif rand_val < self.mask_token_prob + self.random_token_prob:
                        masked_input_ids[b, idx] = self._random_replacement_id()

            residue_mask.append(batch_residue_mask)
            mono_labels.append(batch_mono_labels)
        return {
            'masked_input_ids': masked_input_ids,
            'token_labels': token_labels,
            'residue_mask': residue_mask,
            'mono_labels': mono_labels,
        }