File size: 17,961 Bytes
1d6f391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
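"""
ESM2-Style Masking Strategy for Glycan BERT

Implements masked language modeling following the ESM2 approach:
- Select 15% of the maskable tokens at random
- Of those, 80% are replaced with [MASK]
- 10% are replaced with a random token
- 10% are left unchanged (for robustness)

These are the GlycanMaskingStrategy defaults. The module also provides
MonosaccharideMaskingStrategy (masks whole residues) and
HierarchicalMaskingStrategy (combines residue- and token-level masking).
"""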

import torch
import random
from typing import List, Tuple


class GlycanMaskingStrategy:
    """
    Masking strategy for glycan sequences following ESM2.

    For each sequence, ``mask_prob`` of the maskable tokens (never padding,
    special or ambiguous tokens) are selected; at least one is selected
    whenever the sequence has any maskable token. Every selected token becomes
    a prediction target. Its input is replaced by [MASK] with probability
    ``mask_token_prob``, replaced by a random ordinary token with probability
    ``random_token_prob``, or kept unchanged otherwise (``unchanged_prob``).

    Note: passing ``seed`` reseeds the *global* ``random`` and ``torch``
    generators, which also affects any other code drawing from them.
    """

    def __init__(
        self,
        vocab_size: int,
        mask_token_id: int,
        pad_token_id: int,
        special_token_ids: List[int],
        ambiguous_token_ids: List[int] = None,
        mask_prob: float = 0.15,
        mask_token_prob: float = 0.8,
        random_token_prob: float = 0.1,
        unchanged_prob: float = 0.1,
        seed: int = None
    ):
        """
        Initialize masking strategy.

        Args:
            vocab_size: Size of vocabulary
            mask_token_id: ID of [MASK] token
            pad_token_id: ID of [PAD] token
            special_token_ids: List of special token IDs to never mask
            ambiguous_token_ids: List of ambiguous token IDs to never mask (x, X, ?, u, d, o)
            mask_prob: Fraction of maskable tokens to select per sequence (default: 0.15)
            mask_token_prob: Probability of replacing with [MASK] (default: 0.8)
            random_token_prob: Probability of replacing with random token (default: 0.1)
            unchanged_prob: Probability of leaving unchanged (default: 0.1)
            seed: If given, reseeds the global ``random`` and ``torch`` RNGs

        Raises:
            ValueError: If the three replacement probabilities do not sum to 1.0.
        """
        # Explicit check (not ``assert``) so the validation survives ``python -O``;
        # written as ``not (... < eps)`` so NaN inputs are rejected too.
        if not (abs(mask_token_prob + random_token_prob + unchanged_prob - 1.0) < 1e-6):
            raise ValueError("Masking probabilities must sum to 1.0")

        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
        self.special_token_ids = set(special_token_ids)
        self.ambiguous_token_ids = set(ambiguous_token_ids) if ambiguous_token_ids else set()

        self.mask_prob = mask_prob
        self.mask_token_prob = mask_token_prob
        self.random_token_prob = random_token_prob
        self.unchanged_prob = unchanged_prob

        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)

    def _random_token_candidates(self) -> List[int]:
        """Return the token IDs allowed as random replacements (no special, [PAD] or [MASK])."""
        excluded = self.special_token_ids | {self.pad_token_id, self.mask_token_id}
        return [tok for tok in range(self.vocab_size) if tok not in excluded]

    def mask_sequence(
        self,
        input_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Apply masking to a batch of sequences.

        Args:
            input_ids: Tensor of shape (batch_size, seq_len) with token IDs

        Returns:
            Tuple of:
            - masked_input_ids: Input with masks applied
            - labels: Original token IDs for masked positions (-100 for unmasked)
            - mask_positions: Boolean tensor indicating masked positions

        If a position falls into the "random token" branch but the vocabulary
        has no eligible replacement token, that position is left unchanged
        (it is still a prediction target).
        """
        batch_size, _seq_len = input_ids.shape

        # Create labels (-100 for positions we don't predict)
        labels = torch.full_like(input_ids, -100)

        # Maskable = not padding, not special, not ambiguous (no known ground truth)
        maskable = torch.ones_like(input_ids, dtype=torch.bool)
        maskable &= (input_ids != self.pad_token_id)
        for special_id in self.special_token_ids:
            maskable &= (input_ids != special_id)
        for ambig_id in self.ambiguous_token_ids:
            maskable &= (input_ids != ambig_id)

        # Randomly select mask_prob of maskable tokens (at least one per non-empty row)
        mask_positions = torch.zeros_like(input_ids, dtype=torch.bool)
        for i in range(batch_size):
            maskable_indices = maskable[i].nonzero(as_tuple=True)[0]
            if len(maskable_indices) == 0:
                continue

            n_to_mask = max(1, int(len(maskable_indices) * self.mask_prob))
            mask_indices = maskable_indices[torch.randperm(len(maskable_indices))[:n_to_mask]]
            mask_positions[i, mask_indices] = True

        # Store original tokens for masked positions
        labels[mask_positions] = input_ids[mask_positions]

        masked_input_ids = input_ids.clone()

        # Built once per call; replaces an unbounded rejection-sampling loop that
        # could spin forever (and could also pick [MASK] as a "random" token).
        replacement_pool = self._random_token_candidates()
        random_threshold = self.mask_token_prob + self.random_token_prob

        # For each masked position, decide what to do
        for batch_idx, pos_idx in zip(*mask_positions.nonzero(as_tuple=True)):
            rand_val = random.random()

            if rand_val < self.mask_token_prob:
                # Replace with [MASK]
                masked_input_ids[batch_idx, pos_idx] = self.mask_token_id
            elif rand_val < random_threshold:
                # Replace with a random ordinary token, if any exists
                if replacement_pool:
                    masked_input_ids[batch_idx, pos_idx] = random.choice(replacement_pool)
            # else: leave unchanged (unchanged_prob)

        return masked_input_ids, labels, mask_positions

    def get_mask_statistics(
        self,
        input_ids: torch.Tensor,
        masked_input_ids: torch.Tensor,
        mask_positions: torch.Tensor
    ) -> dict:
        """
        Calculate statistics about masking for logging.

        Args:
            input_ids: Original input IDs
            masked_input_ids: Masked input IDs
            mask_positions: Boolean mask indicating masked positions

        Returns:
            Dictionary with masking statistics. Percentages are relative to
            the number of non-[PAD] tokens. A random replacement that happens
            to equal the original token is counted as ``unchanged_count``.
        """
        total_tokens = (input_ids != self.pad_token_id).sum().item()
        masked_tokens = mask_positions.sum().item()

        # Gather the selected positions once and reuse them for every count.
        selected_new = masked_input_ids[mask_positions]
        selected_orig = input_ids[mask_positions]

        mask_token_count = (selected_new == self.mask_token_id).sum().item()
        random_token_count = ((selected_new != self.mask_token_id) &
                              (selected_new != selected_orig)).sum().item()
        unchanged_count = masked_tokens - mask_token_count - random_token_count

        # Count ambiguous tokens in batch
        ambiguous_tokens = 0
        for ambig_id in self.ambiguous_token_ids:
            ambiguous_tokens += (input_ids == ambig_id).sum().item()

        stats = {
            'total_tokens': total_tokens,
            'masked_tokens': masked_tokens,
            'mask_percentage': masked_tokens / total_tokens * 100 if total_tokens > 0 else 0,
            'mask_token_count': mask_token_count,
            'random_token_count': random_token_count,
            'unchanged_count': unchanged_count,
            'ambiguous_tokens': ambiguous_tokens,
            'ambiguous_percentage': ambiguous_tokens / total_tokens * 100 if total_tokens > 0 else 0
        }

        return stats


class MonosaccharideMaskingStrategy:
    """
    Residue-level ("whole monosaccharide") masking for Glycan BERT.

    Instead of hiding scattered tokens, every token that belongs to a selected
    residue is replaced by [MASK]. The model therefore has to reconstruct the
    monosaccharide as a unit rather than lean on neighbouring tokens of the
    same residue. Optionally a monosaccharide-type label (an index into
    ``MONO_TYPES``) is emitted as well.
    """

    # Monosaccharide type vocabulary; index 0 ('<UNK>') is the fallback for
    # any name not listed here.
    MONO_TYPES = [
        '<UNK>', 'Glc', 'Gal', 'Man', 'Fuc', 'Xyl', 'Rha', 'Ara',
        'GlcNAc', 'GalNAc', 'ManNAc', 'GlcA', 'GalA', 'IdoA',
        'Neu5Ac', 'Neu5Gc', 'Kdn', 'GlcN', 'GalN', 'Hex', 'HexNAc',
        'dHex', 'Pent', 'Sia', 'GlcS', 'GalS', 'Ido', 'All', 'Alt', 'Gul', 'Tal'
    ]

    def __init__(
        self,
        vocab_size: int,
        mask_token_id: int,
        pad_token_id: int,
        special_token_ids: List[int],
        mask_prob: float = 0.15,  # Probability to mask a residue
        predict_mono_type: bool = True,  # If True, predict mono type; if False, predict tokens
        seed: int = None
    ):
        """
        Configure the residue-masking strategy.

        Args:
            vocab_size: Size of the token vocabulary (stored, not used here)
            mask_token_id: ID written over every token of a masked residue
            pad_token_id: ID of the [PAD] token (stored)
            special_token_ids: Special token IDs (stored as a set)
            mask_prob: Fraction of a sequence's residues to mask (default: 0.15);
                at least one residue is masked whenever any residue exists
            predict_mono_type: When True, ``mono_labels`` are filled from
                ``monosaccharide_names``; when False they stay -100
            seed: If given, reseeds the global ``random`` and ``torch`` RNGs
        """
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
        self.special_token_ids = set(special_token_ids)
        self.mask_prob = mask_prob
        self.predict_mono_type = predict_mono_type

        # Forward and reverse lookups between type names and type IDs.
        self.mono_to_id = {}
        self.id_to_mono = {}
        for type_id, type_name in enumerate(self.MONO_TYPES):
            self.mono_to_id[type_name] = type_id
            self.id_to_mono[type_id] = type_name
        self.num_mono_types = len(self.MONO_TYPES)

        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)

    def mask_sequence(
        self,
        input_ids: torch.Tensor,
        residue_ids: torch.Tensor,
        monosaccharide_names: List[List[str]] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Mask whole residues in a batch.

        Args:
            input_ids: (batch, seq_len) token IDs
            residue_ids: (batch, seq_len) residue ID per token; only values
                >= 0 are treated as residues (-1 = special, -2 = linkage)
            monosaccharide_names: Optional per-sequence list of residue names,
                indexed by residue ID

        Returns:
            (masked_input_ids, token_labels, mono_labels, mask_positions):
            - masked_input_ids: copy of ``input_ids`` with every token of each
              selected residue set to ``mask_token_id``
            - token_labels: original IDs at masked tokens, -100 elsewhere
            - mono_labels: shape (batch,), ONE type ID per sequence. When
              several residues are masked, the value written last (from the
              shuffled selection) wins. The value is -100 if no label was set
              and 0 (<UNK>) for unknown names.
            - mask_positions: boolean tensor of masked token positions
        """
        n_rows, _ = input_ids.shape
        device = input_ids.device

        out_ids = input_ids.clone()
        tok_targets = torch.full_like(input_ids, -100)
        type_targets = torch.full((n_rows,), -100, dtype=torch.long, device=device)
        hidden = torch.zeros_like(input_ids, dtype=torch.bool)

        # Both conditions are fixed for the whole call.
        emit_types = self.predict_mono_type and monosaccharide_names is not None

        for row in range(n_rows):
            row_rids = residue_ids[row]
            # torch.unique returns the IDs sorted; keep only real residues.
            present = torch.unique(row_rids)
            candidates = present[present >= 0].tolist()
            if not candidates:
                continue

            quota = max(1, int(len(candidates) * self.mask_prob))
            random.shuffle(candidates)

            for rid in candidates[:quota]:
                covered = row_rids == rid

                tok_targets[row, covered] = input_ids[row, covered]
                out_ids[row, covered] = self.mask_token_id
                hidden[row, covered] = True

                if not emit_types:
                    continue
                row_names = monosaccharide_names[row]
                if rid < len(row_names):
                    type_targets[row] = self.mono_to_id.get(row_names[rid], 0)  # 0 = <UNK>

        return out_ids, tok_targets, type_targets, hidden

    def get_mono_type_id(self, mono_name: str) -> int:
        """Map a monosaccharide name to its type ID, falling back to 0 (<UNK>)."""
        try:
            return self.mono_to_id[mono_name]
        except KeyError:
            return 0


class HierarchicalMaskingStrategy:
    """
    Hierarchical masking combining token-level and monosaccharide-level.

    Novel approach:
    1. Token-level MLM: Predict individual masked tokens (like BERT)
    2. Residue-level MLM: Predict monosaccharide types from masked residues
    3. Global contrastive: Align sequence with MS/3D representations

    This provides multi-scale supervision for better glycan understanding.
    This class only produces the masking/labels for levels 1 and 2; the
    contrastive objective is not implemented here.

    Note: passing ``seed`` reseeds the *global* ``random`` and ``torch``
    generators.
    """

    def __init__(
        self,
        vocab_size: int,
        mask_token_id: int,
        pad_token_id: int,
        special_token_ids: List[int],
        ambiguous_token_ids: List[int] = None,
        token_mask_prob: float = 0.10,  # Lower since we also mask residues
        residue_mask_prob: float = 0.10,  # Mask some whole residues
        seed: int = None,
        mask_token_prob: float = 0.8,
        random_token_prob: float = 0.1,
    ):
        """
        Initialize hierarchical masking.

        Args:
            vocab_size: Size of vocabulary
            mask_token_id: ID of [MASK] token
            pad_token_id: ID of [PAD] token
            special_token_ids: Special tokens to never mask
            ambiguous_token_ids: Ambiguous tokens excluded from individual token masking
            token_mask_prob: Probability to mask individual tokens
            residue_mask_prob: Probability to mask entire residues
            seed: If given, reseeds the global ``random`` and ``torch`` RNGs
            mask_token_prob: For individually selected tokens, probability of
                replacing with [MASK] (default: 0.8)
            random_token_prob: For individually selected tokens, probability of
                replacing with a random ordinary token (default: 0.1); the
                remaining probability leaves the token unchanged

        Raises:
            ValueError: If mask_token_prob / random_token_prob are negative
                or sum to more than 1.0.
        """
        if not (mask_token_prob >= 0.0 and random_token_prob >= 0.0
                and mask_token_prob + random_token_prob <= 1.0 + 1e-6):
            raise ValueError(
                "mask_token_prob and random_token_prob must be non-negative "
                "and sum to at most 1.0"
            )

        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
        self.special_token_ids = set(special_token_ids)
        self.ambiguous_token_ids = set(ambiguous_token_ids) if ambiguous_token_ids else set()
        self.token_mask_prob = token_mask_prob
        self.residue_mask_prob = residue_mask_prob
        self.mask_token_prob = mask_token_prob
        self.random_token_prob = random_token_prob

        # Mono type vocabulary (same as MonosaccharideMaskingStrategy)
        self.MONO_TYPES = MonosaccharideMaskingStrategy.MONO_TYPES
        self.mono_to_id = {m: i for i, m in enumerate(self.MONO_TYPES)}
        self.num_mono_types = len(self.MONO_TYPES)

        if seed is not None:
            random.seed(seed)
            torch.manual_seed(seed)

    def _random_token_candidates(self) -> List[int]:
        """Return the token IDs allowed as random replacements (no special, [PAD] or [MASK])."""
        excluded = self.special_token_ids | {self.pad_token_id, self.mask_token_id}
        return [tok for tok in range(self.vocab_size) if tok not in excluded]

    def mask_sequence(
        self,
        input_ids: torch.Tensor,
        residue_ids: torch.Tensor,
        monosaccharide_names: List[List[str]] = None
    ) -> dict:
        """
        Apply hierarchical masking at both token and residue levels.

        Args:
            input_ids: (batch, seq_len) token IDs
            residue_ids: (batch, seq_len) residue ID per token; values >= 0
                are treated as residues
            monosaccharide_names: Optional per-sequence list of residue names,
                indexed by residue ID

        Returns:
            Dictionary with:
            - masked_input_ids: Input with masks applied
            - token_labels: Token-level labels for MLM (-100 for unmasked)
            - residue_mask: list (one per sequence) of residue IDs that were completely masked
            - mono_labels: list (one per sequence) of {residue_id: mono type ID}
              for masked residues whose name was available (0 = <UNK>)
        """
        batch_size, seq_len = input_ids.shape

        masked_input_ids = input_ids.clone()
        token_labels = torch.full_like(input_ids, -100)
        residue_mask = []
        mono_labels = []

        # Replacement pool for the "random token" branch; the previous
        # unrestricted randint could inject [PAD]/[MASK]/other special tokens.
        replacement_pool = self._random_token_candidates()
        random_threshold = self.mask_token_prob + self.random_token_prob

        for b in range(batch_size):
            # Step 1: Select residues to mask entirely
            unique_residues = torch.unique(residue_ids[b])
            real_residues = [r.item() for r in unique_residues if r >= 0]

            batch_residue_mask = set()
            batch_mono_labels = {}

            if real_residues:
                # May be 0 for short glycans (no forced minimum here).
                n_residue_mask = max(0, int(len(real_residues) * self.residue_mask_prob))
                random.shuffle(real_residues)

                for rid in real_residues[:n_residue_mask]:
                    token_mask = residue_ids[b] == rid
                    # NOTE(review): every token of the residue is labeled here,
                    # including any ambiguous tokens it contains; confirm this
                    # is intended, since Step 2 never targets ambiguous tokens.
                    token_labels[b, token_mask] = input_ids[b, token_mask]
                    masked_input_ids[b, token_mask] = self.mask_token_id
                    batch_residue_mask.add(rid)

                    if monosaccharide_names is not None and rid < len(monosaccharide_names[b]):
                        mono_name = monosaccharide_names[b][rid]
                        batch_mono_labels[rid] = self.mono_to_id.get(mono_name, 0)

            # Step 2: Mask additional individual tokens (not in already-masked residues)
            maskable = torch.ones(seq_len, dtype=torch.bool, device=input_ids.device)
            maskable &= (input_ids[b] != self.pad_token_id)
            for special_id in self.special_token_ids:
                maskable &= (input_ids[b] != special_id)
            for ambig_id in self.ambiguous_token_ids:
                maskable &= (input_ids[b] != ambig_id)
            for rid in batch_residue_mask:
                maskable &= (residue_ids[b] != rid)

            maskable_indices = maskable.nonzero(as_tuple=True)[0]
            if len(maskable_indices) > 0:
                n_to_mask = max(0, int(len(maskable_indices) * self.token_mask_prob))
                perm = torch.randperm(len(maskable_indices))[:n_to_mask]
                for idx in maskable_indices[perm]:
                    token_labels[b, idx] = input_ids[b, idx]
                    # Default split: 80% [MASK], 10% random token, 10% unchanged
                    rand_val = random.random()
                    if rand_val < self.mask_token_prob:
                        masked_input_ids[b, idx] = self.mask_token_id
                    elif rand_val < random_threshold and replacement_pool:
                        masked_input_ids[b, idx] = random.choice(replacement_pool)

            residue_mask.append(batch_residue_mask)
            mono_labels.append(batch_mono_labels)

        return {
            'masked_input_ids': masked_input_ids,
            'token_labels': token_labels,
            'residue_mask': residue_mask,
            'mono_labels': mono_labels,
        }