supanthadey1 commited on
Commit
9b82aa2
·
verified ·
1 Parent(s): a3e4e43

Add files using upload-large-folder tool

Browse files
.gitattributes CHANGED
@@ -1,35 +1,3 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.pt filter=lfs diff=lfs merge=lfs -text
2
+ *.bin filter=lfs diff=lfs merge=lfs -text
 
3
  *.safetensors filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
README.md ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: pytorch
3
+ license: other
4
+ tags:
5
+ - glycans
6
+ - wurcs
7
+ - bertose
8
+ - ambiguity-resolution
9
+ - contrastive-learning
10
+ - pytorch
11
+ ---
12
+
13
+ # Bertose IAR Ambiguity Resolver
14
+
15
+ Draft private release for Bertose ambiguity-resolution inference.
16
+
17
+ This repository contains the contrastive Bertose checkpoint used to score ambiguous WURCS BPE tokens and support iterative ambiguity resolution.
18
+
19
+ ## Files
20
+
21
+ - `checkpoints/best_v51_contrastive_model.pt` - contrastive ambiguity-resolution checkpoint.
22
+ - `vocab/bpe_vocabulary.json` - WURCS BPE vocabulary.
23
+ - `vocab/bpe_ambiguity_tokens.json` - ambiguous BPE token map used by the resolver.
24
+ - `src/multimodal_glycan_bert_v3.py` - model definition.
25
+ - `src/wurcs_bpe_tokenizer.py` - WURCS BPE tokenizer.
26
+
27
+ ## Expected Input
28
+
29
+ Single glycan or batch CSV with WURCS strings.
30
+
31
+ ## Output
32
+
33
+ Token-level ambiguity-resolution predictions with confidence scores. The companion notebook writes both summary and detail CSVs for batch runs.
34
+
35
+ ## Draft Notes
36
+
37
+ This release does not claim to reconstruct final canonical WURCS strings by itself. It provides model-backed token-level updates and confidence values for ambiguous positions.
SHA256SUMS ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ 622368f62c23e97e9137c277eaadcc93ee3901cbb420b591422bb1c2e19689a5 ./.gitattributes
2
+ 266caeb2fb9b68076343b40da91116dca0f2302f03cf28c2332b80b1a69c1758 ./README.md
3
+ ae468f4e8c06dc0c3848138a474dc43249aa6d14dfd0df8f58d68fcaad371152 ./checkpoints/best_v51_contrastive_model.pt
4
+ daf55c190fece0678064e41697a9545592beb1285f8aa74e595b933b9d37b4c2 ./config.json
5
+ 6a56e6f73b8f874470ecde6e538f3f5029ae23aa6c10559817d1c2a8b59b7c0f ./requirements.txt
6
+ 0d9ce16bf90242f38621d64cd974ea5679bff4c2013bea8d7bffe1b8dd120794 ./src/multimodal_glycan_bert_v3.py
7
+ 0bc54399362945601bcfd403441fc80968d173200dd0561f57568b2053a94839 ./src/wurcs_bpe_tokenizer.py
8
+ c68cd003370b2dcdb162f848f766e4e62f2653c6c38d205f8cbe53a9aabe2d74 ./vocab/bpe_ambiguity_tokens.json
9
+ 6a572afdf53f1494ab96c896876b824ca7ea749777352606aa9f96bf270ceecc ./vocab/bpe_vocabulary.json
checkpoints/best_v51_contrastive_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae468f4e8c06dc0c3848138a474dc43249aa6d14dfd0df8f58d68fcaad371152
3
+ size 557458637
config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_family": "Bertose",
3
+ "release_name": "bertose-iar-ambiguity-resolver",
4
+ "checkpoint": "checkpoints/best_v51_contrastive_model.pt",
5
+ "vocabulary": "vocab/bpe_vocabulary.json",
6
+ "ambiguity_tokens": "vocab/bpe_ambiguity_tokens.json",
7
+ "embedding_dim": 768,
8
+ "max_glycan_length": 256,
9
+ "input_format": "WURCS",
10
+ "output_format": "token_level_predictions"
11
+ }
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ numpy
3
+ pandas
4
+ tqdm
5
+ huggingface_hub
src/multimodal_glycan_bert_v3.py ADDED
@@ -0,0 +1,1084 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multimodal Glycan BERT Model v3
3
+
4
+ Extends GlycanBERT to handle three modalities:
5
+ - Sequence (WURCS atomic tokenization)
6
+ - MS (mass spectrometry peaks, RT, intensity)
7
+ - 3D structure (VQ-VAE discrete tokens, 4 per residue)
8
+
9
+ Each modality has its own encoder, with cross-attention for sequence-structure alignment.
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from typing import Dict, Optional, Tuple
15
+ import math
16
+
17
+ try:
18
+ from .glycan_bert import GlycanBERTConfig, GlycanBERTEmbeddings, GlycanBERTLayer
19
+ except ImportError:
20
+ from glycan_bert import GlycanBERTConfig, GlycanBERTEmbeddings, GlycanBERTLayer
21
+
22
+
23
+ class ConvGlycanBERTEmbeddings(nn.Module):
24
+ """
25
+ Improved Convolutional front-end that mixes local WURCS context before the Transformer.
26
+
27
+ Key improvements over original:
28
+ 1. Position embeddings added BEFORE convolution (provides spatial context to conv)
29
+ 2. Residual connection (conv enriches embeddings rather than replacing them)
30
+ 3. Multi-scale convolutions (kernel sizes 3, 5, 7) for better receptive field
31
+ 4. Proper layer normalization on the residual path
32
+ """
33
+
34
+ def __init__(self, config):
35
+ super().__init__()
36
+ self.token_embeddings = nn.Embedding(
37
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
38
+ )
39
+ self.position_embeddings = nn.Embedding(
40
+ config.max_position_embeddings, config.hidden_size
41
+ )
42
+
43
+ # NEW: Branch depth embeddings - encodes depth in glycan tree (0=root, 1=child, etc.)
44
+ max_branch_depth = getattr(config, "max_branch_depth", 8)
45
+ self.branch_embeddings = nn.Embedding(max_branch_depth, config.hidden_size)
46
+
47
+ # NEW: Linkage type embeddings - encodes chemistry of glycosidic bond
48
+ # 0=none, 1=1-3, 2=1-4, 3=1-6, etc.
49
+ num_linkage_types = getattr(config, "num_linkage_types", 9)
50
+ self.linkage_embeddings = nn.Embedding(num_linkage_types, config.hidden_size)
51
+
52
+ # Multi-scale convolutions for different receptive fields
53
+ kernel_size = getattr(config, "cnn_kernel_size", 3)
54
+ # Split channels evenly: 256 + 256 + 256 = 768 for hidden_size=768
55
+ channels_per_scale = config.hidden_size // 3
56
+ self.conv_layers = nn.ModuleList([
57
+ nn.Conv1d(
58
+ in_channels=config.hidden_size,
59
+ out_channels=channels_per_scale,
60
+ kernel_size=kernel_size + 2 * i, # Kernels: 3, 5, 7
61
+ padding=(kernel_size + 2 * i) // 2, # Same padding
62
+ )
63
+ for i in range(3)
64
+ ])
65
+ self.conv_activation = nn.GELU()
66
+ self.conv_proj = nn.Linear(channels_per_scale * 3, config.hidden_size) # Project concatenated back
67
+
68
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
69
+ self.conv_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
70
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
71
+ self.register_buffer(
72
+ "position_ids",
73
+ torch.arange(config.max_position_embeddings).expand((1, -1)),
74
+ )
75
+
76
+ self.hidden_size = config.hidden_size
77
+
78
+ def forward(self, input_ids, branch_depths=None, linkage_types=None):
79
+ seq_len = input_ids.shape[1]
80
+
81
+ # Step 1: Token + Position embeddings FIRST (provides spatial context to conv)
82
+ x = self.token_embeddings(input_ids) # (batch, seq, hidden)
83
+ position_ids = self.position_ids[:, :seq_len]
84
+ x = x + self.position_embeddings(position_ids)
85
+
86
+ # NEW: Add branch depth embeddings (encodes tree structure)
87
+ if branch_depths is not None:
88
+ # Clamp to valid range
89
+ branch_depths = branch_depths.clamp(0, self.branch_embeddings.num_embeddings - 1)
90
+ x = x + self.branch_embeddings(branch_depths)
91
+
92
+ # NEW: Add linkage type embeddings (encodes bond chemistry)
93
+ if linkage_types is not None:
94
+ linkage_types = linkage_types.clamp(0, self.linkage_embeddings.num_embeddings - 1)
95
+ x = x + self.linkage_embeddings(linkage_types)
96
+
97
+ x = self.LayerNorm(x)
98
+
99
+ # Step 2: Multi-scale convolution with RESIDUAL connection
100
+ # Convolution expects (batch, hidden, seq)
101
+ conv_in = x.permute(0, 2, 1)
102
+
103
+ # Apply multi-scale convolutions and concatenate
104
+ conv_outputs = []
105
+ for conv in self.conv_layers:
106
+ conv_out = self.conv_activation(conv(conv_in))
107
+ conv_outputs.append(conv_out)
108
+
109
+ # Concatenate multi-scale features and project back
110
+ conv_out = torch.cat(conv_outputs, dim=1) # (batch, hidden, seq)
111
+ conv_out = conv_out.permute(0, 2, 1) # (batch, seq, hidden)
112
+ conv_out = self.conv_proj(conv_out) # Project to correct size
113
+
114
+ # Step 3: Residual connection - conv ENRICHES rather than replaces
115
+ x = self.conv_norm(x + self.dropout(conv_out))
116
+
117
+ return x
118
+
119
+
120
+ def create_residue_level_mask(
121
+ seq_residue_ids: torch.Tensor, # (batch, N_seq)
122
+ struct_residue_ids: torch.Tensor # (batch, N_struct)
123
+ ) -> torch.Tensor:
124
+ """
125
+ Create residue-level attention mask for cross-attention.
126
+
127
+ Maps WURCS tokens to VQ-VAE structural tokens based on residue IDs.
128
+ A WURCS token with residue_id=0 can only attend to VQ-VAE tokens with residue_id=0.
129
+
130
+ Args:
131
+ seq_residue_ids: Residue IDs for sequence tokens (batch, N_seq)
132
+ struct_residue_ids: Residue IDs for structural tokens (batch, N_struct)
133
+
134
+ Returns:
135
+ Boolean mask (batch, N_seq, N_struct) where True = can attend
136
+ """
137
+ # Expand dimensions for broadcasting
138
+ # seq: (batch, N_seq, 1)
139
+ # struct: (batch, 1, N_struct)
140
+ mask = seq_residue_ids.unsqueeze(2) == struct_residue_ids.unsqueeze(1)
141
+ # Shape: (batch, N_seq, N_struct)
142
+
143
+ # Mask out structural tokens (residue_id = -1) and MS tokens (residue_id = -2)
144
+ # Only tokens with residue_id >= 0 can attend
145
+ mask &= (seq_residue_ids.unsqueeze(2) >= 0)
146
+
147
+ return mask # True = can attend, False = cannot attend
148
+
149
+
150
+ class MultimodalGlycanBERTConfig:
151
+ """Configuration for Multimodal GlycanBERT v3."""
152
+
153
+ def __init__(
154
+ self,
155
+ # Sequence modality
156
+ seq_vocab_size: int = 166,
157
+ seq_hidden_size: int = 768,
158
+ seq_num_layers: int = 12,
159
+ seq_num_heads: int = 12,
160
+ seq_max_length: int = 512,
161
+
162
+ # MS modality
163
+ ms_vocab_size: int = 242,
164
+ ms_hidden_size: int = 384,
165
+ ms_num_layers: int = 6,
166
+ ms_num_heads: int = 6,
167
+ ms_max_length: int = 150,
168
+
169
+ # 3D structure modality
170
+ struct_vocab_size: int = 1024, # VQ-VAE codebook size
171
+ struct_hidden_size: int = 512,
172
+ struct_num_layers: int = 8,
173
+ struct_num_heads: int = 8,
174
+ struct_max_length: int = 200,
175
+ use_3d: bool = True,
176
+
177
+ # Cross-attention
178
+ use_cross_attention: bool = True,
179
+ cross_attn_num_heads: int = 8,
180
+
181
+ # Fusion
182
+ fusion_hidden_size: int = 768,
183
+ fusion_num_layers: int = 2,
184
+
185
+ # Training
186
+ hidden_dropout_prob: float = 0.1,
187
+ attention_probs_dropout_prob: float = 0.1,
188
+ layer_norm_eps: float = 1e-12,
189
+ initializer_range: float = 0.02,
190
+
191
+ # Conv front-end
192
+ use_cnn_frontend: bool = True,
193
+ cnn_kernel_size: int = 3,
194
+
195
+ # Loss weights
196
+ seq_loss_weight: float = 0.60,
197
+ ms_loss_weight: float = 0.15,
198
+ struct_loss_weight: float = 0.25,
199
+
200
+ # Token IDs
201
+ pad_token_id: int = 0,
202
+ mask_token_id: int = 1,
203
+ ):
204
+ # Sequence config
205
+ self.seq_vocab_size = seq_vocab_size
206
+ self.seq_hidden_size = seq_hidden_size
207
+ self.seq_num_layers = seq_num_layers
208
+ self.seq_num_heads = seq_num_heads
209
+ self.seq_max_length = seq_max_length
210
+
211
+ # MS config
212
+ self.ms_vocab_size = ms_vocab_size
213
+ self.ms_vocab_offset = seq_vocab_size # MS tokens start at 166
214
+ self.ms_total_vocab_size = seq_vocab_size + ms_vocab_size # 408 total
215
+ self.ms_hidden_size = ms_hidden_size
216
+ self.ms_num_layers = ms_num_layers
217
+ self.ms_num_heads = ms_num_heads
218
+ self.ms_max_length = ms_max_length
219
+
220
+ # Structure config
221
+ self.struct_vocab_size = struct_vocab_size
222
+ self.struct_hidden_size = struct_hidden_size
223
+ self.struct_num_layers = struct_num_layers
224
+ self.struct_num_heads = struct_num_heads
225
+ self.struct_max_length = struct_max_length
226
+ self.use_3d = use_3d
227
+
228
+ # Cross-attention config
229
+ self.use_cross_attention = use_cross_attention
230
+ self.cross_attn_num_heads = cross_attn_num_heads
231
+
232
+ # Fusion config
233
+ self.fusion_hidden_size = fusion_hidden_size
234
+ self.fusion_num_layers = fusion_num_layers
235
+
236
+ # Training config
237
+ self.hidden_dropout_prob = hidden_dropout_prob
238
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
239
+ self.layer_norm_eps = layer_norm_eps
240
+ self.initializer_range = initializer_range
241
+
242
+ # Conv front-end
243
+ self.use_cnn_frontend = use_cnn_frontend
244
+ self.cnn_kernel_size = cnn_kernel_size
245
+
246
+ # Loss weights
247
+ self.seq_loss_weight = seq_loss_weight
248
+ self.ms_loss_weight = ms_loss_weight
249
+ self.struct_loss_weight = struct_loss_weight
250
+ self.dist_loss_weight = 0.25 # NEW: Topology loss weight (default, can override from config)
251
+
252
+ # Token IDs
253
+ self.pad_token_id = pad_token_id
254
+ self.mask_token_id = mask_token_id
255
+
256
+ def to_seq_config(self) -> GlycanBERTConfig:
257
+ """Convert to sequence-only config."""
258
+ return GlycanBERTConfig(
259
+ vocab_size=self.seq_vocab_size,
260
+ hidden_size=self.seq_hidden_size,
261
+ num_hidden_layers=self.seq_num_layers,
262
+ num_attention_heads=self.seq_num_heads,
263
+ intermediate_size=self.seq_hidden_size * 4,
264
+ hidden_dropout_prob=self.hidden_dropout_prob,
265
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
266
+ max_position_embeddings=self.seq_max_length,
267
+ layer_norm_eps=self.layer_norm_eps,
268
+ pad_token_id=self.pad_token_id,
269
+ mask_token_id=self.mask_token_id,
270
+ initializer_range=self.initializer_range,
271
+ )
272
+
273
+ def to_ms_config(self) -> GlycanBERTConfig:
274
+ """Convert to MS-only config."""
275
+ return GlycanBERTConfig(
276
+ vocab_size=self.ms_total_vocab_size,
277
+ hidden_size=self.ms_hidden_size,
278
+ num_hidden_layers=self.ms_num_layers,
279
+ num_attention_heads=self.ms_num_heads,
280
+ intermediate_size=self.ms_hidden_size * 4,
281
+ hidden_dropout_prob=self.hidden_dropout_prob,
282
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
283
+ max_position_embeddings=self.ms_max_length,
284
+ layer_norm_eps=self.layer_norm_eps,
285
+ pad_token_id=self.pad_token_id,
286
+ mask_token_id=self.mask_token_id,
287
+ initializer_range=self.initializer_range,
288
+ )
289
+
290
+ def to_struct_config(self) -> GlycanBERTConfig:
291
+ """Convert to structure-only config."""
292
+ return GlycanBERTConfig(
293
+ vocab_size=self.struct_vocab_size,
294
+ hidden_size=self.struct_hidden_size,
295
+ num_hidden_layers=self.struct_num_layers,
296
+ num_attention_heads=self.struct_num_heads,
297
+ intermediate_size=self.struct_hidden_size * 4,
298
+ hidden_dropout_prob=self.hidden_dropout_prob,
299
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
300
+ max_position_embeddings=self.struct_max_length,
301
+ layer_norm_eps=self.layer_norm_eps,
302
+ pad_token_id=self.pad_token_id,
303
+ mask_token_id=self.mask_token_id,
304
+ initializer_range=self.initializer_range,
305
+ )
306
+
307
+
308
+ # =============================================================================
309
+ # Improvement #1: Monosaccharide-Level Pooling
310
+ # =============================================================================
311
+
312
+ class MonosaccharidePooling(nn.Module):
313
+ """
314
+ Pool token representations to monosaccharide level, then aggregate.
315
+
316
+ This bridges the gap between token-level BERT and monosaccharide-level CNNs/GNNs.
317
+ Uses monosaccharide_indices from the data to know where each residue starts.
318
+ """
319
+
320
+ def __init__(self, hidden_size: int, num_attention_heads: int = 8, dropout: float = 0.1):
321
+ super().__init__()
322
+ self.hidden_size = hidden_size
323
+
324
+ # Attention pooling over monosaccharide representations
325
+ self.mono_attention = nn.MultiheadAttention(
326
+ embed_dim=hidden_size,
327
+ num_heads=num_attention_heads,
328
+ dropout=dropout,
329
+ batch_first=True
330
+ )
331
+ self.mono_norm = nn.LayerNorm(hidden_size)
332
+
333
+ # Final aggregation to single glycan representation
334
+ self.glycan_query = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
335
+ self.glycan_attention = nn.MultiheadAttention(
336
+ embed_dim=hidden_size,
337
+ num_heads=num_attention_heads,
338
+ dropout=dropout,
339
+ batch_first=True
340
+ )
341
+ self.glycan_norm = nn.LayerNorm(hidden_size)
342
+
343
+ def forward(
344
+ self,
345
+ hidden_states: torch.Tensor, # (batch, seq_len, hidden)
346
+ residue_ids: torch.Tensor, # (batch, seq_len) - which residue each token belongs to
347
+ attention_mask: torch.Tensor = None, # (batch, seq_len)
348
+ ) -> torch.Tensor:
349
+ """
350
+ Pool tokens to monosaccharide level, then to glycan level.
351
+
352
+ Returns:
353
+ Glycan representation: (batch, hidden_size)
354
+ """
355
+ batch_size = hidden_states.size(0)
356
+ device = hidden_states.device
357
+
358
+ # Get unique residue IDs per sample (excluding -1 padding)
359
+ max_residues = 50 # Reasonable max for glycans
360
+
361
+ # Pool tokens within each residue using mean pooling
362
+ mono_reps = torch.zeros(batch_size, max_residues, self.hidden_size, device=device)
363
+ mono_mask = torch.zeros(batch_size, max_residues, dtype=torch.bool, device=device)
364
+
365
+ for b in range(batch_size):
366
+ unique_residues = torch.unique(residue_ids[b][residue_ids[b] >= 0])
367
+ for i, rid in enumerate(unique_residues):
368
+ if i >= max_residues:
369
+ break
370
+ token_mask = residue_ids[b] == rid
371
+ if attention_mask is not None:
372
+ token_mask = token_mask & (attention_mask[b] > 0)
373
+ if token_mask.sum() > 0:
374
+ mono_reps[b, i] = hidden_states[b][token_mask].mean(dim=0)
375
+ mono_mask[b, i] = True
376
+
377
+ # Apply attention over monosaccharide representations
378
+ # Convert mask for attention: True = valid, need to invert for PyTorch
379
+ key_padding_mask = ~mono_mask # True = ignore
380
+
381
+ mono_out, _ = self.mono_attention(
382
+ mono_reps, mono_reps, mono_reps,
383
+ key_padding_mask=key_padding_mask
384
+ )
385
+ mono_out = self.mono_norm(mono_reps + mono_out)
386
+
387
+ # Aggregate to single glycan representation using learned query
388
+ glycan_query = self.glycan_query.expand(batch_size, -1, -1)
389
+ glycan_out, _ = self.glycan_attention(
390
+ glycan_query, mono_out, mono_out,
391
+ key_padding_mask=key_padding_mask
392
+ )
393
+ glycan_out = self.glycan_norm(glycan_query + glycan_out)
394
+
395
+ return glycan_out.squeeze(1) # (batch, hidden)
396
+
397
+
398
+ # =============================================================================
399
+ # Improvement #2: Residue Type Embeddings
400
+ # =============================================================================
401
+
402
+ # Common monosaccharide types vocabulary
403
+ MONOSACCHARIDE_VOCAB = {
404
+ '[PAD_MONO]': 0, '[UNK_MONO]': 1,
405
+ 'Glc': 2, 'GlcNAc': 3, 'GlcA': 4, 'GlcN': 5,
406
+ 'Gal': 6, 'GalNAc': 7, 'GalA': 8, 'GalN': 9,
407
+ 'Man': 10, 'ManNAc': 11, 'ManA': 12, 'ManN': 13,
408
+ 'Fuc': 14, 'Rha': 15, 'Xyl': 16, 'Ara': 17,
409
+ 'Neu5Ac': 18, 'Neu5Gc': 19, 'Kdn': 20, 'Sia': 21,
410
+ 'GalNAcA': 22, 'GlcNAcA': 23, 'IdoA': 24, 'GulA': 25,
411
+ 'Rib': 26, 'Lyx': 27, 'All': 28, 'Alt': 29,
412
+ 'Tal': 30, 'Ido': 31, 'Qui': 32, 'Oli': 33,
413
+ 'Tyv': 34, 'Abe': 35, 'Par': 36, 'Dig': 37,
414
+ 'Col': 38, 'Dha': 39, 'Kdo': 40, 'Hep': 41,
415
+ 'NeuroGc': 42, 'Muramic': 43, 'LDManHep': 44, 'DDManHep': 45,
416
+ 'Bac': 46, 'Pse': 47, 'Leg': 48, 'Aci': 49,
417
+ '6dTal': 50, 'Fru': 51, 'Tag': 52, 'Sor': 53,
418
+ 'Psi': 54, 'Sed': 55, 'MurNAc': 56, 'MurNGc': 57,
419
+ 'Api': 58, 'Erwiniose': 59, 'Yer': 60, 'Thre': 61,
420
+ # Add more as needed, up to ~70
421
+ }
422
+
423
+
424
+ class ResidueTypeEmbeddings(nn.Module):
425
+ """
426
+ Learnable embeddings for monosaccharide types.
427
+
428
+ Instead of the model having to learn that 'a1221m' = Fucose from character patterns,
429
+ we explicitly add a Fucose embedding to all tokens belonging to that residue.
430
+ """
431
+
432
+ def __init__(self, hidden_size: int, num_mono_types: int = 70):
433
+ super().__init__()
434
+ self.mono_embeddings = nn.Embedding(num_mono_types, hidden_size)
435
+ self.mono_vocab = MONOSACCHARIDE_VOCAB
436
+ self.hidden_size = hidden_size
437
+
438
+ def forward(
439
+ self,
440
+ token_embeddings: torch.Tensor, # (batch, seq_len, hidden)
441
+ residue_ids: torch.Tensor, # (batch, seq_len)
442
+ mono_type_ids: torch.Tensor = None, # (batch, max_residues) - monosaccharide type per residue
443
+ ) -> torch.Tensor:
444
+ """
445
+ Add residue type embeddings to token embeddings.
446
+
447
+ Args:
448
+ token_embeddings: Base token embeddings
449
+ residue_ids: Which residue each token belongs to (-1 for special tokens)
450
+ mono_type_ids: Monosaccharide type ID for each residue position
451
+
452
+ Returns:
453
+ Enhanced embeddings with residue type information
454
+ """
455
+ if mono_type_ids is None:
456
+ return token_embeddings
457
+
458
+ batch_size, seq_len, _ = token_embeddings.shape
459
+ enhanced = token_embeddings.clone()
460
+
461
+ # Add mono type embedding to each token based on its residue
462
+ for b in range(batch_size):
463
+ for pos in range(seq_len):
464
+ rid = residue_ids[b, pos].item()
465
+ if rid >= 0 and rid < mono_type_ids.size(1):
466
+ mono_id = mono_type_ids[b, rid]
467
+ enhanced[b, pos] = enhanced[b, pos] + self.mono_embeddings(mono_id)
468
+
469
+ return enhanced
470
+
471
+ @staticmethod
472
+ def get_mono_type_id(mono_name: str) -> int:
473
+ """Convert monosaccharide name to type ID."""
474
+ return MONOSACCHARIDE_VOCAB.get(mono_name, MONOSACCHARIDE_VOCAB['[UNK_MONO]'])
475
+
476
+
477
+ # =============================================================================
478
+ # Improvement #4: Relative Position Encoding for Glycan Trees
479
+ # =============================================================================
480
+
481
+ class RelativePositionBias(nn.Module):
482
+ """
483
+ Compute relative position bias for attention based on residue IDs.
484
+
485
+ Tokens in the same residue get distance 0.
486
+ Tokens in adjacent residues get distance ±1.
487
+ This helps the model understand glycan tree structure.
488
+ """
489
+
490
+ def __init__(self, num_heads: int, max_distance: int = 10):
491
+ super().__init__()
492
+ self.num_heads = num_heads
493
+ self.max_distance = max_distance
494
+
495
+ # Learnable bias for each relative distance (-max to +max)
496
+ num_distances = 2 * max_distance + 1
497
+ self.relative_bias = nn.Embedding(num_distances, num_heads)
498
+
499
+ def forward(self, residue_ids: torch.Tensor) -> torch.Tensor:
500
+ """
501
+ Compute relative position bias.
502
+
503
+ Args:
504
+ residue_ids: (batch, seq_len)
505
+
506
+ Returns:
507
+ Bias to add to attention scores: (batch, num_heads, seq_len, seq_len)
508
+ """
509
+ # Compute pairwise residue distances
510
+ # (batch, seq_len, 1) - (batch, 1, seq_len) = (batch, seq_len, seq_len)
511
+ distance = residue_ids.unsqueeze(2) - residue_ids.unsqueeze(1)
512
+
513
+ # Clamp to max distance range and shift to 0-indexed
514
+ distance_clamped = distance.clamp(-self.max_distance, self.max_distance)
515
+ distance_idx = distance_clamped + self.max_distance # Now 0 to 2*max_distance
516
+
517
+ # Look up bias: (batch, seq_len, seq_len, num_heads)
518
+ bias = self.relative_bias(distance_idx)
519
+
520
+ # Transpose to (batch, num_heads, seq_len, seq_len)
521
+ bias = bias.permute(0, 3, 1, 2)
522
+
523
+ return bias
524
+
525
+
526
+ class CrossAttentionLayer(nn.Module):
527
+ """
528
+ Cross-attention layer for sequence-structure alignment.
529
+
530
+ Allows sequence tokens to attend to structural atoms using attention masks.
531
+ """
532
+
533
+ def __init__(self, config: MultimodalGlycanBERTConfig):
534
+ super().__init__()
535
+ self.num_heads = config.cross_attn_num_heads
536
+ self.hidden_size = config.seq_hidden_size
537
+ self.head_dim = self.hidden_size // self.num_heads
538
+
539
+ assert self.hidden_size % self.num_heads == 0, "hidden_size must be divisible by num_heads"
540
+
541
+ # Query from sequence, Key/Value from structure (VQ-VAE tokens)
542
+ self.query = nn.Linear(config.seq_hidden_size, self.hidden_size)
543
+ self.key = nn.Linear(config.struct_hidden_size, self.hidden_size)
544
+ self.value = nn.Linear(config.struct_hidden_size, self.hidden_size)
545
+
546
+ self.output = nn.Linear(self.hidden_size, config.seq_hidden_size)
547
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
548
+ self.layer_norm = nn.LayerNorm(config.seq_hidden_size, eps=config.layer_norm_eps)
549
+
550
+ def forward(
551
+ self,
552
+ seq_hidden: torch.Tensor, # (batch, seq_len, seq_hidden)
553
+ struct_hidden: torch.Tensor, # (batch, struct_len, struct_hidden)
554
+ attention_mask: Optional[torch.Tensor] = None, # (batch, seq_len, struct_len)
555
+ ) -> torch.Tensor:
556
+ """
557
+ Apply cross-attention from sequence to structure.
558
+
559
+ Args:
560
+ seq_hidden: Sequence hidden states
561
+ struct_hidden: Structure hidden states
562
+ attention_mask: Boolean mask (True = can attend, False = cannot attend)
563
+
564
+ Returns:
565
+ Updated sequence hidden states
566
+ """
567
+ batch_size, seq_len, _ = seq_hidden.shape
568
+ struct_len = struct_hidden.shape[1]
569
+
570
+ # Project to Q, K, V
571
+ Q = self.query(seq_hidden) # (batch, seq_len, hidden)
572
+ K = self.key(struct_hidden) # (batch, struct_len, hidden)
573
+ V = self.value(struct_hidden) # (batch, struct_len, hidden)
574
+
575
+ # Reshape for multi-head attention
576
+ Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch, heads, seq_len, head_dim)
577
+ K = K.view(batch_size, struct_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch, heads, struct_len, head_dim)
578
+ V = V.view(batch_size, struct_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch, heads, struct_len, head_dim)
579
+
580
+ # Compute attention scores
581
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim) # (batch, heads, seq_len, struct_len)
582
+
583
+ # Apply attention mask
584
+ if attention_mask is not None:
585
+ # attention_mask: (batch, seq_len, struct_len) -> (batch, 1, seq_len, struct_len)
586
+ attention_mask = attention_mask.unsqueeze(1)
587
+ # Convert boolean mask to float: True -> 0.0, False -> -10000.0
588
+ attention_mask = (~attention_mask).float() * -10000.0
589
+ scores = scores + attention_mask
590
+
591
+ # Softmax and dropout
592
+ attn_weights = torch.softmax(scores, dim=-1) # (batch, heads, seq_len, struct_len)
593
+ attn_weights = self.dropout(attn_weights)
594
+
595
+ # Apply attention to values
596
+ context = torch.matmul(attn_weights, V) # (batch, heads, seq_len, head_dim)
597
+
598
+ # Reshape back
599
+ context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
600
+
601
+ # Output projection
602
+ output = self.output(context)
603
+ output = self.dropout(output)
604
+
605
+ # Residual connection + layer norm
606
+ output = self.layer_norm(seq_hidden + output)
607
+
608
+ return output
609
+
610
+
611
+ class MultimodalGlycanBERT(nn.Module):
612
+ """
613
+ Multimodal BERT for glycan representation learning (v3).
614
+
615
+ Architecture:
616
+ 1. Separate encoders for each modality (sequence, MS, 3D structure)
617
+ 2. Cross-attention for sequence-structure alignment
618
+ 3. Modality-specific MLM heads
619
+ 4. Fusion layer for combined representation
620
+ """
621
+
622
+ def __init__(self, config: MultimodalGlycanBERTConfig):
623
+ super().__init__()
624
+ self.config = config
625
+
626
+ # ===== Sequence Encoder =====
627
+ seq_config = config.to_seq_config()
628
+ seq_config.cnn_kernel_size = config.cnn_kernel_size
629
+
630
+ if config.use_cnn_frontend:
631
+ print(f"✅ Enabled Convolutional Front-End (Kernel={config.cnn_kernel_size})")
632
+ self.seq_embeddings = ConvGlycanBERTEmbeddings(seq_config)
633
+ else:
634
+ self.seq_embeddings = GlycanBERTEmbeddings(seq_config)
635
+ self.seq_layers = nn.ModuleList([GlycanBERTLayer(seq_config) for _ in range(seq_config.num_hidden_layers)])
636
+ self.seq_mlm_head = nn.Linear(seq_config.hidden_size, seq_config.vocab_size)
637
+
638
+ # ===== MS Encoder =====
639
+ ms_config = config.to_ms_config()
640
+ self.ms_embeddings = GlycanBERTEmbeddings(ms_config)
641
+ self.ms_layers = nn.ModuleList([GlycanBERTLayer(ms_config) for _ in range(ms_config.num_hidden_layers)])
642
+ self.ms_mlm_head = nn.Linear(ms_config.hidden_size, ms_config.vocab_size)
643
+
644
+ # ===== Structure Encoder (VQ-VAE tokens) =====
645
+ if config.use_3d:
646
+ struct_config = config.to_struct_config()
647
+ self.struct_embeddings = GlycanBERTEmbeddings(struct_config)
648
+ self.struct_layers = nn.ModuleList([GlycanBERTLayer(struct_config) for _ in range(struct_config.num_hidden_layers)])
649
+ self.struct_mlm_head = nn.Linear(struct_config.hidden_size, struct_config.vocab_size)
650
+
651
+ # Cross-attention layer (sequence → VQ-VAE structural tokens)
652
+ if config.use_cross_attention:
653
+ self.cross_attention = CrossAttentionLayer(config)
654
+
655
+ # ===== Projection layers (align hidden sizes) =====
656
+ if config.ms_hidden_size != config.seq_hidden_size:
657
+ self.ms_projection = nn.Linear(config.ms_hidden_size, config.seq_hidden_size)
658
+ else:
659
+ self.ms_projection = nn.Identity()
660
+
661
+ if config.use_3d and config.struct_hidden_size != config.seq_hidden_size:
662
+ self.struct_projection = nn.Linear(config.struct_hidden_size, config.seq_hidden_size)
663
+ else:
664
+ self.struct_projection = nn.Identity()
665
+
666
+ # ===== Fusion Layer =====
667
+ # Concatenate seq + ms + struct
668
+ fusion_input_size = config.seq_hidden_size * (3 if config.use_3d else 2)
669
+ self.fusion_layer = nn.Sequential(
670
+ nn.Linear(fusion_input_size, config.fusion_hidden_size),
671
+ nn.LayerNorm(config.fusion_hidden_size, eps=config.layer_norm_eps),
672
+ nn.GELU(),
673
+ nn.Dropout(config.hidden_dropout_prob),
674
+ nn.Linear(config.fusion_hidden_size, config.fusion_hidden_size),
675
+ )
676
+
677
+ # ===== Distance Prediction Head (Topology) =====
678
+ # OPTIMIZED: Project down to 128 dim first to save GPU memory
679
+ # (Batch, 256, 256, 768) -> (Batch, 256, 256, 128) reduces memory by 6x
680
+ self.dist_proj = nn.Linear(config.seq_hidden_size, 128)
681
+ self.distance_head = nn.Sequential(
682
+ nn.Linear(128, 64),
683
+ nn.ReLU(),
684
+ nn.Linear(64, 1)
685
+ )
686
+
687
+ # Initialize weights
688
+ self.apply(self._init_weights)
689
+
690
+ def _init_weights(self, module):
691
+ """Initialize weights."""
692
+ if isinstance(module, nn.Linear):
693
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
694
+ if module.bias is not None:
695
+ module.bias.data.zero_()
696
+ elif isinstance(module, nn.Embedding):
697
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
698
+ if module.padding_idx is not None:
699
+ module.weight.data[module.padding_idx].zero_()
700
+ elif isinstance(module, nn.LayerNorm):
701
+ module.bias.data.zero_()
702
+ module.weight.data.fill_(1.0)
703
+
704
+ def forward(
705
+ self,
706
+ seq_token_ids: torch.Tensor,
707
+ seq_attention_mask: torch.Tensor,
708
+ seq_residue_ids: torch.Tensor,
709
+ seq_branch_depths: Optional[torch.Tensor] = None, # NEW: Branch depths
710
+ seq_linkage_types: Optional[torch.Tensor] = None, # NEW: Linkage types
711
+ ms_token_ids: torch.Tensor = None,
712
+ ms_attention_mask: torch.Tensor = None,
713
+ has_ms: torch.Tensor = None,
714
+ struct_token_ids: Optional[torch.Tensor] = None,
715
+ struct_attention_mask: Optional[torch.Tensor] = None,
716
+ struct_residue_ids: Optional[torch.Tensor] = None,
717
+ has_3d: Optional[torch.Tensor] = None,
718
+ seq_labels: Optional[torch.Tensor] = None,
719
+ ms_labels: Optional[torch.Tensor] = None,
720
+ struct_labels: Optional[torch.Tensor] = None,
721
+ dist_labels: Optional[torch.Tensor] = None, # NEW: Topology distance labels
722
+ return_dict: bool = True,
723
+ ) -> Dict[str, torch.Tensor]:
724
+ """
725
+ Forward pass for multimodal BERT v3.
726
+
727
+ Args:
728
+ seq_token_ids: (batch_size, seq_len) - Sequence token IDs
729
+ seq_attention_mask: (batch_size, seq_len) - Sequence attention mask
730
+ seq_residue_ids: (batch_size, seq_len) - Sequence token residue IDs
731
+ ms_token_ids: (batch_size, ms_len) - MS token IDs
732
+ ms_attention_mask: (batch_size, ms_len) - MS attention mask
733
+ has_ms: (batch_size,) - Boolean mask for samples with MS data
734
+ struct_token_ids: (batch_size, struct_len) - Structure VQ-VAE token IDs (optional)
735
+ struct_attention_mask: (batch_size, struct_len) - Structure attention mask (optional)
736
+ struct_residue_ids: (batch_size, struct_len) - Structure token residue IDs (optional)
737
+ has_3d: (batch_size,) - Boolean mask for samples with 3D data (optional)
738
+ seq_labels: (batch_size, seq_len) - Masked sequence labels (optional)
739
+ ms_labels: (batch_size, ms_len) - Masked MS labels (optional)
740
+ struct_labels: (batch_size, struct_len) - Masked structure labels (optional)
741
+ return_dict: Whether to return dict or tuple
742
+
743
+ Returns:
744
+ Dictionary containing logits, hidden states, losses, etc.
745
+ """
746
+ batch_size = seq_token_ids.shape[0]
747
+ device = seq_token_ids.device
748
+
749
+ # ===== Sequence Encoder =====
750
+ # Pass branch_depths and linkage_types to embeddings for tree-aware encoding
751
+ seq_hidden = self.seq_embeddings(seq_token_ids, seq_branch_depths, seq_linkage_types)
752
+ for layer in self.seq_layers:
753
+ seq_hidden = layer(seq_hidden, seq_attention_mask)
754
+
755
+ seq_pooled = seq_hidden[:, 0, :] # [CLS] token
756
+ seq_logits = self.seq_mlm_head(seq_hidden)
757
+
758
+ # ===== Distance Predictions (Topology) =====
759
+ # Compute pairwise distance predictions
760
+ # MEMORY OPTIMIZATION: Project to 128-dim first
761
+ seq_hidden_small = self.dist_proj(seq_hidden) # (batch, seq_len, 128)
762
+
763
+ # Expand for pairwise: (batch, seq_len, 1, 128) - (batch, 1, seq_len, 128)
764
+ h_i = seq_hidden_small.unsqueeze(2)
765
+ h_j = seq_hidden_small.unsqueeze(1)
766
+ h_diff = torch.abs(h_i - h_j) # (batch, seq_len, seq_len, 128) - Much smaller!
767
+ dist_predictions = self.distance_head(h_diff) # (batch, seq_len, seq_len, 1)
768
+
769
+ # ===== MS Encoder =====
770
+ ms_hidden = None
771
+ ms_pooled = None
772
+ ms_logits = None
773
+
774
+ if ms_token_ids is not None:
775
+ ms_hidden = self.ms_embeddings(ms_token_ids)
776
+ for layer in self.ms_layers:
777
+ ms_hidden = layer(ms_hidden, ms_attention_mask)
778
+
779
+ ms_pooled = ms_hidden[:, 0, :] # [CLS] token
780
+ ms_logits = self.ms_mlm_head(ms_hidden)
781
+
782
+ # Zero out MS representations for samples without MS data
783
+ if has_ms is not None:
784
+ has_ms_expanded = has_ms.unsqueeze(1).float() # (batch, 1)
785
+ ms_pooled = ms_pooled * has_ms_expanded
786
+
787
+ # ===== Structure Encoder =====
788
+ struct_pooled = None
789
+ struct_logits = None
790
+ struct_hidden = None
791
+
792
+ if self.config.use_3d and struct_token_ids is not None:
793
+ struct_hidden = self.struct_embeddings(struct_token_ids)
794
+ for layer in self.struct_layers:
795
+ struct_hidden = layer(struct_hidden, struct_attention_mask)
796
+
797
+ struct_pooled = struct_hidden[:, 0, :] # [CLS] token
798
+ struct_logits = self.struct_mlm_head(struct_hidden)
799
+
800
+ # Zero out structure representations for samples without 3D data
801
+ if has_3d is not None:
802
+ has_3d_expanded = has_3d.unsqueeze(1).float() # (batch, 1)
803
+ struct_pooled = struct_pooled * has_3d_expanded
804
+
805
+ # ===== Cross-Attention (Sequence → VQ-VAE Structural Tokens) =====
806
+ # Use residue-level alignment between WURCS tokens and VQ-VAE tokens
807
+ if self.config.use_cross_attention and struct_residue_ids is not None:
808
+ # Create residue-level mask
809
+ # WURCS token with residue_id=0 → VQ-VAE tokens with residue_id=0
810
+ residue_mask = create_residue_level_mask(
811
+ seq_residue_ids=seq_residue_ids,
812
+ struct_residue_ids=struct_residue_ids,
813
+ ) # (batch, N_seq, N_struct)
814
+
815
+ # Apply cross-attention: sequence tokens attend to VQ-VAE tokens
816
+ seq_hidden = self.cross_attention(
817
+ seq_hidden=seq_hidden,
818
+ struct_hidden=struct_hidden, # VQ-VAE token features
819
+ attention_mask=residue_mask, # Residue-based mask
820
+ )
821
+
822
+ # Update seq_pooled after cross-attention
823
+ seq_pooled = seq_hidden[:, 0, :]
824
+
825
+ # ===== Fusion =====
826
+ # Project to common hidden size
827
+ ms_pooled_projected = self.ms_projection(ms_pooled)
828
+
829
+ if self.config.use_3d and struct_pooled is not None:
830
+ struct_pooled_projected = self.struct_projection(struct_pooled)
831
+ combined = torch.cat([seq_pooled, ms_pooled_projected, struct_pooled_projected], dim=-1)
832
+ else:
833
+ combined = torch.cat([seq_pooled, ms_pooled_projected], dim=-1)
834
+
835
+ fused_repr = self.fusion_layer(combined)
836
+
837
+ # ===== Compute Losses =====
838
+ total_loss = None
839
+ seq_loss = None
840
+ ms_loss = None
841
+ struct_loss = None
842
+ dist_loss = None # NEW: Topology distance loss
843
+
844
+ if seq_labels is not None:
845
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
846
+ seq_loss = loss_fct(
847
+ seq_logits.view(-1, self.config.seq_vocab_size),
848
+ seq_labels.view(-1)
849
+ )
850
+
851
+ if ms_labels is not None:
852
+ ms_labels_masked = ms_labels.clone()
853
+ ms_labels_masked[~has_ms] = -100
854
+ # Only compute loss if there are valid labels (not all -100)
855
+ if (ms_labels_masked != -100).any():
856
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
857
+ ms_loss = loss_fct(
858
+ ms_logits.view(-1, self.config.ms_total_vocab_size),
859
+ ms_labels_masked.view(-1)
860
+ )
861
+ else:
862
+ ms_loss = torch.tensor(0.0, device=seq_token_ids.device)
863
+
864
+ if self.config.use_3d and struct_labels is not None and struct_logits is not None:
865
+ struct_labels_masked = struct_labels.clone()
866
+ if has_3d is not None:
867
+ struct_labels_masked[~has_3d] = -100
868
+ # Only compute loss if there are valid labels (not all -100)
869
+ if (struct_labels_masked != -100).any():
870
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
871
+ struct_loss = loss_fct(
872
+ struct_logits.view(-1, self.config.struct_vocab_size),
873
+ struct_labels_masked.view(-1)
874
+ )
875
+ else:
876
+ struct_loss = torch.tensor(0.0, device=seq_token_ids.device)
877
+
878
+ # ===== Distance Loss (Topology) =====
879
+ if dist_labels is not None:
880
+ # dist_predictions: (Batch, Seq, Seq, 1) -> (Batch, Seq, Seq)
881
+ preds = dist_predictions.squeeze(-1)
882
+
883
+ # Create mask for valid distance pairs (label != -1)
884
+ # Also respect attention mask to avoid padding
885
+ valid_mask = (dist_labels != -1) & (seq_attention_mask.unsqueeze(1) * seq_attention_mask.unsqueeze(2) == 1)
886
+
887
+ # DEBUG: Print once
888
+ if not hasattr(self, '_dist_debug_printed'):
889
+ print(f"[DIST DEBUG] dist_labels shape: {dist_labels.shape}, valid_mask.sum: {valid_mask.sum().item()}")
890
+ self._dist_debug_printed = True
891
+
892
+ if valid_mask.sum() > 0:
893
+ # MSE loss on valid positions only
894
+ loss_fct = nn.MSELoss()
895
+ dist_loss = loss_fct(preds[valid_mask], dist_labels[valid_mask].float())
896
+ else:
897
+ dist_loss = torch.tensor(0.0, device=seq_token_ids.device)
898
+ else:
899
+ # DEBUG: dist_labels is None
900
+ if not hasattr(self, '_dist_none_printed'):
901
+ print("[DIST DEBUG] dist_labels is None!")
902
+ self._dist_none_printed = True
903
+
904
+ # Weighted combination
905
+ losses = []
906
+ if seq_loss is not None:
907
+ losses.append(self.config.seq_loss_weight * seq_loss)
908
+ if ms_loss is not None:
909
+ losses.append(self.config.ms_loss_weight * ms_loss)
910
+ if struct_loss is not None:
911
+ losses.append(self.config.struct_loss_weight * struct_loss)
912
+ if dist_loss is not None:
913
+ losses.append(self.config.dist_loss_weight * dist_loss)
914
+
915
+ if losses:
916
+ total_loss = sum(losses)
917
+
918
+ if return_dict:
919
+ return {
920
+ 'loss': total_loss,
921
+ 'seq_loss': seq_loss,
922
+ 'ms_loss': ms_loss,
923
+ 'struct_loss': struct_loss,
924
+ 'dist_loss': dist_loss, # NEW: Topology loss
925
+ 'seq_logits': seq_logits,
926
+ 'ms_logits': ms_logits,
927
+ 'struct_logits': struct_logits,
928
+ 'dist_predictions': dist_predictions, # NEW: Distance predictions
929
+ 'seq_hidden': seq_hidden,
930
+ 'ms_hidden': ms_hidden,
931
+ 'struct_hidden': struct_hidden,
932
+ 'seq_pooled': seq_pooled,
933
+ 'ms_pooled': ms_pooled,
934
+ 'struct_pooled': struct_pooled,
935
+ 'fused_repr': fused_repr,
936
+ }
937
+ else:
938
+ return (total_loss, seq_logits, ms_logits, struct_logits, fused_repr)
939
+
940
+ def get_multimodal_representation(
941
+ self,
942
+ seq_token_ids: torch.Tensor,
943
+ seq_attention_mask: torch.Tensor,
944
+ seq_residue_ids: torch.Tensor,
945
+ ms_token_ids: torch.Tensor,
946
+ ms_attention_mask: torch.Tensor,
947
+ has_ms: torch.Tensor,
948
+ struct_token_ids: Optional[torch.Tensor] = None,
949
+ struct_attention_mask: Optional[torch.Tensor] = None,
950
+ struct_residue_ids: Optional[torch.Tensor] = None,
951
+ has_3d: Optional[torch.Tensor] = None,
952
+ ) -> torch.Tensor:
953
+ """Get fused multimodal representation (for inference)."""
954
+ outputs = self.forward(
955
+ seq_token_ids=seq_token_ids,
956
+ seq_attention_mask=seq_attention_mask,
957
+ seq_residue_ids=seq_residue_ids,
958
+ ms_token_ids=ms_token_ids,
959
+ ms_attention_mask=ms_attention_mask,
960
+ has_ms=has_ms,
961
+ struct_token_ids=struct_token_ids,
962
+ struct_attention_mask=struct_attention_mask,
963
+ struct_residue_ids=struct_residue_ids,
964
+ has_3d=has_3d,
965
+ return_dict=True,
966
+ )
967
+ return outputs['fused_repr']
968
+
969
+
970
+ if __name__ == "__main__":
971
+ # Test the model
972
+ print("="*80)
973
+ print("Testing Multimodal GlycanBERT v3")
974
+ print("="*80)
975
+
976
+ # Create config
977
+ config = MultimodalGlycanBERTConfig(
978
+ seq_vocab_size=166,
979
+ seq_hidden_size=768,
980
+ seq_num_layers=12,
981
+ seq_num_heads=12,
982
+ ms_vocab_size=242,
983
+ ms_hidden_size=384,
984
+ ms_num_layers=6,
985
+ ms_num_heads=6,
986
+ struct_vocab_size=1024,
987
+ struct_hidden_size=512,
988
+ struct_num_layers=8,
989
+ struct_num_heads=8,
990
+ use_3d=True,
991
+ use_cross_attention=True,
992
+ seq_loss_weight=0.60,
993
+ ms_loss_weight=0.15,
994
+ struct_loss_weight=0.25,
995
+ )
996
+
997
+ print(f"\nConfig:")
998
+ print(f" Sequence vocab: {config.seq_vocab_size}")
999
+ print(f" MS vocab: {config.ms_vocab_size}")
1000
+ print(f" Structure vocab: {config.struct_vocab_size}")
1001
+ print(f" Loss weights: seq={config.seq_loss_weight}, ms={config.ms_loss_weight}, struct={config.struct_loss_weight}")
1002
+
1003
+ # Create model
1004
+ model = MultimodalGlycanBERT(config)
1005
+
1006
+ # Count parameters
1007
+ total_params = sum(p.numel() for p in model.parameters())
1008
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
1009
+
1010
+ print(f"\nModel Parameters:")
1011
+ print(f" Total: {total_params:,}")
1012
+ print(f" Trainable: {trainable_params:,}")
1013
+
1014
+ # Test forward pass
1015
+ print(f"\n{'='*80}")
1016
+ print("Testing Forward Pass (with Conv front-end)")
1017
+ print("="*80)
1018
+
1019
+ batch_size = 4
1020
+ seq_len = 128
1021
+ ms_len = 50
1022
+ struct_len = 40
1023
+
1024
+ # Create dummy inputs
1025
+ seq_token_ids = torch.randint(0, config.seq_vocab_size, (batch_size, seq_len))
1026
+ seq_attention_mask = torch.ones(batch_size, seq_len)
1027
+ # Approximate: ~5 tokens per residue
1028
+ seq_residue_ids = torch.div(
1029
+ torch.arange(seq_len), 5, rounding_mode="floor"
1030
+ ).unsqueeze(0).expand(batch_size, -1)
1031
+
1032
+ ms_token_ids = torch.randint(config.ms_vocab_offset, config.ms_total_vocab_size, (batch_size, ms_len))
1033
+ ms_attention_mask = torch.ones(batch_size, ms_len)
1034
+ struct_token_ids = torch.randint(0, config.struct_vocab_size, (batch_size, struct_len))
1035
+ struct_attention_mask = torch.ones(batch_size, struct_len)
1036
+ # Approximate: 4 tokens per residue for VQ-VAE tokens
1037
+ struct_residue_ids = torch.div(
1038
+ torch.arange(struct_len), 4, rounding_mode="floor"
1039
+ ).unsqueeze(0).expand(batch_size, -1)
1040
+
1041
+ has_ms = torch.tensor([True, True, False, True])
1042
+ has_3d = torch.tensor([True, False, True, True])
1043
+
1044
+ # Create labels for MLM
1045
+ seq_labels = seq_token_ids.clone()
1046
+ seq_labels[seq_labels != config.mask_token_id] = -100
1047
+ ms_labels = ms_token_ids.clone()
1048
+ ms_labels[ms_labels != config.mask_token_id] = -100
1049
+ struct_labels = struct_token_ids.clone()
1050
+ struct_labels[struct_labels != config.mask_token_id] = -100
1051
+
1052
+ # Forward pass
1053
+ outputs = model(
1054
+ seq_token_ids=seq_token_ids,
1055
+ seq_attention_mask=seq_attention_mask,
1056
+ seq_residue_ids=seq_residue_ids,
1057
+ ms_token_ids=ms_token_ids,
1058
+ ms_attention_mask=ms_attention_mask,
1059
+ has_ms=has_ms,
1060
+ struct_token_ids=struct_token_ids,
1061
+ struct_attention_mask=struct_attention_mask,
1062
+ struct_residue_ids=struct_residue_ids,
1063
+ has_3d=has_3d,
1064
+ seq_labels=seq_labels,
1065
+ ms_labels=ms_labels,
1066
+ struct_labels=struct_labels,
1067
+ )
1068
+
1069
+ print(f"\nOutput shapes:")
1070
+ print(f" seq_logits: {outputs['seq_logits'].shape}")
1071
+ print(f" ms_logits: {outputs['ms_logits'].shape}")
1072
+ print(f" struct_logits: {outputs['struct_logits'].shape}")
1073
+ print(f" fused_repr: {outputs['fused_repr'].shape}")
1074
+
1075
+ print(f"\nLosses:")
1076
+ print(f" Total loss: {outputs['loss'].item():.4f}")
1077
+ print(f" Sequence loss: {outputs['seq_loss'].item():.4f}")
1078
+ print(f" MS loss: {outputs['ms_loss'].item():.4f}")
1079
+ print(f" Structure loss: {outputs['struct_loss'].item():.4f}")
1080
+
1081
+ print(f"\n{'='*80}")
1082
+ print("Model Test Complete!")
1083
+ print("="*80)
1084
+
src/wurcs_bpe_tokenizer.py ADDED
@@ -0,0 +1,740 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ WURCS-BPE Tokenizer
4
+
5
+ A hybrid tokenizer that learns semantic subwords from WURCS while preserving
6
+ the ability to handle rare/novel glycan structures character-by-character.
7
+
8
+ Key features:
9
+ 1. Pre-tokenization: Split WURCS into semantic units (residues, linkages, mods)
10
+ 2. BPE: Learn subword merges from corpus
11
+ 3. Character fallback: Handle novel structures
12
+ 4. Tree embeddings: Preserve branch_depth and linkage_type per token
13
+
14
+ Usage:
15
+ # Train BPE on corpus
16
+ tokenizer = WURCSBPETokenizer.train_from_corpus(
17
+ wurcs_strings,
18
+ num_merges=500,
19
+ output_path="bpe_vocabulary.json"
20
+ )
21
+
22
+ # Tokenize
23
+ result = tokenizer.tokenize(wurcs_string)
24
+ """
25
+
26
+ import json
27
+ import re
28
+ from collections import Counter, defaultdict
29
+ from pathlib import Path
30
+ from typing import Dict, List, Optional, Tuple, Set
31
+ import pickle
32
+
33
+
34
+ class WURCSPreTokenizer:
35
+ """
36
+ Pre-tokenize WURCS into semantic units before BPE.
37
+
38
+ WURCS format: WURCS=2.0/count/[residues]/indices/linkages
39
+
40
+ We split into:
41
+ - Residues: [a2122h-1b_1-5_2*NCC/3=O] -> one unit per []
42
+ - Linkages: a4-b1 -> one unit per linkage
43
+ - Special markers: [BRANCH_OPEN], [BRANCH_CLOSE], etc.
44
+ """
45
+
46
+ # Residue patterns for common monosaccharides
47
+ RESIDUE_PATTERN = re.compile(r'\[([^\]]+)\]')
48
+ LINKAGE_PATTERN = re.compile(r'([a-z])(\d+|\?)-([a-z])(\d+|\?)')
49
+
50
+ def __init__(self):
51
+ self.special_tokens = {
52
+ '[PAD]': 0,
53
+ '[UNK]': 1,
54
+ '[START]': 2,
55
+ '[END]': 3,
56
+ '[MASK]': 4,
57
+ '[BRANCH_OPEN]': 5,
58
+ '[BRANCH_CLOSE]': 6,
59
+ '[LINK]': 7,
60
+ '[MOD]': 8,
61
+ '[RESIDUE_ERROR]': 9,
62
+ }
63
+
64
+ def pre_tokenize(self, wurcs: str) -> List[Dict]:
65
+ """
66
+ Pre-tokenize WURCS into semantic units.
67
+
68
+ Returns list of dicts with:
69
+ - text: The unit text
70
+ - type: 'special', 'residue', 'linkage', 'mod', 'index'
71
+ - residue_id: Which residue this belongs to (-1 for special, -2 for linkage)
72
+ - branch_depth: Tree depth (computed later)
73
+ """
74
+ units = []
75
+
76
+ # Add start token
77
+ units.append({
78
+ 'text': '[START]',
79
+ 'type': 'special',
80
+ 'residue_id': -1,
81
+ 'branch_depth': 0,
82
+ 'linkage_type': 0,
83
+ })
84
+
85
+ # Parse WURCS sections
86
+ if not wurcs.startswith('WURCS='):
87
+ units.append({'text': '[RESIDUE_ERROR]', 'type': 'special', 'residue_id': -1, 'branch_depth': 0, 'linkage_type': 0})
88
+ units.append({'text': '[END]', 'type': 'special', 'residue_id': -1, 'branch_depth': 0, 'linkage_type': 0})
89
+ return units
90
+
91
+ try:
92
+ parts = self._split_wurcs_sections(wurcs)
93
+ if len(parts) < 4:
94
+ return [{'text': '[ERROR]', 'type': 'special', 'residue_id': -1, 'branch_depth': 0, 'linkage_type': 0}]
95
+
96
+ # parts: WURCS=2.0/3,3,2/[a2122h-1b_1-5][a2122h-1a_1-5][a1122h-1b_1-5]/1-2-3-1/a4-b1_b3-c1_c4-d1
97
+ # section 2: residue definitions
98
+ # section 3: indices
99
+ # section 4: linkages (optional)
100
+
101
+ version = parts[0] # WURCS=2.0
102
+ counts = parts[1] # residue_count,node_count,link_count
103
+ residue_defs = parts[2] # [res1][res2]...
104
+ indices = parts[3] # 1-2-3-1
105
+ linkages = parts[4] if len(parts) > 4 else "" # a4-b1_b3-c1
106
+
107
+ # Parse residue definitions
108
+ residue_list = self.RESIDUE_PATTERN.findall(residue_defs)
109
+
110
+ # Parse linkages to compute branch structure
111
+ linkage_list = linkages.split('_') if linkages else []
112
+ branch_points, residue_depths, linkage_types_map, adj = self._analyze_tree_structure(linkage_list, num_residues=len(residue_list))
113
+
114
+ # Compute distance matrix and cache it based on the linkage string (structure)
115
+ # This is the most expensive part, so we cache it
116
+ if not hasattr(self, '_dist_cache'): self._dist_cache = {}
117
+ if linkages not in self._dist_cache:
118
+ self._dist_cache[linkages] = self._compute_distance_matrix(adj, len(residue_list))
119
+ dist_matrix_raw = self._dist_cache[linkages]
120
+
121
+ # Parse indices to map positions to residue definitions
122
+ index_list = indices.split('-') if indices else []
123
+
124
+ # Process each residue instance
125
+ residue_letter = ord('a')
126
+ for idx, res_idx in enumerate(index_list):
127
+ current_residue_id = idx
128
+ res_letter = chr(residue_letter + idx)
129
+
130
+ # Check if this is a branch point - add branch marker before
131
+ if res_letter in branch_points and branch_points[res_letter] > 0:
132
+ for _ in range(branch_points[res_letter]):
133
+ units.append({
134
+ 'text': '[BRANCH_OPEN]',
135
+ 'type': 'special',
136
+ 'residue_id': -1,
137
+ 'branch_depth': residue_depths.get(res_letter, 0),
138
+ 'linkage_type': 0,
139
+ })
140
+
141
+ # Get residue definition
142
+ try:
143
+ res_def_idx = int(res_idx) - 1 # 1-indexed to 0-indexed
144
+ res_def = residue_list[res_def_idx] if res_def_idx < len(residue_list) else ""
145
+ except (ValueError, IndexError):
146
+ res_def = ""
147
+
148
+ # Split residue into base and modifications
149
+ res_parts = res_def.split('_')
150
+ base = res_parts[0] if res_parts else res_def
151
+ mods = res_parts[1:] if len(res_parts) > 1 else []
152
+
153
+ # Add residue base as a single unit
154
+ depth = residue_depths.get(res_letter, 0)
155
+ units.append({
156
+ 'text': base,
157
+ 'type': 'residue',
158
+ 'residue_id': current_residue_id,
159
+ 'branch_depth': depth,
160
+ 'linkage_type': 0,
161
+ })
162
+
163
+ # Add modifications
164
+ for mod in mods:
165
+ units.append({
166
+ 'text': mod,
167
+ 'type': 'mod',
168
+ 'residue_id': current_residue_id,
169
+ 'branch_depth': depth,
170
+ 'linkage_type': 0,
171
+ })
172
+
173
+ # Store distance matrix in units for easy access in tokenizer
174
+ if units:
175
+ # Find first residue unit or just use START
176
+ units[0]['distance_matrix'] = dist_matrix_raw
177
+
178
+ # Add linkages
179
+ for link in linkage_list:
180
+ if not link:
181
+ continue
182
+ # Parse linkage type
183
+ lt = self._parse_linkage_type(link)
184
+ units.append({
185
+ 'text': link,
186
+ 'type': 'linkage',
187
+ 'residue_id': -2,
188
+ 'branch_depth': 0,
189
+ 'linkage_type': lt,
190
+ })
191
+ except Exception:
192
+ # Fallback for truly broken WURCS
193
+ pass
194
+
195
+ # Add end token
196
+ units.append({
197
+ 'text': '[END]',
198
+ 'type': 'special',
199
+ 'residue_id': -1,
200
+ 'branch_depth': 0,
201
+ 'linkage_type': 0,
202
+ })
203
+
204
+ return units
205
+
206
+ def _split_wurcs_sections(self, wurcs: str) -> List[str]:
207
+ """Split WURCS string into sections, handling nested brackets."""
208
+ # Remove WURCS= prefix
209
+ if wurcs.startswith('WURCS='):
210
+ wurcs = wurcs[6:]
211
+
212
+ sections = []
213
+ current = ""
214
+ bracket_depth = 0
215
+
216
+ for char in wurcs:
217
+ if char == '[':
218
+ bracket_depth += 1
219
+ current += char
220
+ elif char == ']':
221
+ bracket_depth -= 1
222
+ current += char
223
+ elif char == '/' and bracket_depth == 0:
224
+ sections.append(current)
225
+ current = ""
226
+ else:
227
+ current += char
228
+
229
+ if current:
230
+ sections.append(current)
231
+
232
+ return sections
233
+
234
+ def _analyze_tree_structure(self, linkages: List[str], num_residues: int) -> Tuple[Dict, Dict, Dict, Dict]:
235
+ """Analyze linkages to compute branch points and residue depths."""
236
+ branch_points = defaultdict(int) # residue -> number of children
237
+ children = defaultdict(list)
238
+ all_residues = set()
239
+ linkage_types = {}
240
+
241
+ for link in linkages:
242
+ match = self.LINKAGE_PATTERN.match(link)
243
+ if match:
244
+ from_res, from_pos, to_res, to_pos = match.groups()
245
+ children[from_res].append(to_res)
246
+ all_residues.add(from_res)
247
+ all_residues.add(to_res)
248
+
249
+ # Store linkage type
250
+ linkage_types[link] = self._parse_linkage_type(link)
251
+
252
+ # Build adjacency list for BFS
253
+ adj = defaultdict(list)
254
+ for link in linkages:
255
+ match = self.LINKAGE_PATTERN.match(link)
256
+ if match:
257
+ u = ord(match.group(1)) - ord('a')
258
+ v = ord(match.group(3)) - ord('a')
259
+ if 0 <= u < num_residues and 0 <= v < num_residues:
260
+ adj[u].append(v)
261
+ adj[v].append(u)
262
+
263
+ # Find branch points (residues with >1 child)
264
+ for res, kids in children.items():
265
+ if len(kids) > 1:
266
+ branch_points[res] = len(kids) - 1 # Number of extra branches
267
+
268
+ # Compute depths using BFS
269
+ # Find root (residue with no parent)
270
+ child_set = set()
271
+ for kids in children.values():
272
+ child_set.update(kids)
273
+ roots = all_residues - child_set
274
+ root = min(roots) if roots else 'a'
275
+
276
+ depths = {root: 0}
277
+ queue = [root]
278
+ while queue:
279
+ current = queue.pop(0)
280
+ for child in children.get(current, []):
281
+ if child not in depths:
282
+ depths[child] = depths[current] + 1
283
+ queue.append(child)
284
+
285
+ return branch_points, depths, linkage_types, adj
286
+
287
+ def _compute_distance_matrix(self, adj: Dict[int, List[int]], num_residues: int) -> List[List[int]]:
288
+ """
289
+ Compute shortest path distance (number of bonds) between all residue pairs using BFS.
290
+ """
291
+ if num_residues == 0:
292
+ return []
293
+
294
+ dist_matrix = [[-1] * num_residues for _ in range(num_residues)]
295
+
296
+ for i in range(num_residues):
297
+ dist_matrix[i][i] = 0
298
+ queue = [(i, 0)]
299
+ visited = {i}
300
+
301
+ while queue:
302
+ curr, d = queue.pop(0)
303
+ dist_matrix[i][curr] = d
304
+
305
+ for neighbor in adj[curr]:
306
+ if neighbor not in visited:
307
+ visited.add(neighbor)
308
+ queue.append((neighbor, d + 1))
309
+
310
+ return dist_matrix
311
+
312
+ def _compute_distance_matrix_OLD(self, linkages: List[str], num_residues: int) -> List[List[int]]:
313
+ """
314
+ Compute shortest path distance (number of bonds) between all residue pairs.
315
+ Returns a symmetric N x N matrix where N is num_residues.
316
+ Values are integers (number of steps). 0 on diagonal. -1 if unreachable (shouldn't happen in single tree).
317
+ """
318
+ if num_residues == 0:
319
+ return []
320
+
321
+ # Initialize adjacency list
322
+ adj = defaultdict(list)
323
+ for link in linkages:
324
+ match = self.LINKAGE_PATTERN.match(link)
325
+ if match:
326
+ # WURCS indices are 1-based letters (a=1, b=2...)
327
+ from_res_char, _, to_res_char, _ = match.groups()
328
+ # Convert char to 0-based index
329
+ u = ord(from_res_char) - ord('a')
330
+ v = ord(to_res_char) - ord('a')
331
+
332
+ # Undirected graph for structural distance
333
+ if 0 <= u < num_residues and 0 <= v < num_residues:
334
+ adj[u].append(v)
335
+ adj[v].append(u)
336
+
337
+ # Compute All-Pairs Shortest Path (BFS from each node is fine for small N)
338
+ # Glycans are small (N ~ 5-20 usually), so O(N^2) BFS is cheap.
339
+ dist_matrix = [[-1] * num_residues for _ in range(num_residues)]
340
+
341
+ for i in range(num_residues):
342
+ dist_matrix[i][i] = 0
343
+ queue = [(i, 0)]
344
+ visited = {i}
345
+
346
+ while queue:
347
+ curr, d = queue.pop(0)
348
+ dist_matrix[i][curr] = d
349
+
350
+ for neighbor in adj[curr]:
351
+ if neighbor not in visited:
352
+ visited.add(neighbor)
353
+ queue.append((neighbor, d + 1))
354
+
355
+ return dist_matrix
356
+
357
+ def _parse_linkage_type(self, link: str) -> int:
358
+ """Parse linkage string to get type ID."""
359
+ LINKAGE_TYPES = {
360
+ (1, 2): 0, (2, 1): 0,
361
+ (1, 3): 1, (3, 1): 1,
362
+ (1, 4): 2, (4, 1): 2,
363
+ (1, 6): 3, (6, 1): 3,
364
+ (2, 3): 4, (3, 2): 4,
365
+ (2, 6): 5, (6, 2): 5,
366
+ (3, 6): 6, (6, 3): 6,
367
+ }
368
+
369
+ match = self.LINKAGE_PATTERN.match(link)
370
+ if match:
371
+ _, from_pos, _, to_pos = match.groups()
372
+ try:
373
+ pos_tuple = (int(from_pos), int(to_pos))
374
+ return LINKAGE_TYPES.get(pos_tuple, 7)
375
+ except ValueError:
376
+ return 8 # Unknown
377
+ return 8
378
+
379
+
380
+ class WURCSBPETokenizer:
381
+ """
382
+ BPE tokenizer for WURCS with tree-aware embeddings.
383
+ """
384
+
385
+ def __init__(self, vocab_path: Optional[str] = None):
386
+ self.pre_tokenizer = WURCSPreTokenizer()
387
+
388
+ # Special tokens (fixed)
389
+ self.special_tokens = self.pre_tokenizer.special_tokens
390
+
391
+ # BPE vocabulary
392
+ self.token_to_id: Dict[str, int] = {}
393
+ self.id_to_token: Dict[int, str] = {}
394
+ self.merges: List[Tuple[str, str]] = []
395
+
396
+ if vocab_path:
397
+ self.load_vocab(vocab_path)
398
+ else:
399
+ # Initialize with special tokens only
400
+ self.token_to_id = dict(self.special_tokens)
401
+ self.id_to_token = {v: k for k, v in self.token_to_id.items()}
402
+
403
+ @classmethod
404
+ def train_from_corpus(
405
+ cls,
406
+ wurcs_strings: List[str],
407
+ num_merges: int = 500,
408
+ output_path: Optional[str] = None,
409
+ min_frequency: int = 2,
410
+ max_token_length: Optional[int] = None,
411
+ ) -> 'WURCSBPETokenizer':
412
+ """
413
+ Train BPE on a corpus of WURCS strings.
414
+
415
+ Args:
416
+ wurcs_strings: List of WURCS strings
417
+ num_merges: Number of BPE merge operations
418
+ output_path: Optional path to save vocabulary
419
+ min_frequency: Minimum frequency for a token to be kept
420
+ max_token_length: Maximum length of a merged token (None = no limit)
421
+
422
+ Returns:
423
+ Trained tokenizer
424
+ """
425
+ tokenizer = cls()
426
+ pre_tok = WURCSPreTokenizer()
427
+
428
+ print(f"Training BPE on {len(wurcs_strings)} WURCS strings...")
429
+
430
+ # Step 1: Pre-tokenize all strings to get semantic units
431
+ all_units = []
432
+ for wurcs in wurcs_strings:
433
+ units = pre_tok.pre_tokenize(wurcs)
434
+ for unit in units:
435
+ if unit['type'] != 'special':
436
+ all_units.append(unit['text'])
437
+
438
+ # Step 2: Count unit frequencies
439
+ unit_counts = Counter(all_units)
440
+ print(f"Found {len(unit_counts)} unique units")
441
+
442
+ # Step 3: Initialize vocabulary with characters from all units
443
+ char_vocab = set()
444
+ for unit in unit_counts:
445
+ for char in unit:
446
+ char_vocab.add(char)
447
+
448
+ # Build initial vocab: special tokens + characters
449
+ vocab_id = len(tokenizer.special_tokens)
450
+ for char in sorted(char_vocab):
451
+ tokenizer.token_to_id[char] = vocab_id
452
+ tokenizer.id_to_token[vocab_id] = char
453
+ vocab_id += 1
454
+
455
+ print(f"Initial vocab size: {vocab_id} (special + characters)")
456
+
457
+ # Step 4: Convert units to character sequences
458
+ word_freqs = {}
459
+ for unit, count in unit_counts.items():
460
+ if count >= min_frequency:
461
+ # Split into characters with space separator
462
+ chars = tuple(unit)
463
+ word_freqs[chars] = count
464
+
465
+ # Step 5: BPE merging
466
+ merges = []
467
+
468
+ for merge_idx in range(num_merges):
469
+ # Count pairs
470
+ pair_counts = Counter()
471
+ for word, freq in word_freqs.items():
472
+ for i in range(len(word) - 1):
473
+ pair = (word[i], word[i + 1])
474
+ pair_counts[pair] += freq
475
+
476
+ if not pair_counts:
477
+ break
478
+
479
+ # Find most frequent pair
480
+ best_pair = pair_counts.most_common(1)[0][0]
481
+ best_count = pair_counts[best_pair]
482
+
483
+ if best_count < min_frequency:
484
+ break
485
+
486
+ # Merge pair
487
+ new_token = best_pair[0] + best_pair[1]
488
+
489
+ # Check length constraint
490
+ if max_token_length and len(new_token) > max_token_length:
491
+ # remove this pair from consideration for this iteration and future?
492
+ # Actually, skipping it here is tricky because we need to ignore it in pair_counts next time
493
+ # Simpler: Just skip adding it to merges and modify word_freqs?
494
+ # No, if we don't merge, we just continue to the next best pair in THIS iteration.
495
+ # But pair_counts is already computed.
496
+ # We need to loop until we find a valid pair or run out
497
+
498
+ # In this simple implementation, let's just skip this merge efficiently
499
+ # We need to find the NEXT most common pair.
500
+
501
+ # Re-do finding best pair loop
502
+ found_valid_pair = False
503
+ for pair, count in pair_counts.most_common():
504
+ token_candidate = pair[0] + pair[1]
505
+ if max_token_length and len(token_candidate) > max_token_length:
506
+ continue # Skip too long
507
+
508
+ if count < min_frequency:
509
+ break # Stop if frequency too low
510
+
511
+ # Found valid pair
512
+ best_pair = pair
513
+ best_count = count
514
+ new_token = token_candidate
515
+ found_valid_pair = True
516
+ break
517
+
518
+ if not found_valid_pair:
519
+ print(f" Stopping early: No more pairs satisfy max_token_length={max_token_length}")
520
+ break
521
+
522
+ # Final check before merging (in case we didn't enter the if block but updated vars)
523
+ # Actually the logic above handles it. If we entered the block, we either found a new best_pair or broke.
524
+
525
+ merges.append(best_pair)
526
+
527
+ # Add to vocab
528
+ tokenizer.token_to_id[new_token] = vocab_id
529
+ tokenizer.id_to_token[vocab_id] = new_token
530
+ vocab_id += 1
531
+
532
+ # Update word_freqs
533
+ new_word_freqs = {}
534
+ for word, freq in word_freqs.items():
535
+ new_word = []
536
+ i = 0
537
+ while i < len(word):
538
+ if i < len(word) - 1 and word[i] == best_pair[0] and word[i + 1] == best_pair[1]:
539
+ new_word.append(new_token)
540
+ i += 2
541
+ else:
542
+ new_word.append(word[i])
543
+ i += 1
544
+ new_word_freqs[tuple(new_word)] = freq
545
+ word_freqs = new_word_freqs
546
+
547
+ if (merge_idx + 1) % 100 == 0:
548
+ print(f" Merge {merge_idx + 1}/{num_merges}: '{best_pair[0]}' + '{best_pair[1]}' -> '{new_token}' (count={best_count})")
549
+
550
+ tokenizer.merges = merges
551
+ print(f"Final vocab size: {len(tokenizer.token_to_id)}")
552
+
553
+ # Save if requested
554
+ if output_path:
555
+ tokenizer.save_vocab(output_path)
556
+
557
+ return tokenizer
558
+
559
+ def apply_bpe(self, text: str) -> List[str]:
560
+ """Apply BPE merges to a text string."""
561
+ if text in self.token_to_id:
562
+ return [text]
563
+
564
+ # Split into characters
565
+ tokens = list(text)
566
+
567
+ # Apply merges
568
+ for pair in self.merges:
569
+ new_tokens = []
570
+ i = 0
571
+ while i < len(tokens):
572
+ if i < len(tokens) - 1 and tokens[i] == pair[0] and tokens[i + 1] == pair[1]:
573
+ new_tokens.append(pair[0] + pair[1])
574
+ i += 2
575
+ else:
576
+ new_tokens.append(tokens[i])
577
+ i += 1
578
+ tokens = new_tokens
579
+
580
+ return tokens
581
+
582
+ def tokenize(self, wurcs: str, max_length: int = 256) -> Dict:
583
+ """
584
+ Tokenize a WURCS string.
585
+
586
+ Returns:
587
+ Dict with:
588
+ - tokens: List of token strings
589
+ - token_ids: List of token IDs
590
+ - residue_ids: List of residue IDs
591
+ - branch_depths: List of branch depths
592
+ - linkage_types: List of linkage types
593
+ - attention_mask: Attention mask
594
+ """
595
+ # Pre-tokenize
596
+ units = self.pre_tokenizer.pre_tokenize(wurcs)
597
+
598
+ tokens = []
599
+ token_ids = []
600
+ residue_ids = []
601
+ branch_depths = []
602
+ linkage_types = []
603
+
604
+ for unit in units:
605
+ if unit['type'] == 'special':
606
+ # Special tokens stay as-is
607
+ tok = unit['text']
608
+ tokens.append(tok)
609
+ token_ids.append(self.token_to_id.get(tok, self.token_to_id['[UNK]']))
610
+ residue_ids.append(unit['residue_id'])
611
+ branch_depths.append(unit['branch_depth'])
612
+ linkage_types.append(unit['linkage_type'])
613
+ else:
614
+ # Apply BPE to this unit
615
+ bpe_tokens = self.apply_bpe(unit['text'])
616
+ for tok in bpe_tokens:
617
+ tokens.append(tok)
618
+ token_ids.append(self.token_to_id.get(tok, self.token_to_id['[UNK]']))
619
+ residue_ids.append(unit['residue_id'])
620
+ branch_depths.append(unit['branch_depth'])
621
+ linkage_types.append(unit['linkage_type'])
622
+
623
+ # Truncate if needed
624
+ if len(tokens) > max_length:
625
+ tokens = tokens[:max_length - 1] + ['[END]']
626
+ token_ids = token_ids[:max_length - 1] + [self.token_to_id['[END]']]
627
+ residue_ids = residue_ids[:max_length - 1] + [-1]
628
+ branch_depths = branch_depths[:max_length - 1] + [0]
629
+ linkage_types = linkage_types[:max_length - 1] + [0]
630
+
631
+ # Create attention mask and pad
632
+ length = len(tokens)
633
+ attention_mask = [1] * length
634
+
635
+ while len(tokens) < max_length:
636
+ tokens.append('[PAD]')
637
+ token_ids.append(self.token_to_id['[PAD]'])
638
+ residue_ids.append(-1)
639
+ branch_depths.append(0)
640
+ linkage_types.append(0)
641
+ attention_mask.append(0)
642
+ # Pre-tokenize
643
+ units = self.pre_tokenizer.pre_tokenize(wurcs)
644
+
645
+ # Extract distance matrix from pre-tokenizer result
646
+ dist_matrix_raw = units[0].get('distance_matrix', [])
647
+ num_residues = len(dist_matrix_raw)
648
+
649
+ # Map token-to-token distances using residue_ids
650
+ # token_i is associated with residue_ids[i].
651
+ # residue_ids[i] is index into dist_matrix_raw.
652
+ # If residue_ids[i] == -1 (special), distance is undefined (use -1 or 999)
653
+
654
+ # Use UNPADDED length for distance matrix to save massive memory
655
+ # distance_matrix will be e.g. 20x20, while tokens are padded to 256
656
+ token_len = length
657
+ distance_matrix = [[-1] * token_len for _ in range(token_len)]
658
+
659
+ for i in range(token_len):
660
+ for j in range(token_len):
661
+ r_i = residue_ids[i]
662
+ r_j = residue_ids[j]
663
+
664
+ if r_i >= 0 and r_j >= 0 and r_i < num_residues and r_j < num_residues:
665
+ distance_matrix[i][j] = dist_matrix_raw[r_i][r_j]
666
+ else:
667
+ distance_matrix[i][j] = -1 # Special/Padding
668
+
669
+ # MEMORY OPTIMIZATION: Do NOT pad matrix here.
670
+ # Pad on-the-fly in Dataset class instead.
671
+ # This saves massive memory (0.2GB vs 66GB).
672
+
673
+ return {
674
+ 'tokens': tokens,
675
+ 'token_ids': token_ids,
676
+ 'residue_ids': residue_ids,
677
+ 'branch_depths': branch_depths,
678
+ 'linkage_types': linkage_types,
679
+ 'attention_mask': attention_mask,
680
+ 'distance_matrix': distance_matrix, # New Output
681
+ 'length': length,
682
+ }
683
+
684
+ def save_vocab(self, path: str):
685
+ """Save vocabulary to JSON file."""
686
+ data = {
687
+ 'special_tokens': self.special_tokens,
688
+ 'token_to_id': self.token_to_id,
689
+ 'merges': [list(m) for m in self.merges],
690
+ 'metadata': {
691
+ 'vocab_size': len(self.token_to_id),
692
+ 'num_merges': len(self.merges),
693
+ }
694
+ }
695
+ with open(path, 'w') as f:
696
+ json.dump(data, f, indent=2)
697
+ print(f"Saved vocabulary to {path}")
698
+
699
+ def load_vocab(self, path: str):
700
+ """Load vocabulary from JSON file."""
701
+ with open(path, 'r') as f:
702
+ data = json.load(f)
703
+
704
+ self.special_tokens = data['special_tokens']
705
+ self.token_to_id = data['token_to_id']
706
+ self.id_to_token = {int(v): k for k, v in self.token_to_id.items()}
707
+ self.merges = [tuple(m) for m in data['merges']]
708
+
709
+ print(f"Loaded vocabulary with {len(self.token_to_id)} tokens")
710
+
711
+ @property
712
+ def vocab_size(self) -> int:
713
+ return len(self.token_to_id)
714
+
715
+
716
+ # ============================================================================
717
+ # Testing
718
+ # ============================================================================
719
+
720
+ if __name__ == '__main__':
721
+ # Test pre-tokenizer
722
+ print("="*80)
723
+ print("Testing WURCSPreTokenizer")
724
+ print("="*80)
725
+
726
+ pre_tok = WURCSPreTokenizer()
727
+
728
+ test_wurcs = [
729
+ "WURCS=2.0/2,2,1/[a2122h-1b_1-5][a2211m-1a_1-5]/1-2/a4-b1",
730
+ "WURCS=2.0/3,3,2/[a2122h-1b_1-5_2*NCC/3=O][a2112h-1a_1-5][a2211m-1a_1-5]/1-2-3/a4-b1_b3-c1",
731
+ ]
732
+
733
+ for wurcs in test_wurcs:
734
+ print(f"\nWURCS: {wurcs[:60]}...")
735
+ units = pre_tok.pre_tokenize(wurcs)
736
+ print(f"Units ({len(units)}):")
737
+ for u in units[:10]:
738
+ print(f" {u['type']:10} | res={u['residue_id']:2} | depth={u['branch_depth']} | {u['text']}")
739
+ if len(units) > 10:
740
+ print(f" ... and {len(units) - 10} more")
vocab/bpe_ambiguity_tokens.json ADDED
@@ -0,0 +1,721 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "ambiguous_tokens": {
3
+ "?": 32,
4
+ "?|": 90,
5
+ "a?|": 108,
6
+ "a?|b": 109,
7
+ "?|c": 110,
8
+ "a?|b?|c": 111,
9
+ "?|d": 112,
10
+ "a?|b?|c?|d": 113,
11
+ "?|e": 114,
12
+ "a?|b?|c?|d?|e": 115,
13
+ "?|f": 116,
14
+ "a?|b?|c?|d?|e?|f": 117,
15
+ "?-": 118,
16
+ "?|g": 119,
17
+ "a?|b?|c?|d?|e?|f?|g": 120,
18
+ "?|h": 122,
19
+ "?|i": 123,
20
+ "?|h?|i": 124,
21
+ "?|j": 125,
22
+ "?|h?|i?|j": 126,
23
+ "?|k": 128,
24
+ "?|h?|i?|j?|k": 129,
25
+ "?|l": 130,
26
+ "?|h?|i?|j?|k?|l": 131,
27
+ "?|m": 132,
28
+ "?|h?|i?|j?|k?|l?|m": 133,
29
+ "?|h?|i?|j?|k?|l?|m?|": 138,
30
+ "n?|": 141,
31
+ "n?|o": 142,
32
+ "?}": 143,
33
+ "n?|o?|": 146,
34
+ "n?|o?|p": 147,
35
+ "?}-": 149,
36
+ "?}-{": 150,
37
+ "n?|o?|p?|": 153,
38
+ "n?|o?|p?|q": 154,
39
+ "n?|o?|p?|q?|": 157,
40
+ "n?|o?|p?|q?|r": 158,
41
+ "n?|o?|p?|q?|r?|": 165,
42
+ "n?|o?|p?|q?|r?|s": 166,
43
+ "n?|o?|p?|q?|r?|s?|": 170,
44
+ "n?|o?|p?|q?|r?|s?|t": 171,
45
+ "?|u": 189,
46
+ "a?-": 197,
47
+ "c?-": 201,
48
+ "?|u?|": 209,
49
+ "?|u?|v": 210,
50
+ "b?-": 211,
51
+ "a?-b1": 213,
52
+ "d?-": 217,
53
+ "b?-c1": 221,
54
+ "c?-d1": 230,
55
+ "?|u?|v?|": 231,
56
+ "?|u?|v?|w": 232,
57
+ "1-?": 242,
58
+ "d?-e1": 244,
59
+ "e?-": 245,
60
+ "e?-f1": 262,
61
+ "?|u?|v?|w?|": 266,
62
+ "?|u?|v?|w?|x": 267,
63
+ "f?-": 273,
64
+ "?|u?|v?|w?|x?|": 288,
65
+ "?|u?|v?|w?|x?|y": 289,
66
+ "g?-": 298,
67
+ "n?|o?|p?|q?|r?|s?|t?": 304,
68
+ "i?-": 306,
69
+ "h?-": 308,
70
+ "?|u?}-{": 312,
71
+ "?|u?": 313,
72
+ "n?|o?|p?|q?|r?}-{": 314,
73
+ "n?|o?|p?|q?|r?": 315,
74
+ "f?-g1": 318,
75
+ "n?|o?}-{": 322,
76
+ "n?|o?": 323,
77
+ "?|u?|v?|w?|x?|y?|": 325,
78
+ "?|u?|v?|w?|x?|y?|z": 326,
79
+ "n?|o?|p?|q?|r?|s?}-{": 328,
80
+ "n?|o?|p?|q?|r?|s?": 329,
81
+ "n?|o?|p?}-{": 331,
82
+ "n?|o?|p?": 332,
83
+ "n?|o?|p?|q?}-{": 336,
84
+ "n?|o?|p?|q?": 337,
85
+ "g?-h1": 339,
86
+ "?|h?|i?|j?|k?|l?}-{": 342,
87
+ "?|h?|i?|j?|k?|l?": 343,
88
+ "n?}-{": 344,
89
+ "n?": 345,
90
+ "j?-": 346,
91
+ "h?-i1": 347,
92
+ "?|u?|v?}-{": 351,
93
+ "?|u?|v?": 352,
94
+ "k?-": 353,
95
+ "i?-j1": 355,
96
+ "?|h?|i?|j?|k?|l?|m?": 363,
97
+ "?|h?|i?|j?|k?}-{": 364,
98
+ "?|h?|i?|j?|k?": 365,
99
+ "?|u?|v?|w?|x?|y?|z?|": 369,
100
+ "?|h?|i?|j?}-{": 375,
101
+ "?|h?|i?|j?": 376,
102
+ "l?-": 377,
103
+ "A?|": 392,
104
+ "A?|B": 393,
105
+ "j?-k1": 401,
106
+ "m?-": 404,
107
+ "?|h?|i?}-{": 408,
108
+ "?|h?|i?": 409,
109
+ "k?-l1": 418,
110
+ "?|h?}-{": 420,
111
+ "?|h?": 421,
112
+ "A?|B?|": 424,
113
+ "A?|B?|C": 425,
114
+ "a?|b?|c?|d?|e?|f?|g?": 427,
115
+ "f?-g2": 431,
116
+ "a?|b?|c?|d?|e?|f?}-{": 437,
117
+ "a?|b?|c?|d?|e?|f?": 438,
118
+ "l?-m1": 442,
119
+ "?|u?|v?|w?}-{": 450,
120
+ "?|u?|v?|w?": 451,
121
+ "A?|B?|C?|": 464,
122
+ "A?|B?|C?|D": 465,
123
+ "a?|b?|c?|d?|e?}-{": 475,
124
+ "a?|b?|c?|d?|e?": 476,
125
+ "n?-": 499,
126
+ "a?|b?|c?|d?}-{": 502,
127
+ "a?|b?|c?|d?": 503,
128
+ "m?-n1": 518,
129
+ "A?|B?|C?|D?|": 521,
130
+ "A?|B?|C?|D?|E": 522,
131
+ "o?-": 534,
132
+ "d?-h1": 536,
133
+ "A?|B?|C?|D?|E?|": 542,
134
+ "A?|B?|C?|D?|E?|F": 543,
135
+ "c?-i1": 544,
136
+ "c?-h1": 549,
137
+ "A?|B?|C?|D?|E?|F?|": 550,
138
+ "A?|B?|C?|D?|E?|F?|G": 551,
139
+ "?|u?|v?|w?|x?}-{": 563,
140
+ "?|u?|v?|w?|x?": 564,
141
+ "a?|b?|c?}-{": 571,
142
+ "a?|b?|c?}-{a?|b?|c": 572,
143
+ "a?|b?|c?}-{a?|b?|c?": 573,
144
+ "?|H": 581,
145
+ "2-?": 592,
146
+ "?|H?|": 598,
147
+ "?|H?|I": 599,
148
+ "?}*OC": 600,
149
+ "c?-k1": 607,
150
+ "c?-g1": 609,
151
+ "?|H?|I?|": 615,
152
+ "?|H?|I?|J": 616,
153
+ "n?-o1": 617,
154
+ "d?-g1": 629,
155
+ "o?-p1": 634,
156
+ "p?-": 646,
157
+ "?|u?|v?|w?|x?|y?}-{": 653,
158
+ "?|u?|v?|w?|x?|y?": 654,
159
+ "b?-c2": 656,
160
+ "d?-i1": 658,
161
+ "c?-j1": 691,
162
+ "?}*OSO": 696,
163
+ "e?-h1": 701,
164
+ "q?-": 713,
165
+ "c?-f1": 720,
166
+ "i?-j2": 728,
167
+ "?|h?|i?}": 742,
168
+ "h?-i2": 747,
169
+ "g?-h2": 753,
170
+ "c?-l1": 756,
171
+ "j?-k2": 758,
172
+ "?|h?}": 759,
173
+ "c?-e1": 760,
174
+ "?|H?|I?|J?|": 761,
175
+ "?|H?|I?|J?|K": 762,
176
+ "a?|b?|c?|d?|e?|f?}": 772,
177
+ "b?-e1": 774,
178
+ "b?-f1": 791,
179
+ "d?-f1": 794,
180
+ "p?-q1": 796,
181
+ "a?|b?|c?|d?|e?}": 798,
182
+ "a?-d1": 800,
183
+ "m?-n2": 803,
184
+ "e?-g1": 809,
185
+ "?|h?|i?|j?}": 812,
186
+ "r?-": 817,
187
+ "a?-c1": 818,
188
+ "?|u?|v?|w?|x?|y?|z?": 822,
189
+ "a?-e1": 826,
190
+ "d?-j1": 833,
191
+ "b?-g1": 834,
192
+ "q?-r1": 847,
193
+ "d?-e2": 854,
194
+ "c?-m1": 860,
195
+ "a?-f1": 875,
196
+ "b?-d1": 887,
197
+ "?|H?|I?|J?|K?|": 892,
198
+ "?|H?|I?|J?|K?|L": 893,
199
+ "?|H?|I?|J?|K?|L?|": 894,
200
+ "?|H?|I?|J?|K?|L?|M": 895,
201
+ "?|H?|I?|J?|K?|L?|M?|": 896,
202
+ "a?-l1": 920,
203
+ "?*OSO/3=O/3=O": 923,
204
+ "k?-l2": 940,
205
+ "k?-o1": 942,
206
+ "N?|": 965,
207
+ "N?|O": 966,
208
+ "N?|O?|": 967,
209
+ "N?|O?|P": 968,
210
+ "N?|O?|P?|": 969,
211
+ "N?|O?|P?|Q": 970,
212
+ "N?|O?|P?|Q?|": 971,
213
+ "N?|O?|P?|Q?|R": 972,
214
+ "N?|O?|P?|Q?|R?|": 973,
215
+ "N?|O?|P?|Q?|R?|S": 974,
216
+ "N?|O?|P?|Q?|R?|S?|": 975,
217
+ "N?|O?|P?|Q?|R?|S?|T": 976,
218
+ "?|U": 977,
219
+ "?|U?|": 978,
220
+ "?|U?|V": 979,
221
+ "c?-d2": 983,
222
+ "r?-s1": 988,
223
+ "a?|b?}-{": 995,
224
+ "a?|b?}-{a?|b": 996,
225
+ "a?|b?}-{a?|b?": 997,
226
+ "e?-f2": 1001,
227
+ "g?-i1": 1006,
228
+ "i?-l1": 1010,
229
+ "s?-": 1011,
230
+ "?|h?|i?|j?|k?}": 1017,
231
+ "b?-h1": 1034,
232
+ "a?-j1": 1038,
233
+ "n?-o2": 1046,
234
+ "a?-b2": 1069,
235
+ "e?-i1": 1095,
236
+ "h?-j1": 1102,
237
+ "a?-k1": 1108,
238
+ "i?-k1": 1115,
239
+ "a?-g1": 1116,
240
+ "?}*OPO": 1122,
241
+ "d?-k1": 1129,
242
+ "a?-m1": 1151,
243
+ "a?-i1": 1159,
244
+ "A?}-{": 1174,
245
+ "A?": 1175,
246
+ "?}*OCC": 1177,
247
+ "l?-m2": 1179,
248
+ "A?|B?}-{": 1180,
249
+ "A?|B?": 1181,
250
+ "f?-h1": 1183,
251
+ "a?-n1": 1189,
252
+ "p?-q2": 1192,
253
+ "c?-n1": 1197,
254
+ "?|U?|V?|": 1202,
255
+ "?|U?|V?|W": 1203,
256
+ "?|U?|V?|W?|": 1204,
257
+ "?|U?|V?|W?|X": 1205,
258
+ "?|U?|V?|W?|X?|": 1206,
259
+ "?|U?|V?|W?|X?|Y": 1207,
260
+ "?|a": 1208,
261
+ "s?-t1": 1223,
262
+ "?|h?|i?|j?|k?|l?|m?}": 1228,
263
+ "g?-j1": 1234,
264
+ "A?|B?|C?|D?}-{": 1242,
265
+ "A?|B?|C?|D?": 1243,
266
+ "a?-h1": 1253,
267
+ "?|H?|I?|J?}-{": 1257,
268
+ "?|H?|I?|J?": 1258,
269
+ "o?-p2": 1261,
270
+ "b?-i1": 1273,
271
+ "?|h?|i?|j?|k?|l?}": 1309,
272
+ "j?-m1": 1317,
273
+ "c?-o1": 1318,
274
+ "a?-o1": 1330,
275
+ "a?|b?|c?}*OC": 1331,
276
+ "b?-j1": 1357,
277
+ "a?-r1": 1361,
278
+ "n?}": 1363,
279
+ "A?|B?|C?}-{": 1371,
280
+ "A?|B?|C?": 1372,
281
+ "m?-p1": 1375,
282
+ "l?-p1": 1383,
283
+ "a?-p1": 1444,
284
+ "k?-n1": 1446,
285
+ "j?-l1": 1470,
286
+ "?|U?|V?|W?|X?|Y?|": 1471,
287
+ "?|U?|V?|W?|X?|Y?|Z": 1472,
288
+ "?|aa?|": 1473,
289
+ "?|aa?|a": 1474,
290
+ "?|aa?|ab": 1475,
291
+ "?*OPO/3O/3=O": 1476,
292
+ "l?-q1": 1489,
293
+ "l?-n1": 1499,
294
+ "a?-s1": 1517,
295
+ "k?-m1": 1524,
296
+ "a?-q1": 1546,
297
+ "c?-q1": 1547,
298
+ "t?-": 1551,
299
+ "a?|b?|c?|d?}*OC": 1565,
300
+ "f?-i1": 1590,
301
+ "c?-p1": 1591,
302
+ "n?-q1": 1593,
303
+ "?|i?}": 1611,
304
+ "a?|b?|c?|d?|e?}*OC": 1612,
305
+ "m?-q1": 1617,
306
+ "q?-r2": 1623,
307
+ "l?-o1": 1624,
308
+ "m?-r1": 1628,
309
+ "a?-t1": 1630,
310
+ "a?|b?|c?|d?}*OSO": 1649,
311
+ "c?-r1": 1675,
312
+ "1-d?|i?}": 1683,
313
+ "j?-n1": 1691,
314
+ "u?-": 1694,
315
+ "a?|b?|c?|d?|e?}*OSO": 1718,
316
+ "?*OCC/3=O": 1723,
317
+ "?%": 1752,
318
+ "?*OP^XOCCN/3O/3=O": 1770,
319
+ "t?-u1": 1772,
320
+ "?*": 1774,
321
+ "c?-s1": 1775,
322
+ "a?-u1": 1793,
323
+ "f?-h2": 1808,
324
+ "e?-j1": 1811,
325
+ "c?-t1": 1818,
326
+ "f1-a?|b?|c?|d?|e?}": 1822,
327
+ "u?-v1": 1835,
328
+ "h?-k1": 1841,
329
+ "?|H?|I?|J?|K?}-{": 1846,
330
+ "?|H?|I?|J?|K?": 1847,
331
+ "n?|o?}": 1851,
332
+ "1-d?|h?}": 1852,
333
+ "q?-s1": 1872,
334
+ "%?%": 1880,
335
+ "b?-g2": 1881,
336
+ "r?-s2": 1882,
337
+ "d?-l1": 1898,
338
+ "v?-": 1917,
339
+ "b?-k1": 1927,
340
+ "?|aa?|ab?|a": 1942,
341
+ "?|aa?|ab?|ac": 1943,
342
+ "?|aa?|ab?|ac?|": 1944,
343
+ "?|aa?|ab?|ac?|ad": 1945,
344
+ "a?|b?}*OC": 1949,
345
+ "?*OC": 1952,
346
+ "e?-k1": 1955,
347
+ "a?-d2": 1999,
348
+ "s?-t2": 2013,
349
+ "a?-f2": 2027,
350
+ "o?-q1": 2030,
351
+ "?}*OP^XOCCN": 2040,
352
+ "a?|b?|c?}*OCC": 2047,
353
+ "m?-o1": 2048,
354
+ "c?-f2": 2058,
355
+ "A?|B?|C?|D?|E?|F?|G?": 2060,
356
+ "a?|b?|c?}*OSO": 2071,
357
+ "?|U?|V?}-{": 2079,
358
+ "?|U?|V?": 2080,
359
+ "c?-u1": 2087
360
+ },
361
+ "ambiguous_ids": [
362
+ 32,
363
+ 90,
364
+ 108,
365
+ 109,
366
+ 110,
367
+ 111,
368
+ 112,
369
+ 113,
370
+ 114,
371
+ 115,
372
+ 116,
373
+ 117,
374
+ 118,
375
+ 119,
376
+ 120,
377
+ 122,
378
+ 123,
379
+ 124,
380
+ 125,
381
+ 126,
382
+ 128,
383
+ 129,
384
+ 130,
385
+ 131,
386
+ 132,
387
+ 133,
388
+ 138,
389
+ 141,
390
+ 142,
391
+ 143,
392
+ 146,
393
+ 147,
394
+ 149,
395
+ 150,
396
+ 153,
397
+ 154,
398
+ 157,
399
+ 158,
400
+ 165,
401
+ 166,
402
+ 170,
403
+ 171,
404
+ 189,
405
+ 197,
406
+ 201,
407
+ 209,
408
+ 210,
409
+ 211,
410
+ 213,
411
+ 217,
412
+ 221,
413
+ 230,
414
+ 231,
415
+ 232,
416
+ 242,
417
+ 244,
418
+ 245,
419
+ 262,
420
+ 266,
421
+ 267,
422
+ 273,
423
+ 288,
424
+ 289,
425
+ 298,
426
+ 304,
427
+ 306,
428
+ 308,
429
+ 312,
430
+ 313,
431
+ 314,
432
+ 315,
433
+ 318,
434
+ 322,
435
+ 323,
436
+ 325,
437
+ 326,
438
+ 328,
439
+ 329,
440
+ 331,
441
+ 332,
442
+ 336,
443
+ 337,
444
+ 339,
445
+ 342,
446
+ 343,
447
+ 344,
448
+ 345,
449
+ 346,
450
+ 347,
451
+ 351,
452
+ 352,
453
+ 353,
454
+ 355,
455
+ 363,
456
+ 364,
457
+ 365,
458
+ 369,
459
+ 375,
460
+ 376,
461
+ 377,
462
+ 392,
463
+ 393,
464
+ 401,
465
+ 404,
466
+ 408,
467
+ 409,
468
+ 418,
469
+ 420,
470
+ 421,
471
+ 424,
472
+ 425,
473
+ 427,
474
+ 431,
475
+ 437,
476
+ 438,
477
+ 442,
478
+ 450,
479
+ 451,
480
+ 464,
481
+ 465,
482
+ 475,
483
+ 476,
484
+ 499,
485
+ 502,
486
+ 503,
487
+ 518,
488
+ 521,
489
+ 522,
490
+ 534,
491
+ 536,
492
+ 542,
493
+ 543,
494
+ 544,
495
+ 549,
496
+ 550,
497
+ 551,
498
+ 563,
499
+ 564,
500
+ 571,
501
+ 572,
502
+ 573,
503
+ 581,
504
+ 592,
505
+ 598,
506
+ 599,
507
+ 600,
508
+ 607,
509
+ 609,
510
+ 615,
511
+ 616,
512
+ 617,
513
+ 629,
514
+ 634,
515
+ 646,
516
+ 653,
517
+ 654,
518
+ 656,
519
+ 658,
520
+ 691,
521
+ 696,
522
+ 701,
523
+ 713,
524
+ 720,
525
+ 728,
526
+ 742,
527
+ 747,
528
+ 753,
529
+ 756,
530
+ 758,
531
+ 759,
532
+ 760,
533
+ 761,
534
+ 762,
535
+ 772,
536
+ 774,
537
+ 791,
538
+ 794,
539
+ 796,
540
+ 798,
541
+ 800,
542
+ 803,
543
+ 809,
544
+ 812,
545
+ 817,
546
+ 818,
547
+ 822,
548
+ 826,
549
+ 833,
550
+ 834,
551
+ 847,
552
+ 854,
553
+ 860,
554
+ 875,
555
+ 887,
556
+ 892,
557
+ 893,
558
+ 894,
559
+ 895,
560
+ 896,
561
+ 920,
562
+ 923,
563
+ 940,
564
+ 942,
565
+ 965,
566
+ 966,
567
+ 967,
568
+ 968,
569
+ 969,
570
+ 970,
571
+ 971,
572
+ 972,
573
+ 973,
574
+ 974,
575
+ 975,
576
+ 976,
577
+ 977,
578
+ 978,
579
+ 979,
580
+ 983,
581
+ 988,
582
+ 995,
583
+ 996,
584
+ 997,
585
+ 1001,
586
+ 1006,
587
+ 1010,
588
+ 1011,
589
+ 1017,
590
+ 1034,
591
+ 1038,
592
+ 1046,
593
+ 1069,
594
+ 1095,
595
+ 1102,
596
+ 1108,
597
+ 1115,
598
+ 1116,
599
+ 1122,
600
+ 1129,
601
+ 1151,
602
+ 1159,
603
+ 1174,
604
+ 1175,
605
+ 1177,
606
+ 1179,
607
+ 1180,
608
+ 1181,
609
+ 1183,
610
+ 1189,
611
+ 1192,
612
+ 1197,
613
+ 1202,
614
+ 1203,
615
+ 1204,
616
+ 1205,
617
+ 1206,
618
+ 1207,
619
+ 1208,
620
+ 1223,
621
+ 1228,
622
+ 1234,
623
+ 1242,
624
+ 1243,
625
+ 1253,
626
+ 1257,
627
+ 1258,
628
+ 1261,
629
+ 1273,
630
+ 1309,
631
+ 1317,
632
+ 1318,
633
+ 1330,
634
+ 1331,
635
+ 1357,
636
+ 1361,
637
+ 1363,
638
+ 1371,
639
+ 1372,
640
+ 1375,
641
+ 1383,
642
+ 1444,
643
+ 1446,
644
+ 1470,
645
+ 1471,
646
+ 1472,
647
+ 1473,
648
+ 1474,
649
+ 1475,
650
+ 1476,
651
+ 1489,
652
+ 1499,
653
+ 1517,
654
+ 1524,
655
+ 1546,
656
+ 1547,
657
+ 1551,
658
+ 1565,
659
+ 1590,
660
+ 1591,
661
+ 1593,
662
+ 1611,
663
+ 1612,
664
+ 1617,
665
+ 1623,
666
+ 1624,
667
+ 1628,
668
+ 1630,
669
+ 1649,
670
+ 1675,
671
+ 1683,
672
+ 1691,
673
+ 1694,
674
+ 1718,
675
+ 1723,
676
+ 1752,
677
+ 1770,
678
+ 1772,
679
+ 1774,
680
+ 1775,
681
+ 1793,
682
+ 1808,
683
+ 1811,
684
+ 1818,
685
+ 1822,
686
+ 1835,
687
+ 1841,
688
+ 1846,
689
+ 1847,
690
+ 1851,
691
+ 1852,
692
+ 1872,
693
+ 1880,
694
+ 1881,
695
+ 1882,
696
+ 1898,
697
+ 1917,
698
+ 1927,
699
+ 1942,
700
+ 1943,
701
+ 1944,
702
+ 1945,
703
+ 1949,
704
+ 1952,
705
+ 1955,
706
+ 1999,
707
+ 2013,
708
+ 2027,
709
+ 2030,
710
+ 2040,
711
+ 2047,
712
+ 2048,
713
+ 2058,
714
+ 2060,
715
+ 2071,
716
+ 2079,
717
+ 2080,
718
+ 2087
719
+ ],
720
+ "source_vocab": "data/bpe_vocabulary_clean.json"
721
+ }
vocab/bpe_vocabulary.json ADDED
The diff for this file is too large to render. See raw diff