Add files using upload-large-folder tool
Browse files- .gitattributes +1 -33
- README.md +37 -0
- SHA256SUMS +9 -0
- checkpoints/best_v51_contrastive_model.pt +3 -0
- config.json +11 -0
- requirements.txt +5 -0
- src/multimodal_glycan_bert_v3.py +1084 -0
- src/wurcs_bpe_tokenizer.py +740 -0
- vocab/bpe_ambiguity_tokens.json +721 -0
- vocab/bpe_vocabulary.json +0 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,3 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 3 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: pytorch
|
| 3 |
+
license: other
|
| 4 |
+
tags:
|
| 5 |
+
- glycans
|
| 6 |
+
- wurcs
|
| 7 |
+
- bertose
|
| 8 |
+
- ambiguity-resolution
|
| 9 |
+
- contrastive-learning
|
| 10 |
+
- pytorch
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# Bertose IAR Ambiguity Resolver
|
| 14 |
+
|
| 15 |
+
Draft private release for Bertose ambiguity-resolution inference.
|
| 16 |
+
|
| 17 |
+
This repository contains the contrastive Bertose checkpoint used to score ambiguous WURCS BPE tokens and support iterative ambiguity resolution.
|
| 18 |
+
|
| 19 |
+
## Files
|
| 20 |
+
|
| 21 |
+
- `checkpoints/best_v51_contrastive_model.pt` - contrastive ambiguity-resolution checkpoint.
|
| 22 |
+
- `vocab/bpe_vocabulary.json` - WURCS BPE vocabulary.
|
| 23 |
+
- `vocab/bpe_ambiguity_tokens.json` - ambiguous BPE token map used by the resolver.
|
| 24 |
+
- `src/multimodal_glycan_bert_v3.py` - model definition.
|
| 25 |
+
- `src/wurcs_bpe_tokenizer.py` - WURCS BPE tokenizer.
|
| 26 |
+
|
| 27 |
+
## Expected Input
|
| 28 |
+
|
| 29 |
+
Single glycan or batch CSV with WURCS strings.
|
| 30 |
+
|
| 31 |
+
## Output
|
| 32 |
+
|
| 33 |
+
Token-level ambiguity-resolution predictions with confidence scores. The companion notebook writes both summary and detail CSVs for batch runs.
|
| 34 |
+
|
| 35 |
+
## Draft Notes
|
| 36 |
+
|
| 37 |
+
This release does not claim to reconstruct final canonical WURCS strings by itself. It provides model-backed token-level updates and confidence values for ambiguous positions.
|
SHA256SUMS
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
622368f62c23e97e9137c277eaadcc93ee3901cbb420b591422bb1c2e19689a5 ./.gitattributes
|
| 2 |
+
266caeb2fb9b68076343b40da91116dca0f2302f03cf28c2332b80b1a69c1758 ./README.md
|
| 3 |
+
ae468f4e8c06dc0c3848138a474dc43249aa6d14dfd0df8f58d68fcaad371152 ./checkpoints/best_v51_contrastive_model.pt
|
| 4 |
+
daf55c190fece0678064e41697a9545592beb1285f8aa74e595b933b9d37b4c2 ./config.json
|
| 5 |
+
6a56e6f73b8f874470ecde6e538f3f5029ae23aa6c10559817d1c2a8b59b7c0f ./requirements.txt
|
| 6 |
+
0d9ce16bf90242f38621d64cd974ea5679bff4c2013bea8d7bffe1b8dd120794 ./src/multimodal_glycan_bert_v3.py
|
| 7 |
+
0bc54399362945601bcfd403441fc80968d173200dd0561f57568b2053a94839 ./src/wurcs_bpe_tokenizer.py
|
| 8 |
+
c68cd003370b2dcdb162f848f766e4e62f2653c6c38d205f8cbe53a9aabe2d74 ./vocab/bpe_ambiguity_tokens.json
|
| 9 |
+
6a572afdf53f1494ab96c896876b824ca7ea749777352606aa9f96bf270ceecc ./vocab/bpe_vocabulary.json
|
checkpoints/best_v51_contrastive_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ae468f4e8c06dc0c3848138a474dc43249aa6d14dfd0df8f58d68fcaad371152
|
| 3 |
+
size 557458637
|
config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_family": "Bertose",
|
| 3 |
+
"release_name": "bertose-iar-ambiguity-resolver",
|
| 4 |
+
"checkpoint": "checkpoints/best_v51_contrastive_model.pt",
|
| 5 |
+
"vocabulary": "vocab/bpe_vocabulary.json",
|
| 6 |
+
"ambiguity_tokens": "vocab/bpe_ambiguity_tokens.json",
|
| 7 |
+
"embedding_dim": 768,
|
| 8 |
+
"max_glycan_length": 256,
|
| 9 |
+
"input_format": "WURCS",
|
| 10 |
+
"output_format": "token_level_predictions"
|
| 11 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
numpy
|
| 3 |
+
pandas
|
| 4 |
+
tqdm
|
| 5 |
+
huggingface_hub
|
src/multimodal_glycan_bert_v3.py
ADDED
|
@@ -0,0 +1,1084 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Multimodal Glycan BERT Model v3
|
| 3 |
+
|
| 4 |
+
Extends GlycanBERT to handle three modalities:
|
| 5 |
+
- Sequence (WURCS atomic tokenization)
|
| 6 |
+
- MS (mass spectrometry peaks, RT, intensity)
|
| 7 |
+
- 3D structure (VQ-VAE discrete tokens, 4 per residue)
|
| 8 |
+
|
| 9 |
+
Each modality has its own encoder, with cross-attention for sequence-structure alignment.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from typing import Dict, Optional, Tuple
|
| 15 |
+
import math
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
from .glycan_bert import GlycanBERTConfig, GlycanBERTEmbeddings, GlycanBERTLayer
|
| 19 |
+
except ImportError:
|
| 20 |
+
from glycan_bert import GlycanBERTConfig, GlycanBERTEmbeddings, GlycanBERTLayer
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ConvGlycanBERTEmbeddings(nn.Module):
|
| 24 |
+
"""
|
| 25 |
+
Improved Convolutional front-end that mixes local WURCS context before the Transformer.
|
| 26 |
+
|
| 27 |
+
Key improvements over original:
|
| 28 |
+
1. Position embeddings added BEFORE convolution (provides spatial context to conv)
|
| 29 |
+
2. Residual connection (conv enriches embeddings rather than replacing them)
|
| 30 |
+
3. Multi-scale convolutions (kernel sizes 3, 5, 7) for better receptive field
|
| 31 |
+
4. Proper layer normalization on the residual path
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(self, config):
|
| 35 |
+
super().__init__()
|
| 36 |
+
self.token_embeddings = nn.Embedding(
|
| 37 |
+
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
|
| 38 |
+
)
|
| 39 |
+
self.position_embeddings = nn.Embedding(
|
| 40 |
+
config.max_position_embeddings, config.hidden_size
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# NEW: Branch depth embeddings - encodes depth in glycan tree (0=root, 1=child, etc.)
|
| 44 |
+
max_branch_depth = getattr(config, "max_branch_depth", 8)
|
| 45 |
+
self.branch_embeddings = nn.Embedding(max_branch_depth, config.hidden_size)
|
| 46 |
+
|
| 47 |
+
# NEW: Linkage type embeddings - encodes chemistry of glycosidic bond
|
| 48 |
+
# 0=none, 1=1-3, 2=1-4, 3=1-6, etc.
|
| 49 |
+
num_linkage_types = getattr(config, "num_linkage_types", 9)
|
| 50 |
+
self.linkage_embeddings = nn.Embedding(num_linkage_types, config.hidden_size)
|
| 51 |
+
|
| 52 |
+
# Multi-scale convolutions for different receptive fields
|
| 53 |
+
kernel_size = getattr(config, "cnn_kernel_size", 3)
|
| 54 |
+
# Split channels evenly: 256 + 256 + 256 = 768 for hidden_size=768
|
| 55 |
+
channels_per_scale = config.hidden_size // 3
|
| 56 |
+
self.conv_layers = nn.ModuleList([
|
| 57 |
+
nn.Conv1d(
|
| 58 |
+
in_channels=config.hidden_size,
|
| 59 |
+
out_channels=channels_per_scale,
|
| 60 |
+
kernel_size=kernel_size + 2 * i, # Kernels: 3, 5, 7
|
| 61 |
+
padding=(kernel_size + 2 * i) // 2, # Same padding
|
| 62 |
+
)
|
| 63 |
+
for i in range(3)
|
| 64 |
+
])
|
| 65 |
+
self.conv_activation = nn.GELU()
|
| 66 |
+
self.conv_proj = nn.Linear(channels_per_scale * 3, config.hidden_size) # Project concatenated back
|
| 67 |
+
|
| 68 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 69 |
+
self.conv_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 70 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 71 |
+
self.register_buffer(
|
| 72 |
+
"position_ids",
|
| 73 |
+
torch.arange(config.max_position_embeddings).expand((1, -1)),
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
self.hidden_size = config.hidden_size
|
| 77 |
+
|
| 78 |
+
def forward(self, input_ids, branch_depths=None, linkage_types=None):
|
| 79 |
+
seq_len = input_ids.shape[1]
|
| 80 |
+
|
| 81 |
+
# Step 1: Token + Position embeddings FIRST (provides spatial context to conv)
|
| 82 |
+
x = self.token_embeddings(input_ids) # (batch, seq, hidden)
|
| 83 |
+
position_ids = self.position_ids[:, :seq_len]
|
| 84 |
+
x = x + self.position_embeddings(position_ids)
|
| 85 |
+
|
| 86 |
+
# NEW: Add branch depth embeddings (encodes tree structure)
|
| 87 |
+
if branch_depths is not None:
|
| 88 |
+
# Clamp to valid range
|
| 89 |
+
branch_depths = branch_depths.clamp(0, self.branch_embeddings.num_embeddings - 1)
|
| 90 |
+
x = x + self.branch_embeddings(branch_depths)
|
| 91 |
+
|
| 92 |
+
# NEW: Add linkage type embeddings (encodes bond chemistry)
|
| 93 |
+
if linkage_types is not None:
|
| 94 |
+
linkage_types = linkage_types.clamp(0, self.linkage_embeddings.num_embeddings - 1)
|
| 95 |
+
x = x + self.linkage_embeddings(linkage_types)
|
| 96 |
+
|
| 97 |
+
x = self.LayerNorm(x)
|
| 98 |
+
|
| 99 |
+
# Step 2: Multi-scale convolution with RESIDUAL connection
|
| 100 |
+
# Convolution expects (batch, hidden, seq)
|
| 101 |
+
conv_in = x.permute(0, 2, 1)
|
| 102 |
+
|
| 103 |
+
# Apply multi-scale convolutions and concatenate
|
| 104 |
+
conv_outputs = []
|
| 105 |
+
for conv in self.conv_layers:
|
| 106 |
+
conv_out = self.conv_activation(conv(conv_in))
|
| 107 |
+
conv_outputs.append(conv_out)
|
| 108 |
+
|
| 109 |
+
# Concatenate multi-scale features and project back
|
| 110 |
+
conv_out = torch.cat(conv_outputs, dim=1) # (batch, hidden, seq)
|
| 111 |
+
conv_out = conv_out.permute(0, 2, 1) # (batch, seq, hidden)
|
| 112 |
+
conv_out = self.conv_proj(conv_out) # Project to correct size
|
| 113 |
+
|
| 114 |
+
# Step 3: Residual connection - conv ENRICHES rather than replaces
|
| 115 |
+
x = self.conv_norm(x + self.dropout(conv_out))
|
| 116 |
+
|
| 117 |
+
return x
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def create_residue_level_mask(
|
| 121 |
+
seq_residue_ids: torch.Tensor, # (batch, N_seq)
|
| 122 |
+
struct_residue_ids: torch.Tensor # (batch, N_struct)
|
| 123 |
+
) -> torch.Tensor:
|
| 124 |
+
"""
|
| 125 |
+
Create residue-level attention mask for cross-attention.
|
| 126 |
+
|
| 127 |
+
Maps WURCS tokens to VQ-VAE structural tokens based on residue IDs.
|
| 128 |
+
A WURCS token with residue_id=0 can only attend to VQ-VAE tokens with residue_id=0.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
seq_residue_ids: Residue IDs for sequence tokens (batch, N_seq)
|
| 132 |
+
struct_residue_ids: Residue IDs for structural tokens (batch, N_struct)
|
| 133 |
+
|
| 134 |
+
Returns:
|
| 135 |
+
Boolean mask (batch, N_seq, N_struct) where True = can attend
|
| 136 |
+
"""
|
| 137 |
+
# Expand dimensions for broadcasting
|
| 138 |
+
# seq: (batch, N_seq, 1)
|
| 139 |
+
# struct: (batch, 1, N_struct)
|
| 140 |
+
mask = seq_residue_ids.unsqueeze(2) == struct_residue_ids.unsqueeze(1)
|
| 141 |
+
# Shape: (batch, N_seq, N_struct)
|
| 142 |
+
|
| 143 |
+
# Mask out structural tokens (residue_id = -1) and MS tokens (residue_id = -2)
|
| 144 |
+
# Only tokens with residue_id >= 0 can attend
|
| 145 |
+
mask &= (seq_residue_ids.unsqueeze(2) >= 0)
|
| 146 |
+
|
| 147 |
+
return mask # True = can attend, False = cannot attend
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class MultimodalGlycanBERTConfig:
|
| 151 |
+
"""Configuration for Multimodal GlycanBERT v3."""
|
| 152 |
+
|
| 153 |
+
def __init__(
|
| 154 |
+
self,
|
| 155 |
+
# Sequence modality
|
| 156 |
+
seq_vocab_size: int = 166,
|
| 157 |
+
seq_hidden_size: int = 768,
|
| 158 |
+
seq_num_layers: int = 12,
|
| 159 |
+
seq_num_heads: int = 12,
|
| 160 |
+
seq_max_length: int = 512,
|
| 161 |
+
|
| 162 |
+
# MS modality
|
| 163 |
+
ms_vocab_size: int = 242,
|
| 164 |
+
ms_hidden_size: int = 384,
|
| 165 |
+
ms_num_layers: int = 6,
|
| 166 |
+
ms_num_heads: int = 6,
|
| 167 |
+
ms_max_length: int = 150,
|
| 168 |
+
|
| 169 |
+
# 3D structure modality
|
| 170 |
+
struct_vocab_size: int = 1024, # VQ-VAE codebook size
|
| 171 |
+
struct_hidden_size: int = 512,
|
| 172 |
+
struct_num_layers: int = 8,
|
| 173 |
+
struct_num_heads: int = 8,
|
| 174 |
+
struct_max_length: int = 200,
|
| 175 |
+
use_3d: bool = True,
|
| 176 |
+
|
| 177 |
+
# Cross-attention
|
| 178 |
+
use_cross_attention: bool = True,
|
| 179 |
+
cross_attn_num_heads: int = 8,
|
| 180 |
+
|
| 181 |
+
# Fusion
|
| 182 |
+
fusion_hidden_size: int = 768,
|
| 183 |
+
fusion_num_layers: int = 2,
|
| 184 |
+
|
| 185 |
+
# Training
|
| 186 |
+
hidden_dropout_prob: float = 0.1,
|
| 187 |
+
attention_probs_dropout_prob: float = 0.1,
|
| 188 |
+
layer_norm_eps: float = 1e-12,
|
| 189 |
+
initializer_range: float = 0.02,
|
| 190 |
+
|
| 191 |
+
# Conv front-end
|
| 192 |
+
use_cnn_frontend: bool = True,
|
| 193 |
+
cnn_kernel_size: int = 3,
|
| 194 |
+
|
| 195 |
+
# Loss weights
|
| 196 |
+
seq_loss_weight: float = 0.60,
|
| 197 |
+
ms_loss_weight: float = 0.15,
|
| 198 |
+
struct_loss_weight: float = 0.25,
|
| 199 |
+
|
| 200 |
+
# Token IDs
|
| 201 |
+
pad_token_id: int = 0,
|
| 202 |
+
mask_token_id: int = 1,
|
| 203 |
+
):
|
| 204 |
+
# Sequence config
|
| 205 |
+
self.seq_vocab_size = seq_vocab_size
|
| 206 |
+
self.seq_hidden_size = seq_hidden_size
|
| 207 |
+
self.seq_num_layers = seq_num_layers
|
| 208 |
+
self.seq_num_heads = seq_num_heads
|
| 209 |
+
self.seq_max_length = seq_max_length
|
| 210 |
+
|
| 211 |
+
# MS config
|
| 212 |
+
self.ms_vocab_size = ms_vocab_size
|
| 213 |
+
self.ms_vocab_offset = seq_vocab_size # MS tokens start at 166
|
| 214 |
+
self.ms_total_vocab_size = seq_vocab_size + ms_vocab_size # 408 total
|
| 215 |
+
self.ms_hidden_size = ms_hidden_size
|
| 216 |
+
self.ms_num_layers = ms_num_layers
|
| 217 |
+
self.ms_num_heads = ms_num_heads
|
| 218 |
+
self.ms_max_length = ms_max_length
|
| 219 |
+
|
| 220 |
+
# Structure config
|
| 221 |
+
self.struct_vocab_size = struct_vocab_size
|
| 222 |
+
self.struct_hidden_size = struct_hidden_size
|
| 223 |
+
self.struct_num_layers = struct_num_layers
|
| 224 |
+
self.struct_num_heads = struct_num_heads
|
| 225 |
+
self.struct_max_length = struct_max_length
|
| 226 |
+
self.use_3d = use_3d
|
| 227 |
+
|
| 228 |
+
# Cross-attention config
|
| 229 |
+
self.use_cross_attention = use_cross_attention
|
| 230 |
+
self.cross_attn_num_heads = cross_attn_num_heads
|
| 231 |
+
|
| 232 |
+
# Fusion config
|
| 233 |
+
self.fusion_hidden_size = fusion_hidden_size
|
| 234 |
+
self.fusion_num_layers = fusion_num_layers
|
| 235 |
+
|
| 236 |
+
# Training config
|
| 237 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 238 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
| 239 |
+
self.layer_norm_eps = layer_norm_eps
|
| 240 |
+
self.initializer_range = initializer_range
|
| 241 |
+
|
| 242 |
+
# Conv front-end
|
| 243 |
+
self.use_cnn_frontend = use_cnn_frontend
|
| 244 |
+
self.cnn_kernel_size = cnn_kernel_size
|
| 245 |
+
|
| 246 |
+
# Loss weights
|
| 247 |
+
self.seq_loss_weight = seq_loss_weight
|
| 248 |
+
self.ms_loss_weight = ms_loss_weight
|
| 249 |
+
self.struct_loss_weight = struct_loss_weight
|
| 250 |
+
self.dist_loss_weight = 0.25 # NEW: Topology loss weight (default, can override from config)
|
| 251 |
+
|
| 252 |
+
# Token IDs
|
| 253 |
+
self.pad_token_id = pad_token_id
|
| 254 |
+
self.mask_token_id = mask_token_id
|
| 255 |
+
|
| 256 |
+
def to_seq_config(self) -> GlycanBERTConfig:
|
| 257 |
+
"""Convert to sequence-only config."""
|
| 258 |
+
return GlycanBERTConfig(
|
| 259 |
+
vocab_size=self.seq_vocab_size,
|
| 260 |
+
hidden_size=self.seq_hidden_size,
|
| 261 |
+
num_hidden_layers=self.seq_num_layers,
|
| 262 |
+
num_attention_heads=self.seq_num_heads,
|
| 263 |
+
intermediate_size=self.seq_hidden_size * 4,
|
| 264 |
+
hidden_dropout_prob=self.hidden_dropout_prob,
|
| 265 |
+
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
| 266 |
+
max_position_embeddings=self.seq_max_length,
|
| 267 |
+
layer_norm_eps=self.layer_norm_eps,
|
| 268 |
+
pad_token_id=self.pad_token_id,
|
| 269 |
+
mask_token_id=self.mask_token_id,
|
| 270 |
+
initializer_range=self.initializer_range,
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
def to_ms_config(self) -> GlycanBERTConfig:
|
| 274 |
+
"""Convert to MS-only config."""
|
| 275 |
+
return GlycanBERTConfig(
|
| 276 |
+
vocab_size=self.ms_total_vocab_size,
|
| 277 |
+
hidden_size=self.ms_hidden_size,
|
| 278 |
+
num_hidden_layers=self.ms_num_layers,
|
| 279 |
+
num_attention_heads=self.ms_num_heads,
|
| 280 |
+
intermediate_size=self.ms_hidden_size * 4,
|
| 281 |
+
hidden_dropout_prob=self.hidden_dropout_prob,
|
| 282 |
+
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
| 283 |
+
max_position_embeddings=self.ms_max_length,
|
| 284 |
+
layer_norm_eps=self.layer_norm_eps,
|
| 285 |
+
pad_token_id=self.pad_token_id,
|
| 286 |
+
mask_token_id=self.mask_token_id,
|
| 287 |
+
initializer_range=self.initializer_range,
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
def to_struct_config(self) -> GlycanBERTConfig:
|
| 291 |
+
"""Convert to structure-only config."""
|
| 292 |
+
return GlycanBERTConfig(
|
| 293 |
+
vocab_size=self.struct_vocab_size,
|
| 294 |
+
hidden_size=self.struct_hidden_size,
|
| 295 |
+
num_hidden_layers=self.struct_num_layers,
|
| 296 |
+
num_attention_heads=self.struct_num_heads,
|
| 297 |
+
intermediate_size=self.struct_hidden_size * 4,
|
| 298 |
+
hidden_dropout_prob=self.hidden_dropout_prob,
|
| 299 |
+
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
| 300 |
+
max_position_embeddings=self.struct_max_length,
|
| 301 |
+
layer_norm_eps=self.layer_norm_eps,
|
| 302 |
+
pad_token_id=self.pad_token_id,
|
| 303 |
+
mask_token_id=self.mask_token_id,
|
| 304 |
+
initializer_range=self.initializer_range,
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
# =============================================================================
|
| 309 |
+
# Improvement #1: Monosaccharide-Level Pooling
|
| 310 |
+
# =============================================================================
|
| 311 |
+
|
| 312 |
+
class MonosaccharidePooling(nn.Module):
|
| 313 |
+
"""
|
| 314 |
+
Pool token representations to monosaccharide level, then aggregate.
|
| 315 |
+
|
| 316 |
+
This bridges the gap between token-level BERT and monosaccharide-level CNNs/GNNs.
|
| 317 |
+
Uses monosaccharide_indices from the data to know where each residue starts.
|
| 318 |
+
"""
|
| 319 |
+
|
| 320 |
+
def __init__(self, hidden_size: int, num_attention_heads: int = 8, dropout: float = 0.1):
|
| 321 |
+
super().__init__()
|
| 322 |
+
self.hidden_size = hidden_size
|
| 323 |
+
|
| 324 |
+
# Attention pooling over monosaccharide representations
|
| 325 |
+
self.mono_attention = nn.MultiheadAttention(
|
| 326 |
+
embed_dim=hidden_size,
|
| 327 |
+
num_heads=num_attention_heads,
|
| 328 |
+
dropout=dropout,
|
| 329 |
+
batch_first=True
|
| 330 |
+
)
|
| 331 |
+
self.mono_norm = nn.LayerNorm(hidden_size)
|
| 332 |
+
|
| 333 |
+
# Final aggregation to single glycan representation
|
| 334 |
+
self.glycan_query = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
|
| 335 |
+
self.glycan_attention = nn.MultiheadAttention(
|
| 336 |
+
embed_dim=hidden_size,
|
| 337 |
+
num_heads=num_attention_heads,
|
| 338 |
+
dropout=dropout,
|
| 339 |
+
batch_first=True
|
| 340 |
+
)
|
| 341 |
+
self.glycan_norm = nn.LayerNorm(hidden_size)
|
| 342 |
+
|
| 343 |
+
def forward(
|
| 344 |
+
self,
|
| 345 |
+
hidden_states: torch.Tensor, # (batch, seq_len, hidden)
|
| 346 |
+
residue_ids: torch.Tensor, # (batch, seq_len) - which residue each token belongs to
|
| 347 |
+
attention_mask: torch.Tensor = None, # (batch, seq_len)
|
| 348 |
+
) -> torch.Tensor:
|
| 349 |
+
"""
|
| 350 |
+
Pool tokens to monosaccharide level, then to glycan level.
|
| 351 |
+
|
| 352 |
+
Returns:
|
| 353 |
+
Glycan representation: (batch, hidden_size)
|
| 354 |
+
"""
|
| 355 |
+
batch_size = hidden_states.size(0)
|
| 356 |
+
device = hidden_states.device
|
| 357 |
+
|
| 358 |
+
# Get unique residue IDs per sample (excluding -1 padding)
|
| 359 |
+
max_residues = 50 # Reasonable max for glycans
|
| 360 |
+
|
| 361 |
+
# Pool tokens within each residue using mean pooling
|
| 362 |
+
mono_reps = torch.zeros(batch_size, max_residues, self.hidden_size, device=device)
|
| 363 |
+
mono_mask = torch.zeros(batch_size, max_residues, dtype=torch.bool, device=device)
|
| 364 |
+
|
| 365 |
+
for b in range(batch_size):
|
| 366 |
+
unique_residues = torch.unique(residue_ids[b][residue_ids[b] >= 0])
|
| 367 |
+
for i, rid in enumerate(unique_residues):
|
| 368 |
+
if i >= max_residues:
|
| 369 |
+
break
|
| 370 |
+
token_mask = residue_ids[b] == rid
|
| 371 |
+
if attention_mask is not None:
|
| 372 |
+
token_mask = token_mask & (attention_mask[b] > 0)
|
| 373 |
+
if token_mask.sum() > 0:
|
| 374 |
+
mono_reps[b, i] = hidden_states[b][token_mask].mean(dim=0)
|
| 375 |
+
mono_mask[b, i] = True
|
| 376 |
+
|
| 377 |
+
# Apply attention over monosaccharide representations
|
| 378 |
+
# Convert mask for attention: True = valid, need to invert for PyTorch
|
| 379 |
+
key_padding_mask = ~mono_mask # True = ignore
|
| 380 |
+
|
| 381 |
+
mono_out, _ = self.mono_attention(
|
| 382 |
+
mono_reps, mono_reps, mono_reps,
|
| 383 |
+
key_padding_mask=key_padding_mask
|
| 384 |
+
)
|
| 385 |
+
mono_out = self.mono_norm(mono_reps + mono_out)
|
| 386 |
+
|
| 387 |
+
# Aggregate to single glycan representation using learned query
|
| 388 |
+
glycan_query = self.glycan_query.expand(batch_size, -1, -1)
|
| 389 |
+
glycan_out, _ = self.glycan_attention(
|
| 390 |
+
glycan_query, mono_out, mono_out,
|
| 391 |
+
key_padding_mask=key_padding_mask
|
| 392 |
+
)
|
| 393 |
+
glycan_out = self.glycan_norm(glycan_query + glycan_out)
|
| 394 |
+
|
| 395 |
+
return glycan_out.squeeze(1) # (batch, hidden)
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
# =============================================================================
|
| 399 |
+
# Improvement #2: Residue Type Embeddings
|
| 400 |
+
# =============================================================================
|
| 401 |
+
|
| 402 |
+
# Common monosaccharide types vocabulary
|
| 403 |
+
MONOSACCHARIDE_VOCAB = {
|
| 404 |
+
'[PAD_MONO]': 0, '[UNK_MONO]': 1,
|
| 405 |
+
'Glc': 2, 'GlcNAc': 3, 'GlcA': 4, 'GlcN': 5,
|
| 406 |
+
'Gal': 6, 'GalNAc': 7, 'GalA': 8, 'GalN': 9,
|
| 407 |
+
'Man': 10, 'ManNAc': 11, 'ManA': 12, 'ManN': 13,
|
| 408 |
+
'Fuc': 14, 'Rha': 15, 'Xyl': 16, 'Ara': 17,
|
| 409 |
+
'Neu5Ac': 18, 'Neu5Gc': 19, 'Kdn': 20, 'Sia': 21,
|
| 410 |
+
'GalNAcA': 22, 'GlcNAcA': 23, 'IdoA': 24, 'GulA': 25,
|
| 411 |
+
'Rib': 26, 'Lyx': 27, 'All': 28, 'Alt': 29,
|
| 412 |
+
'Tal': 30, 'Ido': 31, 'Qui': 32, 'Oli': 33,
|
| 413 |
+
'Tyv': 34, 'Abe': 35, 'Par': 36, 'Dig': 37,
|
| 414 |
+
'Col': 38, 'Dha': 39, 'Kdo': 40, 'Hep': 41,
|
| 415 |
+
'NeuroGc': 42, 'Muramic': 43, 'LDManHep': 44, 'DDManHep': 45,
|
| 416 |
+
'Bac': 46, 'Pse': 47, 'Leg': 48, 'Aci': 49,
|
| 417 |
+
'6dTal': 50, 'Fru': 51, 'Tag': 52, 'Sor': 53,
|
| 418 |
+
'Psi': 54, 'Sed': 55, 'MurNAc': 56, 'MurNGc': 57,
|
| 419 |
+
'Api': 58, 'Erwiniose': 59, 'Yer': 60, 'Thre': 61,
|
| 420 |
+
# Add more as needed, up to ~70
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
class ResidueTypeEmbeddings(nn.Module):
|
| 425 |
+
"""
|
| 426 |
+
Learnable embeddings for monosaccharide types.
|
| 427 |
+
|
| 428 |
+
Instead of the model having to learn that 'a1221m' = Fucose from character patterns,
|
| 429 |
+
we explicitly add a Fucose embedding to all tokens belonging to that residue.
|
| 430 |
+
"""
|
| 431 |
+
|
| 432 |
+
def __init__(self, hidden_size: int, num_mono_types: int = 70):
|
| 433 |
+
super().__init__()
|
| 434 |
+
self.mono_embeddings = nn.Embedding(num_mono_types, hidden_size)
|
| 435 |
+
self.mono_vocab = MONOSACCHARIDE_VOCAB
|
| 436 |
+
self.hidden_size = hidden_size
|
| 437 |
+
|
| 438 |
+
def forward(
|
| 439 |
+
self,
|
| 440 |
+
token_embeddings: torch.Tensor, # (batch, seq_len, hidden)
|
| 441 |
+
residue_ids: torch.Tensor, # (batch, seq_len)
|
| 442 |
+
mono_type_ids: torch.Tensor = None, # (batch, max_residues) - monosaccharide type per residue
|
| 443 |
+
) -> torch.Tensor:
|
| 444 |
+
"""
|
| 445 |
+
Add residue type embeddings to token embeddings.
|
| 446 |
+
|
| 447 |
+
Args:
|
| 448 |
+
token_embeddings: Base token embeddings
|
| 449 |
+
residue_ids: Which residue each token belongs to (-1 for special tokens)
|
| 450 |
+
mono_type_ids: Monosaccharide type ID for each residue position
|
| 451 |
+
|
| 452 |
+
Returns:
|
| 453 |
+
Enhanced embeddings with residue type information
|
| 454 |
+
"""
|
| 455 |
+
if mono_type_ids is None:
|
| 456 |
+
return token_embeddings
|
| 457 |
+
|
| 458 |
+
batch_size, seq_len, _ = token_embeddings.shape
|
| 459 |
+
enhanced = token_embeddings.clone()
|
| 460 |
+
|
| 461 |
+
# Add mono type embedding to each token based on its residue
|
| 462 |
+
for b in range(batch_size):
|
| 463 |
+
for pos in range(seq_len):
|
| 464 |
+
rid = residue_ids[b, pos].item()
|
| 465 |
+
if rid >= 0 and rid < mono_type_ids.size(1):
|
| 466 |
+
mono_id = mono_type_ids[b, rid]
|
| 467 |
+
enhanced[b, pos] = enhanced[b, pos] + self.mono_embeddings(mono_id)
|
| 468 |
+
|
| 469 |
+
return enhanced
|
| 470 |
+
|
| 471 |
+
@staticmethod
|
| 472 |
+
def get_mono_type_id(mono_name: str) -> int:
|
| 473 |
+
"""Convert monosaccharide name to type ID."""
|
| 474 |
+
return MONOSACCHARIDE_VOCAB.get(mono_name, MONOSACCHARIDE_VOCAB['[UNK_MONO]'])
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
# =============================================================================
|
| 478 |
+
# Improvement #4: Relative Position Encoding for Glycan Trees
|
| 479 |
+
# =============================================================================
|
| 480 |
+
|
| 481 |
+
class RelativePositionBias(nn.Module):
|
| 482 |
+
"""
|
| 483 |
+
Compute relative position bias for attention based on residue IDs.
|
| 484 |
+
|
| 485 |
+
Tokens in the same residue get distance 0.
|
| 486 |
+
Tokens in adjacent residues get distance ±1.
|
| 487 |
+
This helps the model understand glycan tree structure.
|
| 488 |
+
"""
|
| 489 |
+
|
| 490 |
+
def __init__(self, num_heads: int, max_distance: int = 10):
|
| 491 |
+
super().__init__()
|
| 492 |
+
self.num_heads = num_heads
|
| 493 |
+
self.max_distance = max_distance
|
| 494 |
+
|
| 495 |
+
# Learnable bias for each relative distance (-max to +max)
|
| 496 |
+
num_distances = 2 * max_distance + 1
|
| 497 |
+
self.relative_bias = nn.Embedding(num_distances, num_heads)
|
| 498 |
+
|
| 499 |
+
def forward(self, residue_ids: torch.Tensor) -> torch.Tensor:
|
| 500 |
+
"""
|
| 501 |
+
Compute relative position bias.
|
| 502 |
+
|
| 503 |
+
Args:
|
| 504 |
+
residue_ids: (batch, seq_len)
|
| 505 |
+
|
| 506 |
+
Returns:
|
| 507 |
+
Bias to add to attention scores: (batch, num_heads, seq_len, seq_len)
|
| 508 |
+
"""
|
| 509 |
+
# Compute pairwise residue distances
|
| 510 |
+
# (batch, seq_len, 1) - (batch, 1, seq_len) = (batch, seq_len, seq_len)
|
| 511 |
+
distance = residue_ids.unsqueeze(2) - residue_ids.unsqueeze(1)
|
| 512 |
+
|
| 513 |
+
# Clamp to max distance range and shift to 0-indexed
|
| 514 |
+
distance_clamped = distance.clamp(-self.max_distance, self.max_distance)
|
| 515 |
+
distance_idx = distance_clamped + self.max_distance # Now 0 to 2*max_distance
|
| 516 |
+
|
| 517 |
+
# Look up bias: (batch, seq_len, seq_len, num_heads)
|
| 518 |
+
bias = self.relative_bias(distance_idx)
|
| 519 |
+
|
| 520 |
+
# Transpose to (batch, num_heads, seq_len, seq_len)
|
| 521 |
+
bias = bias.permute(0, 3, 1, 2)
|
| 522 |
+
|
| 523 |
+
return bias
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
class CrossAttentionLayer(nn.Module):
|
| 527 |
+
"""
|
| 528 |
+
Cross-attention layer for sequence-structure alignment.
|
| 529 |
+
|
| 530 |
+
Allows sequence tokens to attend to structural atoms using attention masks.
|
| 531 |
+
"""
|
| 532 |
+
|
| 533 |
+
def __init__(self, config: MultimodalGlycanBERTConfig):
|
| 534 |
+
super().__init__()
|
| 535 |
+
self.num_heads = config.cross_attn_num_heads
|
| 536 |
+
self.hidden_size = config.seq_hidden_size
|
| 537 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 538 |
+
|
| 539 |
+
assert self.hidden_size % self.num_heads == 0, "hidden_size must be divisible by num_heads"
|
| 540 |
+
|
| 541 |
+
# Query from sequence, Key/Value from structure (VQ-VAE tokens)
|
| 542 |
+
self.query = nn.Linear(config.seq_hidden_size, self.hidden_size)
|
| 543 |
+
self.key = nn.Linear(config.struct_hidden_size, self.hidden_size)
|
| 544 |
+
self.value = nn.Linear(config.struct_hidden_size, self.hidden_size)
|
| 545 |
+
|
| 546 |
+
self.output = nn.Linear(self.hidden_size, config.seq_hidden_size)
|
| 547 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 548 |
+
self.layer_norm = nn.LayerNorm(config.seq_hidden_size, eps=config.layer_norm_eps)
|
| 549 |
+
|
| 550 |
+
def forward(
|
| 551 |
+
self,
|
| 552 |
+
seq_hidden: torch.Tensor, # (batch, seq_len, seq_hidden)
|
| 553 |
+
struct_hidden: torch.Tensor, # (batch, struct_len, struct_hidden)
|
| 554 |
+
attention_mask: Optional[torch.Tensor] = None, # (batch, seq_len, struct_len)
|
| 555 |
+
) -> torch.Tensor:
|
| 556 |
+
"""
|
| 557 |
+
Apply cross-attention from sequence to structure.
|
| 558 |
+
|
| 559 |
+
Args:
|
| 560 |
+
seq_hidden: Sequence hidden states
|
| 561 |
+
struct_hidden: Structure hidden states
|
| 562 |
+
attention_mask: Boolean mask (True = can attend, False = cannot attend)
|
| 563 |
+
|
| 564 |
+
Returns:
|
| 565 |
+
Updated sequence hidden states
|
| 566 |
+
"""
|
| 567 |
+
batch_size, seq_len, _ = seq_hidden.shape
|
| 568 |
+
struct_len = struct_hidden.shape[1]
|
| 569 |
+
|
| 570 |
+
# Project to Q, K, V
|
| 571 |
+
Q = self.query(seq_hidden) # (batch, seq_len, hidden)
|
| 572 |
+
K = self.key(struct_hidden) # (batch, struct_len, hidden)
|
| 573 |
+
V = self.value(struct_hidden) # (batch, struct_len, hidden)
|
| 574 |
+
|
| 575 |
+
# Reshape for multi-head attention
|
| 576 |
+
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch, heads, seq_len, head_dim)
|
| 577 |
+
K = K.view(batch_size, struct_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch, heads, struct_len, head_dim)
|
| 578 |
+
V = V.view(batch_size, struct_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch, heads, struct_len, head_dim)
|
| 579 |
+
|
| 580 |
+
# Compute attention scores
|
| 581 |
+
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim) # (batch, heads, seq_len, struct_len)
|
| 582 |
+
|
| 583 |
+
# Apply attention mask
|
| 584 |
+
if attention_mask is not None:
|
| 585 |
+
# attention_mask: (batch, seq_len, struct_len) -> (batch, 1, seq_len, struct_len)
|
| 586 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 587 |
+
# Convert boolean mask to float: True -> 0.0, False -> -10000.0
|
| 588 |
+
attention_mask = (~attention_mask).float() * -10000.0
|
| 589 |
+
scores = scores + attention_mask
|
| 590 |
+
|
| 591 |
+
# Softmax and dropout
|
| 592 |
+
attn_weights = torch.softmax(scores, dim=-1) # (batch, heads, seq_len, struct_len)
|
| 593 |
+
attn_weights = self.dropout(attn_weights)
|
| 594 |
+
|
| 595 |
+
# Apply attention to values
|
| 596 |
+
context = torch.matmul(attn_weights, V) # (batch, heads, seq_len, head_dim)
|
| 597 |
+
|
| 598 |
+
# Reshape back
|
| 599 |
+
context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
|
| 600 |
+
|
| 601 |
+
# Output projection
|
| 602 |
+
output = self.output(context)
|
| 603 |
+
output = self.dropout(output)
|
| 604 |
+
|
| 605 |
+
# Residual connection + layer norm
|
| 606 |
+
output = self.layer_norm(seq_hidden + output)
|
| 607 |
+
|
| 608 |
+
return output
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
class MultimodalGlycanBERT(nn.Module):
|
| 612 |
+
"""
|
| 613 |
+
Multimodal BERT for glycan representation learning (v3).
|
| 614 |
+
|
| 615 |
+
Architecture:
|
| 616 |
+
1. Separate encoders for each modality (sequence, MS, 3D structure)
|
| 617 |
+
2. Cross-attention for sequence-structure alignment
|
| 618 |
+
3. Modality-specific MLM heads
|
| 619 |
+
4. Fusion layer for combined representation
|
| 620 |
+
"""
|
| 621 |
+
|
| 622 |
+
def __init__(self, config: MultimodalGlycanBERTConfig):
|
| 623 |
+
super().__init__()
|
| 624 |
+
self.config = config
|
| 625 |
+
|
| 626 |
+
# ===== Sequence Encoder =====
|
| 627 |
+
seq_config = config.to_seq_config()
|
| 628 |
+
seq_config.cnn_kernel_size = config.cnn_kernel_size
|
| 629 |
+
|
| 630 |
+
if config.use_cnn_frontend:
|
| 631 |
+
print(f"✅ Enabled Convolutional Front-End (Kernel={config.cnn_kernel_size})")
|
| 632 |
+
self.seq_embeddings = ConvGlycanBERTEmbeddings(seq_config)
|
| 633 |
+
else:
|
| 634 |
+
self.seq_embeddings = GlycanBERTEmbeddings(seq_config)
|
| 635 |
+
self.seq_layers = nn.ModuleList([GlycanBERTLayer(seq_config) for _ in range(seq_config.num_hidden_layers)])
|
| 636 |
+
self.seq_mlm_head = nn.Linear(seq_config.hidden_size, seq_config.vocab_size)
|
| 637 |
+
|
| 638 |
+
# ===== MS Encoder =====
|
| 639 |
+
ms_config = config.to_ms_config()
|
| 640 |
+
self.ms_embeddings = GlycanBERTEmbeddings(ms_config)
|
| 641 |
+
self.ms_layers = nn.ModuleList([GlycanBERTLayer(ms_config) for _ in range(ms_config.num_hidden_layers)])
|
| 642 |
+
self.ms_mlm_head = nn.Linear(ms_config.hidden_size, ms_config.vocab_size)
|
| 643 |
+
|
| 644 |
+
# ===== Structure Encoder (VQ-VAE tokens) =====
|
| 645 |
+
if config.use_3d:
|
| 646 |
+
struct_config = config.to_struct_config()
|
| 647 |
+
self.struct_embeddings = GlycanBERTEmbeddings(struct_config)
|
| 648 |
+
self.struct_layers = nn.ModuleList([GlycanBERTLayer(struct_config) for _ in range(struct_config.num_hidden_layers)])
|
| 649 |
+
self.struct_mlm_head = nn.Linear(struct_config.hidden_size, struct_config.vocab_size)
|
| 650 |
+
|
| 651 |
+
# Cross-attention layer (sequence → VQ-VAE structural tokens)
|
| 652 |
+
if config.use_cross_attention:
|
| 653 |
+
self.cross_attention = CrossAttentionLayer(config)
|
| 654 |
+
|
| 655 |
+
# ===== Projection layers (align hidden sizes) =====
|
| 656 |
+
if config.ms_hidden_size != config.seq_hidden_size:
|
| 657 |
+
self.ms_projection = nn.Linear(config.ms_hidden_size, config.seq_hidden_size)
|
| 658 |
+
else:
|
| 659 |
+
self.ms_projection = nn.Identity()
|
| 660 |
+
|
| 661 |
+
if config.use_3d and config.struct_hidden_size != config.seq_hidden_size:
|
| 662 |
+
self.struct_projection = nn.Linear(config.struct_hidden_size, config.seq_hidden_size)
|
| 663 |
+
else:
|
| 664 |
+
self.struct_projection = nn.Identity()
|
| 665 |
+
|
| 666 |
+
# ===== Fusion Layer =====
|
| 667 |
+
# Concatenate seq + ms + struct
|
| 668 |
+
fusion_input_size = config.seq_hidden_size * (3 if config.use_3d else 2)
|
| 669 |
+
self.fusion_layer = nn.Sequential(
|
| 670 |
+
nn.Linear(fusion_input_size, config.fusion_hidden_size),
|
| 671 |
+
nn.LayerNorm(config.fusion_hidden_size, eps=config.layer_norm_eps),
|
| 672 |
+
nn.GELU(),
|
| 673 |
+
nn.Dropout(config.hidden_dropout_prob),
|
| 674 |
+
nn.Linear(config.fusion_hidden_size, config.fusion_hidden_size),
|
| 675 |
+
)
|
| 676 |
+
|
| 677 |
+
# ===== Distance Prediction Head (Topology) =====
|
| 678 |
+
# OPTIMIZED: Project down to 128 dim first to save GPU memory
|
| 679 |
+
# (Batch, 256, 256, 768) -> (Batch, 256, 256, 128) reduces memory by 6x
|
| 680 |
+
self.dist_proj = nn.Linear(config.seq_hidden_size, 128)
|
| 681 |
+
self.distance_head = nn.Sequential(
|
| 682 |
+
nn.Linear(128, 64),
|
| 683 |
+
nn.ReLU(),
|
| 684 |
+
nn.Linear(64, 1)
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
# Initialize weights
|
| 688 |
+
self.apply(self._init_weights)
|
| 689 |
+
|
| 690 |
+
def _init_weights(self, module):
|
| 691 |
+
"""Initialize weights."""
|
| 692 |
+
if isinstance(module, nn.Linear):
|
| 693 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 694 |
+
if module.bias is not None:
|
| 695 |
+
module.bias.data.zero_()
|
| 696 |
+
elif isinstance(module, nn.Embedding):
|
| 697 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 698 |
+
if module.padding_idx is not None:
|
| 699 |
+
module.weight.data[module.padding_idx].zero_()
|
| 700 |
+
elif isinstance(module, nn.LayerNorm):
|
| 701 |
+
module.bias.data.zero_()
|
| 702 |
+
module.weight.data.fill_(1.0)
|
| 703 |
+
|
| 704 |
+
def forward(
|
| 705 |
+
self,
|
| 706 |
+
seq_token_ids: torch.Tensor,
|
| 707 |
+
seq_attention_mask: torch.Tensor,
|
| 708 |
+
seq_residue_ids: torch.Tensor,
|
| 709 |
+
seq_branch_depths: Optional[torch.Tensor] = None, # NEW: Branch depths
|
| 710 |
+
seq_linkage_types: Optional[torch.Tensor] = None, # NEW: Linkage types
|
| 711 |
+
ms_token_ids: torch.Tensor = None,
|
| 712 |
+
ms_attention_mask: torch.Tensor = None,
|
| 713 |
+
has_ms: torch.Tensor = None,
|
| 714 |
+
struct_token_ids: Optional[torch.Tensor] = None,
|
| 715 |
+
struct_attention_mask: Optional[torch.Tensor] = None,
|
| 716 |
+
struct_residue_ids: Optional[torch.Tensor] = None,
|
| 717 |
+
has_3d: Optional[torch.Tensor] = None,
|
| 718 |
+
seq_labels: Optional[torch.Tensor] = None,
|
| 719 |
+
ms_labels: Optional[torch.Tensor] = None,
|
| 720 |
+
struct_labels: Optional[torch.Tensor] = None,
|
| 721 |
+
dist_labels: Optional[torch.Tensor] = None, # NEW: Topology distance labels
|
| 722 |
+
return_dict: bool = True,
|
| 723 |
+
) -> Dict[str, torch.Tensor]:
|
| 724 |
+
"""
|
| 725 |
+
Forward pass for multimodal BERT v3.
|
| 726 |
+
|
| 727 |
+
Args:
|
| 728 |
+
seq_token_ids: (batch_size, seq_len) - Sequence token IDs
|
| 729 |
+
seq_attention_mask: (batch_size, seq_len) - Sequence attention mask
|
| 730 |
+
seq_residue_ids: (batch_size, seq_len) - Sequence token residue IDs
|
| 731 |
+
ms_token_ids: (batch_size, ms_len) - MS token IDs
|
| 732 |
+
ms_attention_mask: (batch_size, ms_len) - MS attention mask
|
| 733 |
+
has_ms: (batch_size,) - Boolean mask for samples with MS data
|
| 734 |
+
struct_token_ids: (batch_size, struct_len) - Structure VQ-VAE token IDs (optional)
|
| 735 |
+
struct_attention_mask: (batch_size, struct_len) - Structure attention mask (optional)
|
| 736 |
+
struct_residue_ids: (batch_size, struct_len) - Structure token residue IDs (optional)
|
| 737 |
+
has_3d: (batch_size,) - Boolean mask for samples with 3D data (optional)
|
| 738 |
+
seq_labels: (batch_size, seq_len) - Masked sequence labels (optional)
|
| 739 |
+
ms_labels: (batch_size, ms_len) - Masked MS labels (optional)
|
| 740 |
+
struct_labels: (batch_size, struct_len) - Masked structure labels (optional)
|
| 741 |
+
return_dict: Whether to return dict or tuple
|
| 742 |
+
|
| 743 |
+
Returns:
|
| 744 |
+
Dictionary containing logits, hidden states, losses, etc.
|
| 745 |
+
"""
|
| 746 |
+
batch_size = seq_token_ids.shape[0]
|
| 747 |
+
device = seq_token_ids.device
|
| 748 |
+
|
| 749 |
+
# ===== Sequence Encoder =====
|
| 750 |
+
# Pass branch_depths and linkage_types to embeddings for tree-aware encoding
|
| 751 |
+
seq_hidden = self.seq_embeddings(seq_token_ids, seq_branch_depths, seq_linkage_types)
|
| 752 |
+
for layer in self.seq_layers:
|
| 753 |
+
seq_hidden = layer(seq_hidden, seq_attention_mask)
|
| 754 |
+
|
| 755 |
+
seq_pooled = seq_hidden[:, 0, :] # [CLS] token
|
| 756 |
+
seq_logits = self.seq_mlm_head(seq_hidden)
|
| 757 |
+
|
| 758 |
+
# ===== Distance Predictions (Topology) =====
|
| 759 |
+
# Compute pairwise distance predictions
|
| 760 |
+
# MEMORY OPTIMIZATION: Project to 128-dim first
|
| 761 |
+
seq_hidden_small = self.dist_proj(seq_hidden) # (batch, seq_len, 128)
|
| 762 |
+
|
| 763 |
+
# Expand for pairwise: (batch, seq_len, 1, 128) - (batch, 1, seq_len, 128)
|
| 764 |
+
h_i = seq_hidden_small.unsqueeze(2)
|
| 765 |
+
h_j = seq_hidden_small.unsqueeze(1)
|
| 766 |
+
h_diff = torch.abs(h_i - h_j) # (batch, seq_len, seq_len, 128) - Much smaller!
|
| 767 |
+
dist_predictions = self.distance_head(h_diff) # (batch, seq_len, seq_len, 1)
|
| 768 |
+
|
| 769 |
+
# ===== MS Encoder =====
|
| 770 |
+
ms_hidden = None
|
| 771 |
+
ms_pooled = None
|
| 772 |
+
ms_logits = None
|
| 773 |
+
|
| 774 |
+
if ms_token_ids is not None:
|
| 775 |
+
ms_hidden = self.ms_embeddings(ms_token_ids)
|
| 776 |
+
for layer in self.ms_layers:
|
| 777 |
+
ms_hidden = layer(ms_hidden, ms_attention_mask)
|
| 778 |
+
|
| 779 |
+
ms_pooled = ms_hidden[:, 0, :] # [CLS] token
|
| 780 |
+
ms_logits = self.ms_mlm_head(ms_hidden)
|
| 781 |
+
|
| 782 |
+
# Zero out MS representations for samples without MS data
|
| 783 |
+
if has_ms is not None:
|
| 784 |
+
has_ms_expanded = has_ms.unsqueeze(1).float() # (batch, 1)
|
| 785 |
+
ms_pooled = ms_pooled * has_ms_expanded
|
| 786 |
+
|
| 787 |
+
# ===== Structure Encoder =====
|
| 788 |
+
struct_pooled = None
|
| 789 |
+
struct_logits = None
|
| 790 |
+
struct_hidden = None
|
| 791 |
+
|
| 792 |
+
if self.config.use_3d and struct_token_ids is not None:
|
| 793 |
+
struct_hidden = self.struct_embeddings(struct_token_ids)
|
| 794 |
+
for layer in self.struct_layers:
|
| 795 |
+
struct_hidden = layer(struct_hidden, struct_attention_mask)
|
| 796 |
+
|
| 797 |
+
struct_pooled = struct_hidden[:, 0, :] # [CLS] token
|
| 798 |
+
struct_logits = self.struct_mlm_head(struct_hidden)
|
| 799 |
+
|
| 800 |
+
# Zero out structure representations for samples without 3D data
|
| 801 |
+
if has_3d is not None:
|
| 802 |
+
has_3d_expanded = has_3d.unsqueeze(1).float() # (batch, 1)
|
| 803 |
+
struct_pooled = struct_pooled * has_3d_expanded
|
| 804 |
+
|
| 805 |
+
# ===== Cross-Attention (Sequence → VQ-VAE Structural Tokens) =====
|
| 806 |
+
# Use residue-level alignment between WURCS tokens and VQ-VAE tokens
|
| 807 |
+
if self.config.use_cross_attention and struct_residue_ids is not None:
|
| 808 |
+
# Create residue-level mask
|
| 809 |
+
# WURCS token with residue_id=0 → VQ-VAE tokens with residue_id=0
|
| 810 |
+
residue_mask = create_residue_level_mask(
|
| 811 |
+
seq_residue_ids=seq_residue_ids,
|
| 812 |
+
struct_residue_ids=struct_residue_ids,
|
| 813 |
+
) # (batch, N_seq, N_struct)
|
| 814 |
+
|
| 815 |
+
# Apply cross-attention: sequence tokens attend to VQ-VAE tokens
|
| 816 |
+
seq_hidden = self.cross_attention(
|
| 817 |
+
seq_hidden=seq_hidden,
|
| 818 |
+
struct_hidden=struct_hidden, # VQ-VAE token features
|
| 819 |
+
attention_mask=residue_mask, # Residue-based mask
|
| 820 |
+
)
|
| 821 |
+
|
| 822 |
+
# Update seq_pooled after cross-attention
|
| 823 |
+
seq_pooled = seq_hidden[:, 0, :]
|
| 824 |
+
|
| 825 |
+
# ===== Fusion =====
|
| 826 |
+
# Project to common hidden size
|
| 827 |
+
ms_pooled_projected = self.ms_projection(ms_pooled)
|
| 828 |
+
|
| 829 |
+
if self.config.use_3d and struct_pooled is not None:
|
| 830 |
+
struct_pooled_projected = self.struct_projection(struct_pooled)
|
| 831 |
+
combined = torch.cat([seq_pooled, ms_pooled_projected, struct_pooled_projected], dim=-1)
|
| 832 |
+
else:
|
| 833 |
+
combined = torch.cat([seq_pooled, ms_pooled_projected], dim=-1)
|
| 834 |
+
|
| 835 |
+
fused_repr = self.fusion_layer(combined)
|
| 836 |
+
|
| 837 |
+
# ===== Compute Losses =====
|
| 838 |
+
total_loss = None
|
| 839 |
+
seq_loss = None
|
| 840 |
+
ms_loss = None
|
| 841 |
+
struct_loss = None
|
| 842 |
+
dist_loss = None # NEW: Topology distance loss
|
| 843 |
+
|
| 844 |
+
if seq_labels is not None:
|
| 845 |
+
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
| 846 |
+
seq_loss = loss_fct(
|
| 847 |
+
seq_logits.view(-1, self.config.seq_vocab_size),
|
| 848 |
+
seq_labels.view(-1)
|
| 849 |
+
)
|
| 850 |
+
|
| 851 |
+
if ms_labels is not None:
|
| 852 |
+
ms_labels_masked = ms_labels.clone()
|
| 853 |
+
ms_labels_masked[~has_ms] = -100
|
| 854 |
+
# Only compute loss if there are valid labels (not all -100)
|
| 855 |
+
if (ms_labels_masked != -100).any():
|
| 856 |
+
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
| 857 |
+
ms_loss = loss_fct(
|
| 858 |
+
ms_logits.view(-1, self.config.ms_total_vocab_size),
|
| 859 |
+
ms_labels_masked.view(-1)
|
| 860 |
+
)
|
| 861 |
+
else:
|
| 862 |
+
ms_loss = torch.tensor(0.0, device=seq_token_ids.device)
|
| 863 |
+
|
| 864 |
+
if self.config.use_3d and struct_labels is not None and struct_logits is not None:
|
| 865 |
+
struct_labels_masked = struct_labels.clone()
|
| 866 |
+
if has_3d is not None:
|
| 867 |
+
struct_labels_masked[~has_3d] = -100
|
| 868 |
+
# Only compute loss if there are valid labels (not all -100)
|
| 869 |
+
if (struct_labels_masked != -100).any():
|
| 870 |
+
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
| 871 |
+
struct_loss = loss_fct(
|
| 872 |
+
struct_logits.view(-1, self.config.struct_vocab_size),
|
| 873 |
+
struct_labels_masked.view(-1)
|
| 874 |
+
)
|
| 875 |
+
else:
|
| 876 |
+
struct_loss = torch.tensor(0.0, device=seq_token_ids.device)
|
| 877 |
+
|
| 878 |
+
# ===== Distance Loss (Topology) =====
|
| 879 |
+
if dist_labels is not None:
|
| 880 |
+
# dist_predictions: (Batch, Seq, Seq, 1) -> (Batch, Seq, Seq)
|
| 881 |
+
preds = dist_predictions.squeeze(-1)
|
| 882 |
+
|
| 883 |
+
# Create mask for valid distance pairs (label != -1)
|
| 884 |
+
# Also respect attention mask to avoid padding
|
| 885 |
+
valid_mask = (dist_labels != -1) & (seq_attention_mask.unsqueeze(1) * seq_attention_mask.unsqueeze(2) == 1)
|
| 886 |
+
|
| 887 |
+
# DEBUG: Print once
|
| 888 |
+
if not hasattr(self, '_dist_debug_printed'):
|
| 889 |
+
print(f"[DIST DEBUG] dist_labels shape: {dist_labels.shape}, valid_mask.sum: {valid_mask.sum().item()}")
|
| 890 |
+
self._dist_debug_printed = True
|
| 891 |
+
|
| 892 |
+
if valid_mask.sum() > 0:
|
| 893 |
+
# MSE loss on valid positions only
|
| 894 |
+
loss_fct = nn.MSELoss()
|
| 895 |
+
dist_loss = loss_fct(preds[valid_mask], dist_labels[valid_mask].float())
|
| 896 |
+
else:
|
| 897 |
+
dist_loss = torch.tensor(0.0, device=seq_token_ids.device)
|
| 898 |
+
else:
|
| 899 |
+
# DEBUG: dist_labels is None
|
| 900 |
+
if not hasattr(self, '_dist_none_printed'):
|
| 901 |
+
print("[DIST DEBUG] dist_labels is None!")
|
| 902 |
+
self._dist_none_printed = True
|
| 903 |
+
|
| 904 |
+
# Weighted combination
|
| 905 |
+
losses = []
|
| 906 |
+
if seq_loss is not None:
|
| 907 |
+
losses.append(self.config.seq_loss_weight * seq_loss)
|
| 908 |
+
if ms_loss is not None:
|
| 909 |
+
losses.append(self.config.ms_loss_weight * ms_loss)
|
| 910 |
+
if struct_loss is not None:
|
| 911 |
+
losses.append(self.config.struct_loss_weight * struct_loss)
|
| 912 |
+
if dist_loss is not None:
|
| 913 |
+
losses.append(self.config.dist_loss_weight * dist_loss)
|
| 914 |
+
|
| 915 |
+
if losses:
|
| 916 |
+
total_loss = sum(losses)
|
| 917 |
+
|
| 918 |
+
if return_dict:
|
| 919 |
+
return {
|
| 920 |
+
'loss': total_loss,
|
| 921 |
+
'seq_loss': seq_loss,
|
| 922 |
+
'ms_loss': ms_loss,
|
| 923 |
+
'struct_loss': struct_loss,
|
| 924 |
+
'dist_loss': dist_loss, # NEW: Topology loss
|
| 925 |
+
'seq_logits': seq_logits,
|
| 926 |
+
'ms_logits': ms_logits,
|
| 927 |
+
'struct_logits': struct_logits,
|
| 928 |
+
'dist_predictions': dist_predictions, # NEW: Distance predictions
|
| 929 |
+
'seq_hidden': seq_hidden,
|
| 930 |
+
'ms_hidden': ms_hidden,
|
| 931 |
+
'struct_hidden': struct_hidden,
|
| 932 |
+
'seq_pooled': seq_pooled,
|
| 933 |
+
'ms_pooled': ms_pooled,
|
| 934 |
+
'struct_pooled': struct_pooled,
|
| 935 |
+
'fused_repr': fused_repr,
|
| 936 |
+
}
|
| 937 |
+
else:
|
| 938 |
+
return (total_loss, seq_logits, ms_logits, struct_logits, fused_repr)
|
| 939 |
+
|
| 940 |
+
def get_multimodal_representation(
|
| 941 |
+
self,
|
| 942 |
+
seq_token_ids: torch.Tensor,
|
| 943 |
+
seq_attention_mask: torch.Tensor,
|
| 944 |
+
seq_residue_ids: torch.Tensor,
|
| 945 |
+
ms_token_ids: torch.Tensor,
|
| 946 |
+
ms_attention_mask: torch.Tensor,
|
| 947 |
+
has_ms: torch.Tensor,
|
| 948 |
+
struct_token_ids: Optional[torch.Tensor] = None,
|
| 949 |
+
struct_attention_mask: Optional[torch.Tensor] = None,
|
| 950 |
+
struct_residue_ids: Optional[torch.Tensor] = None,
|
| 951 |
+
has_3d: Optional[torch.Tensor] = None,
|
| 952 |
+
) -> torch.Tensor:
|
| 953 |
+
"""Get fused multimodal representation (for inference)."""
|
| 954 |
+
outputs = self.forward(
|
| 955 |
+
seq_token_ids=seq_token_ids,
|
| 956 |
+
seq_attention_mask=seq_attention_mask,
|
| 957 |
+
seq_residue_ids=seq_residue_ids,
|
| 958 |
+
ms_token_ids=ms_token_ids,
|
| 959 |
+
ms_attention_mask=ms_attention_mask,
|
| 960 |
+
has_ms=has_ms,
|
| 961 |
+
struct_token_ids=struct_token_ids,
|
| 962 |
+
struct_attention_mask=struct_attention_mask,
|
| 963 |
+
struct_residue_ids=struct_residue_ids,
|
| 964 |
+
has_3d=has_3d,
|
| 965 |
+
return_dict=True,
|
| 966 |
+
)
|
| 967 |
+
return outputs['fused_repr']
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
if __name__ == "__main__":
|
| 971 |
+
# Test the model
|
| 972 |
+
print("="*80)
|
| 973 |
+
print("Testing Multimodal GlycanBERT v3")
|
| 974 |
+
print("="*80)
|
| 975 |
+
|
| 976 |
+
# Create config
|
| 977 |
+
config = MultimodalGlycanBERTConfig(
|
| 978 |
+
seq_vocab_size=166,
|
| 979 |
+
seq_hidden_size=768,
|
| 980 |
+
seq_num_layers=12,
|
| 981 |
+
seq_num_heads=12,
|
| 982 |
+
ms_vocab_size=242,
|
| 983 |
+
ms_hidden_size=384,
|
| 984 |
+
ms_num_layers=6,
|
| 985 |
+
ms_num_heads=6,
|
| 986 |
+
struct_vocab_size=1024,
|
| 987 |
+
struct_hidden_size=512,
|
| 988 |
+
struct_num_layers=8,
|
| 989 |
+
struct_num_heads=8,
|
| 990 |
+
use_3d=True,
|
| 991 |
+
use_cross_attention=True,
|
| 992 |
+
seq_loss_weight=0.60,
|
| 993 |
+
ms_loss_weight=0.15,
|
| 994 |
+
struct_loss_weight=0.25,
|
| 995 |
+
)
|
| 996 |
+
|
| 997 |
+
print(f"\nConfig:")
|
| 998 |
+
print(f" Sequence vocab: {config.seq_vocab_size}")
|
| 999 |
+
print(f" MS vocab: {config.ms_vocab_size}")
|
| 1000 |
+
print(f" Structure vocab: {config.struct_vocab_size}")
|
| 1001 |
+
print(f" Loss weights: seq={config.seq_loss_weight}, ms={config.ms_loss_weight}, struct={config.struct_loss_weight}")
|
| 1002 |
+
|
| 1003 |
+
# Create model
|
| 1004 |
+
model = MultimodalGlycanBERT(config)
|
| 1005 |
+
|
| 1006 |
+
# Count parameters
|
| 1007 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 1008 |
+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 1009 |
+
|
| 1010 |
+
print(f"\nModel Parameters:")
|
| 1011 |
+
print(f" Total: {total_params:,}")
|
| 1012 |
+
print(f" Trainable: {trainable_params:,}")
|
| 1013 |
+
|
| 1014 |
+
# Test forward pass
|
| 1015 |
+
print(f"\n{'='*80}")
|
| 1016 |
+
print("Testing Forward Pass (with Conv front-end)")
|
| 1017 |
+
print("="*80)
|
| 1018 |
+
|
| 1019 |
+
batch_size = 4
|
| 1020 |
+
seq_len = 128
|
| 1021 |
+
ms_len = 50
|
| 1022 |
+
struct_len = 40
|
| 1023 |
+
|
| 1024 |
+
# Create dummy inputs
|
| 1025 |
+
seq_token_ids = torch.randint(0, config.seq_vocab_size, (batch_size, seq_len))
|
| 1026 |
+
seq_attention_mask = torch.ones(batch_size, seq_len)
|
| 1027 |
+
# Approximate: ~5 tokens per residue
|
| 1028 |
+
seq_residue_ids = torch.div(
|
| 1029 |
+
torch.arange(seq_len), 5, rounding_mode="floor"
|
| 1030 |
+
).unsqueeze(0).expand(batch_size, -1)
|
| 1031 |
+
|
| 1032 |
+
ms_token_ids = torch.randint(config.ms_vocab_offset, config.ms_total_vocab_size, (batch_size, ms_len))
|
| 1033 |
+
ms_attention_mask = torch.ones(batch_size, ms_len)
|
| 1034 |
+
struct_token_ids = torch.randint(0, config.struct_vocab_size, (batch_size, struct_len))
|
| 1035 |
+
struct_attention_mask = torch.ones(batch_size, struct_len)
|
| 1036 |
+
# Approximate: 4 tokens per residue for VQ-VAE tokens
|
| 1037 |
+
struct_residue_ids = torch.div(
|
| 1038 |
+
torch.arange(struct_len), 4, rounding_mode="floor"
|
| 1039 |
+
).unsqueeze(0).expand(batch_size, -1)
|
| 1040 |
+
|
| 1041 |
+
has_ms = torch.tensor([True, True, False, True])
|
| 1042 |
+
has_3d = torch.tensor([True, False, True, True])
|
| 1043 |
+
|
| 1044 |
+
# Create labels for MLM
|
| 1045 |
+
seq_labels = seq_token_ids.clone()
|
| 1046 |
+
seq_labels[seq_labels != config.mask_token_id] = -100
|
| 1047 |
+
ms_labels = ms_token_ids.clone()
|
| 1048 |
+
ms_labels[ms_labels != config.mask_token_id] = -100
|
| 1049 |
+
struct_labels = struct_token_ids.clone()
|
| 1050 |
+
struct_labels[struct_labels != config.mask_token_id] = -100
|
| 1051 |
+
|
| 1052 |
+
# Forward pass
|
| 1053 |
+
outputs = model(
|
| 1054 |
+
seq_token_ids=seq_token_ids,
|
| 1055 |
+
seq_attention_mask=seq_attention_mask,
|
| 1056 |
+
seq_residue_ids=seq_residue_ids,
|
| 1057 |
+
ms_token_ids=ms_token_ids,
|
| 1058 |
+
ms_attention_mask=ms_attention_mask,
|
| 1059 |
+
has_ms=has_ms,
|
| 1060 |
+
struct_token_ids=struct_token_ids,
|
| 1061 |
+
struct_attention_mask=struct_attention_mask,
|
| 1062 |
+
struct_residue_ids=struct_residue_ids,
|
| 1063 |
+
has_3d=has_3d,
|
| 1064 |
+
seq_labels=seq_labels,
|
| 1065 |
+
ms_labels=ms_labels,
|
| 1066 |
+
struct_labels=struct_labels,
|
| 1067 |
+
)
|
| 1068 |
+
|
| 1069 |
+
print(f"\nOutput shapes:")
|
| 1070 |
+
print(f" seq_logits: {outputs['seq_logits'].shape}")
|
| 1071 |
+
print(f" ms_logits: {outputs['ms_logits'].shape}")
|
| 1072 |
+
print(f" struct_logits: {outputs['struct_logits'].shape}")
|
| 1073 |
+
print(f" fused_repr: {outputs['fused_repr'].shape}")
|
| 1074 |
+
|
| 1075 |
+
print(f"\nLosses:")
|
| 1076 |
+
print(f" Total loss: {outputs['loss'].item():.4f}")
|
| 1077 |
+
print(f" Sequence loss: {outputs['seq_loss'].item():.4f}")
|
| 1078 |
+
print(f" MS loss: {outputs['ms_loss'].item():.4f}")
|
| 1079 |
+
print(f" Structure loss: {outputs['struct_loss'].item():.4f}")
|
| 1080 |
+
|
| 1081 |
+
print(f"\n{'='*80}")
|
| 1082 |
+
print("Model Test Complete!")
|
| 1083 |
+
print("="*80)
|
| 1084 |
+
|
src/wurcs_bpe_tokenizer.py
ADDED
|
@@ -0,0 +1,740 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
WURCS-BPE Tokenizer
|
| 4 |
+
|
| 5 |
+
A hybrid tokenizer that learns semantic subwords from WURCS while preserving
|
| 6 |
+
the ability to handle rare/novel glycan structures character-by-character.
|
| 7 |
+
|
| 8 |
+
Key features:
|
| 9 |
+
1. Pre-tokenization: Split WURCS into semantic units (residues, linkages, mods)
|
| 10 |
+
2. BPE: Learn subword merges from corpus
|
| 11 |
+
3. Character fallback: Handle novel structures
|
| 12 |
+
4. Tree embeddings: Preserve branch_depth and linkage_type per token
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
# Train BPE on corpus
|
| 16 |
+
tokenizer = WURCSBPETokenizer.train_from_corpus(
|
| 17 |
+
wurcs_strings,
|
| 18 |
+
num_merges=500,
|
| 19 |
+
output_path="bpe_vocabulary.json"
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# Tokenize
|
| 23 |
+
result = tokenizer.tokenize(wurcs_string)
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import json
|
| 27 |
+
import re
|
| 28 |
+
from collections import Counter, defaultdict
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
from typing import Dict, List, Optional, Tuple, Set
|
| 31 |
+
import pickle
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class WURCSPreTokenizer:
|
| 35 |
+
"""
|
| 36 |
+
Pre-tokenize WURCS into semantic units before BPE.
|
| 37 |
+
|
| 38 |
+
WURCS format: WURCS=2.0/count/[residues]/indices/linkages
|
| 39 |
+
|
| 40 |
+
We split into:
|
| 41 |
+
- Residues: [a2122h-1b_1-5_2*NCC/3=O] -> one unit per []
|
| 42 |
+
- Linkages: a4-b1 -> one unit per linkage
|
| 43 |
+
- Special markers: [BRANCH_OPEN], [BRANCH_CLOSE], etc.
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
# Residue patterns for common monosaccharides
|
| 47 |
+
RESIDUE_PATTERN = re.compile(r'\[([^\]]+)\]')
|
| 48 |
+
LINKAGE_PATTERN = re.compile(r'([a-z])(\d+|\?)-([a-z])(\d+|\?)')
|
| 49 |
+
|
| 50 |
+
def __init__(self):
|
| 51 |
+
self.special_tokens = {
|
| 52 |
+
'[PAD]': 0,
|
| 53 |
+
'[UNK]': 1,
|
| 54 |
+
'[START]': 2,
|
| 55 |
+
'[END]': 3,
|
| 56 |
+
'[MASK]': 4,
|
| 57 |
+
'[BRANCH_OPEN]': 5,
|
| 58 |
+
'[BRANCH_CLOSE]': 6,
|
| 59 |
+
'[LINK]': 7,
|
| 60 |
+
'[MOD]': 8,
|
| 61 |
+
'[RESIDUE_ERROR]': 9,
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
def pre_tokenize(self, wurcs: str) -> List[Dict]:
|
| 65 |
+
"""
|
| 66 |
+
Pre-tokenize WURCS into semantic units.
|
| 67 |
+
|
| 68 |
+
Returns list of dicts with:
|
| 69 |
+
- text: The unit text
|
| 70 |
+
- type: 'special', 'residue', 'linkage', 'mod', 'index'
|
| 71 |
+
- residue_id: Which residue this belongs to (-1 for special, -2 for linkage)
|
| 72 |
+
- branch_depth: Tree depth (computed later)
|
| 73 |
+
"""
|
| 74 |
+
units = []
|
| 75 |
+
|
| 76 |
+
# Add start token
|
| 77 |
+
units.append({
|
| 78 |
+
'text': '[START]',
|
| 79 |
+
'type': 'special',
|
| 80 |
+
'residue_id': -1,
|
| 81 |
+
'branch_depth': 0,
|
| 82 |
+
'linkage_type': 0,
|
| 83 |
+
})
|
| 84 |
+
|
| 85 |
+
# Parse WURCS sections
|
| 86 |
+
if not wurcs.startswith('WURCS='):
|
| 87 |
+
units.append({'text': '[RESIDUE_ERROR]', 'type': 'special', 'residue_id': -1, 'branch_depth': 0, 'linkage_type': 0})
|
| 88 |
+
units.append({'text': '[END]', 'type': 'special', 'residue_id': -1, 'branch_depth': 0, 'linkage_type': 0})
|
| 89 |
+
return units
|
| 90 |
+
|
| 91 |
+
try:
|
| 92 |
+
parts = self._split_wurcs_sections(wurcs)
|
| 93 |
+
if len(parts) < 4:
|
| 94 |
+
return [{'text': '[ERROR]', 'type': 'special', 'residue_id': -1, 'branch_depth': 0, 'linkage_type': 0}]
|
| 95 |
+
|
| 96 |
+
# parts: WURCS=2.0/3,3,2/[a2122h-1b_1-5][a2122h-1a_1-5][a1122h-1b_1-5]/1-2-3-1/a4-b1_b3-c1_c4-d1
|
| 97 |
+
# section 2: residue definitions
|
| 98 |
+
# section 3: indices
|
| 99 |
+
# section 4: linkages (optional)
|
| 100 |
+
|
| 101 |
+
version = parts[0] # WURCS=2.0
|
| 102 |
+
counts = parts[1] # residue_count,node_count,link_count
|
| 103 |
+
residue_defs = parts[2] # [res1][res2]...
|
| 104 |
+
indices = parts[3] # 1-2-3-1
|
| 105 |
+
linkages = parts[4] if len(parts) > 4 else "" # a4-b1_b3-c1
|
| 106 |
+
|
| 107 |
+
# Parse residue definitions
|
| 108 |
+
residue_list = self.RESIDUE_PATTERN.findall(residue_defs)
|
| 109 |
+
|
| 110 |
+
# Parse linkages to compute branch structure
|
| 111 |
+
linkage_list = linkages.split('_') if linkages else []
|
| 112 |
+
branch_points, residue_depths, linkage_types_map, adj = self._analyze_tree_structure(linkage_list, num_residues=len(residue_list))
|
| 113 |
+
|
| 114 |
+
# Compute distance matrix and cache it based on the linkage string (structure)
|
| 115 |
+
# This is the most expensive part, so we cache it
|
| 116 |
+
if not hasattr(self, '_dist_cache'): self._dist_cache = {}
|
| 117 |
+
if linkages not in self._dist_cache:
|
| 118 |
+
self._dist_cache[linkages] = self._compute_distance_matrix(adj, len(residue_list))
|
| 119 |
+
dist_matrix_raw = self._dist_cache[linkages]
|
| 120 |
+
|
| 121 |
+
# Parse indices to map positions to residue definitions
|
| 122 |
+
index_list = indices.split('-') if indices else []
|
| 123 |
+
|
| 124 |
+
# Process each residue instance
|
| 125 |
+
residue_letter = ord('a')
|
| 126 |
+
for idx, res_idx in enumerate(index_list):
|
| 127 |
+
current_residue_id = idx
|
| 128 |
+
res_letter = chr(residue_letter + idx)
|
| 129 |
+
|
| 130 |
+
# Check if this is a branch point - add branch marker before
|
| 131 |
+
if res_letter in branch_points and branch_points[res_letter] > 0:
|
| 132 |
+
for _ in range(branch_points[res_letter]):
|
| 133 |
+
units.append({
|
| 134 |
+
'text': '[BRANCH_OPEN]',
|
| 135 |
+
'type': 'special',
|
| 136 |
+
'residue_id': -1,
|
| 137 |
+
'branch_depth': residue_depths.get(res_letter, 0),
|
| 138 |
+
'linkage_type': 0,
|
| 139 |
+
})
|
| 140 |
+
|
| 141 |
+
# Get residue definition
|
| 142 |
+
try:
|
| 143 |
+
res_def_idx = int(res_idx) - 1 # 1-indexed to 0-indexed
|
| 144 |
+
res_def = residue_list[res_def_idx] if res_def_idx < len(residue_list) else ""
|
| 145 |
+
except (ValueError, IndexError):
|
| 146 |
+
res_def = ""
|
| 147 |
+
|
| 148 |
+
# Split residue into base and modifications
|
| 149 |
+
res_parts = res_def.split('_')
|
| 150 |
+
base = res_parts[0] if res_parts else res_def
|
| 151 |
+
mods = res_parts[1:] if len(res_parts) > 1 else []
|
| 152 |
+
|
| 153 |
+
# Add residue base as a single unit
|
| 154 |
+
depth = residue_depths.get(res_letter, 0)
|
| 155 |
+
units.append({
|
| 156 |
+
'text': base,
|
| 157 |
+
'type': 'residue',
|
| 158 |
+
'residue_id': current_residue_id,
|
| 159 |
+
'branch_depth': depth,
|
| 160 |
+
'linkage_type': 0,
|
| 161 |
+
})
|
| 162 |
+
|
| 163 |
+
# Add modifications
|
| 164 |
+
for mod in mods:
|
| 165 |
+
units.append({
|
| 166 |
+
'text': mod,
|
| 167 |
+
'type': 'mod',
|
| 168 |
+
'residue_id': current_residue_id,
|
| 169 |
+
'branch_depth': depth,
|
| 170 |
+
'linkage_type': 0,
|
| 171 |
+
})
|
| 172 |
+
|
| 173 |
+
# Store distance matrix in units for easy access in tokenizer
|
| 174 |
+
if units:
|
| 175 |
+
# Find first residue unit or just use START
|
| 176 |
+
units[0]['distance_matrix'] = dist_matrix_raw
|
| 177 |
+
|
| 178 |
+
# Add linkages
|
| 179 |
+
for link in linkage_list:
|
| 180 |
+
if not link:
|
| 181 |
+
continue
|
| 182 |
+
# Parse linkage type
|
| 183 |
+
lt = self._parse_linkage_type(link)
|
| 184 |
+
units.append({
|
| 185 |
+
'text': link,
|
| 186 |
+
'type': 'linkage',
|
| 187 |
+
'residue_id': -2,
|
| 188 |
+
'branch_depth': 0,
|
| 189 |
+
'linkage_type': lt,
|
| 190 |
+
})
|
| 191 |
+
except Exception:
|
| 192 |
+
# Fallback for truly broken WURCS
|
| 193 |
+
pass
|
| 194 |
+
|
| 195 |
+
# Add end token
|
| 196 |
+
units.append({
|
| 197 |
+
'text': '[END]',
|
| 198 |
+
'type': 'special',
|
| 199 |
+
'residue_id': -1,
|
| 200 |
+
'branch_depth': 0,
|
| 201 |
+
'linkage_type': 0,
|
| 202 |
+
})
|
| 203 |
+
|
| 204 |
+
return units
|
| 205 |
+
|
| 206 |
+
def _split_wurcs_sections(self, wurcs: str) -> List[str]:
|
| 207 |
+
"""Split WURCS string into sections, handling nested brackets."""
|
| 208 |
+
# Remove WURCS= prefix
|
| 209 |
+
if wurcs.startswith('WURCS='):
|
| 210 |
+
wurcs = wurcs[6:]
|
| 211 |
+
|
| 212 |
+
sections = []
|
| 213 |
+
current = ""
|
| 214 |
+
bracket_depth = 0
|
| 215 |
+
|
| 216 |
+
for char in wurcs:
|
| 217 |
+
if char == '[':
|
| 218 |
+
bracket_depth += 1
|
| 219 |
+
current += char
|
| 220 |
+
elif char == ']':
|
| 221 |
+
bracket_depth -= 1
|
| 222 |
+
current += char
|
| 223 |
+
elif char == '/' and bracket_depth == 0:
|
| 224 |
+
sections.append(current)
|
| 225 |
+
current = ""
|
| 226 |
+
else:
|
| 227 |
+
current += char
|
| 228 |
+
|
| 229 |
+
if current:
|
| 230 |
+
sections.append(current)
|
| 231 |
+
|
| 232 |
+
return sections
|
| 233 |
+
|
| 234 |
+
def _analyze_tree_structure(self, linkages: List[str], num_residues: int) -> Tuple[Dict, Dict, Dict, Dict]:
|
| 235 |
+
"""Analyze linkages to compute branch points and residue depths."""
|
| 236 |
+
branch_points = defaultdict(int) # residue -> number of children
|
| 237 |
+
children = defaultdict(list)
|
| 238 |
+
all_residues = set()
|
| 239 |
+
linkage_types = {}
|
| 240 |
+
|
| 241 |
+
for link in linkages:
|
| 242 |
+
match = self.LINKAGE_PATTERN.match(link)
|
| 243 |
+
if match:
|
| 244 |
+
from_res, from_pos, to_res, to_pos = match.groups()
|
| 245 |
+
children[from_res].append(to_res)
|
| 246 |
+
all_residues.add(from_res)
|
| 247 |
+
all_residues.add(to_res)
|
| 248 |
+
|
| 249 |
+
# Store linkage type
|
| 250 |
+
linkage_types[link] = self._parse_linkage_type(link)
|
| 251 |
+
|
| 252 |
+
# Build adjacency list for BFS
|
| 253 |
+
adj = defaultdict(list)
|
| 254 |
+
for link in linkages:
|
| 255 |
+
match = self.LINKAGE_PATTERN.match(link)
|
| 256 |
+
if match:
|
| 257 |
+
u = ord(match.group(1)) - ord('a')
|
| 258 |
+
v = ord(match.group(3)) - ord('a')
|
| 259 |
+
if 0 <= u < num_residues and 0 <= v < num_residues:
|
| 260 |
+
adj[u].append(v)
|
| 261 |
+
adj[v].append(u)
|
| 262 |
+
|
| 263 |
+
# Find branch points (residues with >1 child)
|
| 264 |
+
for res, kids in children.items():
|
| 265 |
+
if len(kids) > 1:
|
| 266 |
+
branch_points[res] = len(kids) - 1 # Number of extra branches
|
| 267 |
+
|
| 268 |
+
# Compute depths using BFS
|
| 269 |
+
# Find root (residue with no parent)
|
| 270 |
+
child_set = set()
|
| 271 |
+
for kids in children.values():
|
| 272 |
+
child_set.update(kids)
|
| 273 |
+
roots = all_residues - child_set
|
| 274 |
+
root = min(roots) if roots else 'a'
|
| 275 |
+
|
| 276 |
+
depths = {root: 0}
|
| 277 |
+
queue = [root]
|
| 278 |
+
while queue:
|
| 279 |
+
current = queue.pop(0)
|
| 280 |
+
for child in children.get(current, []):
|
| 281 |
+
if child not in depths:
|
| 282 |
+
depths[child] = depths[current] + 1
|
| 283 |
+
queue.append(child)
|
| 284 |
+
|
| 285 |
+
return branch_points, depths, linkage_types, adj
|
| 286 |
+
|
| 287 |
+
def _compute_distance_matrix(self, adj: Dict[int, List[int]], num_residues: int) -> List[List[int]]:
|
| 288 |
+
"""
|
| 289 |
+
Compute shortest path distance (number of bonds) between all residue pairs using BFS.
|
| 290 |
+
"""
|
| 291 |
+
if num_residues == 0:
|
| 292 |
+
return []
|
| 293 |
+
|
| 294 |
+
dist_matrix = [[-1] * num_residues for _ in range(num_residues)]
|
| 295 |
+
|
| 296 |
+
for i in range(num_residues):
|
| 297 |
+
dist_matrix[i][i] = 0
|
| 298 |
+
queue = [(i, 0)]
|
| 299 |
+
visited = {i}
|
| 300 |
+
|
| 301 |
+
while queue:
|
| 302 |
+
curr, d = queue.pop(0)
|
| 303 |
+
dist_matrix[i][curr] = d
|
| 304 |
+
|
| 305 |
+
for neighbor in adj[curr]:
|
| 306 |
+
if neighbor not in visited:
|
| 307 |
+
visited.add(neighbor)
|
| 308 |
+
queue.append((neighbor, d + 1))
|
| 309 |
+
|
| 310 |
+
return dist_matrix
|
| 311 |
+
|
| 312 |
+
def _compute_distance_matrix_OLD(self, linkages: List[str], num_residues: int) -> List[List[int]]:
|
| 313 |
+
"""
|
| 314 |
+
Compute shortest path distance (number of bonds) between all residue pairs.
|
| 315 |
+
Returns a symmetric N x N matrix where N is num_residues.
|
| 316 |
+
Values are integers (number of steps). 0 on diagonal. -1 if unreachable (shouldn't happen in single tree).
|
| 317 |
+
"""
|
| 318 |
+
if num_residues == 0:
|
| 319 |
+
return []
|
| 320 |
+
|
| 321 |
+
# Initialize adjacency list
|
| 322 |
+
adj = defaultdict(list)
|
| 323 |
+
for link in linkages:
|
| 324 |
+
match = self.LINKAGE_PATTERN.match(link)
|
| 325 |
+
if match:
|
| 326 |
+
# WURCS indices are 1-based letters (a=1, b=2...)
|
| 327 |
+
from_res_char, _, to_res_char, _ = match.groups()
|
| 328 |
+
# Convert char to 0-based index
|
| 329 |
+
u = ord(from_res_char) - ord('a')
|
| 330 |
+
v = ord(to_res_char) - ord('a')
|
| 331 |
+
|
| 332 |
+
# Undirected graph for structural distance
|
| 333 |
+
if 0 <= u < num_residues and 0 <= v < num_residues:
|
| 334 |
+
adj[u].append(v)
|
| 335 |
+
adj[v].append(u)
|
| 336 |
+
|
| 337 |
+
# Compute All-Pairs Shortest Path (BFS from each node is fine for small N)
|
| 338 |
+
# Glycans are small (N ~ 5-20 usually), so O(N^2) BFS is cheap.
|
| 339 |
+
dist_matrix = [[-1] * num_residues for _ in range(num_residues)]
|
| 340 |
+
|
| 341 |
+
for i in range(num_residues):
|
| 342 |
+
dist_matrix[i][i] = 0
|
| 343 |
+
queue = [(i, 0)]
|
| 344 |
+
visited = {i}
|
| 345 |
+
|
| 346 |
+
while queue:
|
| 347 |
+
curr, d = queue.pop(0)
|
| 348 |
+
dist_matrix[i][curr] = d
|
| 349 |
+
|
| 350 |
+
for neighbor in adj[curr]:
|
| 351 |
+
if neighbor not in visited:
|
| 352 |
+
visited.add(neighbor)
|
| 353 |
+
queue.append((neighbor, d + 1))
|
| 354 |
+
|
| 355 |
+
return dist_matrix
|
| 356 |
+
|
| 357 |
+
def _parse_linkage_type(self, link: str) -> int:
|
| 358 |
+
"""Parse linkage string to get type ID."""
|
| 359 |
+
LINKAGE_TYPES = {
|
| 360 |
+
(1, 2): 0, (2, 1): 0,
|
| 361 |
+
(1, 3): 1, (3, 1): 1,
|
| 362 |
+
(1, 4): 2, (4, 1): 2,
|
| 363 |
+
(1, 6): 3, (6, 1): 3,
|
| 364 |
+
(2, 3): 4, (3, 2): 4,
|
| 365 |
+
(2, 6): 5, (6, 2): 5,
|
| 366 |
+
(3, 6): 6, (6, 3): 6,
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
match = self.LINKAGE_PATTERN.match(link)
|
| 370 |
+
if match:
|
| 371 |
+
_, from_pos, _, to_pos = match.groups()
|
| 372 |
+
try:
|
| 373 |
+
pos_tuple = (int(from_pos), int(to_pos))
|
| 374 |
+
return LINKAGE_TYPES.get(pos_tuple, 7)
|
| 375 |
+
except ValueError:
|
| 376 |
+
return 8 # Unknown
|
| 377 |
+
return 8
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class WURCSBPETokenizer:
|
| 381 |
+
"""
|
| 382 |
+
BPE tokenizer for WURCS with tree-aware embeddings.
|
| 383 |
+
"""
|
| 384 |
+
|
| 385 |
+
def __init__(self, vocab_path: Optional[str] = None):
|
| 386 |
+
self.pre_tokenizer = WURCSPreTokenizer()
|
| 387 |
+
|
| 388 |
+
# Special tokens (fixed)
|
| 389 |
+
self.special_tokens = self.pre_tokenizer.special_tokens
|
| 390 |
+
|
| 391 |
+
# BPE vocabulary
|
| 392 |
+
self.token_to_id: Dict[str, int] = {}
|
| 393 |
+
self.id_to_token: Dict[int, str] = {}
|
| 394 |
+
self.merges: List[Tuple[str, str]] = []
|
| 395 |
+
|
| 396 |
+
if vocab_path:
|
| 397 |
+
self.load_vocab(vocab_path)
|
| 398 |
+
else:
|
| 399 |
+
# Initialize with special tokens only
|
| 400 |
+
self.token_to_id = dict(self.special_tokens)
|
| 401 |
+
self.id_to_token = {v: k for k, v in self.token_to_id.items()}
|
| 402 |
+
|
| 403 |
+
@classmethod
|
| 404 |
+
def train_from_corpus(
|
| 405 |
+
cls,
|
| 406 |
+
wurcs_strings: List[str],
|
| 407 |
+
num_merges: int = 500,
|
| 408 |
+
output_path: Optional[str] = None,
|
| 409 |
+
min_frequency: int = 2,
|
| 410 |
+
max_token_length: Optional[int] = None,
|
| 411 |
+
) -> 'WURCSBPETokenizer':
|
| 412 |
+
"""
|
| 413 |
+
Train BPE on a corpus of WURCS strings.
|
| 414 |
+
|
| 415 |
+
Args:
|
| 416 |
+
wurcs_strings: List of WURCS strings
|
| 417 |
+
num_merges: Number of BPE merge operations
|
| 418 |
+
output_path: Optional path to save vocabulary
|
| 419 |
+
min_frequency: Minimum frequency for a token to be kept
|
| 420 |
+
max_token_length: Maximum length of a merged token (None = no limit)
|
| 421 |
+
|
| 422 |
+
Returns:
|
| 423 |
+
Trained tokenizer
|
| 424 |
+
"""
|
| 425 |
+
tokenizer = cls()
|
| 426 |
+
pre_tok = WURCSPreTokenizer()
|
| 427 |
+
|
| 428 |
+
print(f"Training BPE on {len(wurcs_strings)} WURCS strings...")
|
| 429 |
+
|
| 430 |
+
# Step 1: Pre-tokenize all strings to get semantic units
|
| 431 |
+
all_units = []
|
| 432 |
+
for wurcs in wurcs_strings:
|
| 433 |
+
units = pre_tok.pre_tokenize(wurcs)
|
| 434 |
+
for unit in units:
|
| 435 |
+
if unit['type'] != 'special':
|
| 436 |
+
all_units.append(unit['text'])
|
| 437 |
+
|
| 438 |
+
# Step 2: Count unit frequencies
|
| 439 |
+
unit_counts = Counter(all_units)
|
| 440 |
+
print(f"Found {len(unit_counts)} unique units")
|
| 441 |
+
|
| 442 |
+
# Step 3: Initialize vocabulary with characters from all units
|
| 443 |
+
char_vocab = set()
|
| 444 |
+
for unit in unit_counts:
|
| 445 |
+
for char in unit:
|
| 446 |
+
char_vocab.add(char)
|
| 447 |
+
|
| 448 |
+
# Build initial vocab: special tokens + characters
|
| 449 |
+
vocab_id = len(tokenizer.special_tokens)
|
| 450 |
+
for char in sorted(char_vocab):
|
| 451 |
+
tokenizer.token_to_id[char] = vocab_id
|
| 452 |
+
tokenizer.id_to_token[vocab_id] = char
|
| 453 |
+
vocab_id += 1
|
| 454 |
+
|
| 455 |
+
print(f"Initial vocab size: {vocab_id} (special + characters)")
|
| 456 |
+
|
| 457 |
+
# Step 4: Convert units to character sequences
|
| 458 |
+
word_freqs = {}
|
| 459 |
+
for unit, count in unit_counts.items():
|
| 460 |
+
if count >= min_frequency:
|
| 461 |
+
# Split into characters with space separator
|
| 462 |
+
chars = tuple(unit)
|
| 463 |
+
word_freqs[chars] = count
|
| 464 |
+
|
| 465 |
+
# Step 5: BPE merging
|
| 466 |
+
merges = []
|
| 467 |
+
|
| 468 |
+
for merge_idx in range(num_merges):
|
| 469 |
+
# Count pairs
|
| 470 |
+
pair_counts = Counter()
|
| 471 |
+
for word, freq in word_freqs.items():
|
| 472 |
+
for i in range(len(word) - 1):
|
| 473 |
+
pair = (word[i], word[i + 1])
|
| 474 |
+
pair_counts[pair] += freq
|
| 475 |
+
|
| 476 |
+
if not pair_counts:
|
| 477 |
+
break
|
| 478 |
+
|
| 479 |
+
# Find most frequent pair
|
| 480 |
+
best_pair = pair_counts.most_common(1)[0][0]
|
| 481 |
+
best_count = pair_counts[best_pair]
|
| 482 |
+
|
| 483 |
+
if best_count < min_frequency:
|
| 484 |
+
break
|
| 485 |
+
|
| 486 |
+
# Merge pair
|
| 487 |
+
new_token = best_pair[0] + best_pair[1]
|
| 488 |
+
|
| 489 |
+
# Check length constraint
|
| 490 |
+
if max_token_length and len(new_token) > max_token_length:
|
| 491 |
+
# remove this pair from consideration for this iteration and future?
|
| 492 |
+
# Actually, skipping it here is tricky because we need to ignore it in pair_counts next time
|
| 493 |
+
# Simpler: Just skip adding it to merges and modify word_freqs?
|
| 494 |
+
# No, if we don't merge, we just continue to the next best pair in THIS iteration.
|
| 495 |
+
# But pair_counts is already computed.
|
| 496 |
+
# We need to loop until we find a valid pair or run out
|
| 497 |
+
|
| 498 |
+
# In this simple implementation, let's just skip this merge efficiently
|
| 499 |
+
# We need to find the NEXT most common pair.
|
| 500 |
+
|
| 501 |
+
# Re-do finding best pair loop
|
| 502 |
+
found_valid_pair = False
|
| 503 |
+
for pair, count in pair_counts.most_common():
|
| 504 |
+
token_candidate = pair[0] + pair[1]
|
| 505 |
+
if max_token_length and len(token_candidate) > max_token_length:
|
| 506 |
+
continue # Skip too long
|
| 507 |
+
|
| 508 |
+
if count < min_frequency:
|
| 509 |
+
break # Stop if frequency too low
|
| 510 |
+
|
| 511 |
+
# Found valid pair
|
| 512 |
+
best_pair = pair
|
| 513 |
+
best_count = count
|
| 514 |
+
new_token = token_candidate
|
| 515 |
+
found_valid_pair = True
|
| 516 |
+
break
|
| 517 |
+
|
| 518 |
+
if not found_valid_pair:
|
| 519 |
+
print(f" Stopping early: No more pairs satisfy max_token_length={max_token_length}")
|
| 520 |
+
break
|
| 521 |
+
|
| 522 |
+
# Final check before merging (in case we didn't enter the if block but updated vars)
|
| 523 |
+
# Actually the logic above handles it. If we entered the block, we either found a new best_pair or broke.
|
| 524 |
+
|
| 525 |
+
merges.append(best_pair)
|
| 526 |
+
|
| 527 |
+
# Add to vocab
|
| 528 |
+
tokenizer.token_to_id[new_token] = vocab_id
|
| 529 |
+
tokenizer.id_to_token[vocab_id] = new_token
|
| 530 |
+
vocab_id += 1
|
| 531 |
+
|
| 532 |
+
# Update word_freqs
|
| 533 |
+
new_word_freqs = {}
|
| 534 |
+
for word, freq in word_freqs.items():
|
| 535 |
+
new_word = []
|
| 536 |
+
i = 0
|
| 537 |
+
while i < len(word):
|
| 538 |
+
if i < len(word) - 1 and word[i] == best_pair[0] and word[i + 1] == best_pair[1]:
|
| 539 |
+
new_word.append(new_token)
|
| 540 |
+
i += 2
|
| 541 |
+
else:
|
| 542 |
+
new_word.append(word[i])
|
| 543 |
+
i += 1
|
| 544 |
+
new_word_freqs[tuple(new_word)] = freq
|
| 545 |
+
word_freqs = new_word_freqs
|
| 546 |
+
|
| 547 |
+
if (merge_idx + 1) % 100 == 0:
|
| 548 |
+
print(f" Merge {merge_idx + 1}/{num_merges}: '{best_pair[0]}' + '{best_pair[1]}' -> '{new_token}' (count={best_count})")
|
| 549 |
+
|
| 550 |
+
tokenizer.merges = merges
|
| 551 |
+
print(f"Final vocab size: {len(tokenizer.token_to_id)}")
|
| 552 |
+
|
| 553 |
+
# Save if requested
|
| 554 |
+
if output_path:
|
| 555 |
+
tokenizer.save_vocab(output_path)
|
| 556 |
+
|
| 557 |
+
return tokenizer
|
| 558 |
+
|
| 559 |
+
def apply_bpe(self, text: str) -> List[str]:
|
| 560 |
+
"""Apply BPE merges to a text string."""
|
| 561 |
+
if text in self.token_to_id:
|
| 562 |
+
return [text]
|
| 563 |
+
|
| 564 |
+
# Split into characters
|
| 565 |
+
tokens = list(text)
|
| 566 |
+
|
| 567 |
+
# Apply merges
|
| 568 |
+
for pair in self.merges:
|
| 569 |
+
new_tokens = []
|
| 570 |
+
i = 0
|
| 571 |
+
while i < len(tokens):
|
| 572 |
+
if i < len(tokens) - 1 and tokens[i] == pair[0] and tokens[i + 1] == pair[1]:
|
| 573 |
+
new_tokens.append(pair[0] + pair[1])
|
| 574 |
+
i += 2
|
| 575 |
+
else:
|
| 576 |
+
new_tokens.append(tokens[i])
|
| 577 |
+
i += 1
|
| 578 |
+
tokens = new_tokens
|
| 579 |
+
|
| 580 |
+
return tokens
|
| 581 |
+
|
| 582 |
+
def tokenize(self, wurcs: str, max_length: int = 256) -> Dict:
|
| 583 |
+
"""
|
| 584 |
+
Tokenize a WURCS string.
|
| 585 |
+
|
| 586 |
+
Returns:
|
| 587 |
+
Dict with:
|
| 588 |
+
- tokens: List of token strings
|
| 589 |
+
- token_ids: List of token IDs
|
| 590 |
+
- residue_ids: List of residue IDs
|
| 591 |
+
- branch_depths: List of branch depths
|
| 592 |
+
- linkage_types: List of linkage types
|
| 593 |
+
- attention_mask: Attention mask
|
| 594 |
+
"""
|
| 595 |
+
# Pre-tokenize
|
| 596 |
+
units = self.pre_tokenizer.pre_tokenize(wurcs)
|
| 597 |
+
|
| 598 |
+
tokens = []
|
| 599 |
+
token_ids = []
|
| 600 |
+
residue_ids = []
|
| 601 |
+
branch_depths = []
|
| 602 |
+
linkage_types = []
|
| 603 |
+
|
| 604 |
+
for unit in units:
|
| 605 |
+
if unit['type'] == 'special':
|
| 606 |
+
# Special tokens stay as-is
|
| 607 |
+
tok = unit['text']
|
| 608 |
+
tokens.append(tok)
|
| 609 |
+
token_ids.append(self.token_to_id.get(tok, self.token_to_id['[UNK]']))
|
| 610 |
+
residue_ids.append(unit['residue_id'])
|
| 611 |
+
branch_depths.append(unit['branch_depth'])
|
| 612 |
+
linkage_types.append(unit['linkage_type'])
|
| 613 |
+
else:
|
| 614 |
+
# Apply BPE to this unit
|
| 615 |
+
bpe_tokens = self.apply_bpe(unit['text'])
|
| 616 |
+
for tok in bpe_tokens:
|
| 617 |
+
tokens.append(tok)
|
| 618 |
+
token_ids.append(self.token_to_id.get(tok, self.token_to_id['[UNK]']))
|
| 619 |
+
residue_ids.append(unit['residue_id'])
|
| 620 |
+
branch_depths.append(unit['branch_depth'])
|
| 621 |
+
linkage_types.append(unit['linkage_type'])
|
| 622 |
+
|
| 623 |
+
# Truncate if needed
|
| 624 |
+
if len(tokens) > max_length:
|
| 625 |
+
tokens = tokens[:max_length - 1] + ['[END]']
|
| 626 |
+
token_ids = token_ids[:max_length - 1] + [self.token_to_id['[END]']]
|
| 627 |
+
residue_ids = residue_ids[:max_length - 1] + [-1]
|
| 628 |
+
branch_depths = branch_depths[:max_length - 1] + [0]
|
| 629 |
+
linkage_types = linkage_types[:max_length - 1] + [0]
|
| 630 |
+
|
| 631 |
+
# Create attention mask and pad
|
| 632 |
+
length = len(tokens)
|
| 633 |
+
attention_mask = [1] * length
|
| 634 |
+
|
| 635 |
+
while len(tokens) < max_length:
|
| 636 |
+
tokens.append('[PAD]')
|
| 637 |
+
token_ids.append(self.token_to_id['[PAD]'])
|
| 638 |
+
residue_ids.append(-1)
|
| 639 |
+
branch_depths.append(0)
|
| 640 |
+
linkage_types.append(0)
|
| 641 |
+
attention_mask.append(0)
|
| 642 |
+
# Pre-tokenize
|
| 643 |
+
units = self.pre_tokenizer.pre_tokenize(wurcs)
|
| 644 |
+
|
| 645 |
+
# Extract distance matrix from pre-tokenizer result
|
| 646 |
+
dist_matrix_raw = units[0].get('distance_matrix', [])
|
| 647 |
+
num_residues = len(dist_matrix_raw)
|
| 648 |
+
|
| 649 |
+
# Map token-to-token distances using residue_ids
|
| 650 |
+
# token_i is associated with residue_ids[i].
|
| 651 |
+
# residue_ids[i] is index into dist_matrix_raw.
|
| 652 |
+
# If residue_ids[i] == -1 (special), distance is undefined (use -1 or 999)
|
| 653 |
+
|
| 654 |
+
# Use UNPADDED length for distance matrix to save massive memory
|
| 655 |
+
# distance_matrix will be e.g. 20x20, while tokens are padded to 256
|
| 656 |
+
token_len = length
|
| 657 |
+
distance_matrix = [[-1] * token_len for _ in range(token_len)]
|
| 658 |
+
|
| 659 |
+
for i in range(token_len):
|
| 660 |
+
for j in range(token_len):
|
| 661 |
+
r_i = residue_ids[i]
|
| 662 |
+
r_j = residue_ids[j]
|
| 663 |
+
|
| 664 |
+
if r_i >= 0 and r_j >= 0 and r_i < num_residues and r_j < num_residues:
|
| 665 |
+
distance_matrix[i][j] = dist_matrix_raw[r_i][r_j]
|
| 666 |
+
else:
|
| 667 |
+
distance_matrix[i][j] = -1 # Special/Padding
|
| 668 |
+
|
| 669 |
+
# MEMORY OPTIMIZATION: Do NOT pad matrix here.
|
| 670 |
+
# Pad on-the-fly in Dataset class instead.
|
| 671 |
+
# This saves massive memory (0.2GB vs 66GB).
|
| 672 |
+
|
| 673 |
+
return {
|
| 674 |
+
'tokens': tokens,
|
| 675 |
+
'token_ids': token_ids,
|
| 676 |
+
'residue_ids': residue_ids,
|
| 677 |
+
'branch_depths': branch_depths,
|
| 678 |
+
'linkage_types': linkage_types,
|
| 679 |
+
'attention_mask': attention_mask,
|
| 680 |
+
'distance_matrix': distance_matrix, # New Output
|
| 681 |
+
'length': length,
|
| 682 |
+
}
|
| 683 |
+
|
| 684 |
+
def save_vocab(self, path: str):
|
| 685 |
+
"""Save vocabulary to JSON file."""
|
| 686 |
+
data = {
|
| 687 |
+
'special_tokens': self.special_tokens,
|
| 688 |
+
'token_to_id': self.token_to_id,
|
| 689 |
+
'merges': [list(m) for m in self.merges],
|
| 690 |
+
'metadata': {
|
| 691 |
+
'vocab_size': len(self.token_to_id),
|
| 692 |
+
'num_merges': len(self.merges),
|
| 693 |
+
}
|
| 694 |
+
}
|
| 695 |
+
with open(path, 'w') as f:
|
| 696 |
+
json.dump(data, f, indent=2)
|
| 697 |
+
print(f"Saved vocabulary to {path}")
|
| 698 |
+
|
| 699 |
+
def load_vocab(self, path: str):
|
| 700 |
+
"""Load vocabulary from JSON file."""
|
| 701 |
+
with open(path, 'r') as f:
|
| 702 |
+
data = json.load(f)
|
| 703 |
+
|
| 704 |
+
self.special_tokens = data['special_tokens']
|
| 705 |
+
self.token_to_id = data['token_to_id']
|
| 706 |
+
self.id_to_token = {int(v): k for k, v in self.token_to_id.items()}
|
| 707 |
+
self.merges = [tuple(m) for m in data['merges']]
|
| 708 |
+
|
| 709 |
+
print(f"Loaded vocabulary with {len(self.token_to_id)} tokens")
|
| 710 |
+
|
| 711 |
+
@property
|
| 712 |
+
def vocab_size(self) -> int:
|
| 713 |
+
return len(self.token_to_id)
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
# ============================================================================
|
| 717 |
+
# Testing
|
| 718 |
+
# ============================================================================
|
| 719 |
+
|
| 720 |
+
if __name__ == '__main__':
|
| 721 |
+
# Test pre-tokenizer
|
| 722 |
+
print("="*80)
|
| 723 |
+
print("Testing WURCSPreTokenizer")
|
| 724 |
+
print("="*80)
|
| 725 |
+
|
| 726 |
+
pre_tok = WURCSPreTokenizer()
|
| 727 |
+
|
| 728 |
+
test_wurcs = [
|
| 729 |
+
"WURCS=2.0/2,2,1/[a2122h-1b_1-5][a2211m-1a_1-5]/1-2/a4-b1",
|
| 730 |
+
"WURCS=2.0/3,3,2/[a2122h-1b_1-5_2*NCC/3=O][a2112h-1a_1-5][a2211m-1a_1-5]/1-2-3/a4-b1_b3-c1",
|
| 731 |
+
]
|
| 732 |
+
|
| 733 |
+
for wurcs in test_wurcs:
|
| 734 |
+
print(f"\nWURCS: {wurcs[:60]}...")
|
| 735 |
+
units = pre_tok.pre_tokenize(wurcs)
|
| 736 |
+
print(f"Units ({len(units)}):")
|
| 737 |
+
for u in units[:10]:
|
| 738 |
+
print(f" {u['type']:10} | res={u['residue_id']:2} | depth={u['branch_depth']} | {u['text']}")
|
| 739 |
+
if len(units) > 10:
|
| 740 |
+
print(f" ... and {len(units) - 10} more")
|
vocab/bpe_ambiguity_tokens.json
ADDED
|
@@ -0,0 +1,721 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"ambiguous_tokens": {
|
| 3 |
+
"?": 32,
|
| 4 |
+
"?|": 90,
|
| 5 |
+
"a?|": 108,
|
| 6 |
+
"a?|b": 109,
|
| 7 |
+
"?|c": 110,
|
| 8 |
+
"a?|b?|c": 111,
|
| 9 |
+
"?|d": 112,
|
| 10 |
+
"a?|b?|c?|d": 113,
|
| 11 |
+
"?|e": 114,
|
| 12 |
+
"a?|b?|c?|d?|e": 115,
|
| 13 |
+
"?|f": 116,
|
| 14 |
+
"a?|b?|c?|d?|e?|f": 117,
|
| 15 |
+
"?-": 118,
|
| 16 |
+
"?|g": 119,
|
| 17 |
+
"a?|b?|c?|d?|e?|f?|g": 120,
|
| 18 |
+
"?|h": 122,
|
| 19 |
+
"?|i": 123,
|
| 20 |
+
"?|h?|i": 124,
|
| 21 |
+
"?|j": 125,
|
| 22 |
+
"?|h?|i?|j": 126,
|
| 23 |
+
"?|k": 128,
|
| 24 |
+
"?|h?|i?|j?|k": 129,
|
| 25 |
+
"?|l": 130,
|
| 26 |
+
"?|h?|i?|j?|k?|l": 131,
|
| 27 |
+
"?|m": 132,
|
| 28 |
+
"?|h?|i?|j?|k?|l?|m": 133,
|
| 29 |
+
"?|h?|i?|j?|k?|l?|m?|": 138,
|
| 30 |
+
"n?|": 141,
|
| 31 |
+
"n?|o": 142,
|
| 32 |
+
"?}": 143,
|
| 33 |
+
"n?|o?|": 146,
|
| 34 |
+
"n?|o?|p": 147,
|
| 35 |
+
"?}-": 149,
|
| 36 |
+
"?}-{": 150,
|
| 37 |
+
"n?|o?|p?|": 153,
|
| 38 |
+
"n?|o?|p?|q": 154,
|
| 39 |
+
"n?|o?|p?|q?|": 157,
|
| 40 |
+
"n?|o?|p?|q?|r": 158,
|
| 41 |
+
"n?|o?|p?|q?|r?|": 165,
|
| 42 |
+
"n?|o?|p?|q?|r?|s": 166,
|
| 43 |
+
"n?|o?|p?|q?|r?|s?|": 170,
|
| 44 |
+
"n?|o?|p?|q?|r?|s?|t": 171,
|
| 45 |
+
"?|u": 189,
|
| 46 |
+
"a?-": 197,
|
| 47 |
+
"c?-": 201,
|
| 48 |
+
"?|u?|": 209,
|
| 49 |
+
"?|u?|v": 210,
|
| 50 |
+
"b?-": 211,
|
| 51 |
+
"a?-b1": 213,
|
| 52 |
+
"d?-": 217,
|
| 53 |
+
"b?-c1": 221,
|
| 54 |
+
"c?-d1": 230,
|
| 55 |
+
"?|u?|v?|": 231,
|
| 56 |
+
"?|u?|v?|w": 232,
|
| 57 |
+
"1-?": 242,
|
| 58 |
+
"d?-e1": 244,
|
| 59 |
+
"e?-": 245,
|
| 60 |
+
"e?-f1": 262,
|
| 61 |
+
"?|u?|v?|w?|": 266,
|
| 62 |
+
"?|u?|v?|w?|x": 267,
|
| 63 |
+
"f?-": 273,
|
| 64 |
+
"?|u?|v?|w?|x?|": 288,
|
| 65 |
+
"?|u?|v?|w?|x?|y": 289,
|
| 66 |
+
"g?-": 298,
|
| 67 |
+
"n?|o?|p?|q?|r?|s?|t?": 304,
|
| 68 |
+
"i?-": 306,
|
| 69 |
+
"h?-": 308,
|
| 70 |
+
"?|u?}-{": 312,
|
| 71 |
+
"?|u?": 313,
|
| 72 |
+
"n?|o?|p?|q?|r?}-{": 314,
|
| 73 |
+
"n?|o?|p?|q?|r?": 315,
|
| 74 |
+
"f?-g1": 318,
|
| 75 |
+
"n?|o?}-{": 322,
|
| 76 |
+
"n?|o?": 323,
|
| 77 |
+
"?|u?|v?|w?|x?|y?|": 325,
|
| 78 |
+
"?|u?|v?|w?|x?|y?|z": 326,
|
| 79 |
+
"n?|o?|p?|q?|r?|s?}-{": 328,
|
| 80 |
+
"n?|o?|p?|q?|r?|s?": 329,
|
| 81 |
+
"n?|o?|p?}-{": 331,
|
| 82 |
+
"n?|o?|p?": 332,
|
| 83 |
+
"n?|o?|p?|q?}-{": 336,
|
| 84 |
+
"n?|o?|p?|q?": 337,
|
| 85 |
+
"g?-h1": 339,
|
| 86 |
+
"?|h?|i?|j?|k?|l?}-{": 342,
|
| 87 |
+
"?|h?|i?|j?|k?|l?": 343,
|
| 88 |
+
"n?}-{": 344,
|
| 89 |
+
"n?": 345,
|
| 90 |
+
"j?-": 346,
|
| 91 |
+
"h?-i1": 347,
|
| 92 |
+
"?|u?|v?}-{": 351,
|
| 93 |
+
"?|u?|v?": 352,
|
| 94 |
+
"k?-": 353,
|
| 95 |
+
"i?-j1": 355,
|
| 96 |
+
"?|h?|i?|j?|k?|l?|m?": 363,
|
| 97 |
+
"?|h?|i?|j?|k?}-{": 364,
|
| 98 |
+
"?|h?|i?|j?|k?": 365,
|
| 99 |
+
"?|u?|v?|w?|x?|y?|z?|": 369,
|
| 100 |
+
"?|h?|i?|j?}-{": 375,
|
| 101 |
+
"?|h?|i?|j?": 376,
|
| 102 |
+
"l?-": 377,
|
| 103 |
+
"A?|": 392,
|
| 104 |
+
"A?|B": 393,
|
| 105 |
+
"j?-k1": 401,
|
| 106 |
+
"m?-": 404,
|
| 107 |
+
"?|h?|i?}-{": 408,
|
| 108 |
+
"?|h?|i?": 409,
|
| 109 |
+
"k?-l1": 418,
|
| 110 |
+
"?|h?}-{": 420,
|
| 111 |
+
"?|h?": 421,
|
| 112 |
+
"A?|B?|": 424,
|
| 113 |
+
"A?|B?|C": 425,
|
| 114 |
+
"a?|b?|c?|d?|e?|f?|g?": 427,
|
| 115 |
+
"f?-g2": 431,
|
| 116 |
+
"a?|b?|c?|d?|e?|f?}-{": 437,
|
| 117 |
+
"a?|b?|c?|d?|e?|f?": 438,
|
| 118 |
+
"l?-m1": 442,
|
| 119 |
+
"?|u?|v?|w?}-{": 450,
|
| 120 |
+
"?|u?|v?|w?": 451,
|
| 121 |
+
"A?|B?|C?|": 464,
|
| 122 |
+
"A?|B?|C?|D": 465,
|
| 123 |
+
"a?|b?|c?|d?|e?}-{": 475,
|
| 124 |
+
"a?|b?|c?|d?|e?": 476,
|
| 125 |
+
"n?-": 499,
|
| 126 |
+
"a?|b?|c?|d?}-{": 502,
|
| 127 |
+
"a?|b?|c?|d?": 503,
|
| 128 |
+
"m?-n1": 518,
|
| 129 |
+
"A?|B?|C?|D?|": 521,
|
| 130 |
+
"A?|B?|C?|D?|E": 522,
|
| 131 |
+
"o?-": 534,
|
| 132 |
+
"d?-h1": 536,
|
| 133 |
+
"A?|B?|C?|D?|E?|": 542,
|
| 134 |
+
"A?|B?|C?|D?|E?|F": 543,
|
| 135 |
+
"c?-i1": 544,
|
| 136 |
+
"c?-h1": 549,
|
| 137 |
+
"A?|B?|C?|D?|E?|F?|": 550,
|
| 138 |
+
"A?|B?|C?|D?|E?|F?|G": 551,
|
| 139 |
+
"?|u?|v?|w?|x?}-{": 563,
|
| 140 |
+
"?|u?|v?|w?|x?": 564,
|
| 141 |
+
"a?|b?|c?}-{": 571,
|
| 142 |
+
"a?|b?|c?}-{a?|b?|c": 572,
|
| 143 |
+
"a?|b?|c?}-{a?|b?|c?": 573,
|
| 144 |
+
"?|H": 581,
|
| 145 |
+
"2-?": 592,
|
| 146 |
+
"?|H?|": 598,
|
| 147 |
+
"?|H?|I": 599,
|
| 148 |
+
"?}*OC": 600,
|
| 149 |
+
"c?-k1": 607,
|
| 150 |
+
"c?-g1": 609,
|
| 151 |
+
"?|H?|I?|": 615,
|
| 152 |
+
"?|H?|I?|J": 616,
|
| 153 |
+
"n?-o1": 617,
|
| 154 |
+
"d?-g1": 629,
|
| 155 |
+
"o?-p1": 634,
|
| 156 |
+
"p?-": 646,
|
| 157 |
+
"?|u?|v?|w?|x?|y?}-{": 653,
|
| 158 |
+
"?|u?|v?|w?|x?|y?": 654,
|
| 159 |
+
"b?-c2": 656,
|
| 160 |
+
"d?-i1": 658,
|
| 161 |
+
"c?-j1": 691,
|
| 162 |
+
"?}*OSO": 696,
|
| 163 |
+
"e?-h1": 701,
|
| 164 |
+
"q?-": 713,
|
| 165 |
+
"c?-f1": 720,
|
| 166 |
+
"i?-j2": 728,
|
| 167 |
+
"?|h?|i?}": 742,
|
| 168 |
+
"h?-i2": 747,
|
| 169 |
+
"g?-h2": 753,
|
| 170 |
+
"c?-l1": 756,
|
| 171 |
+
"j?-k2": 758,
|
| 172 |
+
"?|h?}": 759,
|
| 173 |
+
"c?-e1": 760,
|
| 174 |
+
"?|H?|I?|J?|": 761,
|
| 175 |
+
"?|H?|I?|J?|K": 762,
|
| 176 |
+
"a?|b?|c?|d?|e?|f?}": 772,
|
| 177 |
+
"b?-e1": 774,
|
| 178 |
+
"b?-f1": 791,
|
| 179 |
+
"d?-f1": 794,
|
| 180 |
+
"p?-q1": 796,
|
| 181 |
+
"a?|b?|c?|d?|e?}": 798,
|
| 182 |
+
"a?-d1": 800,
|
| 183 |
+
"m?-n2": 803,
|
| 184 |
+
"e?-g1": 809,
|
| 185 |
+
"?|h?|i?|j?}": 812,
|
| 186 |
+
"r?-": 817,
|
| 187 |
+
"a?-c1": 818,
|
| 188 |
+
"?|u?|v?|w?|x?|y?|z?": 822,
|
| 189 |
+
"a?-e1": 826,
|
| 190 |
+
"d?-j1": 833,
|
| 191 |
+
"b?-g1": 834,
|
| 192 |
+
"q?-r1": 847,
|
| 193 |
+
"d?-e2": 854,
|
| 194 |
+
"c?-m1": 860,
|
| 195 |
+
"a?-f1": 875,
|
| 196 |
+
"b?-d1": 887,
|
| 197 |
+
"?|H?|I?|J?|K?|": 892,
|
| 198 |
+
"?|H?|I?|J?|K?|L": 893,
|
| 199 |
+
"?|H?|I?|J?|K?|L?|": 894,
|
| 200 |
+
"?|H?|I?|J?|K?|L?|M": 895,
|
| 201 |
+
"?|H?|I?|J?|K?|L?|M?|": 896,
|
| 202 |
+
"a?-l1": 920,
|
| 203 |
+
"?*OSO/3=O/3=O": 923,
|
| 204 |
+
"k?-l2": 940,
|
| 205 |
+
"k?-o1": 942,
|
| 206 |
+
"N?|": 965,
|
| 207 |
+
"N?|O": 966,
|
| 208 |
+
"N?|O?|": 967,
|
| 209 |
+
"N?|O?|P": 968,
|
| 210 |
+
"N?|O?|P?|": 969,
|
| 211 |
+
"N?|O?|P?|Q": 970,
|
| 212 |
+
"N?|O?|P?|Q?|": 971,
|
| 213 |
+
"N?|O?|P?|Q?|R": 972,
|
| 214 |
+
"N?|O?|P?|Q?|R?|": 973,
|
| 215 |
+
"N?|O?|P?|Q?|R?|S": 974,
|
| 216 |
+
"N?|O?|P?|Q?|R?|S?|": 975,
|
| 217 |
+
"N?|O?|P?|Q?|R?|S?|T": 976,
|
| 218 |
+
"?|U": 977,
|
| 219 |
+
"?|U?|": 978,
|
| 220 |
+
"?|U?|V": 979,
|
| 221 |
+
"c?-d2": 983,
|
| 222 |
+
"r?-s1": 988,
|
| 223 |
+
"a?|b?}-{": 995,
|
| 224 |
+
"a?|b?}-{a?|b": 996,
|
| 225 |
+
"a?|b?}-{a?|b?": 997,
|
| 226 |
+
"e?-f2": 1001,
|
| 227 |
+
"g?-i1": 1006,
|
| 228 |
+
"i?-l1": 1010,
|
| 229 |
+
"s?-": 1011,
|
| 230 |
+
"?|h?|i?|j?|k?}": 1017,
|
| 231 |
+
"b?-h1": 1034,
|
| 232 |
+
"a?-j1": 1038,
|
| 233 |
+
"n?-o2": 1046,
|
| 234 |
+
"a?-b2": 1069,
|
| 235 |
+
"e?-i1": 1095,
|
| 236 |
+
"h?-j1": 1102,
|
| 237 |
+
"a?-k1": 1108,
|
| 238 |
+
"i?-k1": 1115,
|
| 239 |
+
"a?-g1": 1116,
|
| 240 |
+
"?}*OPO": 1122,
|
| 241 |
+
"d?-k1": 1129,
|
| 242 |
+
"a?-m1": 1151,
|
| 243 |
+
"a?-i1": 1159,
|
| 244 |
+
"A?}-{": 1174,
|
| 245 |
+
"A?": 1175,
|
| 246 |
+
"?}*OCC": 1177,
|
| 247 |
+
"l?-m2": 1179,
|
| 248 |
+
"A?|B?}-{": 1180,
|
| 249 |
+
"A?|B?": 1181,
|
| 250 |
+
"f?-h1": 1183,
|
| 251 |
+
"a?-n1": 1189,
|
| 252 |
+
"p?-q2": 1192,
|
| 253 |
+
"c?-n1": 1197,
|
| 254 |
+
"?|U?|V?|": 1202,
|
| 255 |
+
"?|U?|V?|W": 1203,
|
| 256 |
+
"?|U?|V?|W?|": 1204,
|
| 257 |
+
"?|U?|V?|W?|X": 1205,
|
| 258 |
+
"?|U?|V?|W?|X?|": 1206,
|
| 259 |
+
"?|U?|V?|W?|X?|Y": 1207,
|
| 260 |
+
"?|a": 1208,
|
| 261 |
+
"s?-t1": 1223,
|
| 262 |
+
"?|h?|i?|j?|k?|l?|m?}": 1228,
|
| 263 |
+
"g?-j1": 1234,
|
| 264 |
+
"A?|B?|C?|D?}-{": 1242,
|
| 265 |
+
"A?|B?|C?|D?": 1243,
|
| 266 |
+
"a?-h1": 1253,
|
| 267 |
+
"?|H?|I?|J?}-{": 1257,
|
| 268 |
+
"?|H?|I?|J?": 1258,
|
| 269 |
+
"o?-p2": 1261,
|
| 270 |
+
"b?-i1": 1273,
|
| 271 |
+
"?|h?|i?|j?|k?|l?}": 1309,
|
| 272 |
+
"j?-m1": 1317,
|
| 273 |
+
"c?-o1": 1318,
|
| 274 |
+
"a?-o1": 1330,
|
| 275 |
+
"a?|b?|c?}*OC": 1331,
|
| 276 |
+
"b?-j1": 1357,
|
| 277 |
+
"a?-r1": 1361,
|
| 278 |
+
"n?}": 1363,
|
| 279 |
+
"A?|B?|C?}-{": 1371,
|
| 280 |
+
"A?|B?|C?": 1372,
|
| 281 |
+
"m?-p1": 1375,
|
| 282 |
+
"l?-p1": 1383,
|
| 283 |
+
"a?-p1": 1444,
|
| 284 |
+
"k?-n1": 1446,
|
| 285 |
+
"j?-l1": 1470,
|
| 286 |
+
"?|U?|V?|W?|X?|Y?|": 1471,
|
| 287 |
+
"?|U?|V?|W?|X?|Y?|Z": 1472,
|
| 288 |
+
"?|aa?|": 1473,
|
| 289 |
+
"?|aa?|a": 1474,
|
| 290 |
+
"?|aa?|ab": 1475,
|
| 291 |
+
"?*OPO/3O/3=O": 1476,
|
| 292 |
+
"l?-q1": 1489,
|
| 293 |
+
"l?-n1": 1499,
|
| 294 |
+
"a?-s1": 1517,
|
| 295 |
+
"k?-m1": 1524,
|
| 296 |
+
"a?-q1": 1546,
|
| 297 |
+
"c?-q1": 1547,
|
| 298 |
+
"t?-": 1551,
|
| 299 |
+
"a?|b?|c?|d?}*OC": 1565,
|
| 300 |
+
"f?-i1": 1590,
|
| 301 |
+
"c?-p1": 1591,
|
| 302 |
+
"n?-q1": 1593,
|
| 303 |
+
"?|i?}": 1611,
|
| 304 |
+
"a?|b?|c?|d?|e?}*OC": 1612,
|
| 305 |
+
"m?-q1": 1617,
|
| 306 |
+
"q?-r2": 1623,
|
| 307 |
+
"l?-o1": 1624,
|
| 308 |
+
"m?-r1": 1628,
|
| 309 |
+
"a?-t1": 1630,
|
| 310 |
+
"a?|b?|c?|d?}*OSO": 1649,
|
| 311 |
+
"c?-r1": 1675,
|
| 312 |
+
"1-d?|i?}": 1683,
|
| 313 |
+
"j?-n1": 1691,
|
| 314 |
+
"u?-": 1694,
|
| 315 |
+
"a?|b?|c?|d?|e?}*OSO": 1718,
|
| 316 |
+
"?*OCC/3=O": 1723,
|
| 317 |
+
"?%": 1752,
|
| 318 |
+
"?*OP^XOCCN/3O/3=O": 1770,
|
| 319 |
+
"t?-u1": 1772,
|
| 320 |
+
"?*": 1774,
|
| 321 |
+
"c?-s1": 1775,
|
| 322 |
+
"a?-u1": 1793,
|
| 323 |
+
"f?-h2": 1808,
|
| 324 |
+
"e?-j1": 1811,
|
| 325 |
+
"c?-t1": 1818,
|
| 326 |
+
"f1-a?|b?|c?|d?|e?}": 1822,
|
| 327 |
+
"u?-v1": 1835,
|
| 328 |
+
"h?-k1": 1841,
|
| 329 |
+
"?|H?|I?|J?|K?}-{": 1846,
|
| 330 |
+
"?|H?|I?|J?|K?": 1847,
|
| 331 |
+
"n?|o?}": 1851,
|
| 332 |
+
"1-d?|h?}": 1852,
|
| 333 |
+
"q?-s1": 1872,
|
| 334 |
+
"%?%": 1880,
|
| 335 |
+
"b?-g2": 1881,
|
| 336 |
+
"r?-s2": 1882,
|
| 337 |
+
"d?-l1": 1898,
|
| 338 |
+
"v?-": 1917,
|
| 339 |
+
"b?-k1": 1927,
|
| 340 |
+
"?|aa?|ab?|a": 1942,
|
| 341 |
+
"?|aa?|ab?|ac": 1943,
|
| 342 |
+
"?|aa?|ab?|ac?|": 1944,
|
| 343 |
+
"?|aa?|ab?|ac?|ad": 1945,
|
| 344 |
+
"a?|b?}*OC": 1949,
|
| 345 |
+
"?*OC": 1952,
|
| 346 |
+
"e?-k1": 1955,
|
| 347 |
+
"a?-d2": 1999,
|
| 348 |
+
"s?-t2": 2013,
|
| 349 |
+
"a?-f2": 2027,
|
| 350 |
+
"o?-q1": 2030,
|
| 351 |
+
"?}*OP^XOCCN": 2040,
|
| 352 |
+
"a?|b?|c?}*OCC": 2047,
|
| 353 |
+
"m?-o1": 2048,
|
| 354 |
+
"c?-f2": 2058,
|
| 355 |
+
"A?|B?|C?|D?|E?|F?|G?": 2060,
|
| 356 |
+
"a?|b?|c?}*OSO": 2071,
|
| 357 |
+
"?|U?|V?}-{": 2079,
|
| 358 |
+
"?|U?|V?": 2080,
|
| 359 |
+
"c?-u1": 2087
|
| 360 |
+
},
|
| 361 |
+
"ambiguous_ids": [
|
| 362 |
+
32,
|
| 363 |
+
90,
|
| 364 |
+
108,
|
| 365 |
+
109,
|
| 366 |
+
110,
|
| 367 |
+
111,
|
| 368 |
+
112,
|
| 369 |
+
113,
|
| 370 |
+
114,
|
| 371 |
+
115,
|
| 372 |
+
116,
|
| 373 |
+
117,
|
| 374 |
+
118,
|
| 375 |
+
119,
|
| 376 |
+
120,
|
| 377 |
+
122,
|
| 378 |
+
123,
|
| 379 |
+
124,
|
| 380 |
+
125,
|
| 381 |
+
126,
|
| 382 |
+
128,
|
| 383 |
+
129,
|
| 384 |
+
130,
|
| 385 |
+
131,
|
| 386 |
+
132,
|
| 387 |
+
133,
|
| 388 |
+
138,
|
| 389 |
+
141,
|
| 390 |
+
142,
|
| 391 |
+
143,
|
| 392 |
+
146,
|
| 393 |
+
147,
|
| 394 |
+
149,
|
| 395 |
+
150,
|
| 396 |
+
153,
|
| 397 |
+
154,
|
| 398 |
+
157,
|
| 399 |
+
158,
|
| 400 |
+
165,
|
| 401 |
+
166,
|
| 402 |
+
170,
|
| 403 |
+
171,
|
| 404 |
+
189,
|
| 405 |
+
197,
|
| 406 |
+
201,
|
| 407 |
+
209,
|
| 408 |
+
210,
|
| 409 |
+
211,
|
| 410 |
+
213,
|
| 411 |
+
217,
|
| 412 |
+
221,
|
| 413 |
+
230,
|
| 414 |
+
231,
|
| 415 |
+
232,
|
| 416 |
+
242,
|
| 417 |
+
244,
|
| 418 |
+
245,
|
| 419 |
+
262,
|
| 420 |
+
266,
|
| 421 |
+
267,
|
| 422 |
+
273,
|
| 423 |
+
288,
|
| 424 |
+
289,
|
| 425 |
+
298,
|
| 426 |
+
304,
|
| 427 |
+
306,
|
| 428 |
+
308,
|
| 429 |
+
312,
|
| 430 |
+
313,
|
| 431 |
+
314,
|
| 432 |
+
315,
|
| 433 |
+
318,
|
| 434 |
+
322,
|
| 435 |
+
323,
|
| 436 |
+
325,
|
| 437 |
+
326,
|
| 438 |
+
328,
|
| 439 |
+
329,
|
| 440 |
+
331,
|
| 441 |
+
332,
|
| 442 |
+
336,
|
| 443 |
+
337,
|
| 444 |
+
339,
|
| 445 |
+
342,
|
| 446 |
+
343,
|
| 447 |
+
344,
|
| 448 |
+
345,
|
| 449 |
+
346,
|
| 450 |
+
347,
|
| 451 |
+
351,
|
| 452 |
+
352,
|
| 453 |
+
353,
|
| 454 |
+
355,
|
| 455 |
+
363,
|
| 456 |
+
364,
|
| 457 |
+
365,
|
| 458 |
+
369,
|
| 459 |
+
375,
|
| 460 |
+
376,
|
| 461 |
+
377,
|
| 462 |
+
392,
|
| 463 |
+
393,
|
| 464 |
+
401,
|
| 465 |
+
404,
|
| 466 |
+
408,
|
| 467 |
+
409,
|
| 468 |
+
418,
|
| 469 |
+
420,
|
| 470 |
+
421,
|
| 471 |
+
424,
|
| 472 |
+
425,
|
| 473 |
+
427,
|
| 474 |
+
431,
|
| 475 |
+
437,
|
| 476 |
+
438,
|
| 477 |
+
442,
|
| 478 |
+
450,
|
| 479 |
+
451,
|
| 480 |
+
464,
|
| 481 |
+
465,
|
| 482 |
+
475,
|
| 483 |
+
476,
|
| 484 |
+
499,
|
| 485 |
+
502,
|
| 486 |
+
503,
|
| 487 |
+
518,
|
| 488 |
+
521,
|
| 489 |
+
522,
|
| 490 |
+
534,
|
| 491 |
+
536,
|
| 492 |
+
542,
|
| 493 |
+
543,
|
| 494 |
+
544,
|
| 495 |
+
549,
|
| 496 |
+
550,
|
| 497 |
+
551,
|
| 498 |
+
563,
|
| 499 |
+
564,
|
| 500 |
+
571,
|
| 501 |
+
572,
|
| 502 |
+
573,
|
| 503 |
+
581,
|
| 504 |
+
592,
|
| 505 |
+
598,
|
| 506 |
+
599,
|
| 507 |
+
600,
|
| 508 |
+
607,
|
| 509 |
+
609,
|
| 510 |
+
615,
|
| 511 |
+
616,
|
| 512 |
+
617,
|
| 513 |
+
629,
|
| 514 |
+
634,
|
| 515 |
+
646,
|
| 516 |
+
653,
|
| 517 |
+
654,
|
| 518 |
+
656,
|
| 519 |
+
658,
|
| 520 |
+
691,
|
| 521 |
+
696,
|
| 522 |
+
701,
|
| 523 |
+
713,
|
| 524 |
+
720,
|
| 525 |
+
728,
|
| 526 |
+
742,
|
| 527 |
+
747,
|
| 528 |
+
753,
|
| 529 |
+
756,
|
| 530 |
+
758,
|
| 531 |
+
759,
|
| 532 |
+
760,
|
| 533 |
+
761,
|
| 534 |
+
762,
|
| 535 |
+
772,
|
| 536 |
+
774,
|
| 537 |
+
791,
|
| 538 |
+
794,
|
| 539 |
+
796,
|
| 540 |
+
798,
|
| 541 |
+
800,
|
| 542 |
+
803,
|
| 543 |
+
809,
|
| 544 |
+
812,
|
| 545 |
+
817,
|
| 546 |
+
818,
|
| 547 |
+
822,
|
| 548 |
+
826,
|
| 549 |
+
833,
|
| 550 |
+
834,
|
| 551 |
+
847,
|
| 552 |
+
854,
|
| 553 |
+
860,
|
| 554 |
+
875,
|
| 555 |
+
887,
|
| 556 |
+
892,
|
| 557 |
+
893,
|
| 558 |
+
894,
|
| 559 |
+
895,
|
| 560 |
+
896,
|
| 561 |
+
920,
|
| 562 |
+
923,
|
| 563 |
+
940,
|
| 564 |
+
942,
|
| 565 |
+
965,
|
| 566 |
+
966,
|
| 567 |
+
967,
|
| 568 |
+
968,
|
| 569 |
+
969,
|
| 570 |
+
970,
|
| 571 |
+
971,
|
| 572 |
+
972,
|
| 573 |
+
973,
|
| 574 |
+
974,
|
| 575 |
+
975,
|
| 576 |
+
976,
|
| 577 |
+
977,
|
| 578 |
+
978,
|
| 579 |
+
979,
|
| 580 |
+
983,
|
| 581 |
+
988,
|
| 582 |
+
995,
|
| 583 |
+
996,
|
| 584 |
+
997,
|
| 585 |
+
1001,
|
| 586 |
+
1006,
|
| 587 |
+
1010,
|
| 588 |
+
1011,
|
| 589 |
+
1017,
|
| 590 |
+
1034,
|
| 591 |
+
1038,
|
| 592 |
+
1046,
|
| 593 |
+
1069,
|
| 594 |
+
1095,
|
| 595 |
+
1102,
|
| 596 |
+
1108,
|
| 597 |
+
1115,
|
| 598 |
+
1116,
|
| 599 |
+
1122,
|
| 600 |
+
1129,
|
| 601 |
+
1151,
|
| 602 |
+
1159,
|
| 603 |
+
1174,
|
| 604 |
+
1175,
|
| 605 |
+
1177,
|
| 606 |
+
1179,
|
| 607 |
+
1180,
|
| 608 |
+
1181,
|
| 609 |
+
1183,
|
| 610 |
+
1189,
|
| 611 |
+
1192,
|
| 612 |
+
1197,
|
| 613 |
+
1202,
|
| 614 |
+
1203,
|
| 615 |
+
1204,
|
| 616 |
+
1205,
|
| 617 |
+
1206,
|
| 618 |
+
1207,
|
| 619 |
+
1208,
|
| 620 |
+
1223,
|
| 621 |
+
1228,
|
| 622 |
+
1234,
|
| 623 |
+
1242,
|
| 624 |
+
1243,
|
| 625 |
+
1253,
|
| 626 |
+
1257,
|
| 627 |
+
1258,
|
| 628 |
+
1261,
|
| 629 |
+
1273,
|
| 630 |
+
1309,
|
| 631 |
+
1317,
|
| 632 |
+
1318,
|
| 633 |
+
1330,
|
| 634 |
+
1331,
|
| 635 |
+
1357,
|
| 636 |
+
1361,
|
| 637 |
+
1363,
|
| 638 |
+
1371,
|
| 639 |
+
1372,
|
| 640 |
+
1375,
|
| 641 |
+
1383,
|
| 642 |
+
1444,
|
| 643 |
+
1446,
|
| 644 |
+
1470,
|
| 645 |
+
1471,
|
| 646 |
+
1472,
|
| 647 |
+
1473,
|
| 648 |
+
1474,
|
| 649 |
+
1475,
|
| 650 |
+
1476,
|
| 651 |
+
1489,
|
| 652 |
+
1499,
|
| 653 |
+
1517,
|
| 654 |
+
1524,
|
| 655 |
+
1546,
|
| 656 |
+
1547,
|
| 657 |
+
1551,
|
| 658 |
+
1565,
|
| 659 |
+
1590,
|
| 660 |
+
1591,
|
| 661 |
+
1593,
|
| 662 |
+
1611,
|
| 663 |
+
1612,
|
| 664 |
+
1617,
|
| 665 |
+
1623,
|
| 666 |
+
1624,
|
| 667 |
+
1628,
|
| 668 |
+
1630,
|
| 669 |
+
1649,
|
| 670 |
+
1675,
|
| 671 |
+
1683,
|
| 672 |
+
1691,
|
| 673 |
+
1694,
|
| 674 |
+
1718,
|
| 675 |
+
1723,
|
| 676 |
+
1752,
|
| 677 |
+
1770,
|
| 678 |
+
1772,
|
| 679 |
+
1774,
|
| 680 |
+
1775,
|
| 681 |
+
1793,
|
| 682 |
+
1808,
|
| 683 |
+
1811,
|
| 684 |
+
1818,
|
| 685 |
+
1822,
|
| 686 |
+
1835,
|
| 687 |
+
1841,
|
| 688 |
+
1846,
|
| 689 |
+
1847,
|
| 690 |
+
1851,
|
| 691 |
+
1852,
|
| 692 |
+
1872,
|
| 693 |
+
1880,
|
| 694 |
+
1881,
|
| 695 |
+
1882,
|
| 696 |
+
1898,
|
| 697 |
+
1917,
|
| 698 |
+
1927,
|
| 699 |
+
1942,
|
| 700 |
+
1943,
|
| 701 |
+
1944,
|
| 702 |
+
1945,
|
| 703 |
+
1949,
|
| 704 |
+
1952,
|
| 705 |
+
1955,
|
| 706 |
+
1999,
|
| 707 |
+
2013,
|
| 708 |
+
2027,
|
| 709 |
+
2030,
|
| 710 |
+
2040,
|
| 711 |
+
2047,
|
| 712 |
+
2048,
|
| 713 |
+
2058,
|
| 714 |
+
2060,
|
| 715 |
+
2071,
|
| 716 |
+
2079,
|
| 717 |
+
2080,
|
| 718 |
+
2087
|
| 719 |
+
],
|
| 720 |
+
"source_vocab": "data/bpe_vocabulary_clean.json"
|
| 721 |
+
}
|
vocab/bpe_vocabulary.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|