IlPakoZ committed on
Commit 12cd9ef · 1 Parent(s): c3c533f

Initial upload
README.md ADDED
@@ -0,0 +1,134 @@
+ ---
+ library_name: transformers
+ tags:
+ - chemistry
+ - molecular-property-prediction
+ - selfies
+ - encoder
+ license: apache-2.0
+ ---
+
+ # M5 Encoder
+
+ A SELFIES-based molecular encoder built on a T5 backbone with custom
+ distance-aware relative position encodings. Two classes are available:
+
+ | Class | Description |
+ |---|---|
+ | `M5Encoder` | Bare encoder, outputs `last_hidden_state` |
+ | `M5ModelForRegression` | Encoder + sequence-level and token-level regression heads |
+
+ The model is pretrained on a multi-task regression objective, including quantum chemistry (QC) tasks
+ from the [PubChemQC B3LYP/PM6 dataset](https://nakatamaho.riken.jp/pubchemqc.riken.jp/b3lyp_pm6_datasets.html).
+
+ ## Usage
+
+ ```python
+ from transformers import AutoConfig, AutoModel
+
+ config = AutoConfig.from_pretrained("IlPakoZ/m5-encoder", trust_remote_code=True)
+ model = AutoModel.from_pretrained("IlPakoZ/m5-encoder", trust_remote_code=True)
+ ```
+
+ To load `M5ModelForRegression` explicitly:
+
+ ```python
+ from transformers import AutoModelForSequenceClassification
+
+ model = AutoModelForSequenceClassification.from_pretrained(
+     "IlPakoZ/m5-encoder", trust_remote_code=True
+ )
+ ```
+
+ ## Architecture
+
+ | Hyper-parameter | Value |
+ |---|---|
+ | `d_model` | 512 |
+ | `d_ff` | 2048 |
+ | `d_kv` | 64 |
+ | `num_layers` | 24 |
+ | `num_heads` | 12 |
+ | `vocab_size` | 1032 |
+ | `feed_forward_proj` | gated-gelu |
+ | `relative_attention_num_buckets` | 48 |
+ | `relative_attention_max_distance` | 128 |
+
+ Position biases are replaced by molecular-graph distances computed
+ with RDKit and binned with a modified version of T5's logarithmic binning algorithm, giving the model awareness of molecular topology without tying it too strictly to precise distances.
+
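For reference, the stock T5 bucketing scheme that this modification builds on can be sketched in pure Python. This is a sketch only: the exact modification applied to graph distances lives in the bundled `prepare_data`/modeling code, and `t5_bucket` is a hypothetical helper name, not part of the repository.

```python
import math

def t5_bucket(rel, num_buckets=48, max_distance=128):
    """Map a signed relative position to a bucket index, T5-style.

    Half the buckets are reserved for positive offsets; small distances get
    exact buckets, larger ones are binned logarithmically up to max_distance.
    """
    bucket = 0
    num = num_buckets // 2          # buckets per direction
    if rel > 0:
        bucket += num               # upper half for positive offsets
    rel = abs(rel)

    max_exact = num // 2            # distances below this get their own bucket
    if rel < max_exact:
        return bucket + rel

    # Logarithmic binning between max_exact and max_distance, clamped at the top
    log_val = max_exact + int(
        math.log(rel / max_exact) / math.log(max_distance / max_exact) * (num - max_exact)
    )
    return bucket + min(log_val, num - 1)
```

With the card's hyper-parameters (48 buckets, max distance 128), distances 0–11 in each direction map to exact buckets and everything beyond is compressed logarithmically.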
+ ## Tasks
+
+ Pretraining consists of up to 1085 tasks across five regression heads. Tasks are grouped by source and prediction target:
+
+ ### Group 0 — General molecular descriptors (RDKit)
+
+ | Task | Description |
+ |---|---|
+ | `MW` | Molecular weight |
+ | `TDM` | Total dipole moment |
+
+ ### Group 1 — Physicochemical properties (RDKit)
+
+ | Task | Description |
+ |---|---|
+ | `MolLogP` | Wildman-Crippen LogP estimate |
+ | `MolMR` | Wildman-Crippen molar refractivity |
+ | `TPSA` | Topological polar surface area |
+ | `FractionCSP3` | Fraction of sp³ carbons |
+
+ ### Group 2 — Frontier orbital energies (PubChemQC B3LYP/PM6)
+
+ Alpha and beta spin-orbital energies from DFT calculations:
+
+ | Task | Description |
+ |---|---|
+ | `energy_alpha_homo` | Alpha HOMO energy |
+ | `energy_alpha_gap` | Alpha HOMO–LUMO gap |
+ | `energy_alpha_lumo` | Alpha LUMO energy |
+ | `energy_beta_homo` | Beta HOMO energy |
+ | `energy_beta_gap` | Beta HOMO–LUMO gap |
+ | `energy_beta_lumo` | Beta LUMO energy |
+
+ ### Group 3 — Orbital energies (PubChemQC B3LYP/PM6)
+
+ 50 linearly sampled energies (`orbital_0` … `orbital_49`) spanning each molecule's full orbital spectrum, predicted at the sequence level.
+
+ ### Group 4 — Atom Löwdin charges (PubChemQC B3LYP/PM6)
+
+ Up to 1023 partial charges (`lowdin_0` … `lowdin_1022`), one per atom, predicted using each atom's corresponding output token embedding. This head covers well beyond the maximum number of atoms observed in the dataset.
+
+ ## Dataset
+
+ The model is pretrained on a processed version of the
+ [PubChemQC B3LYP/PM6 dataset](https://nakatamaho.riken.jp/pubchemqc.riken.jp/b3lyp_pm6_datasets.html).
+ The raw database exposes a `b3lyp_pm6` table (columns: `cid`, `state`, `data` as JSON). The data was extracted,
+ invalid SMILES removed, relevant features selected, and the result saved in compressed HDF5 format. Duplicate
+ SMILES were intentionally retained so that the model encounters molecules with multiple conformers
+ and learns a soft compromise across them. This trades auxiliary-task accuracy for richer structural
+ representations. Molecules incompatible with strict SELFIES encoding were discarded.
+
+ The processed dataset contains **82,686,706 SMILES sequences**, each paired with a full set of labels across all tasks. It is split by scaffold:
+
+ | Split | Sequences | Tokens (approx.) |
+ |---|---|---|
+ | Train | 66,149,364 | ~2.5 B (×2 with augmentation → ~5 B) |
+ | Validation | 8,268,673 | — |
+ | Test | 8,268,669 | ~0.82 B (×2 with augmentation → ~1.64 B) |
+
+ Training augmentation generates randomized SELFIES on the fly from each SMILES. Labels are normalized before training.
+
+ The HDF5 files are available for download below. They are intended to be processed with the bundled `data_processing` library into LMDB datasets optimised for fast training throughput; the resulting LMDB files are too large to distribute directly.
+
+ | Split | Download |
+ |---|---|
+ | Train | [train.h5](#) |
+ | Validation | [validation.h5](#) |
+ | Test | [test.h5](#) |
+
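The split sizes in the table add up exactly to the full dataset and correspond to a scaffold-based 80/10/10 split, which a quick check confirms:

```python
# Sanity check of the scaffold split sizes quoted in the table above.
train, validation, test = 66_149_364, 8_268_673, 8_268_669
total = train + validation + test   # should equal the full dataset size
train_fraction = train / total      # roughly an 80/10/10 split
```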
+ ## Limitations
+
+ - **Token length:** The built-in `prepare_data` helper encodes pairwise molecular-graph distances in an `int16` matrix. Consequently, molecules whose SELFIES tokenization exceeds **32,767 tokens** (`numpy.iinfo(numpy.int16).max`) are not supported. In practice, no molecule in the training dataset approaches this limit.
+ - **Conformer handling:** Duplicate SMILES representing different conformers are kept in the dataset. The model therefore predicts an implicit average over conformers rather than a geometry-specific value, which may reduce accuracy for conformation-sensitive properties.
+ - **Scope:** The model is pretrained on organic molecules present in PubChemQC. Performance on inorganic compounds, organometallics, or very large macromolecules outside the training distribution has not been evaluated.
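The token-length limitation above amounts to a simple pre-flight check; a minimal sketch, with `fits_int16_distance_matrix` as a hypothetical helper name (not part of the repository):

```python
# INT16_MAX mirrors numpy.iinfo(numpy.int16).max without requiring numpy.
INT16_MAX = 2**15 - 1  # 32767

def fits_int16_distance_matrix(selfies_tokens):
    """True if the token list is short enough for an int16-indexed distance matrix."""
    return len(selfies_tokens) <= INT16_MAX
```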
common.py ADDED
@@ -0,0 +1,33 @@
+ import torch
+ import torch.nn.functional as F
+
+ from torch import nn
+ from transformers.models.t5.configuration_t5 import T5Config
+
+
+ class M5Pooler(nn.Module):
+     def __init__(self, config: T5Config):
+         super().__init__()
+         self.pool_weights = nn.Parameter(torch.tensor([0.5, 0.5]))
+         self.pad_token_id = config.pad_token_id
+
+     def forward(self, input_ids: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
+         # Skip the leading CLS token; mask out padding positions
+         mask = (input_ids[:, 1:] != self.pad_token_id).unsqueeze(-1).float()  # [batch, seq_len, 1]
+         atoms = hidden_states[:, 1:, :]
+
+         # Zero out padding token embeddings
+         masked_embedded = atoms * mask  # [batch, seq_len, hidden_dim]
+
+         # Sum and divide by the number of real tokens
+         sum_embedded = masked_embedded.sum(dim=1)  # [batch, hidden_dim]
+         num_real_tokens = mask.sum(dim=1).clamp(min=1e-9)  # [batch, 1], avoid division by zero
+         mean_pool = sum_embedded / num_real_tokens  # [batch, hidden_dim]
+
+         cls_token = hidden_states[:, 0, :]
+
+         # Learned weights for a weighted average between the CLS token and the mean pool
+         weights = F.softmax(self.pool_weights, dim=0)
+
+         pooled = weights[0] * mean_pool + weights[1] * cls_token
+         return pooled
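The weighted combination at the end of `forward` can be illustrated in plain Python, operating on single vectors instead of batched tensors. This is a sketch of the formula only; `softmax2` and `combine` are hypothetical names, not part of the repository.

```python
import math

def softmax2(a, b):
    # Numerically stable softmax over two scalars
    m = max(a, b)
    ea, eb = math.exp(a - m), math.exp(b - m)
    return ea / (ea + eb), eb / (ea + eb)

def combine(mean_pool, cls_token, pool_weights=(0.5, 0.5)):
    # pooled = w0 * mean_pool + w1 * cls_token, with (w0, w1) = softmax(pool_weights)
    w0, w1 = softmax2(*pool_weights)
    return [w0 * m + w1 * c for m, c in zip(mean_pool, cls_token)]
```

With the initial `pool_weights` of `[0.5, 0.5]`, the softmax yields equal weights, so the pooled vector starts out as the plain average of the mean pool and the CLS embedding; training then learns how to shift the balance.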
config.json ADDED
@@ -0,0 +1,156 @@
+ {
+   "architectures": [
+     "M5ModelForRegression"
+   ],
+   "classifier_dropout": 0,
+   "d_ff": 2048,
+   "d_kv": 64,
+   "d_model": 512,
+   "dense_act_fn": "gelu_new",
+   "dropout_rate": 0,
+   "eos_token_id": 1,
+   "feed_forward_proj": "gated-gelu",
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1",
+     "2": "LABEL_2",
+     "3": "LABEL_3",
+     "4": "LABEL_4",
+     "5": "LABEL_5",
+     "6": "LABEL_6",
+     "7": "LABEL_7",
+     "8": "LABEL_8",
+     "9": "LABEL_9",
+     "10": "LABEL_10",
+     "11": "LABEL_11",
+     "12": "LABEL_12",
+     "13": "LABEL_13",
+     "14": "LABEL_14",
+     "15": "LABEL_15",
+     "16": "LABEL_16",
+     "17": "LABEL_17",
+     "18": "LABEL_18",
+     "19": "LABEL_19",
+     "20": "LABEL_20",
+     "21": "LABEL_21",
+     "22": "LABEL_22",
+     "23": "LABEL_23",
+     "24": "LABEL_24",
+     "25": "LABEL_25",
+     "26": "LABEL_26",
+     "27": "LABEL_27",
+     "28": "LABEL_28",
+     "29": "LABEL_29",
+     "30": "LABEL_30",
+     "31": "LABEL_31",
+     "32": "LABEL_32",
+     "33": "LABEL_33",
+     "34": "LABEL_34",
+     "35": "LABEL_35",
+     "36": "LABEL_36",
+     "37": "LABEL_37",
+     "38": "LABEL_38",
+     "39": "LABEL_39",
+     "40": "LABEL_40",
+     "41": "LABEL_41",
+     "42": "LABEL_42",
+     "43": "LABEL_43",
+     "44": "LABEL_44",
+     "45": "LABEL_45",
+     "46": "LABEL_46",
+     "47": "LABEL_47",
+     "48": "LABEL_48",
+     "49": "LABEL_49",
+     "50": "LABEL_50",
+     "51": "LABEL_51",
+     "52": "LABEL_52",
+     "53": "LABEL_53",
+     "54": "LABEL_54",
+     "55": "LABEL_55",
+     "56": "LABEL_56",
+     "57": "LABEL_57",
+     "58": "LABEL_58",
+     "59": "LABEL_59",
+     "60": "LABEL_60",
+     "61": "LABEL_61"
+   },
+   "initializer_factor": 1.0,
+   "is_encoder_decoder": false,
+   "is_gated_act": true,
+   "label2id": {
+     "LABEL_0": 0,
+     "LABEL_1": 1,
+     "LABEL_10": 10,
+     "LABEL_11": 11,
+     "LABEL_12": 12,
+     "LABEL_13": 13,
+     "LABEL_14": 14,
+     "LABEL_15": 15,
+     "LABEL_16": 16,
+     "LABEL_17": 17,
+     "LABEL_18": 18,
+     "LABEL_19": 19,
+     "LABEL_2": 2,
+     "LABEL_20": 20,
+     "LABEL_21": 21,
+     "LABEL_22": 22,
+     "LABEL_23": 23,
+     "LABEL_24": 24,
+     "LABEL_25": 25,
+     "LABEL_26": 26,
+     "LABEL_27": 27,
+     "LABEL_28": 28,
+     "LABEL_29": 29,
+     "LABEL_3": 3,
+     "LABEL_30": 30,
+     "LABEL_31": 31,
+     "LABEL_32": 32,
+     "LABEL_33": 33,
+     "LABEL_34": 34,
+     "LABEL_35": 35,
+     "LABEL_36": 36,
+     "LABEL_37": 37,
+     "LABEL_38": 38,
+     "LABEL_39": 39,
+     "LABEL_4": 4,
+     "LABEL_40": 40,
+     "LABEL_41": 41,
+     "LABEL_42": 42,
+     "LABEL_43": 43,
+     "LABEL_44": 44,
+     "LABEL_45": 45,
+     "LABEL_46": 46,
+     "LABEL_47": 47,
+     "LABEL_48": 48,
+     "LABEL_49": 49,
+     "LABEL_5": 5,
+     "LABEL_50": 50,
+     "LABEL_51": 51,
+     "LABEL_52": 52,
+     "LABEL_53": 53,
+     "LABEL_54": 54,
+     "LABEL_55": 55,
+     "LABEL_56": 56,
+     "LABEL_57": 57,
+     "LABEL_58": 58,
+     "LABEL_59": 59,
+     "LABEL_6": 6,
+     "LABEL_60": 60,
+     "LABEL_61": 61,
+     "LABEL_7": 7,
+     "LABEL_8": 8,
+     "LABEL_9": 9
+   },
+   "layer_norm_epsilon": 1e-06,
+   "model_type": "m5_model",
+   "num_decoder_layers": 24,
+   "num_heads": 12,
+   "num_layers": 24,
+   "pad_token_id": 2,
+   "relative_attention_max_distance": 96,
+   "relative_attention_num_buckets": 32,
+   "torch_dtype": "float32",
+   "transformers_version": "4.51.3",
+   "use_cache": false,
+   "vocab_size": 1032
+ }
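The 62 `id2label` entries correspond to the sequence-level targets from Groups 0–3 of the README (2 descriptors, 4 physicochemical properties, 6 frontier-orbital energies, 50 sampled orbitals); the generic mapping can be reproduced programmatically:

```python
# Groups 0-3 of the model card: 2 descriptors + 4 physicochemical
# properties + 6 frontier-orbital energies + 50 sampled orbitals.
num_labels = 2 + 4 + 6 + 50
id2label = {str(i): f"LABEL_{i}" for i in range(num_labels)}
label2id = {v: int(k) for k, v in id2label.items()}
```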
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eac7062b1d66d0ad63fff0f71e8f86d7cc86397d1c6783ee3099bcaf1237027d
+ size 497310076
modeling_m5_encoder.py ADDED
@@ -0,0 +1,922 @@
+ import torch
+ import numpy as np
+ import math
+ import logging
+
+ from typing import Optional, Union
+ import torch.nn as nn
+ from transformers import PreTrainedModel, T5EncoderModel, T5ForConditionalGeneration, T5ForQuestionAnswering, T5ForTokenClassification, T5Model, load_tf_weights_in_t5
+ from transformers.models.t5.modeling_t5 import T5Attention, T5DenseActDense, T5DenseGatedActDense, T5ClassificationHead, T5LayerNorm, T5Stack, T5Block, T5LayerSelfAttention, T5LayerFF
+ from transformers.cache_utils import DynamicCache, EncoderDecoderCache
+ from transformers.models.t5.configuration_t5 import T5Config
+ from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutput
+ from transformers.utils import DUMMY_INPUTS, DUMMY_MASK, is_torch_fx_proxy, is_torchdynamo_compiling
+ from transformers.utils.deprecation import deprecate_kwarg
+ from .common import M5Pooler
+ from .prepare_data import get_positional_encodings_and_align
+
+ logger = logging.getLogger(__name__)
+
+ class M5EncoderConfig(T5Config):
+     model_type = "m5_model"
+
+     def __init__(
+         self,
+         d_ff=2048,
+         d_kv=64,
+         d_model=512,
+         num_layers=24,
+         num_heads=12,
+         pad_token_id=2,
+         dropout_rate=0,
+         feed_forward_proj="gated-gelu",
+         classifier_dropout=0,
+         relative_attention_max_distance=128,
+         relative_attention_num_buckets=48,
+         vocab_size=1032,
+         **kwargs,
+     ):
+         super().__init__(
+             d_ff=d_ff,
+             d_kv=d_kv,
+             d_model=d_model,
+             num_layers=num_layers,
+             num_heads=num_heads,
+             pad_token_id=pad_token_id,
+             dropout_rate=dropout_rate,
+             feed_forward_proj=feed_forward_proj,
+             classifier_dropout=classifier_dropout,
+             relative_attention_max_distance=relative_attention_max_distance,
+             relative_attention_num_buckets=relative_attention_num_buckets,
+             vocab_size=vocab_size,
+             **kwargs,
+         )
+
+ class M5Encoder(PreTrainedModel):
+     config_class = M5EncoderConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = M5EncoderModel(config)
+         # self.model = torch.compile(self.model, mode="max-autotune", fullgraph=True)
+
+     def forward(self, input_ids, attention_mask=None, relative_position=None, **kwargs):
+         return self.model(input_ids=input_ids,
+                           attention_mask=attention_mask,
+                           relative_position=relative_position)
+
+     def get_positional_embeddings_and_align(self, smiles, token_regr, seed):
+         return get_positional_encodings_and_align(smiles, token_regr, seed)
+
+ class M5EncoderModel(T5EncoderModel):
+     def __init__(self, config: T5Config):
+         super().__init__(config)
+
+         encoder_config = config
+         encoder_config.use_cache = False
+         encoder_config.is_encoder_decoder = False
+         self.encoder = M5Stack(encoder_config, self.shared)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         relative_position: Optional[torch.LongTensor] = None,
+     ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]:
+         r"""
+         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+             Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
+             should be able to pad the inputs on both the right and the left.
+
+             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+             [`PreTrainedTokenizer.__call__`] for detail.
+
+             To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training).
+
+         Example:
+
+         ```python
+         >>> from transformers import AutoTokenizer, T5EncoderModel
+
+         >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
+         >>> model = T5EncoderModel.from_pretrained("google-t5/t5-small")
+         >>> input_ids = tokenizer(
+         ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
+         ... ).input_ids  # Batch size 1
+         >>> outputs = model(input_ids=input_ids)
+         >>> last_hidden_states = outputs.last_hidden_state
+         ```"""
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         encoder_outputs = self.encoder(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             inputs_embeds=inputs_embeds,
+             head_mask=head_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             relative_position=relative_position,
+         )
+
+         return encoder_outputs
+
+ class M5Stack(T5Stack):
+     def __init__(self, config, embed_tokens=None):
+         super().__init__(config, embed_tokens)
+
+         self.block = nn.ModuleList(
+             [M5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)]
+         )
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         inputs_embeds=None,
+         head_mask=None,
+         cross_attn_head_mask=None,
+         past_key_values=None,
+         use_cache=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         cache_position=None,
+         relative_position=None,
+     ):
+         # Model parallel
+         if self.model_parallel:
+             torch.cuda.set_device(self.first_device)
+             self.embed_tokens = self.embed_tokens.to(self.first_device)
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if input_ids is not None and inputs_embeds is not None:
+             err_msg_prefix = "decoder_" if self.is_decoder else ""
+             raise ValueError(
+                 f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
+             )
+         elif input_ids is not None:
+             input_shape = input_ids.size()
+             input_ids = input_ids.view(-1, input_shape[-1])
+         elif inputs_embeds is not None:
+             input_shape = inputs_embeds.size()[:-1]
+         else:
+             err_msg_prefix = "decoder_" if self.is_decoder else ""
+             raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 logger.warning_once(
+                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                 )
+                 use_cache = False
+
+         if inputs_embeds is None:
+             if self.embed_tokens is None:
+                 raise ValueError("You have to initialize the model with valid token embeddings")
+             inputs_embeds = self.embed_tokens(input_ids)
+
+         batch_size, seq_length = input_shape
+
+         if use_cache is True:
+             if not self.is_decoder:
+                 raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
+
+         if self.is_decoder:
+             if use_cache and past_key_values is None:
+                 if self.config.is_encoder_decoder:
+                     past_key_values = EncoderDecoderCache(
+                         DynamicCache(config=self.config), DynamicCache(config=self.config)
+                     )
+                 else:
+                     past_key_values = DynamicCache(config=self.config)
+         elif not self.is_decoder:
+             # do not pass cache object down the line for encoder stack
+             # it messes indexing later in decoder-stack because cache object is modified in-place
+             past_key_values = None
+
+         past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+         if cache_position is None:
+             cache_position = torch.arange(
+                 past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
+             )
+
+         if attention_mask is None and not is_torchdynamo_compiling():
+             # required mask seq length can be calculated via length of past cache
+             mask_seq_length = past_key_values_length + seq_length
+             attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+
+         if self.config.is_decoder:
+             causal_mask = self._update_causal_mask(
+                 attention_mask,
+                 inputs_embeds,
+                 cache_position,
+                 past_key_values.self_attention_cache
+                 if isinstance(past_key_values, EncoderDecoderCache)
+                 else past_key_values,
+                 output_attentions,
+             )
+         elif attention_mask is not None:
+             causal_mask = attention_mask[:, None, None, :]
+             causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
+             causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
+         else:
+             causal_mask = None
+
+         # If a 2D or 3D attention mask is provided for the cross-attention
+         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+         if self.is_decoder and encoder_hidden_states is not None:
+             encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+             if encoder_attention_mask is None:
+                 encoder_attention_mask = torch.ones(
+                     encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
+                 )
+             encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+         else:
+             encoder_extended_attention_mask = None
+
+         # Prepare head mask if needed
+         head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+         cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
+         all_hidden_states = () if output_hidden_states else None
+         all_attentions = () if output_attentions else None
+         all_cross_attentions = () if (output_attentions and self.is_decoder) else None
+         position_bias = None
+         encoder_decoder_position_bias = None
+
+         hidden_states = self.dropout(inputs_embeds)
+
+         for i, layer_module in enumerate(self.block):
+             layer_head_mask = head_mask[i]
+             cross_attn_layer_head_mask = cross_attn_head_mask[i]
+             # Model parallel
+             if self.model_parallel:
+                 torch.cuda.set_device(hidden_states.device)
+                 # Ensure that attention_mask is always on the same device as hidden_states
+                 if causal_mask is not None:
+                     causal_mask = causal_mask.to(hidden_states.device)
+                 if position_bias is not None:
+                     position_bias = position_bias.to(hidden_states.device)
+                 if encoder_hidden_states is not None:
+                     encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
+                 if encoder_extended_attention_mask is not None:
+                     encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
+                 if encoder_decoder_position_bias is not None:
+                     encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
+                 if layer_head_mask is not None:
+                     layer_head_mask = layer_head_mask.to(hidden_states.device)
+                 if cross_attn_layer_head_mask is not None:
+                     cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states,)
+
+             layer_outputs = layer_module(
+                 hidden_states,
+                 causal_mask,
+                 position_bias,
+                 encoder_hidden_states,
+                 encoder_extended_attention_mask,
+                 encoder_decoder_position_bias,  # as a positional argument for gradient checkpointing
+                 layer_head_mask=layer_head_mask,
+                 cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+                 past_key_values=past_key_values,
+                 use_cache=use_cache,
+                 output_attentions=output_attentions,
+                 return_dict=return_dict,
+                 cache_position=cache_position,
+                 relative_position=relative_position,
+             )
+
+             hidden_states = layer_outputs[0]
+
+             # We share the position biases between the layers - the first layer stores them
+             # layer_outputs = hidden-states, key-value-states, (self-attention position bias), (self-attention weights),
+             # (cross-attention position bias), (cross-attention weights)
+             position_bias = layer_outputs[1]
+             if self.is_decoder and encoder_hidden_states is not None:
+                 encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]
+
+             if output_attentions:
+                 all_attentions = all_attentions + (layer_outputs[2],)
+                 if self.is_decoder:
+                     all_cross_attentions = all_cross_attentions + (layer_outputs[4],)
+
+             # Model Parallel: If it's the last layer for that device, put things on the next device
+             if self.model_parallel:
+                 for k, v in self.device_map.items():
+                     if i == v[-1] and "cuda:" + str(k) != self.last_device:
+                         hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+         hidden_states = self.final_layer_norm(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+
+         # Add last layer
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (hidden_states,)
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [
+                     hidden_states,
+                     past_key_values,
+                     all_hidden_states,
+                     all_attentions,
+                     all_cross_attentions,
+                 ]
+                 if v is not None
+             )
+         return BaseModelOutputWithPastAndCrossAttentions(
+             last_hidden_state=hidden_states,
+             past_key_values=past_key_values,
+             hidden_states=all_hidden_states,
+             attentions=all_attentions,
+             cross_attentions=all_cross_attentions,
+         )
+
+ class M5Block(T5Block):
+     def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
+         super().__init__(config, has_relative_attention_bias, layer_idx)
+         self.layer = nn.ModuleList()
+         self.layer.append(
+             M5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx)
+         )
+         if self.is_decoder:
+             self.layer.append(M5LayerSelfAttention(config, layer_idx=layer_idx))
+         self.layer.append(T5LayerFF(config))
+
+     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+     def forward(
+         self,
+         hidden_states,
+         attention_mask=None,
+         position_bias=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         encoder_decoder_position_bias=None,
+         layer_head_mask=None,
+         cross_attn_layer_head_mask=None,
+         past_key_values=None,
+         use_cache=False,
+         output_attentions=False,
+         return_dict=True,
+         cache_position=None,
+         relative_position=None,
+     ):
+         self_attention_outputs = self.layer[0](
+             hidden_states,
+             attention_mask=attention_mask,
+             position_bias=position_bias,
+             layer_head_mask=layer_head_mask,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             cache_position=cache_position,
+             relative_position=relative_position,
+         )
+         hidden_states = self_attention_outputs[0]
+         attention_outputs = self_attention_outputs[1:]  # Keep self-attention outputs and relative position weights
+
+         # clamp inf values to enable fp16 training
+         if hidden_states.dtype == torch.float16:
+             clamp_value = torch.where(
+                 torch.isinf(hidden_states).any(),
+                 torch.finfo(hidden_states.dtype).max - 1000,
+                 torch.finfo(hidden_states.dtype).max,
+             )
+             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+         do_cross_attention = self.is_decoder and encoder_hidden_states is not None
+         if do_cross_attention:
+             cross_attention_outputs = self.layer[1](
+                 hidden_states,
+                 key_value_states=encoder_hidden_states,
+                 attention_mask=encoder_attention_mask,
+                 position_bias=encoder_decoder_position_bias,
+                 layer_head_mask=cross_attn_layer_head_mask,
+                 past_key_values=past_key_values,
+                 query_length=cache_position[-1] + 1,
+                 use_cache=use_cache,
+                 output_attentions=output_attentions,
+             )
+             hidden_states = cross_attention_outputs[0]
+
+             # clamp inf values to enable fp16 training
+             if hidden_states.dtype == torch.float16:
+                 clamp_value = torch.where(
+                     torch.isinf(hidden_states).any(),
+                     torch.finfo(hidden_states.dtype).max - 1000,
+                     torch.finfo(hidden_states.dtype).max,
+                 )
+                 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+             # Keep cross-attention outputs and relative position weights
+             attention_outputs = attention_outputs + cross_attention_outputs[1:]
+
+         # Apply Feed Forward layer
+         hidden_states = self.layer[-1](hidden_states)
+
+         # clamp inf values to enable fp16 training
+         if hidden_states.dtype == torch.float16:
+             clamp_value = torch.where(
+                 torch.isinf(hidden_states).any(),
+                 torch.finfo(hidden_states.dtype).max - 1000,
+                 torch.finfo(hidden_states.dtype).max,
+             )
+             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+         outputs = (hidden_states,)
+
+         return (
+             outputs + attention_outputs
+         )  # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+
+ class M5LayerSelfAttention(T5LayerSelfAttention):
+     def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
+         super().__init__(config, has_relative_attention_bias, layer_idx)
+         self.SelfAttention = M5Attention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx)
+
+     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+     def forward(
+         self,
+         hidden_states,
+         attention_mask=None,
+         position_bias=None,
+         layer_head_mask=None,
+         past_key_values=None,
+         use_cache=False,
+         output_attentions=False,
+         cache_position=None,
+         relative_position=None,
+     ):
+         normed_hidden_states = self.layer_norm(hidden_states)
+         attention_output = self.SelfAttention(
+             normed_hidden_states,
+             mask=attention_mask,
+             position_bias=position_bias,
+             layer_head_mask=layer_head_mask,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             cache_position=cache_position,
+             relative_position=relative_position,
+         )
+         hidden_states = hidden_states + self.dropout(attention_output[0])
+         outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
+         return outputs
+
+ class M5Attention(T5Attention):
+     """
+     def __init__(
+         self,
+         config: T5Config,
+         has_relative_attention_bias=False,
+         layer_idx: Optional[int] = None,
+     ):
+         super().__init__(config, has_relative_attention_bias, layer_idx)
+
+         if self.has_relative_attention_bias:
+             self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+         else:
+             self.elaborate = nn.Linear()
+     """
+
+     @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+     def forward(
+         self,
+         hidden_states,
+         mask=None,
+         key_value_states=None,
511
+ position_bias=None,
512
+ past_key_values=None,
513
+ layer_head_mask=None,
514
+ query_length=None,
515
+ use_cache=False,
516
+ output_attentions=False,
517
+ cache_position=None,
518
+ relative_position=None
519
+
520
+ ):
521
+ """
522
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
523
+ """
524
+ # Input is (batch_size, seq_length, dim)
525
+ # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
526
+ batch_size, seq_length = hidden_states.shape[:2]
527
+
528
+ # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
529
+ is_cross_attention = key_value_states is not None
530
+
531
+ query_states = self.q(hidden_states)
532
+ query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
533
+
534
+ # Check if an encoder-decoder model is being used; otherwise we'll get `DynamicCache`
535
+ is_updated = False
536
+ if isinstance(past_key_values, EncoderDecoderCache):
537
+ is_updated = past_key_values.is_updated.get(self.layer_idx)
538
+ if is_cross_attention:
539
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
540
+ curr_past_key_value = past_key_values.cross_attention_cache
541
+ else:
542
+ curr_past_key_value = past_key_values.self_attention_cache
543
+ else:
544
+ curr_past_key_value = past_key_values
545
+
546
+ current_states = key_value_states if is_cross_attention else hidden_states
547
+ if is_cross_attention and past_key_values is not None and is_updated:
548
+ # reuse k,v, cross_attentions
549
+ key_states = curr_past_key_value.layers[self.layer_idx].keys
550
+ value_states = curr_past_key_value.layers[self.layer_idx].values
551
+ else:
552
+ key_states = self.k(current_states)
553
+ value_states = self.v(current_states)
554
+ key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
555
+ value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
556
+
557
+ if past_key_values is not None:
558
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
559
+ cache_position = cache_position if not is_cross_attention else None
560
+ key_states, value_states = curr_past_key_value.update(
561
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
562
+ )
563
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
564
+ if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
565
+ past_key_values.is_updated[self.layer_idx] = True
566
+
567
+ # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
568
+ scores = torch.matmul(query_states, key_states.transpose(3, 2))
569
+
570
+ if position_bias is None:
571
+ key_length = key_states.shape[-2]
572
+ # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
573
+ real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
574
+ if not self.has_relative_attention_bias:
575
+ position_bias = torch.zeros(
576
+ (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
577
+ )
578
+ if self.gradient_checkpointing and self.training:
579
+ position_bias.requires_grad = True
580
+ else:
581
+ position_bias = self.compute_bias(
582
+ real_seq_length, key_length, device=scores.device, cache_position=cache_position, relative_position=relative_position
583
+ )
584
+ position_bias = position_bias[:, :, -seq_length:, :]
585
+
586
+ if mask is not None:
587
+ causal_mask = mask[:, :, :, : key_states.shape[-2]]
588
+ position_bias = position_bias + causal_mask
589
+
590
+ if self.pruned_heads:
591
+ mask = torch.ones(position_bias.shape[1])
592
+ mask[list(self.pruned_heads)] = 0
593
+ position_bias_masked = position_bias[:, mask.bool()]
594
+ else:
595
+ position_bias_masked = position_bias
596
+
597
+ scores += position_bias_masked
598
+
599
+ # (batch_size, n_heads, seq_length, key_length)
600
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
601
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
602
+
603
+ # Mask heads if we want to
604
+ if layer_head_mask is not None:
605
+ attn_weights = attn_weights * layer_head_mask
606
+
607
+ attn_output = torch.matmul(attn_weights, value_states)
608
+
609
+ attn_output = attn_output.transpose(1, 2).contiguous()
610
+ attn_output = attn_output.view(batch_size, -1, self.inner_dim)
611
+ attn_output = self.o(attn_output)
612
+
613
+ outputs = (attn_output, position_bias)
614
+
615
+ if output_attentions:
616
+ outputs = outputs + (attn_weights,)
617
+ return outputs
618
+
619
+ @staticmethod
620
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
621
+ """
622
+ Adapted from Mesh Tensorflow:
623
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
624
+
625
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
626
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
627
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
628
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
629
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
630
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
631
+
632
+ Args:
633
+ relative_position: an int32 Tensor
634
+ bidirectional: a boolean - whether the attention is bidirectional
635
+ num_buckets: an integer
636
+ max_distance: an integer
637
+
638
+ Returns:
639
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
640
+ """
641
+ # Make all positions non-negative, effectively using the bidirectional=False path,
643
+ # but with positive distances instead of negative ones
643
+ relative_position = relative_position + 1
644
+ relative_position = torch.max(relative_position, torch.zeros_like(relative_position))
645
+
646
+ # half of the buckets are for exact increments in positions
647
+ max_exact = num_buckets // 2
648
+ is_small = relative_position < max_exact
649
+
650
+ num_log_buckets = num_buckets - max_exact - 1
651
+
652
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
653
+ relative_position_if_large = max_exact + (
654
+ torch.log(relative_position.float() / max_exact)
655
+ / math.log(max_distance / max_exact)
656
+ * (num_buckets - num_log_buckets)
657
+ ).to(torch.long)
658
+
659
+ relative_position_if_large = torch.min(
660
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 2)
661
+ )
662
+
663
+ relative_buckets = torch.where(is_small, relative_position, relative_position_if_large)
664
+
665
+ # The +1 is because we added 1 at the beginning (relative_position + 1).
666
+ # This special mask is the equivalent of +inf distance and is assigned
667
+ # to the last bucket.
668
+ special_mask = (relative_position == np.iinfo(np.int16).max+1)
669
+ relative_buckets[special_mask] = num_buckets-1
670
+
671
+ return relative_buckets
672
+
673
+ def compute_bias(self, query_length, key_length, device=None, cache_position=None, relative_position=None):
674
+ """Compute binned relative position bias"""
675
+ if device is None:
676
+ device = self.relative_attention_bias.weight.device
677
+
678
+ if relative_position is None:
679
+ if cache_position is None:
680
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
681
+ else:
682
+ context_position = cache_position[:, None].to(device)
683
+ memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
684
+ relative_position = memory_position - context_position # shape (query_length, key_length)
685
+
686
+ # Skipping the relative_position calculation breaks cache_position, but that's fine since the positions are precomputed anyway
687
+ relative_position_bucket = self._relative_position_bucket(
688
+ relative_position, # shape (query_length, key_length)
689
+ bidirectional=(not self.is_decoder),
690
+ num_buckets=self.relative_attention_num_buckets,
691
+ max_distance=self.relative_attention_max_distance,
692
+ )
693
+
694
+ values = self.relative_attention_bias(relative_position_bucket) # shape (batch_size, query_length, key_length, num_heads)
695
+ values = values.permute([0, 3, 1, 2]) # shape (batch_size, num_heads, query_length, key_length)
696
+ return values
697
+
698
+ # RegressionHead for tasks from groups 0, 1, 2 and 3
699
+ # Used as regression head and classification head for pretraining
700
+ class M5RegressionHead(nn.Module):
701
+ def __init__(self, config: T5Config):
702
+ super().__init__()
703
+
704
+ self.pooler = M5Pooler(config)
705
+ self.transform = nn.Linear(config.d_model, config.d_model)
706
+ if config.is_gated_act:
707
+ self.DenseReluDense = T5DenseGatedActDense(config)
708
+ else:
709
+ self.DenseReluDense = T5DenseActDense(config)
710
+ self.out_proj = nn.Linear(config.d_model, config.num_labels)
711
+
712
+ def forward(self, input_ids: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
713
+ pooled = self.pooler(input_ids, hidden_states)
714
+
715
+ pooled = self.transform(pooled)
716
+ pooled = self.DenseReluDense(pooled)
717
+ output = self.out_proj(pooled)
718
+
719
+ return output
720
+
721
+ # TokenRegressionHead for tasks from group 4
722
+ class M5TokenRegressionHead(nn.Module):
723
+ def __init__(self, config: T5Config):
724
+ super().__init__()
725
+
726
+ # Input dimension is multiplied by 2 to account for the concatenated [CLS] embedding.
727
+ self.transform1 = nn.Linear(config.d_model*2, config.d_model)
728
+ if config.is_gated_act:
729
+ self.DenseReluDense1 = T5DenseGatedActDense(config)
730
+ else:
731
+ self.DenseReluDense1 = T5DenseActDense(config)
732
+
733
+ self.transform2 = nn.Linear(config.d_model, config.d_model)
734
+
735
+ if config.is_gated_act:
736
+ self.DenseReluDense2 = T5DenseGatedActDense(config)
737
+ else:
738
+ self.DenseReluDense2 = T5DenseActDense(config)
739
+
740
+ # The output has shape (batch_size, context_length, 1) because each token has a label
741
+
742
+ self.output = nn.Linear(config.d_model, 1)
743
+ self.config = config
744
+
745
+ def forward(self, token_hidden_states: torch.Tensor) -> torch.Tensor:
746
+ # Concatenate CLS token hidden states to each token hidden state
747
+
748
749
+ cls_hidden = token_hidden_states[:, 0, :]
750
+ token_hidden = token_hidden_states[:, 1:, :]
751
+
752
+ cls_repeated = cls_hidden.unsqueeze(1).expand(-1, token_hidden.size(1), -1)
753
+ augmented_hidden = torch.cat([token_hidden, cls_repeated], dim=-1).contiguous()
754
+
755
+ transformed = self.transform1(augmented_hidden)
756
+ transformed = self.DenseReluDense1(transformed)
757
+ transformed = self.transform2(transformed)
758
+ transformed = self.DenseReluDense2(transformed)
759
+
760
+ output = self.output(transformed)
761
+ output = output.squeeze(-1)
762
+ # (batch_size, num_labels)
763
+ # NOTE: num_labels = seq_length - 1 (one label per non-[CLS] token)
764
+ return output
765
+
766
+
767
+ class M5PreTrainedModel(PreTrainedModel):
768
+ """
769
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
770
+ models.
771
+ """
772
+
773
+ config_class = T5Config
774
+ load_tf_weights = load_tf_weights_in_t5
775
+ base_model_prefix = "transformer"
776
+ is_parallelizable = True
777
+ supports_gradient_checkpointing = True
778
+ _supports_quantized_cache = False # encoder-decoder models don't support this yet
779
+ _supports_static_cache = True
780
+ _supports_cache_class = True
781
+ _no_split_modules = ["T5Block"]
782
+ _keep_in_fp32_modules = ["wo"]
783
+
784
+ @property
785
+ def dummy_inputs(self):
786
+ input_ids = torch.tensor(DUMMY_INPUTS)
787
+ input_mask = torch.tensor(DUMMY_MASK)
788
+ dummy_inputs = {
789
+ "decoder_input_ids": input_ids,
790
+ "input_ids": input_ids,
791
+ "decoder_attention_mask": input_mask,
792
+ }
793
+ return dummy_inputs
794
+
795
+ def _init_weights(self, module):
796
+ """Initialize the weights"""
797
+ factor = self.config.initializer_factor # Used for testing weights initialization
798
+ if isinstance(module, T5LayerNorm):
799
+ module.weight.data.fill_(factor * 1.0)
800
+ elif isinstance(
801
+ module,
802
+ (T5Model, T5ForConditionalGeneration, T5EncoderModel, T5ForQuestionAnswering),
803
+ ):
804
+ # Mesh TensorFlow embeddings initialization
805
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
806
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
807
+ if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
808
+ module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
809
+ if hasattr(module, "qa_outputs"):
810
+ module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
811
+ module.qa_outputs.bias.data.zero_()
812
+ elif isinstance(module, T5ForTokenClassification):
813
+ if hasattr(module, "classifier"):
814
+ module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0)
815
+ module.classifier.bias.data.zero_()
816
+ elif isinstance(module, T5ClassificationHead):
817
+ module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
818
+ if hasattr(module.dense, "bias") and module.dense.bias is not None:
819
+ module.dense.bias.data.zero_()
820
+ module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
821
+ if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
822
+ module.out_proj.bias.data.zero_()
823
+ elif isinstance(module, T5DenseActDense):
824
+ # Mesh TensorFlow FF initialization
825
+ # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
826
+ # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
827
+ module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
828
+ if hasattr(module.wi, "bias") and module.wi.bias is not None:
829
+ module.wi.bias.data.zero_()
830
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
831
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
832
+ module.wo.bias.data.zero_()
833
+ elif isinstance(module, T5DenseGatedActDense):
834
+ module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
835
+ if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
836
+ module.wi_0.bias.data.zero_()
837
+ module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
838
+ if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
839
+ module.wi_1.bias.data.zero_()
840
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
841
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
842
+ module.wo.bias.data.zero_()
843
+ elif isinstance(module, M5RegressionHead):
844
+ module.transform.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
845
+ if hasattr(module.transform, "bias") and module.transform.bias is not None:
846
+ module.transform.bias.data.zero_()
847
+ module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
848
+ if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
849
+ module.out_proj.bias.data.zero_()
850
+ elif isinstance(module, M5TokenRegressionHead):
851
+ module.transform1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model*2) ** -0.5))
852
+ module.transform1.bias.data.zero_()
853
+ module.transform2.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
854
+ module.transform2.bias.data.zero_()
855
+ module.output.weight.data.normal_(mean=0.0, std=factor * ((37.84) ** -0.5))
856
+ module.output.bias.data.zero_()
857
+
858
+ elif isinstance(module, T5Attention):
859
+ # Mesh TensorFlow attention initialization to avoid scaling before softmax
860
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
861
+ d_model = self.config.d_model
862
+ key_value_proj_dim = self.config.d_kv
863
+ n_heads = self.config.num_heads
864
+ module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
865
+ module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
866
+ module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
867
+ module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
868
+ if module.has_relative_attention_bias:
869
+ module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
870
+
871
+ def _shift_right(self, input_ids):
872
+ decoder_start_token_id = self.config.decoder_start_token_id
873
+ pad_token_id = self.config.pad_token_id
874
+
875
+ if decoder_start_token_id is None:
876
+ raise ValueError(
877
+ "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
878
+ "See T5 docs for more information."
879
+ )
880
+
881
+ # shift inputs to the right
882
+ if is_torch_fx_proxy(input_ids):
883
+ # Item assignment is not supported natively for proxies.
884
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
885
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
886
+ else:
887
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
888
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
889
+ shifted_input_ids[..., 0] = decoder_start_token_id
890
+
891
+ if pad_token_id is None:
892
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
893
+ # replace possible -100 values in labels by `pad_token_id`
894
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
895
+
896
+ return shifted_input_ids
897
+
898
+
899
+ class M5ModelForRegression(M5PreTrainedModel):
900
+ config_class = M5EncoderConfig
901
+ model_type = "m5_model"
902
+
903
+ def __init__(
904
+ self,
905
+ config: T5Config):
906
+
907
+ super().__init__(config)
908
+ self.encoder: M5Encoder = M5Encoder(config)
909
+ self.token_reg_head: M5TokenRegressionHead = M5TokenRegressionHead(config)
910
+ self.reg_head: M5RegressionHead = M5RegressionHead(config)
911
+
912
+ self.init_weights()
913
+
914
+ def forward(self, input_ids, attention_mask=None, relative_position=None, **kwargs):
915
+ output = self.encoder(input_ids, attention_mask, relative_position=relative_position, **kwargs)
916
+ hidden_states = output.last_hidden_state
917
+
918
+ tokreg_head = self.token_reg_head(hidden_states)
919
+ reg_head = self.reg_head(input_ids, hidden_states)
920
+
921
+ concatenated_preds = torch.cat([reg_head, tokreg_head], dim=-1)
922
+ return concatenated_preds
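
For reference, the distance-aware bucketing above (M5's variant of T5's `_relative_position_bucket`) can be sketched in plain Python. This is a hypothetical scalar re-implementation, not the model's tensorized code; `num_buckets=48` comes from the config table, while `max_distance=128` is an assumed default:

```python
import math

NUM_BUCKETS = 48        # relative_attention_num_buckets (config table)
MAX_DISTANCE = 128      # relative_attention_max_distance (assumed default)
INT16_MAX = 2**15 - 1   # np.iinfo(np.int16).max, reserved for "infinite" distance

def relative_position_bucket(distance: int) -> int:
    """Scalar sketch of M5Attention._relative_position_bucket."""
    rp = max(distance + 1, 0)           # shift by +1 and clip negatives to 0
    if rp == INT16_MAX + 1:             # the special sentinel -> last bucket
        return NUM_BUCKETS - 1
    max_exact = NUM_BUCKETS // 2        # first half: one bucket per exact distance
    if rp < max_exact:
        return rp
    num_log_buckets = NUM_BUCKETS - max_exact - 1
    # Remaining distances are binned logarithmically up to MAX_DISTANCE
    large = max_exact + int(
        math.log(rp / max_exact) / math.log(MAX_DISTANCE / max_exact)
        * (NUM_BUCKETS - num_log_buckets)
    )
    return min(large, NUM_BUCKETS - 2)
```

With these values, graph distances 0-22 map to exact buckets 1-23, larger distances share logarithmic buckets up to 46, and the int16-max sentinel lands in the final bucket 47.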
prepare_data.py ADDED
@@ -0,0 +1,147 @@
1
+ import selfies as sf
2
+ from rdkit import Chem
3
+ import ast
4
+ import numpy as np
5
+
6
+
7
+ # Get the correspondence between the original SMILES atom order and the permuted SMILES (for token_regr labels)
8
+ def __get_correspondence__(mol, epoch):
9
+ if epoch == 0:
10
+ new_smiles = Chem.MolToSmiles(mol, canonical=True)
11
+ else:
12
+ new_smiles = Chem.MolToRandomSmilesVect(mol, 1, randomSeed=epoch)[0]
13
+
14
+ output_order = mol.GetProp('_smilesAtomOutputOrder')
15
+ mapping = ast.literal_eval(output_order)
16
+
17
+ return new_smiles, mapping
18
+
19
+ # We already know the [Ring] token connects the token immediately before...
20
+
21
+ def get_ring_masks(mol, map_smiles_to_selfies, tokens):
22
+ # This is fine, atoms are given indices in the molecule based on the order they appear in the SMILES
23
+
24
+ Chem.FastFindRings(mol)
25
+
26
+ rings = mol.GetRingInfo().AtomRings()
27
+ ring_masks = []
28
+ for i, ring in enumerate(rings):
29
+ selfies_ring = map_smiles_to_selfies[list(ring)]
30
+ ring_idx = selfies_ring.max()+1
31
+ ring_masks.append((ring_idx, selfies_ring))
32
+ assert "Ring" in tokens[ring_idx]
33
+
34
+ return ring_masks
35
+
36
+
37
+ # Distances are set to 0 between each "." token and the molecule tokens to its left and right (padding tokens excluded)
38
+ def __get_attribution_mapping__(tokens):
39
+ special_token_masks = []
40
+ map_smiles_to_selfies = []
41
+ dots = []
42
+
43
+ idx = 1 # Start after [CLS]
44
+
45
+ while idx < len(tokens):
46
+ token = tokens[idx]
47
+
48
+ if token == ".":
49
+ dots.append(idx)
50
+ idx += 1
51
+ continue
52
+
53
+ branch_idx = token.find("Branch")
54
+ if branch_idx >= 0:
55
+ n = int(token[branch_idx + 6])
56
+ special_token_masks.append(np.arange(idx, idx + n + 1, dtype=np.int16))
57
+ idx += n + 1
58
+ continue
59
+ else:
60
+ ring_idx = token.find("Ring")
61
+ if ring_idx >= 0:
62
+ n = int(token[ring_idx + 4])
63
+ special_token_masks.append(np.arange(idx, idx + n + 1, dtype=np.int16))
64
+ idx += n + 1
65
+ continue
66
+
67
+ # Real (atom) token
68
+ map_smiles_to_selfies.append(idx)
69
+ idx += 1
70
+
71
+ # Existing dot_masks construction (unchanged)
72
+ dot_masks = []
73
+ last_dots = [-1]
74
+ for dot_idx in dots:
75
+ if len(last_dots) == 2:
76
+ val = last_dots.pop(0)
77
+ dot_masks.append([el for el in range(val + 1, dot_idx, 1)])
78
+ last_dots.append(dot_idx)
79
+
80
+ if len(dots) >= 1:
81
+ dot_masks.append([el for el in range(last_dots.pop(0) + 1, len(tokens), 1)])
82
+
83
+ return special_token_masks, np.array(map_smiles_to_selfies), list(zip(dots, dot_masks, strict=True))
84
+
85
+ def __get_positional_encodings__(mol, smiles_to_selfies, context_length, special_token_masks, double_masks, first_padding_token_idx):
86
+ ats = np.array(smiles_to_selfies, dtype=np.int64)
87
+ distance = Chem.GetDistanceMatrix(mol)
88
+
89
+ # Distance of encodings is capped at the int16 upper bound minus 1
90
+ # (because the int16 upper bound value is reserved for special distances)
91
+ limit = np.iinfo(np.int16).max
92
+ distance = np.minimum(distance, limit-1).astype(np.int16)
93
+
94
+ pos_encod = np.full((context_length, context_length), limit, dtype=np.int16)
95
+
96
+ # Set first row and column to 0 only for non-padding tokens (positions in ats)
97
+ pos_encod[0, :first_padding_token_idx] = 0
98
+ pos_encod[:first_padding_token_idx, 0] = 0
99
+
100
+ for m in special_token_masks:
101
+ pos_encod[m[:, None], m] = -1
102
+
103
+ for i, m in double_masks:
104
+ pos_encod[i, m] = 0
105
+ pos_encod[m, i] = 0
106
+
107
+ np.fill_diagonal(pos_encod, 0)
108
+
109
+ # Use advanced indexing for distance assignment
110
+ pos_encod[ats[:, None], ats] = distance
111
+
112
+ return pos_encod
113
+
114
+ def get_positional_encodings_and_align(smiles, token_regr, epoch):
115
+ orig_mol = Chem.MolFromSmiles(smiles, sanitize=False)
116
+
117
+ # Converts SMILES to the final SMILES so that the mapping is already correct for the token-level labels.
118
+ # Generates a predictable variation of the SMILES.
119
+ new_smiles, mapping_to_new = __get_correspondence__(orig_mol, epoch)
120
+
121
+ # Convert to SELFIES, simulate tokenization and add [CLS] token at the beginning
122
+ selfies = sf.encoder(new_smiles)
123
+ tokens = ["[CLS]"] + list(sf.split_selfies(selfies))
124
+
125
+ special_token_masks, map_smiles_to_selfies, dot_masks = __get_attribution_mapping__(tokens)
126
+
127
+ # Align token labels to SELFIES tokens
128
+ if token_regr is not None:
129
+ # Align token labels to the new SMILES
130
+ token_regr[:len(mapping_to_new)] = token_regr[mapping_to_new]
131
+
132
+ token_regr_selfies = np.full(len(tokens)-1, np.nan, dtype=token_regr.dtype)
133
+
134
+ valid = map_smiles_to_selfies < len(tokens)
135
+ token_regr_selfies[map_smiles_to_selfies[valid] - 1] = token_regr[:np.sum(valid)]
136
+ else:
137
+ token_regr_selfies = None
138
+
139
+ # Generate molecule from the new SMILES (remove sanitization to preserve the original structure)
140
+ mol = Chem.MolFromSmiles(new_smiles, sanitize=False)
141
+
142
+ ring_masks = get_ring_masks(mol, map_smiles_to_selfies, tokens)
143
+ double_masks = ring_masks + dot_masks
144
+ pos_encod = __get_positional_encodings__(mol, map_smiles_to_selfies, len(tokens), special_token_masks, double_masks, len(tokens))
145
+
146
+ return selfies, pos_encod, token_regr_selfies
147
+
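
The matrix initialisation in `__get_positional_encodings__` above relies on sentinel values: every token pair starts at the int16 maximum (reserved for "infinite" distance), the [CLS] row and column are zeroed for non-padding positions, and the diagonal is zeroed before the capped graph distances are written in. A minimal pure-Python sketch of just that initialisation (plain lists instead of NumPy; the special-token masks and the graph-distance write-back are omitted):

```python
LIMIT = 2**15 - 1  # np.iinfo(np.int16).max: sentinel for "infinite" distance

def init_pos_encod(context_length: int, n_real: int) -> list:
    """Hypothetical sketch of the sentinel initialisation only."""
    # Every token pair starts at the reserved "infinite" distance.
    m = [[LIMIT] * context_length for _ in range(context_length)]
    # [CLS] (index 0) is at distance 0 from every non-padding token.
    for j in range(n_real):
        m[0][j] = 0
        m[j][0] = 0
    # Each token is at distance 0 from itself.
    for i in range(context_length):
        m[i][i] = 0
    return m
```

Pairs that never receive a real graph distance (e.g. padding positions) keep the sentinel, which the bucket function later maps to the dedicated last bucket.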
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "unk_token": "[UNK]"
6
+ }
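
The tokenizer (tokenizer.json) pre-tokenizes with the regex `\[.+?\]|\.`: each bracketed SELFIES symbol becomes one token, and the fragment separator `.` is a token of its own. A minimal sketch of the same split in plain Python:

```python
import re

# Same pattern as the "Split" pre-tokenizer in tokenizer.json.
SELFIES_SPLIT = re.compile(r"\[.+?\]|\.")

def pre_tokenize(selfies: str) -> list:
    """Split a SELFIES string into tokenizer-level symbols."""
    return SELFIES_SPLIT.findall(selfies)
```

For example, `pre_tokenize("[C][=C][O].[Na+1]")` yields the five symbol tokens plus the `.` separator, each of which is then looked up in the WordLevel vocabulary.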
tokenizer.json ADDED
@@ -0,0 +1,1139 @@
1
+ {
2
+ "version": "1.0",
3
+ "truncation": null,
4
+ "padding": null,
5
+ "added_tokens": [
6
+ {
7
+ "id": 0,
8
+ "content": "[UNK]",
9
+ "single_word": false,
10
+ "lstrip": false,
11
+ "rstrip": false,
12
+ "normalized": false,
13
+ "special": true
14
+ },
15
+ {
16
+ "id": 1,
17
+ "content": "[CLS]",
18
+ "single_word": false,
19
+ "lstrip": false,
20
+ "rstrip": false,
21
+ "normalized": false,
22
+ "special": true
23
+ },
24
+ {
25
+ "id": 2,
26
+ "content": "[PAD]",
27
+ "single_word": false,
28
+ "lstrip": false,
29
+ "rstrip": false,
30
+ "normalized": false,
31
+ "special": true
32
+ },
33
+ {
34
+ "id": 3,
35
+ "content": "[MASK]",
36
+ "single_word": false,
37
+ "lstrip": false,
38
+ "rstrip": false,
39
+ "normalized": false,
40
+ "special": true
41
+ }
42
+ ],
43
+ "normalizer": {
44
+ "type": "Replace",
45
+ "pattern": {
46
+ "String": "\n"
47
+ },
48
+ "content": ""
49
+ },
50
+ "pre_tokenizer": {
51
+ "type": "Split",
52
+ "pattern": {
53
+ "Regex": "\\[.+?\\]|\\."
54
+ },
55
+ "behavior": "Isolated",
56
+ "invert": false
57
+ },
58
+ "post_processor": {
59
+ "type": "TemplateProcessing",
60
+ "single": [
61
+ {
62
+ "SpecialToken": {
63
+ "id": "[CLS]",
64
+ "type_id": 0
65
+ }
66
+ },
67
+ {
68
+ "Sequence": {
69
+ "id": "A",
70
+ "type_id": 0
71
+ }
72
+ }
73
+ ],
74
+ "pair": [
75
+ {
76
+ "Sequence": {
77
+ "id": "A",
78
+ "type_id": 0
79
+ }
80
+ },
81
+ {
82
+ "Sequence": {
83
+ "id": "B",
84
+ "type_id": 1
85
+ }
86
+ }
87
+ ],
88
+ "special_tokens": {
89
+ "[CLS]": {
90
+ "id": "[CLS]",
91
+ "ids": [
92
+ 1
93
+ ],
94
+ "tokens": [
95
+ "[CLS]"
96
+ ]
97
+ }
98
+ }
99
+ },
100
+ "decoder": null,
101
+ "model": {
102
+ "type": "WordLevel",
103
+ "vocab": {
104
+ "[UNK]": 0,
+ "[CLS]": 1,
+ "[PAD]": 2,
+ "[MASK]": 3,
+ "[C]": 4,
+ "[=C]": 5,
+ "[=Branch1]": 6,
+ "[Branch1]": 7,
+ "[Ring1]": 8,
+ "[N]": 9,
+ "[=O]": 10,
+ "[O]": 11,
+ "[Ring2]": 12,
+ "[=N]": 13,
+ "[Branch2]": 14,
+ "[F]": 15,
+ "[S]": 16,
+ "[=Branch2]": 17,
+ "[Cl]": 18,
+ "[#Branch2]": 19,
+ "[#Branch1]": 20,
+ "[C@@H1]": 21,
+ "[C@H1]": 22,
+ "[Br]": 23,
+ "[#C]": 24,
+ "[P]": 25,
+ "[/C]": 26,
+ "[O-1]": 27,
+ "[#N]": 28,
+ "[N+1]": 29,
+ ".": 30,
+ "[=S]": 31,
+ "[I]": 32,
+ "[C@@]": 33,
+ "[C@]": 34,
+ "[/N]": 35,
+ "[Si]": 36,
+ "[2H]": 37,
+ "[/O]": 38,
+ "[=N+1]": 39,
+ "[B]": 40,
+ "[/S]": 41,
+ "[=N-1]": 42,
+ "[Na+1]": 43,
+ "[Cl-1]": 44,
+ "[#C-1]": 45,
+ "[NH1+1]": 46,
+ "[BH0]": 47,
+ "[K+1]": 48,
+ "[Br-1]": 49,
+ "[S@@]": 50,
+ "[/C@H1]": 51,
+ "[S@]": 52,
+ "[P+1]": 53,
+ "[NH3+1]": 54,
+ "[/Cl]": 55,
+ "[/C@@H1]": 56,
+ "[Se]": 57,
+ "[NH2+1]": 58,
+ "[I-1]": 59,
+ "[C-1]": 60,
+ "[Li+1]": 61,
+ "[B-1]": 62,
+ "[#N+1]": 63,
+ "[3H]": 64,
+ "[/N+1]": 65,
+ "[N-1]": 66,
+ "[CH1]": 67,
+ "[H+1]": 68,
+ "[13C]": 69,
+ "[S-1]": 70,
+ "[CH2-1]": 71,
+ "[Mg+2]": 72,
+ "[P@@]": 73,
+ "[=P]": 74,
+ "[P@]": 75,
+ "[S+1]": 76,
+ "[/F]": 77,
+ "[/O-1]": 78,
+ "[As]": 79,
+ "[/Br]": 80,
+ "[SiH1]": 81,
+ "[18F]": 82,
+ "[NH4+1]": 83,
+ "[Al]": 84,
+ "[13CH2]": 85,
+ "[Ge]": 86,
+ "[Sn]": 87,
+ "[Ca+2]": 88,
+ "[13CH1]": 89,
+ "[OH1-1]": 90,
+ "[/I]": 91,
+ "[Zn+2]": 92,
+ "[/Si]": 93,
+ "[=13CH1]": 94,
+ "[=C-1]": 95,
+ "[Zn]": 96,
+ "[Na]": 97,
+ "[SiH2]": 98,
+ "[=NH1+1]": 99,
+ "[/-Ring1]": 100,
+ "[/P]": 101,
+ "[14C]": 102,
+ "[=13C]": 103,
+ "[Te]": 104,
+ "[13CH3]": 105,
+ "[H]": 106,
+ "[Li]": 107,
+ "[Mg]": 108,
+ "[CH1-1]": 109,
+ "[PH1+1]": 110,
+ "[=Se]": 111,
+ "[Zn+1]": 112,
+ "[SiH3]": 113,
+ "[/C@@]": 114,
+ "[/C@]": 115,
+ "[#P]": 116,
+ "[P-1]": 117,
+ "[15NH1]": 118,
+ "[=NH2+1]": 119,
+ "[PH3+1]": 120,
+ "[F-1]": 121,
+ "[CH0]": 122,
+ "[13C@@H1]": 123,
+ "[11CH3]": 124,
+ "[Ca]": 125,
+ "[15N]": 126,
+ "[13C@H1]": 127,
+ "[14CH1]": 128,
+ "[Cu+1]": 129,
+ "[14CH2]": 130,
+ "[15NH2]": 131,
+ "[NH1-1]": 132,
+ "[=14CH1]": 133,
+ "[125I]": 134,
+ "[=O+1]": 135,
+ "[Sb]": 136,
+ "[CH2]": 137,
+ "[SeH1]": 138,
+ "[SH2+1]": 139,
+ "[Ga]": 140,
+ "[11C]": 141,
+ "[=14C]": 142,
+ "[CH3-1]": 143,
+ "[14CH3]": 144,
+ "[=15N]": 145,
+ "[123I]": 146,
+ "[Al+1]": 147,
+ "[=Si]": 148,
+ "[=18O]": 149,
+ "[K]": 150,
+ "[Sn+2]": 151,
+ "[H-1]": 152,
+ "[OH0]": 153,
+ "[PH2+1]": 154,
+ "[OH2+1]": 155,
+ "[CH2+1]": 156,
+ "[/Se]": 157,
+ "[=CH0]": 158,
+ "[Se-1]": 159,
+ "[Al-1]": 160,
+ "[Sb-1]": 161,
+ "[O+1]": 162,
+ "[In]": 163,
+ "[C+1]": 164,
+ "[/S@]": 165,
+ "[N@+1]": 166,
+ "[Cu]": 167,
+ "[131I]": 168,
+ "[SnH1]": 169,
+ "[/S@@]": 170,
+ "[=CH1-1]": 171,
+ "[N@@+1]": 172,
+ "[1H]": 173,
+ "[18OH1]": 174,
+ "[GeH1]": 175,
+ "[=S@]": 176,
+ "[/P+1]": 177,
+ "[19F]": 178,
+ "[Al+3]": 179,
+ "[14C@H1]": 180,
+ "[As+1]": 181,
+ "[14C@@H1]": 182,
+ "[18O]": 183,
+ "[Si@@]": 184,
+ "[SnH2]": 185,
+ "[GeH3]": 186,
+ "[=S@@]": 187,
+ "[HH1]": 188,
+ "[Sn+1]": 189,
+ "[GeH2]": 190,
+ "[Si@]": 191,
+ "[#O+1]": 192,
+ "[CH1+1]": 193,
+ "[#S]": 194,
+ "[SnH3]": 195,
+ "[AsH1]": 196,
+ "[15N+1]": 197,
+ "[#NH1+1]": 198,
+ "[124I]": 199,
+ "[11CH2]": 200,
+ "[/-Ring2]": 201,
+ "[Al+2]": 202,
+ "[16OH1]": 203,
+ "[Si-1]": 204,
+ "[Ar]": 205,
+ "[/13CH1]": 206,
+ "[/2H]": 207,
+ "[13C@@]": 208,
+ "[PH1-1]": 209,
+ "[#15N]": 210,
+ "[/13C]": 211,
+ "[NH0]": 212,
+ "[13C@]": 213,
+ "[12C]": 214,
+ "[Ag]": 215,
+ "[BH3-1]": 216,
+ "[=C+1]": 217,
+ "[NH2-1]": 218,
+ "[Pd]": 219,
+ "[AsH2]": 220,
+ "[As-1]": 221,
+ "[=Te]": 222,
+ "[Ti]": 223,
+ "[Be+2]": 224,
+ "[PH4+1]": 225,
+ "[BH2-1]": 226,
+ "[#CH0]": 227,
+ "[=13CH2]": 228,
+ "[SH1+1]": 229,
+ "[32P]": 230,
+ "[/NH1+1]": 231,
+ "[=CH1]": 232,
+ "[35S]": 233,
+ "[/Te]": 234,
+ "[Be]": 235,
+ "[Ni+2]": 236,
+ "[SH1-1]": 237,
+ "[=17O]": 238,
+ "[/Ge]": 239,
+ "[11CH1]": 240,
+ "[/CH2-1]": 241,
+ "[/SiH1]": 242,
+ "[Se+1]": 243,
+ "[=OH1+1]": 244,
+ "[NH1]": 245,
+ "[OH1+1]": 246,
+ "[O-2]": 247,
+ "[=NH0]": 248,
+ "[=SiH2]": 249,
+ "[BH1-1]": 250,
+ "[TeH1]": 251,
+ "[76Br]": 252,
+ "[SH3+1]": 253,
+ "[/123I]": 254,
+ "[#13C]": 255,
+ "[=SiH1]": 256,
+ "[17OH1]": 257,
+ "[SH0]": 258,
+ "[Si@@H1]": 259,
+ "[=As]": 260,
+ "[18O-1]": 261,
+ "[/13CH2]": 262,
+ "[Si@H1]": 263,
+ "[17O]": 264,
+ "[35Cl]": 265,
+ "[3HH1]": 266,
+ "[=S+1]": 267,
+ "[Pd+2]": 268,
+ "[/B]": 269,
+ "[37Cl]": 270,
+ "[P-2]": 271,
+ "[Si+1]": 272,
+ "[/125I]": 273,
+ "[=SH1+1]": 274,
+ "[Cd]": 275,
+ "[Te+1]": 276,
+ "[Zn-2]": 277,
+ "[/Al]": 278,
+ "[=Sn]": 279,
+ "[12CH1]": 280,
+ "[=11C]": 281,
+ "[=15NH1]": 282,
+ "[16N]": 283,
+ "[Ag+1]": 284,
+ "[=Ge]": 285,
+ "[=CH1+1]": 286,
+ "[Pd+1]": 287,
+ "[12CH3]": 288,
+ "[=Zn]": 289,
+ "[/C-1]": 290,
+ "[/S-1]": 291,
+ "[Ti+4]": 292,
+ "[14NH1]": 293,
+ "[Ga+3]": 294,
+ "[GeH4]": 295,
+ "[#Si]": 296,
+ "[16O]": 297,
+ "[AlH1]": 298,
+ "[AlH2]": 299,
+ "[#C+1]": 300,
+ "[=P@@]": 301,
+ "[=Ring1]": 302,
+ "[/13CH3]": 303,
+ "[/S+1]": 304,
+ "[=P@]": 305,
+ "[PH0]": 306,
+ "[10B]": 307,
+ "[77Br]": 308,
+ "[=12CH1]": 309,
+ "[Ti+2]": 310,
+ "[/14C]": 311,
+ "[/CH1]": 312,
+ "[/SiH2]": 313,
+ "[He]": 314,
+ "[/N-1]": 315,
+ "[/NH3+1]": 316,
+ "[13NH2]": 317,
+ "[SbH2]": 318,
+ "[/As]": 319,
+ "[12CH2]": 320,
+ "[=14N]": 321,
+ "[/14CH1]": 322,
+ "[=12C]": 323,
+ "[=35S]": 324,
+ "[=P+1]": 325,
+ "[=16O]": 326,
+ "[=Ti]": 327,
+ "[In+1]": 328,
+ "[Pt+2]": 329,
+ "[#13CH1]": 330,
+ "[14NH2]": 331,
+ "[2HH1]": 332,
+ "[=Ti+2]": 333,
+ "[S-2]": 334,
+ "[14N]": 335,
+ "[33P]": 336,
+ "[Pt]": 337,
+ "[=11CH1]": 338,
+ "[AlH3]": 339,
+ "[BH4-1]": 340,
+ "[Ni]": 341,
+ "[/SiH3]": 342,
+ "[Cd+2]": 343,
+ "[Cr]": 344,
+ "[PH2-1]": 345,
+ "[Pb]": 346,
+ "[Sn+3]": 347,
+ "[Sn+4]": 348,
+ "[/Sn]": 349,
+ "[15NH3+1]": 350,
+ "[75Se]": 351,
+ "[=GeH2]": 352,
+ "[In-1]": 353,
+ "[Sn-1]": 354,
+ "[13C-1]": 355,
+ "[16NH1]": 356,
+ "[=14CH2]": 357,
+ "[Hg]": 358,
+ "[In+3]": 359,
+ "[Rh]": 360,
+ "[Ru]": 361,
+ "[Sc+3]": 362,
+ "[/131I]": 363,
+ "[/NH2+1]": 364,
+ "[34S]": 365,
+ "[35SH1]": 366,
+ "[P@H1]": 367,
+ "[PH1]": 368,
+ "[#14N]": 369,
+ "[14C@]": 370,
+ "[=S-1]": 371,
+ "[Cu+2]": 372,
+ "[I+1]": 373,
+ "[Sb+1]": 374,
+ "[/SeH1]": 375,
+ "[=P-1]": 376,
+ "[Sb+3]": 377,
+ "[127I]": 378,
+ "[14C@@]": 379,
+ "[15NH3]": 380,
+ "[16NH2]": 381,
+ "[75Br]": 382,
+ "[=Al]": 383,
+ "[=Mg]": 384,
+ "[Hg+2]": 385,
+ "[P@@H1]": 386,
+ "[Ti-2]": 387,
+ "[#13C-1]": 388,
+ "[#As]": 389,
+ "[/11C]": 390,
+ "[/P@]": 391,
+ "[14C-1]": 392,
+ "[18F-1]": 393,
+ "[6Li+1]": 394,
+ "[82Br]": 395,
+ "[=15N+1]": 396,
+ "[=34S]": 397,
+ "[Au]": 398,
+ "[Ga-1]": 399,
+ "[Kr]": 400,
+ "[Li-1]": 401,
+ "[Rh+3]": 402,
+ "[V]": 403,
+ "[15NH4+1]": 404,
+ "[7Li+1]": 405,
+ "[=BH0]": 406,
+ "[Cs+1]": 407,
+ "[Fe]": 408,
+ "[Ru+2]": 409,
+ "[#15N+1]": 410,
+ "[/18F]": 411,
+ "[/CH0]": 412,
+ "[/P@@]": 413,
+ "[80Br]": 414,
+ "[AlH2-1]": 415,
+ "[Bi]": 416,
+ "[GaH2]": 417,
+ "[PH2]": 418,
+ "[/14CH2]": 419,
+ "[/15NH1]": 420,
+ "[/76Br]": 421,
+ "[/CH1-1]": 422,
+ "[/PH1+1]": 423,
+ "[36Cl]": 424,
+ "[4H]": 425,
+ "[79Br]": 426,
+ "[=Ca]": 427,
+ "[=GeH1]": 428,
+ "[=Pd]": 429,
+ "[Au+1]": 430,
+ "[GaH3]": 431,
+ "[Gd]": 432,
+ "[Hf]": 433,
+ "[Pd-2]": 434,
+ "[SbH1]": 435,
+ "[Ti+3]": 436,
+ "[Y]": 437,
+ "[Zr]": 438,
+ "[#Sb]": 439,
+ "[#Si+1]": 440,
+ "[#SiH1]": 441,
+ "[121I]": 442,
+ "[13N]": 443,
+ "[14NH3]": 444,
+ "[15O]": 445,
+ "[17F]": 446,
+ "[28Si]": 447,
+ "[2H-1]": 448,
+ "[3H-1]": 449,
+ "[6H]": 450,
+ "[8CH1]": 451,
+ "[=15O]": 452,
+ "[=SiH1-1]": 453,
+ "[AlH1+2]": 454,
+ "[AlH1-1]": 455,
+ "[Ba+2]": 456,
+ "[Ba]": 457,
+ "[CH3+1]": 458,
+ "[Mg+1]": 459,
+ "[Ne]": 460,
+ "[OH3+1]": 461,
+ "[Si-2]": 462,
+ "[SiH4]": 463,
+ "[SnH4]": 464,
+ "[#Si-1]": 465,
+ "[/13C@@H1]": 466,
+ "[/13C@H1]": 467,
+ "[11C@H1]": 468,
+ "[11CH4]": 469,
+ "[122I]": 470,
+ "[125I-1]": 471,
+ "[14N+1]": 472,
+ "[15OH1]": 473,
+ "[17O-1]": 474,
+ "[18FH1]": 475,
+ "[5H]": 476,
+ "[77Se]": 477,
+ "[=33S]": 478,
+ "[=SH2]": 479,
+ "[AsH3]": 480,
+ "[BH1]": 481,
+ "[InH2]": 482,
+ "[Lu+3]": 483,
+ "[Mo]": 484,
+ "[Ti+1]": 485,
+ "[Y+3]": 486,
+ "[Zr+2]": 487,
+ "[#13N]": 488,
+ "[#Ge]": 489,
+ "[#Nb]": 490,
+ "[#P+1]": 491,
+ "[/Ga]": 492,
+ "[11C-1]": 493,
+ "[11C@@H1]": 494,
+ "[12C@@H1]": 495,
+ "[12C@@]": 496,
+ "[12C@]": 497,
+ "[131I-1]": 498,
+ "[13CH4]": 499,
+ "[15C]": 500,
+ "[2H+1]": 501,
+ "[8CH2]": 502,
+ "[=13O]": 503,
+ "[=14C-1]": 504,
+ "[=15N-1]": 505,
+ "[=AsH3]": 506,
+ "[=Pt]": 507,
+ "[Al-2]": 508,
+ "[AlH4-1]": 509,
+ "[Au-1]": 510,
+ "[Hg+1]": 511,
+ "[Ru+3]": 512,
+ "[SiH3-1]": 513,
+ "[Ta]": 514,
+ "[#Al]": 515,
+ "[#S+1]": 516,
+ "[/11CH3]": 517,
+ "[/15NH2]": 518,
+ "[/15N]": 519,
+ "[/Al-1]": 520,
+ "[/GeH1]": 521,
+ "[/NH1-1]": 522,
+ "[11B]": 523,
+ "[11CH3-1]": 524,
+ "[121Sb]": 525,
+ "[123I-1]": 526,
+ "[125IH1]": 527,
+ "[12C-1]": 528,
+ "[12C@H1]": 529,
+ "[13CH3-1]": 530,
+ "[13NH1]": 531,
+ "[14CH3-1]": 532,
+ "[14O]": 533,
+ "[15NH1+1]": 534,
+ "[1H+1]": 535,
+ "[32Cl]": 536,
+ "[33S]": 537,
+ "[68Ga+3]": 538,
+ "[74As]": 539,
+ "[75Ge]": 540,
+ "[82Se]": 541,
+ "[9CH2]": 542,
+ "[=15NH2+1]": 543,
+ "[=AsH1]": 544,
+ "[=Cr]": 545,
+ "[=Cu]": 546,
+ "[=Ga]": 547,
+ "[=Ni]": 548,
+ "[=Os]": 549,
+ "[=Sb]": 550,
+ "[=SeH1]": 551,
+ "[=SnH2]": 552,
+ "[=TeH2]": 553,
+ "[=Zr]": 554,
+ "[Al-3]": 555,
+ "[Co+3]": 556,
+ "[Fe+2]": 557,
+ "[GaH1]": 558,
+ "[Ge-2]": 559,
+ "[InH1]": 560,
+ "[Os]": 561,
+ "[Rb+1]": 562,
+ "[Sc]": 563,
+ "[SiH1-1]": 564,
+ "[Sr+2]": 565,
+ "[TeH2]": 566,
+ "[Zr+4]": 567,
+ "[#12CH1]": 568,
+ "[#14C]": 569,
+ "[#17O+1]": 570,
+ "[#18O+1]": 571,
+ "[#AsH1]": 572,
+ "[#Ga]": 573,
+ "[#In]": 574,
+ "[#Lu]": 575,
+ "[#Sc]": 576,
+ "[#Ta]": 577,
+ "[/124I]": 578,
+ "[/35Cl]": 579,
+ "[/37Cl]": 580,
+ "[/As+1]": 581,
+ "[/BH0]": 582,
+ "[/In]": 583,
+ "[/O+1]": 584,
+ "[/Sb]": 585,
+ "[120I]": 586,
+ "[124I-1]": 587,
+ "[129I]": 588,
+ "[14CH4]": 589,
+ "[15N-1]": 590,
+ "[29Si]": 591,
+ "[32PH2]": 592,
+ "[32S]": 593,
+ "[34SH1]": 594,
+ "[35Cl-1]": 595,
+ "[45Ca+2]": 596,
+ "[47Ca+2]": 597,
+ "[70Zn]": 598,
+ "[72Zn]": 599,
+ "[73Ge]": 600,
+ "[74Se]": 601,
+ "[76Br-1]": 602,
+ "[79BrH1]": 603,
+ "[7Be]": 604,
+ "[81BrH1]": 605,
+ "[81Br]": 606,
+ "[8CH4]": 607,
+ "[9CH1]": 608,
+ "[=11CH2]": 609,
+ "[=12CH2]": 610,
+ "[=13N]": 611,
+ "[=18CH2]": 612,
+ "[=32S]": 613,
+ "[=Ag]": 614,
+ "[=AlH1]": 615,
+ "[=Mo]": 616,
+ "[=PH2+1]": 617,
+ "[=SH0]": 618,
+ "[=SeH2]": 619,
+ "[=Ta]": 620,
+ "[=V]": 621,
+ "[=W]": 622,
+ "[Cr+2]": 623,
+ "[Ir]": 624,
+ "[Nb]": 625,
+ "[Ni-2]": 626,
+ "[OH1]": 627,
+ "[PbH3]": 628,
+ "[Rb]": 629,
+ "[Rh+2]": 630,
+ "[SbH1+1]": 631,
+ "[Si+4]": 632,
+ "[Tl+1]": 633,
+ "[Tl+3]": 634,
+ "[#11CH1]": 635,
+ "[#11C]": 636,
+ "[#14C-1]": 637,
+ "[#14CH1]": 638,
+ "[#15O+1]": 639,
+ "[#16O+1]": 640,
+ "[#17CH1]": 641,
+ "[#18CH1]": 642,
+ "[#Cr]": 643,
+ "[#GeH1]": 644,
+ "[#Mo+1]": 645,
+ "[#Mo]": 646,
+ "[#PH2]": 647,
+ "[#SH1-1]": 648,
+ "[#Se]": 649,
+ "[#Sn]": 650,
+ "[#Ti+1]": 651,
+ "[#V]": 652,
+ "[#Y]": 653,
+ "[/127I]": 654,
+ "[/14CH3]": 655,
+ "[/15N+1]": 656,
+ "[/18OH1]": 657,
+ "[/18O]": 658,
+ "[/32P]": 659,
+ "[/80Br]": 660,
+ "[/Al+1]": 661,
+ "[/CH2]": 662,
+ "[/GeH3]": 663,
+ "[/N@+1]": 664,
+ "[/N@@+1]": 665,
+ "[/NH0]": 666,
+ "[/OH0]": 667,
+ "[/PH3+1]": 668,
+ "[/Te+1]": 669,
+ "[/TeH1]": 670,
+ "[100Mo]": 671,
+ "[100Pd]": 672,
+ "[101Mo]": 673,
+ "[101Pd]": 674,
+ "[104Pd]": 675,
+ "[105Pd]": 676,
+ "[108Pd]": 677,
+ "[10B-1]": 678,
+ "[10BH3]": 679,
+ "[10Be]": 680,
+ "[10CH4]": 681,
+ "[10C]": 682,
+ "[111I-1]": 683,
+ "[111IH1]": 684,
+ "[111In+3]": 685,
+ "[111In]": 686,
+ "[112Pd]": 687,
+ "[117SnH2]": 688,
+ "[119Sn]": 689,
+ "[11NH3]": 690,
+ "[120I-1]": 691,
+ "[120IH1]": 692,
+ "[121I-1]": 693,
+ "[121IH1]": 694,
+ "[121SnH2]": 695,
+ "[122IH1]": 696,
+ "[123IH1]": 697,
+ "[123Te]": 698,
+ "[124IH1]": 699,
+ "[124Xe]": 700,
+ "[125Te]": 701,
+ "[126IH1]": 702,
+ "[126Xe]": 703,
+ "[127I-1]": 704,
+ "[127IH1]": 705,
+ "[127Xe]": 706,
+ "[128I-1]": 707,
+ "[128IH1]": 708,
+ "[128I]": 709,
+ "[129I-1]": 710,
+ "[129IH1]": 711,
+ "[12B]": 712,
+ "[12CH4]": 713,
+ "[12Li+1]": 714,
+ "[12OH1]": 715,
+ "[130I-1]": 716,
+ "[130IH1]": 717,
+ "[131IH1]": 718,
+ "[131Xe]": 719,
+ "[132I-1]": 720,
+ "[132IH1]": 721,
+ "[132Xe]": 722,
+ "[133I-1]": 723,
+ "[133IH1]": 724,
+ "[134I-1]": 725,
+ "[134IH1]": 726,
+ "[134Xe]": 727,
+ "[135I-1]": 728,
+ "[135IH1]": 729,
+ "[135I]": 730,
+ "[13CH1+1]": 731,
+ "[13CH2-1]": 732,
+ "[13NH3]": 733,
+ "[13OH2]": 734,
+ "[13O]": 735,
+ "[145Gd]": 736,
+ "[146Gd]": 737,
+ "[147Gd]": 738,
+ "[148Gd]": 739,
+ "[149Gd]": 740,
+ "[14CH2-1]": 741,
+ "[14NH4+1]": 742,
+ "[151Gd]": 743,
+ "[152Gd]": 744,
+ "[153Gd]": 745,
+ "[154Gd]": 746,
+ "[155Gd]": 747,
+ "[156Gd]": 748,
+ "[157Gd]": 749,
+ "[158Gd]": 750,
+ "[159Gd]": 751,
+ "[15CH3]": 752,
+ "[15CH4]": 753,
+ "[15NH2+1]": 754,
+ "[15OH2]": 755,
+ "[160Gd]": 756,
+ "[161Gd]": 757,
+ "[16CH1]": 758,
+ "[16CH3]": 759,
+ "[16C]": 760,
+ "[16F]": 761,
+ "[16NH3]": 762,
+ "[16O-1]": 763,
+ "[16OH1-1]": 764,
+ "[16OH2]": 765,
+ "[177Lu+3]": 766,
+ "[17CH1]": 767,
+ "[17CH2]": 768,
+ "[17FH1]": 769,
+ "[17NH3]": 770,
+ "[17OH1-1]": 771,
+ "[17OH2]": 772,
+ "[18CH1]": 773,
+ "[18CH2]": 774,
+ "[18OH1-1]": 775,
+ "[18OH2]": 776,
+ "[19B]": 777,
+ "[19FH1]": 778,
+ "[19Ne]": 779,
+ "[19OH2]": 780,
+ "[19O]": 781,
+ "[1H-1]": 782,
+ "[1HH1]": 783,
+ "[20CH1]": 784,
+ "[20Ne]": 785,
+ "[20OH1]": 786,
+ "[21CH4]": 787,
+ "[21NH3]": 788,
+ "[21Ne]": 789,
+ "[22CH4]": 790,
+ "[22Na+1]": 791,
+ "[22Ne]": 792,
+ "[24FH1]": 793,
+ "[24Mg]": 794,
+ "[24NH3]": 795,
+ "[24Na+1]": 796,
+ "[25FH1]": 797,
+ "[25Mg]": 798,
+ "[25OH1]": 799,
+ "[26FH1]": 800,
+ "[27Mg]": 801,
+ "[28F]": 802,
+ "[28Mg]": 803,
+ "[28SiH3]": 804,
+ "[30Si]": 805,
+ "[31PH3]": 806,
+ "[31P]": 807,
+ "[31Si]": 808,
+ "[32ClH1]": 809,
+ "[32PH3]": 810,
+ "[32SH2]": 811,
+ "[32Si]": 812,
+ "[33ClH1]": 813,
+ "[33PH3]": 814,
+ "[33SH2]": 815,
+ "[34ClH1]": 816,
+ "[34SH2]": 817,
+ "[35ClH1]": 818,
+ "[35P]": 819,
+ "[35S-1]": 820,
+ "[35SH2]": 821,
+ "[36Ar]": 822,
+ "[36Cl-1]": 823,
+ "[36ClH1]": 824,
+ "[36SH2]": 825,
+ "[37Ar]": 826,
+ "[37Cl-1]": 827,
+ "[37ClH1]": 828,
+ "[37SH2]": 829,
+ "[38Ar]": 830,
+ "[38Cl-1]": 831,
+ "[38ClH1]": 832,
+ "[38PH3]": 833,
+ "[38SH2]": 834,
+ "[39Ar]": 835,
+ "[39ClH1]": 836,
+ "[3He]": 837,
+ "[40Ar]": 838,
+ "[40Ca]": 839,
+ "[40PH3]": 840,
+ "[41Ar]": 841,
+ "[41Ca+2]": 842,
+ "[41Ca]": 843,
+ "[42Ca]": 844,
+ "[42K+1]": 845,
+ "[43Ca+2]": 846,
+ "[43Ca]": 847,
+ "[43K+1]": 848,
+ "[44Ca+2]": 849,
+ "[44Ca]": 850,
+ "[45Ca]": 851,
+ "[46Ca]": 852,
+ "[47Ca]": 853,
+ "[48Ca]": 854,
+ "[49Ca]": 855,
+ "[4HH1]": 856,
+ "[4He]": 857,
+ "[61Cu+1]": 858,
+ "[62Cu+1]": 859,
+ "[62Zn]": 860,
+ "[63Zn]": 861,
+ "[64Cu+1]": 862,
+ "[64Cu]": 863,
+ "[64Zn+2]": 864,
+ "[64Zn]": 865,
+ "[65Zn+2]": 866,
+ "[65Zn]": 867,
+ "[66Ge]": 868,
+ "[66Zn]": 869,
+ "[67Ga+3]": 870,
+ "[67Ge]": 871,
+ "[67Zn]": 872,
+ "[68Ga]": 873,
+ "[68Ge]": 874,
+ "[68Zn]": 875,
+ "[69Ge]": 876,
+ "[69Zn]": 877,
+ "[6He]": 878,
+ "[70As]": 879,
+ "[70Se]": 880,
+ "[71As]": 881,
+ "[71Ge]": 882,
+ "[71Se]": 883,
+ "[71Zn]": 884,
+ "[72As]": 885,
+ "[72BrH1]": 886,
+ "[72Ge]": 887,
+ "[72Se]": 888,
+ "[73Se]": 889,
+ "[74Br-1]": 890,
+ "[74BrH1]": 891,
+ "[74Ge]": 892,
+ "[74Kr]": 893,
+ "[75Br-1]": 894,
+ "[75BrH1]": 895,
+ "[76As]": 896,
+ "[76BrH1]": 897,
+ "[76Kr]": 898,
+ "[76Se]": 899,
+ "[77As]": 900,
+ "[77Br-1]": 901,
+ "[77BrH1]": 902,
+ "[77Ge]": 903,
+ "[77Kr]": 904,
+ "[78BrH1]": 905,
+ "[78Ge]": 906,
+ "[78Kr]": 907,
+ "[78Se]": 908,
+ "[79Kr]": 909,
+ "[79Se]": 910,
+ "[80Br-1]": 911,
+ "[80BrH1]": 912,
+ "[80Kr]": 913,
+ "[80Se]": 914,
+ "[80Sr]": 915,
+ "[81Kr]": 916,
+ "[81Se]": 917,
+ "[82Br-1]": 918,
+ "[82BrH1]": 919,
+ "[82Kr]": 920,
+ "[82Rb+1]": 921,
+ "[83Br-1]": 922,
+ "[83BrH1]": 923,
+ "[83Kr]": 924,
+ "[83Se]": 925,
+ "[84BrH1]": 926,
+ "[84Kr]": 927,
+ "[85Br]": 928,
+ "[85Kr]": 929,
+ "[86Kr]": 930,
+ "[86Rb+1]": 931,
+ "[86Zr]": 932,
+ "[87Kr]": 933,
+ "[87Sr]": 934,
+ "[88Kr]": 935,
+ "[88Zr]": 936,
+ "[89Kr]": 937,
+ "[89Zr]": 938,
+ "[8B]": 939,
+ "[8Be]": 940,
+ "[8He]": 941,
+ "[90Mo]": 942,
+ "[90Y+3]": 943,
+ "[90Zr]": 944,
+ "[91Y+3]": 945,
+ "[92Mo]": 946,
+ "[92Sr]": 947,
+ "[93Mo]": 948,
+ "[93Zr]": 949,
+ "[94Zr]": 950,
+ "[95Mo]": 951,
+ "[95Zr]": 952,
+ "[96Mo]": 953,
+ "[97Mo]": 954,
+ "[97Zr]": 955,
+ "[98Mo]": 956,
+ "[99Mo]": 957,
+ "[99Ru+2]": 958,
+ "[9B]": 959,
+ "[9Be]": 960,
+ "[=11NH1]": 961,
+ "[=12O]": 962,
+ "[=13C-1]": 963,
+ "[=14NH1]": 964,
+ "[=16N]": 965,
+ "[=19O]": 966,
+ "[=25O]": 967,
+ "[=77Se]": 968,
+ "[=8CH1]": 969,
+ "[=Al-1]": 970,
+ "[=AsH2]": 971,
+ "[=Ba]": 972,
+ "[=Be]": 973,
+ "[=Cd]": 974,
+ "[=Fe]": 975,
+ "[=Hg]": 976,
+ "[=In]": 977,
+ "[=Mo+4]": 978,
+ "[=Rh]": 979,
+ "[=SH1-1]": 980,
+ "[=Si+1]": 981,
+ "[=Si-1]": 982,
+ "[=SiH1+1]": 983,
+ "[=TeH1]": 984,
+ "[=Ti+1]": 985,
+ "[AlH6-3]": 986,
+ "[As+3]": 987,
+ "[AsH1+1]": 988,
+ "[AsH5]": 989,
+ "[Au+3]": 990,
+ "[Bi+2]": 991,
+ "[Bi+3]": 992,
+ "[Branch3]": 993,
+ "[CH3]": 994,
+ "[Cr+4]": 995,
+ "[CuH1]": 996,
+ "[Fe+4]": 997,
+ "[Gd+2]": 998,
+ "[Ge+4]": 999,
+ "[Ge-1]": 1000,
+ "[Ge@@H1]": 1001,
+ "[Ge@@]": 1002,
+ "[Ge@]": 1003,
+ "[InH3]": 1004,
+ "[Ir+3]": 1005,
+ "[Mn]": 1006,
+ "[Mo+2]": 1007,
+ "[Nb+3]": 1008,
+ "[Pt+4]": 1009,
+ "[Re]": 1010,
+ "[Rh-3]": 1011,
+ "[RhH1+2]": 1012,
+ "[Ru+4]": 1013,
+ "[Ru-2]": 1014,
+ "[RuH1+3]": 1015,
+ "[RuH4]": 1016,
+ "[S@@H1]": 1017,
+ "[SbH3]": 1018,
+ "[SbH5]": 1019,
+ "[SeH2]": 1020,
+ "[Si+2]": 1021,
+ "[SiH2-1]": 1022,
+ "[SiH4-1]": 1023,
+ "[Sr]": 1024,
+ "[TeH3]": 1025,
+ "[TeH4]": 1026,
+ "[TlH2]": 1027,
+ "[Tl]": 1028,
+ "[W]": 1029,
+ "[Xe]": 1030,
+ "[ZnH1+1]": 1031
1136
+ },
+ "unk_token": "[UNK]"
+ }
+ }
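The normalizer, pre-tokenizer, and post-processor above define the encoding pipeline: newlines are removed, the string is split on the regex `\[.+?\]|\.`, and tokens are looked up in the WordLevel vocabulary. A minimal sketch of the split step, using only Python's `re` module (the `pre_tokenize` helper name is illustrative, not part of the repo):

```python
import re

# Pre-tokenization rule from tokenizer.json: split a SELFIES string into
# bracketed symbols and standalone "." separators ("Isolated" behavior).
# The non-greedy ".+?" ensures "[C][=C]" splits into two tokens.
SELFIES_TOKEN = re.compile(r"\[.+?\]|\.")

def pre_tokenize(selfies: str) -> list[str]:
    """Split a SELFIES string the same way the Split pre-tokenizer does."""
    return SELFIES_TOKEN.findall(selfies)

print(pre_tokenize("[C][=C][Branch1][Ring1][O].[Na+1]"))
# → ['[C]', '[=C]', '[Branch1]', '[Ring1]', '[O]', '.', '[Na+1]']
```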
tokenizer_config.json ADDED
@@ -0,0 +1,44 @@
+ {
+ "added_tokens_decoder": {
+ "0": {
+ "content": "[UNK]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "1": {
+ "content": "[CLS]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "2": {
+ "content": "[PAD]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "3": {
+ "content": "[MASK]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "clean_up_tokenization_spaces": false,
+ "cls_token": "[CLS]",
+ "extra_special_tokens": {},
+ "mask_token": "[MASK]",
+ "model_max_length": 1000000000000000019884624838656,
+ "pad_token": "[PAD]",
+ "tokenizer_class": "PreTrainedTokenizerFast",
+ "unk_token": "[UNK]"
+ }