Upload folder using huggingface_hub
- README.md +178 -0
- config.json +104 -0
- model.safetensors +3 -0
- modeling_mist_finetuned.py +775 -0
- requirements.txt +4 -0
- special_tokens_map.json +51 -0
- tokenizer.json +267 -0
- tokenizer_config.json +72 -0
README.md
ADDED
@@ -0,0 +1,178 @@
---
language: en
library_name: transformers
license: gpl-3.0
tags:
- mist
- chemistry
- molecular-property-prediction
---

# MIST: Molecular Insight SMILES Transformers

MIST is a family of molecular foundation models for molecular property prediction.
The models were pre-trained on SMILES strings from the [Enamine REAL Space](https://enamine.net/compound-collections/real-compounds/real-space-navigator) dataset using the Masked Language Modeling (MLM) objective, then fine-tuned for downstream prediction tasks.

## Model Details

- **Architecture**: Encoder-only transformer ([``RoBERTa-PreLayerNorm``](https://huggingface.co/docs/transformers/en/model_doc/roberta-prelayernorm))
- **Pre-training**: Masked Language Modeling on molecular SMILES
- **Tokenization**: [``Smirk``](https://eeg.engin.umich.edu/smirk/) tokenizer

### Quick Start

```python
from transformers import AutoModel

# Load the model
model = AutoModel.from_pretrained(
    "path/to/model",
    trust_remote_code=True
)

# Make predictions
smiles_batch = [
    "CCO",       # Ethanol
    "CC(=O)O",   # Acetic acid
    "c1ccccc1"   # Benzene
]
results = model.predict(smiles_batch)
```

### Setting Up Your Environment

Create a virtual environment and install dependencies:

```bash
python -m venv .venv
source .venv/bin/activate  # On Windows: .venv\Scripts\activate
pip install -r requirements.txt
```

> **Note**: SMIRK tokenizers require Rust to be installed. See the [Rust installation guide](https://www.rust-lang.org/tools/install) for details.

## Model Inputs and Outputs

### Inputs
- **SMILES strings**: Standard SMILES notation for molecular structures
- **Batch size**: Variable, automatically padded during inference

### Outputs
- **Predictions**: Task-specific numerical or categorical predictions
- **Format**: Dictionary with channel names and predicted values (if channels are configured), or raw tensor output

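When channels are configured, `predict` returns a dictionary keyed by channel name rather than a raw tensor. A minimal sketch of how that dictionary is assembled, mirroring the `annotate_prediction` helper in `modeling_mist_finetuned.py` (the channel names and values below are illustrative, not outputs of a real model):

```python
import torch

# Mirrors `annotate_prediction` from modeling_mist_finetuned.py:
# column idx of the prediction tensor becomes the "value" entry for
# channel idx; all other channel fields are copied alongside it.
def annotate_prediction(y, channels):
    out = {}
    for idx, chn in enumerate(channels):
        channel_info = {f: v for f, v in chn.items() if f != "name"}
        out[chn["name"]] = {"value": y[:, idx], **channel_info}
    return out

# Hypothetical two-channel output for a batch of two molecules
y = torch.tensor([[-0.24, 0.01], [-0.22, 0.03]])
channels = [
    {"name": "homo", "unit": "Hartree"},
    {"name": "lumo", "unit": "Hartree"},
]
preds = annotate_prediction(y, channels)
# preds["homo"] -> {"value": tensor([-0.2400, -0.2200]), "unit": "Hartree"}
```

A raw tensor of shape `(batch, n_channels)` is returned instead when `channels` is `null` in `config.json` or when `return_dict=False` is passed.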
## Provided Models

### Pre-trained
- `mist-1.8B-dh61satt`: Flagship MIST model (MIST-1.8B)
- `mist-28M-ti624ev1`: Smaller MIST model (MIST-28M)

Below is a full list of finetuned variants hosted on HuggingFace:

### MoleculeNet Benchmark Models

| Folder                      | Encoder   | Dataset                   |
| --------------------------- | :-------: | ------------------------- |
| mist-1.8B-fbdn8e35-bbbp     | MIST-1.8B | MoleculeNet BBBP          |
| mist-1.8B-1a4puhg2-hiv      | MIST-1.8B | MoleculeNet HIV           |
| mist-1.8B-m50jgolp-bace     | MIST-1.8B | MoleculeNet BACE          |
| mist-1.8B-uop1z0dc-tox21    | MIST-1.8B | MoleculeNet Tox21         |
| mist-1.8B-lu1l5ieh-clintox  | MIST-1.8B | MoleculeNet ClinTox       |
| mist-1.8B-l1wfo7oa-sider    | MIST-1.8B | MoleculeNet SIDER         |
| mist-1.8B-hxiygjsm-esol     | MIST-1.8B | MoleculeNet ESOL          |
| mist-1.8B-iwqj2cld-freesolv | MIST-1.8B | MoleculeNet FreeSolv      |
| mist-1.8B-jvt4azpz-lipo     | MIST-1.8B | MoleculeNet Lipophilicity |
| mist-1.8B-8nd1ot5j-qm8      | MIST-1.8B | MoleculeNet QM8           |
| mist-28M-3xpfhv48-bbbp      | MIST-28M  | MoleculeNet BBBP          |
| mist-28M-8fh43gke-hiv       | MIST-28M  | MoleculeNet HIV           |
| mist-28M-8loj3bab-bace      | MIST-28M  | MoleculeNet BACE          |
| mist-28M-kw4ks27p-tox21     | MIST-28M  | MoleculeNet Tox21         |
| mist-28M-97vfcykk-clintox   | MIST-28M  | MoleculeNet ClinTox       |
| mist-28M-z8qo16uy-sider     | MIST-28M  | MoleculeNet SIDER         |
| mist-28M-kcwb9le5-esol      | MIST-28M  | MoleculeNet ESOL          |
| mist-28M-0uiq7o7m-freesolv  | MIST-28M  | MoleculeNet FreeSolv      |
| mist-28M-xzr5ulva-lipo      | MIST-28M  | MoleculeNet Lipophilicity |
| mist-28M-gzwqzpcr-qm8       | MIST-28M  | MoleculeNet QM8           |

### QM9 Benchmark Models

Single-target models (MIST-1.8B encoder) are available for each property in QM9.

| Folder                   | Encoder   | Target                                                            |
| ------------------------ | :-------: | ----------------------------------------------------------------- |
| mist-1.8B-ez05expv-mu    | MIST-1.8B | μ - Dipole moment (unit: D)                                       |
| mist-1.8B-rcwary93-alpha | MIST-1.8B | α - Isotropic polarizability (unit: Bohr^3)                       |
| mist-1.8B-jmjosq12-homo  | MIST-1.8B | HOMO - Highest occupied molecular orbital energy (unit: Hartree)  |
| mist-1.8B-n14wshc9-lumo  | MIST-1.8B | LUMO - Lowest unoccupied molecular orbital energy (unit: Hartree) |
| mist-1.8B-kayun6v3-gap   | MIST-1.8B | Gap - Gap between HOMO and LUMO (unit: Hartree)                   |
| mist-1.8B-xxe7t35e-r2    | MIST-1.8B | \<R2\> - Electronic spatial extent (unit: Bohr^2)                 |
| mist-1.8B-6nmcwyrp-zpve  | MIST-1.8B | ZPVE - Zero point vibrational energy (unit: Hartree)              |
| mist-1.8B-a7akimjj-u0    | MIST-1.8B | U0 - Internal energy at 0K (unit: Hartree)                        |
| mist-1.8B-85f24xkj-u298  | MIST-1.8B | U298 - Internal energy at 298.15K (unit: Hartree)                 |
| mist-1.8B-3fbbz4is-h298  | MIST-1.8B | H298 - Enthalpy at 298.15K (unit: Hartree)                        |
| mist-1.8B-09sntn03-g298  | MIST-1.8B | G298 - Free energy at 298.15K (unit: Hartree)                     |
| mist-1.8B-j356b3nf-cv    | MIST-1.8B | Cv - Heat capacity at 298.15K (unit: cal/(mol*K))                 |

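Most of the energy targets above are reported in Hartree; a common post-processing step is converting predictions to eV. A quick sketch (the conversion factor is a physical constant, and the gap value below is hypothetical, not a model output):

```python
# 1 Hartree ≈ 27.211386 eV
HARTREE_TO_EV = 27.211386
gap_hartree = 0.25  # hypothetical predicted HOMO-LUMO gap, in Hartree
gap_ev = gap_hartree * HARTREE_TO_EV
# gap_ev ≈ 6.803 eV
```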
- `mist-ti624ev1-moleculenet`: Contains MoleculeNet benchmark MIST-28M models trained as part of doi:10.5281/zenodo.13761263

### Finetuned Single-Task Models

These models consist of a MIST encoder and a task network finetuned on a single dataset, used in the applications demonstrated in the manuscript.

| Folder                    | Encoder  | Dataset                                                     |
| ------------------------- | :------: | ----------------------------------------------------------- |
| mist-26.9M-48kpooqf-odour | MIST-28M | Olfaction                                                   |
| mist-26.9M-6hk5coof-dn    | MIST-28M | Donor Number                                                |
| mist-26.9M-0vxdbm36-kt    | MIST-28M | Kamlet-Taft Solvatochromic Parameters                       |
| mist-26.9M-b302p09x-bp    | MIST-28M | Boiling Point (part of Characteristic Temperatures dataset) |
| mist-26.9M-cyuo2xb6-fp    | MIST-28M | Flash Point (part of Characteristic Temperatures dataset)   |
| mist-26.9M-y3ge5pf9-mp    | MIST-28M | Melting Point (part of Characteristic Temperatures dataset) |

### Finetuned Multi-Task Models

These are additional multi-target finetuned models consisting of a MIST encoder and a task network.

| Folder                     | Encoder  | Dataset                               |
| -------------------------- | :------: | ------------------------------------- |
| mist-26.9M-kkgx0omx-qm9    | MIST-28M | QM9 dataset with SMILES randomization |
| mist-28M-ttqcvt6fs-toxcast | MIST-28M | ToxCast                               |
| mist-28M-yr1urd2c-muv      | MIST-28M | Maximum Unbiased Validation (MUV)     |

### Finetuned Mixture Models

These models consist of a MIST encoder and a physics-informed task network for mixture property prediction.

| Folder                         | Encoder  | Dataset                                         |
| ------------------------------ | :------: | ----------------------------------------------- |
| mist-conductivity-28M-2mpg8dcd | MIST-28M | Ionic Conductivity                              |
| mist-mixtures-zffffbex         | MIST-28M | Excess Density, Molar Volume and Molar Enthalpy |

## Citation

If you use this model in your research, please cite:

```bibtex
@online{MIST,
  title = {Foundation Models for Discovery and Exploration in Chemical Space},
  author = {Wadell, Alexius and Bhutani, Anoushka and Azumah, Victor and Ellis-Mohr, Austin R. and Kelly, Celia and Zhao, Hancheng and Nayak, Anuj K. and Hegazy, Kareem and Brace, Alexander and Lin, Hongyi and Emani, Murali and Vishwanath, Venkatram and Gering, Kevin and Alkan, Melisa and Gibbs, Tom and Wells, Jack and Varshney, Lav R. and Ramsundar, Bharath and Duraisamy, Karthik and Mahoney, Michael W. and Ramanathan, Arvind and Viswanathan, Venkatasubramanian},
  date = {2025-10-20},
  eprint = {2510.18900},
  eprinttype = {arXiv},
  eprintclass = {physics},
  doi = {10.48550/arXiv.2510.18900},
  url = {http://arxiv.org/abs/2510.18900},
}
```

## License and Notice

Model weights are provided as-is for research purposes only, without guarantees of correctness, fitness for purpose, or warranties of any kind.

**Restrictions:**
- Research use only
- No redistribution without permission
- No commercial use without a licensing agreement

For questions, issues, or licensing inquiries, please contact [venkvis@umich.edu](mailto:venkvis@umich.edu).

<hr>
config.json
ADDED
@@ -0,0 +1,104 @@
{
  "architectures": [
    "MISTFinetuned"
  ],
  "channels": null,
  "dtype": "float32",
  "encoder": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_probs_dropout_prob": 0.1,
    "attn_implementation": null,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 0,
    "chunk_size_feed_forward": 0,
    "classifier_dropout": null,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dtype": null,
    "early_stopping": false,
    "enable_token_counter": true,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 2,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 512,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_range": 0.02,
    "intermediate_size": 2048,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-12,
    "length_penalty": 1.0,
    "max_length": 20,
    "max_position_embeddings": 2048,
    "min_length": 0,
    "model_type": "roberta-prelayernorm",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 8,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 8,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": 1,
    "position_embedding_type": "absolute",
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torchscript": false,
    "transformers_version": "4.57.1",
    "type_vocab_size": 2,
    "typical_p": 1.0,
    "use_bfloat16": false,
    "use_cache": true,
    "vocab_size": 165
  },
  "model_type": "mist_finetuned",
  "task_network": {
    "dropout": 0.2,
    "embed_dim": 512,
    "output_size": 1
  },
  "tokenizer_class": "SmirkTokenizerFast",
  "transform": {
    "class": "Standardize",
    "num_outputs": 1
  },
  "transformers_version": "4.57.1",
  "auto_map": {
    "AutoConfig": "modeling_mist_finetuned.MISTFinetunedConfig",
    "AutoModel": "modeling_mist_finetuned.MISTFinetuned"
  }
}
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3d82a38329ae3821140679785e3fd2083c26ecd167cd9a15298f076501131da1
size 108591540
modeling_mist_finetuned.py
ADDED
@@ -0,0 +1,775 @@
from datasets import IterableDataset
from pathlib import Path
from smirk import SmirkTokenizerFast
from torch.masked import MaskedTensor, masked_tensor
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    DataCollatorWithPadding,
    PreTrainedModel,
    PretrainedConfig,
)
from typing import Any, Callable, Dict, List, Optional, Union
import json
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

MODEL_TYPE_ALIASES = {}


def build_encoder(enc_dict: Dict[str, Any]):
    mtype = enc_dict.get("model_type")
    if mtype:
        base = MODEL_TYPE_ALIASES.get(mtype, mtype)
        cfg_cls = AutoConfig.for_model(base)
        enc_cfg = cfg_cls.from_dict(enc_dict)
    elif enc_dict.get("_name_or_path"):
        enc_cfg = AutoConfig.from_pretrained(enc_dict["_name_or_path"])
    else:
        raise KeyError("encoder config missing 'model_type' or '_name_or_path'")
    if hasattr(enc_cfg, "add_pooling_layer"):
        enc_cfg.add_pooling_layer = False
    return AutoModel.from_config(enc_cfg)


class MISTFinetunedConfig(PretrainedConfig):
    """HF config for a single-task MIST wrapper."""

    model_type = "mist_finetuned"

    def __init__(
        self,
        encoder: Optional[Dict[str, Any]] = None,
        task_network: Optional[Dict[str, Any]] = None,
        transform: Optional[Dict[str, Any]] = None,
        channels: Optional[List[Dict[str, Any]]] = None,
        tokenizer_class: Optional[str] = "SmirkTokenizer",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.encoder = encoder or {}
        self.task_network = task_network or {}
        self.transform = transform or {}
        self.channels = channels
        self.tokenizer_class = tokenizer_class


class MISTFinetuned(PreTrainedModel):
    config_class = MISTFinetunedConfig

    def __init__(self, config: MISTFinetunedConfig):
        super().__init__(config)
        self.encoder = build_encoder_from_dict(config.encoder)

        tn = config.task_network
        self.task_network = PredictionTaskHead(
            embed_dim=tn["embed_dim"],
            output_size=tn["output_size"],
            dropout=tn["dropout"],
        )
        self.transform = AbstractNormalizer.get(
            config.transform["class"], config.transform["num_outputs"]
        )
        self.channels = config.channels
        self.tokenizer = None
        self.post_init()

    @classmethod
    def from_components(
        cls,
        encoder: PreTrainedModel,
        task_network: nn.Module,
        transform: Any,
        tokenizer: Optional[Any] = None,
        channels: Optional[List[Dict[str, Any]]] = None,
    ) -> "MISTFinetuned":
        cfg = MISTFinetunedConfig(
            encoder=encoder.config.to_dict(),
            task_network={
                "embed_dim": encoder.config.hidden_size,
                "output_size": task_network.final.out_features,
                "dropout": task_network.dropout1.p,
            },
            transform=transform.to_config(),
            channels=channels,
            tokenizer_class=(
                getattr(tokenizer, "__class__", type("T", (), {})).__name__
                if tokenizer
                else "SmirkTokenizer"
            ),
        )
        model = cls(cfg)
        # load component weights
        model.encoder.load_state_dict(encoder.state_dict(), strict=False)
        model.task_network.load_state_dict(task_network.state_dict())
        model.transform.load_state_dict(transform.state_dict())
        model.tokenizer = tokenizer
        return model

    def forward(self, input_ids, attention_mask=None):
        hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        y = self.task_network(hs)
        return self.transform.forward(y)

    def _resolve_tokenizer(self, tokenizer):
        if tokenizer is not None:
            return tokenizer
        if getattr(self, "tokenizer", None) is not None:
            return self.tokenizer
        try:
            return AutoTokenizer.from_pretrained(self.name_or_path, use_fast=True)
        except Exception:
            return AutoTokenizer.from_pretrained(
                self.config._name_or_path, use_fast=True
            )

    def embed(self, smi: List[str], tokenizer=None):
        tok = self._resolve_tokenizer(tokenizer)
        batch = tok(smi)
        batch = DataCollatorWithPadding(tok)(batch)
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        with torch.inference_mode():
            hs = self.encoder(
                input_ids, attention_mask=attention_mask
            ).last_hidden_state[:, 0, :]
        return hs.to("cpu")

    def predict(self, smi: List[str], return_dict: bool = True, tokenizer=None):
        tok = self._resolve_tokenizer(tokenizer)
        batch = tok(smi)
        batch = DataCollatorWithPadding(tok)(batch)
        inputs = {k: v.to(self.device) for k, v in batch.items()}
        with torch.inference_mode():
            out = self(**inputs).cpu()
        if self.channels is None or not return_dict:
            return out
        return annotate_prediction(out, self.channels)

    def save_pretrained(self, save_directory, **kwargs):
        super().save_pretrained(save_directory, **kwargs)
        if getattr(self, "tokenizer", None) is not None:
            self.tokenizer.save_pretrained(save_directory)


def maybe_get_annotated_channels(channels: List[Any]):
    for chn in channels:
        if isinstance(chn, str):
            yield {"name": chn, "description": None, "unit": None}
        else:
            yield chn


def annotate_prediction(
    y: torch.Tensor, channels: List[Dict[str, str]]
) -> Dict[str, Dict[str, Any]]:
    out: Dict[str, Dict[str, Any]] = {}
    for idx, chn in enumerate(channels):
        channel_info = {f: v for f, v in chn.items() if f != "name"}
        out[chn["name"]] = {"value": y[:, idx], **channel_info}
    return out


def build_encoder_from_dict(enc_dict):
    if "model_type" in enc_dict:
        cfg_cls = AutoConfig.for_model(enc_dict["model_type"])
        enc_cfg = cfg_cls.from_dict(enc_dict, strict=False)
    elif "_name_or_path" in enc_dict:
        enc_cfg = AutoConfig.from_pretrained(enc_dict["_name_or_path"], strict=False)
    else:
        raise KeyError("Encoder config is missing 'model_type' and '_name_or_path'.")

    # Ensure pooling layer is disabled to match saved checkpoints
    if hasattr(enc_cfg, "add_pooling_layer"):
        enc_cfg.add_pooling_layer = False

    return AutoModel.from_config(enc_cfg)


class MISTMultiTaskConfig(PretrainedConfig):
    """HuggingFace config for a multi-task MIST wrapper."""

    model_type = "mist_multitask"

    def __init__(
        self,
        encoder: Optional[Dict[str, Any]] = None,
        task_networks: Optional[List[Dict[str, Any]]] = None,
        transforms: Optional[List[Dict[str, Any]]] = None,
        channels: Optional[List[Dict[str, Any]]] = None,
        tokenizer_class: Optional[str] = "SmirkTokenizer",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.encoder = encoder or {}
        self.task_networks = task_networks or []
        self.transforms = transforms or []
        self.channels = channels
        self.tokenizer_class = tokenizer_class


class MISTMultiTask(PreTrainedModel):
    config_class = MISTMultiTaskConfig

    def __init__(self, config: MISTMultiTaskConfig):
        super().__init__(config)
        self.encoder = build_encoder_from_dict(config.encoder)

        self.task_networks = nn.ModuleList(
            [
                PredictionTaskHead(
                    embed_dim=tn["embed_dim"],
                    output_size=tn["output_size"],
                    dropout=tn["dropout"],
                )
                for tn in config.task_networks
            ]
        )
        self.transforms = nn.ModuleList(
            [
                AbstractNormalizer.get(tf_cfg["class"], tf_cfg["num_outputs"])
                for tf_cfg in config.transforms
            ]
        )

        assert len(self.task_networks) == len(
            self.transforms
        ), "task_networks and transforms must align"
        self.channels = config.channels
        self.tokenizer = None
        self.post_init()

    @classmethod
    def from_components(
        cls,
        encoder: PreTrainedModel,
        task_networks: List[nn.Module],
        transforms: List[Any],
        tokenizer: Optional[Any] = None,
        channels: Optional[List[Dict[str, Any]]] = None,
    ) -> "MISTMultiTask":
        cfg = MISTMultiTaskConfig(
            encoder=encoder.config.to_dict(),
            task_networks=[
                {
                    "embed_dim": encoder.config.hidden_size,
                    "output_size": tn.final.out_features,
                    "dropout": tn.dropout1.p,
                }
                for tn in task_networks
            ],
            transforms=[tf.to_config() for tf in transforms],
            channels=channels,
            tokenizer_class=(
                getattr(tokenizer, "__class__", type("T", (), {})).__name__
                if tokenizer
                else "SmirkTokenizer"
            ),
        )
        model = cls(cfg)
        model.encoder.load_state_dict(encoder.state_dict(), strict=False)
        for dst, src in zip(model.task_networks, task_networks):
            dst.load_state_dict(src.state_dict())
        for dst, src in zip(model.transforms, transforms):
            dst.load_state_dict(src.state_dict())
        model.tokenizer = tokenizer
        return model

    def forward(self, input_ids, attention_mask=None):
        hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        outs = []
        for tn, tf in zip(self.task_networks, self.transforms):
            outs.append(tf.forward(tn(hs)))
        return torch.cat(outs, dim=-1)

    def _resolve_tokenizer(self, tokenizer):
        if tokenizer is not None:
            return tokenizer
        if getattr(self, "tokenizer", None) is not None:
            return self.tokenizer
        try:
            return AutoTokenizer.from_pretrained(self.name_or_path, use_fast=True)
        except Exception:
            return AutoTokenizer.from_pretrained(
                self.config._name_or_path, use_fast=True
            )

    def predict(self, smi: List[str], tokenizer=None):
        tok = self._resolve_tokenizer(tokenizer)
        batch = tok(smi)
        batch = DataCollatorWithPadding(tok)(batch)
        inputs = {k: v.to(self.device) for k, v in batch.items()}
        with torch.inference_mode():
            out = self(**inputs).cpu()
        if self.channels is None:
            return out
        return annotate_prediction(out, self.channels)

    def embed(self, smi: List[str], tokenizer=None):
        tok = self._resolve_tokenizer(tokenizer)
        batch = tok(smi)
        batch = DataCollatorWithPadding(tok)(batch)
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        with torch.inference_mode():
            hs = self.encoder(
                input_ids, attention_mask=attention_mask
            ).last_hidden_state[:, 0, :]
        return hs.to("cpu")

    def save_pretrained(self, save_directory, **kwargs):
        super().save_pretrained(save_directory, **kwargs)
        if getattr(self, "tokenizer", None) is not None:
            self.tokenizer.save_pretrained(save_directory)


class PredictionTaskHead(nn.Module):
    def __init__(
        self, embed_dim: int, output_size: int = 1, dropout: float = 0.2
    ) -> None:
        super().__init__()
        self.desc_skip_connection = True

        self.fc1 = nn.Linear(embed_dim, embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.relu1 = nn.GELU()
        self.fc2 = nn.Linear(embed_dim, embed_dim)
        self.dropout2 = nn.Dropout(dropout)
        self.relu2 = nn.GELU()
        self.final = nn.Linear(embed_dim, output_size)

    def forward(self, emb):
        if emb.ndim > 2:
            emb = emb[:, 0, :]
|
| 342 |
+
x_out = self.fc1(emb)
|
| 343 |
+
x_out = self.dropout1(x_out)
|
| 344 |
+
x_out = self.relu1(x_out)
|
| 345 |
+
|
| 346 |
+
if self.desc_skip_connection is True:
|
| 347 |
+
x_out = x_out + emb
|
| 348 |
+
|
| 349 |
+
z = self.fc2(x_out)
|
| 350 |
+
z = self.dropout2(z)
|
| 351 |
+
z = self.relu2(z)
|
| 352 |
+
if self.desc_skip_connection is True:
|
| 353 |
+
z = self.final(z + x_out)
|
| 354 |
+
else:
|
| 355 |
+
z = self.final(z)
|
| 356 |
+
return z
|
| 357 |
+
|
| 358 |
+
class AbstractNormalizer(torch.nn.Module):
|
| 359 |
+
def __init__(self, num_outputs: Optional[int] = None):
|
| 360 |
+
super().__init__()
|
| 361 |
+
self.num_outputs = num_outputs
|
| 362 |
+
|
| 363 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 364 |
+
"""Remove normalization"""
|
| 365 |
+
raise NotImplementedError
|
| 366 |
+
|
| 367 |
+
def inverse(self, x: torch.Tensor) -> torch.Tensor:
|
| 368 |
+
"""Apply normalization"""
|
| 369 |
+
raise NotImplementedError
|
| 370 |
+
|
| 371 |
+
def _fit(self, x: MaskedTensor) -> dict:
|
| 372 |
+
"""Fit the normalization parameters"""
|
| 373 |
+
raise NotImplementedError
|
| 374 |
+
|
| 375 |
+
def to_config(self) -> dict:
|
| 376 |
+
return {"class": self.__class__.__name__, "num_outputs": self.num_outputs}
|
| 377 |
+
|
| 378 |
+
def leader_fit(self, ds, rank: int, broadcast: Callable):
|
| 379 |
+
state = None
|
| 380 |
+
if rank == 0:
|
| 381 |
+
state = self.fit(ds)
|
| 382 |
+
state = broadcast(state)
|
| 383 |
+
self.load_state_dict(state)
|
| 384 |
+
|
| 385 |
+
def fit(self, ds, name: str = "target") -> dict:
|
| 386 |
+
"""Fit the normalization parameters on dataset"""
|
| 387 |
+
if isinstance(ds, IterableDataset):
|
| 388 |
+
target = []
|
| 389 |
+
mask = []
|
| 390 |
+
for x in ds:
|
| 391 |
+
target.append(x[name])
|
| 392 |
+
mask.append(x[f"{name}_mask"])
|
| 393 |
+
|
| 394 |
+
target = torch.stack(target)
|
| 395 |
+
mask = torch.stack(mask)
|
| 396 |
+
|
| 397 |
+
else:
|
| 398 |
+
target = torch.stack([torch.tensor(x) for x in ds[name]])
|
| 399 |
+
mask = torch.stack([torch.tensor(x) for x in ds[f"{name}_mask"]])
|
| 400 |
+
|
| 401 |
+
# Use masked tensor to compute normalization parameters
|
| 402 |
+
target = masked_tensor(target, mask)
|
| 403 |
+
|
| 404 |
+
state = self._fit(target)
|
| 405 |
+
return state
|
| 406 |
+
|
| 407 |
+
@classmethod
|
| 408 |
+
def get(
|
| 409 |
+
cls, transform: Optional[Union[list[str], str]], num_outputs: int
|
| 410 |
+
) -> "AbstractNormalizer":
|
| 411 |
+
if isinstance(transform, list):
|
| 412 |
+
assert len(transform) == num_outputs
|
| 413 |
+
return ChannelWiseTransform([cls.get(t, 1) for t in transform])
|
| 414 |
+
elif transform in ["standardize", Standardize.__name__]:
|
| 415 |
+
return Standardize(num_outputs)
|
| 416 |
+
elif transform in ["power_transform", PowerTransform.__name__]:
|
| 417 |
+
return PowerTransform(num_outputs)
|
| 418 |
+
elif transform in ["log_transform", LogTransform.__name__]:
|
| 419 |
+
return LogTransform(num_outputs)
|
| 420 |
+
elif transform in ["max_scale", MaxScaleTransform.__name__]:
|
| 421 |
+
return MaxScaleTransform(num_outputs)
|
| 422 |
+
else:
|
| 423 |
+
return IdentityTransform()
|
| 424 |
+
|
| 425 |
+
class Standardize(AbstractNormalizer):
|
| 426 |
+
def __init__(self, num_outputs: int, eps: float = 1e-8):
|
| 427 |
+
super().__init__(num_outputs)
|
| 428 |
+
self.register_buffer("mean", torch.zeros(num_outputs))
|
| 429 |
+
self.register_buffer("std", torch.zeros(num_outputs))
|
| 430 |
+
self.eps = float(eps)
|
| 431 |
+
assert 0 <= self.eps
|
| 432 |
+
|
| 433 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 434 |
+
return (self.std * x) + self.mean
|
| 435 |
+
|
| 436 |
+
def inverse(self, x: torch.Tensor) -> torch.Tensor:
|
| 437 |
+
return (x - self.mean) / self.std
|
| 438 |
+
|
| 439 |
+
def fit(self, ds, name: str = "target") -> dict:
|
| 440 |
+
num_outputs = self.num_outputs
|
| 441 |
+
assert num_outputs is not None
|
| 442 |
+
mean = torch.zeros(num_outputs)
|
| 443 |
+
m2 = torch.zeros(num_outputs)
|
| 444 |
+
n = torch.zeros(num_outputs, dtype=torch.int)
|
| 445 |
+
for row in ds:
|
| 446 |
+
target = torch.tensor(row[name])
|
| 447 |
+
mask = torch.tensor(row[f"{name}_mask"])
|
| 448 |
+
x = masked_tensor(target, mask)
|
| 449 |
+
n += mask.view(-1, num_outputs).sum(0)
|
| 450 |
+
xs = x.view(-1, num_outputs).sum(0)
|
| 451 |
+
delta = xs - mean
|
| 452 |
+
# Only update masked values
|
| 453 |
+
mean += (delta / n).get_data().masked_fill(~delta.get_mask(), 0)
|
| 454 |
+
delta2 = xs - mean
|
| 455 |
+
m2 += (delta * delta2).get_data().masked_fill(~delta.get_mask(), 0)
|
| 456 |
+
|
| 457 |
+
self.mean = mean.to(self.mean)
|
| 458 |
+
self.std = (m2 / n).sqrt().to(self.std) + self.eps
|
| 459 |
+
self.mean[self.mean.isnan()] = 0
|
| 460 |
+
self.std[self.std.isnan()] = 1
|
| 461 |
+
logging.debug("Fitted %s", self.state_dict())
|
| 462 |
+
return self.state_dict()
|
| 463 |
+
|
| 464 |
+
def _fit(self, target: MaskedTensor) -> dict:
|
| 465 |
+
self.mean = target.mean(0).get_data().to(self.mean)
|
| 466 |
+
self.std = target.std(0).get_data().to(self.std) + self.eps
|
| 467 |
+
return self.state_dict()
|
| 468 |
+
|
| 469 |
+
def load_state_dict(
|
| 470 |
+
self, state_dict: dict[str, Any], strict: bool = True, assign: bool = False
|
| 471 |
+
):
|
| 472 |
+
# Handle legacy case where keys have "transform." prefix
|
| 473 |
+
if "transform.mean" in state_dict:
|
| 474 |
+
state_dict = state_dict.copy() # Don't modify original
|
| 475 |
+
state_dict["mean"] = state_dict.pop("transform.mean")
|
| 476 |
+
state_dict["std"] = state_dict.pop("transform.std")
|
| 477 |
+
|
| 478 |
+
if assign:
|
| 479 |
+
# Manually assign buffers when assign=True
|
| 480 |
+
for key, value in state_dict.items():
|
| 481 |
+
if key in ["mean", "std"]:
|
| 482 |
+
# Use register_buffer to properly replace the buffer
|
| 483 |
+
self.register_buffer(key, value)
|
| 484 |
+
result = None # No incompatible keys when we do it manually
|
| 485 |
+
else:
|
| 486 |
+
result = super().load_state_dict(state_dict, strict=strict, assign=False)
|
| 487 |
+
|
| 488 |
+
logging.debug(f" After loading: mean={self.mean}, std={self.std}")
|
| 489 |
+
return result
|
| 490 |
+
|
| 491 |
+
class TokenTaskHead(nn.Module):
|
| 492 |
+
def __init__(
|
| 493 |
+
self, embed_dim: int, output_size: int = 1, dropout: float = 0.2
|
| 494 |
+
) -> None:
|
| 495 |
+
super().__init__()
|
| 496 |
+
self.layers = nn.Sequential(
|
| 497 |
+
nn.Linear(embed_dim, embed_dim),
|
| 498 |
+
nn.Dropout(dropout),
|
| 499 |
+
nn.GELU(),
|
| 500 |
+
nn.Linear(embed_dim, embed_dim),
|
| 501 |
+
nn.Dropout(dropout),
|
| 502 |
+
nn.GELU(),
|
| 503 |
+
nn.Linear(embed_dim, output_size),
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
def forward(self, emb):
|
| 507 |
+
return self.layers(emb)
|
| 508 |
+
|
| 509 |
+
class TokenPairwiseDistance(nn.Module):
|
| 510 |
+
def __init__(
|
| 511 |
+
self,
|
| 512 |
+
embed_dim: int,
|
| 513 |
+
dropout: float = 0.2,
|
| 514 |
+
num_attention_heads: int = 1,
|
| 515 |
+
num_layers: int = 1,
|
| 516 |
+
activation: str = "relu",
|
| 517 |
+
ff_ratio: int = 2,
|
| 518 |
+
) -> None:
|
| 519 |
+
super().__init__()
|
| 520 |
+
enc_layer = nn.TransformerEncoderLayer(
|
| 521 |
+
d_model=embed_dim,
|
| 522 |
+
nhead=num_attention_heads,
|
| 523 |
+
dim_feedforward=ff_ratio * embed_dim,
|
| 524 |
+
dropout=dropout,
|
| 525 |
+
batch_first=True,
|
| 526 |
+
norm_first=True,
|
| 527 |
+
)
|
| 528 |
+
self.interaction = nn.TransformerEncoder(enc_layer, num_layers)
|
| 529 |
+
self.pairwise_distance = PairwiseMLP(embed_dim, dropout)
|
| 530 |
+
self.distance1 = nn.Sequential(
|
| 531 |
+
nn.Linear(embed_dim, embed_dim), nn.Dropout(dropout), nn.GELU()
|
| 532 |
+
)
|
| 533 |
+
self.distance2 = nn.Linear(embed_dim, 1)
|
| 534 |
+
|
| 535 |
+
def forward(self, hs: torch.Tensor) -> torch.Tensor:
|
| 536 |
+
hs = self.interaction(hs)
|
| 537 |
+
|
| 538 |
+
with torch.autocast("cuda", dtype=torch.float32):
|
| 539 |
+
pw_dist = self.pairwise_distance(hs)
|
| 540 |
+
d = self.distance1(pw_dist) + pw_dist
|
| 541 |
+
d = self.distance2(d).squeeze(-1)
|
| 542 |
+
return F.relu(F.elu(d) + 1)
|
| 543 |
+
|
| 544 |
+
class BiPairwiseBlock(nn.Module):
|
| 545 |
+
def __init__(self, d_model: int, bias: bool = True, device=None, dtype=None):
|
| 546 |
+
super().__init__()
|
| 547 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 548 |
+
|
| 549 |
+
self.bi_weight = nn.Parameter(torch.empty((d_model, d_model), **factory_kwargs))
|
| 550 |
+
self.lin_weight = nn.Parameter(
|
| 551 |
+
torch.empty((d_model, d_model), **factory_kwargs)
|
| 552 |
+
)
|
| 553 |
+
if bias:
|
| 554 |
+
self.bias = nn.Parameter(torch.empty(d_model, **factory_kwargs))
|
| 555 |
+
else:
|
| 556 |
+
self.register_parameter("bias", None)
|
| 557 |
+
self.reset_parameters()
|
| 558 |
+
|
| 559 |
+
# Gradient hook to enforce symmetry
|
| 560 |
+
self.bi_weight.register_hook(lambda grad: 0.5 * (grad + grad.T))
|
| 561 |
+
|
| 562 |
+
def reset_parameters(self):
|
| 563 |
+
nn.init.xavier_normal_(self.lin_weight, gain=nn.init.calculate_gain("relu"))
|
| 564 |
+
nn.init.xavier_normal_(self.bi_weight, gain=nn.init.calculate_gain("relu"))
|
| 565 |
+
with torch.no_grad():
|
| 566 |
+
self.bi_weight.copy_(0.5 * (self.bi_weight + self.bi_weight.T))
|
| 567 |
+
|
| 568 |
+
if self.bias is not None:
|
| 569 |
+
bound = 1 / math.sqrt(self.bias.size(0))
|
| 570 |
+
nn.init.uniform_(self.bias, -bound, bound)
|
| 571 |
+
|
| 572 |
+
def forward(self, x: torch.Tensor):
|
| 573 |
+
y_bi = torch.einsum("...ld,df,...rf->...lrf", x, self.bi_weight, x)
|
| 574 |
+
y_bi = 0.5 * (y_bi + y_bi.transpose(-3, -2)) # Enforce symmetry
|
| 575 |
+
|
| 576 |
+
x_linear = x.unsqueeze(-2) + x.unsqueeze(-3)
|
| 577 |
+
return y_bi + F.linear(x_linear, self.lin_weight, self.bias)
|
| 578 |
+
|
| 579 |
+
class PairwiseMLP(nn.Module):
|
| 580 |
+
def __init__(
|
| 581 |
+
self,
|
| 582 |
+
d_model: int,
|
| 583 |
+
dropout: float = 0.2,
|
| 584 |
+
device=None,
|
| 585 |
+
dtype=None,
|
| 586 |
+
) -> None:
|
| 587 |
+
super().__init__()
|
| 588 |
+
self.mlp = nn.Sequential(
|
| 589 |
+
nn.Linear(2 * d_model, d_model),
|
| 590 |
+
nn.Dropout(dropout),
|
| 591 |
+
nn.GELU(),
|
| 592 |
+
nn.Linear(d_model, d_model),
|
| 593 |
+
nn.GELU(),
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
def forward(self, x: torch.Tensor):
|
| 597 |
+
_, N, _ = x.shape
|
| 598 |
+
x_l = x.unsqueeze(-2).expand(-1, N, N, -1)
|
| 599 |
+
x_r = x.unsqueeze(-3).expand(-1, N, N, -1)
|
| 600 |
+
x_pw = torch.cat([x_l, x_r], dim=-1)
|
| 601 |
+
y = self.mlp(x_pw)
|
| 602 |
+
return 0.5 * (y + y.transpose(1, 2))
|
| 603 |
+
|
| 604 |
+
class ChannelWiseTransform(AbstractNormalizer):
|
| 605 |
+
def __init__(self, transforms: list[AbstractNormalizer]):
|
| 606 |
+
super().__init__(len(transforms))
|
| 607 |
+
self.transforms = torch.nn.ModuleList(transforms)
|
| 608 |
+
|
| 609 |
+
def to_config(self) -> dict:
|
| 610 |
+
return {
|
| 611 |
+
"class": [t.__class__.__name__ for t in self.transforms],
|
| 612 |
+
"num_outputs": self.num_outputs,
|
| 613 |
+
}
|
| 614 |
+
|
| 615 |
+
def inverse(self, x: torch.Tensor) -> torch.Tensor:
|
| 616 |
+
return torch.cat(
|
| 617 |
+
[
|
| 618 |
+
transform.inverse(x[:, [idx]])
|
| 619 |
+
for idx, transform in enumerate(self.transforms)
|
| 620 |
+
],
|
| 621 |
+
dim=1,
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 625 |
+
return torch.cat(
|
| 626 |
+
[
|
| 627 |
+
transform.forward(x[:, [idx]])
|
| 628 |
+
for idx, transform in enumerate(self.transforms)
|
| 629 |
+
],
|
| 630 |
+
dim=1,
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
def _fit(self, x: MaskedTensor) -> dict:
|
| 634 |
+
for idx, transform in enumerate(self.transforms):
|
| 635 |
+
transform._fit(x[:, [idx]])
|
| 636 |
+
return self.state_dict()
|
| 637 |
+
|
| 638 |
+
class LogTransform(Standardize):
|
| 639 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 640 |
+
return torch.exp(super().forward(x))
|
| 641 |
+
|
| 642 |
+
def inverse(self, x: torch.Tensor) -> torch.Tensor:
|
| 643 |
+
return super().inverse(torch.log(x))
|
| 644 |
+
|
| 645 |
+
def _fit(self, target: MaskedTensor) -> dict:
|
| 646 |
+
return super()._fit(torch.log(target))
|
| 647 |
+
|
| 648 |
+
class PowerTransform(AbstractNormalizer):
|
| 649 |
+
"""
|
| 650 |
+
Apply a power transform (Yeo-Johnson) featurewise to make data more Gaussian-like.
|
| 651 |
+
Followed by applying a zero-mean, unit-variance normalization to the
|
| 652 |
+
transformed output to rescale targets to [-1, 1].
|
| 653 |
+
"""
|
| 654 |
+
|
| 655 |
+
def __init__(self, num_outputs, eps: float = 1e-8):
|
| 656 |
+
super().__init__(num_outputs)
|
| 657 |
+
self.num_outputs = num_outputs
|
| 658 |
+
self.register_buffer("lmbdas", torch.zeros(num_outputs))
|
| 659 |
+
self.register_buffer("mean", torch.zeros(num_outputs))
|
| 660 |
+
self.register_buffer("std", torch.zeros(num_outputs))
|
| 661 |
+
self.eps = float(eps)
|
| 662 |
+
assert 0 <= self.eps
|
| 663 |
+
|
| 664 |
+
def _yeo_johnson_transform(self, x, lmbda):
|
| 665 |
+
"""
|
| 666 |
+
Return transformed input x following Yeo-Johnson transform with
|
| 667 |
+
parameter lambda.
|
| 668 |
+
Adapted from
|
| 669 |
+
https://github.com/scikit-learn/scikit-learn/blob/fbb32eae5/sklearn/preprocessing/_data.py#L3354
|
| 670 |
+
"""
|
| 671 |
+
x_out = x.clone()
|
| 672 |
+
eps = torch.finfo(x.dtype).eps
|
| 673 |
+
pos = x >= 0 # binary mask
|
| 674 |
+
|
| 675 |
+
# when x >= 0
|
| 676 |
+
if abs(lmbda) < eps:
|
| 677 |
+
x_out[pos] = torch.log1p(x[pos])
|
| 678 |
+
else: # lmbda != 0
|
| 679 |
+
x_out[pos] = (torch.pow(x[pos] + 1, lmbda) - 1) / lmbda
|
| 680 |
+
|
| 681 |
+
# when x < 0
|
| 682 |
+
if abs(lmbda - 2) > eps:
|
| 683 |
+
x_out[~pos] = -(torch.pow(-x[~pos] + 1, 2 - lmbda) - 1) / (2 - lmbda)
|
| 684 |
+
else: # lmbda == 2
|
| 685 |
+
x_out[~pos] = -torch.log1p(-x[~pos])
|
| 686 |
+
|
| 687 |
+
return x_out
|
| 688 |
+
|
| 689 |
+
def _yeo_johnson_inverse_transform(self, x, lmbda):
|
| 690 |
+
"""
|
| 691 |
+
Return inverse-transformed input x following Yeo-Johnson inverse
|
| 692 |
+
transform with parameter lambda.
|
| 693 |
+
Adapted from
|
| 694 |
+
https://github.com/scikit-learn/scikit-learn/blob/fbb32eae5/sklearn/preprocessing/_data.py#L3383
|
| 695 |
+
"""
|
| 696 |
+
x_out = x.clone()
|
| 697 |
+
pos = x >= 0
|
| 698 |
+
eps = torch.finfo(x.dtype).eps
|
| 699 |
+
|
| 700 |
+
# when x >= 0
|
| 701 |
+
if abs(lmbda) < eps: # lmbda == 0
|
| 702 |
+
x_out[pos] = torch.exp(x[pos]) - 1
|
| 703 |
+
else: # lmbda != 0
|
| 704 |
+
x_out[pos] = torch.pow(x[pos] * lmbda + 1, 1 / lmbda) - 1
|
| 705 |
+
|
| 706 |
+
# when x < 0
|
| 707 |
+
if abs(lmbda - 2) > eps: # lmbda != 2
|
| 708 |
+
x_out[~pos] = 1 - torch.pow(-(2 - lmbda) * x[~pos] + 1, 1 / (2 - lmbda))
|
| 709 |
+
else: # lmbda == 2
|
| 710 |
+
x_out[~pos] = 1 - torch.exp(-x[~pos])
|
| 711 |
+
return x_out
|
| 712 |
+
|
| 713 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 714 |
+
# Undo standardization
|
| 715 |
+
x = (self.std * x) + self.mean
|
| 716 |
+
x_out = torch.zeros_like(x)
|
| 717 |
+
for i in range(self.num_outputs):
|
| 718 |
+
x_out[:, i] = self._yeo_johnson_inverse_transform(x[:, i], self.lmbdas[i])
|
| 719 |
+
return x_out
|
| 720 |
+
|
| 721 |
+
def inverse(self, x: torch.Tensor) -> torch.Tensor:
|
| 722 |
+
x_out = torch.zeros_like(x)
|
| 723 |
+
for i in range(self.num_outputs):
|
| 724 |
+
x_out[:, i] = self._yeo_johnson_transform(x[:, i], self.lmbdas[i])
|
| 725 |
+
# Standardization
|
| 726 |
+
x_out = (x_out - self.mean) / self.std
|
| 727 |
+
return x_out
|
| 728 |
+
|
| 729 |
+
def _fit(self, target: MaskedTensor) -> dict:
|
| 730 |
+
# Fit Yeo-Johnson lambdas
|
| 731 |
+
from sklearn.preprocessing import (
|
| 732 |
+
PowerTransformer as _PowerTransformer, # noqa: F811
|
| 733 |
+
)
|
| 734 |
+
|
| 735 |
+
transformer = _PowerTransformer(method="yeo-johnson", standardize=False)
|
| 736 |
+
target = torch.tensor(transformer.fit_transform(target.get_data().numpy()))
|
| 737 |
+
self.lmbdas = torch.tensor(transformer.lambdas_)
|
| 738 |
+
# Fit standardization scaling
|
| 739 |
+
self.mean = target.mean(0).to(self.mean)
|
| 740 |
+
self.std = target.std(0).to(self.std) + self.eps
|
| 741 |
+
return self.state_dict()
|
| 742 |
+
|
| 743 |
+
class MaxScaleTransform(AbstractNormalizer):
|
| 744 |
+
"""
|
| 745 |
+
Divide by maximum value in training dataset.
|
| 746 |
+
"""
|
| 747 |
+
|
| 748 |
+
def __init__(self, mx: int, eps: float = 1e-8):
|
| 749 |
+
super().__init__(1)
|
| 750 |
+
self.num_outputs = 1
|
| 751 |
+
self.max = mx
|
| 752 |
+
self.eps = float(eps)
|
| 753 |
+
assert 0 <= self.eps
|
| 754 |
+
|
| 755 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 756 |
+
# Undo standardization
|
| 757 |
+
x_out = self.max * x
|
| 758 |
+
return x_out
|
| 759 |
+
|
| 760 |
+
def inverse(self, x: torch.Tensor) -> torch.Tensor:
|
| 761 |
+
x_out = x / self.max
|
| 762 |
+
return x_out
|
| 763 |
+
|
| 764 |
+
def _fit(self, target: MaskedTensor) -> dict:
|
| 765 |
+
return self.state_dict()
|
| 766 |
+
|
| 767 |
+
class IdentityTransform(AbstractNormalizer):
|
| 768 |
+
def inverse(self, x: torch.Tensor) -> torch.Tensor:
|
| 769 |
+
return x
|
| 770 |
+
|
| 771 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 772 |
+
return x
|
| 773 |
+
|
| 774 |
+
def _fit(self, x: MaskedTensor) -> dict:
|
| 775 |
+
return self.state_dict()
|
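The target-normalization convention in `modeling_mist_finetuned.py` is inverted relative to its naming: `Standardize.inverse` applies normalization (`(x - mean) / std`, used on training targets), while `Standardize.forward` undoes it (`std * x + mean`, used on model predictions). A minimal sketch of that round-trip with hypothetical `mean`/`std` values, using plain torch rather than the module itself:

```python
import torch

# Hypothetical per-channel statistics (in the model these are registered buffers).
mean = torch.tensor([1.0, -2.0])
std = torch.tensor([3.0, 0.5])

x = torch.tensor([[4.0, -1.0], [7.0, 0.0]])

z = (x - mean) / std       # Standardize.inverse(): normalize raw targets
x_back = (std * z) + mean  # Standardize.forward(): de-normalize predictions

print(z)       # → tensor([[1., 2.], [2., 4.]])
assert torch.allclose(x, x_back)
```

This is why `MISTMultiTask.forward` applies `tf.forward(tn(hs))` to each task head's output: the heads are trained against normalized targets, and the transform maps predictions back to the original units.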
requirements.txt
ADDED

transformers==4.57.1
torch==2.9.0
scikit-learn==1.7.2
smirk==0.2.0.dev0
special_tokens_map.json
ADDED

{
  "bos_token": {
    "content": "[BOS]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "cls_token": {
    "content": "[CLS]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "[EOS]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "mask_token": {
    "content": "[MASK]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "[PAD]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "sep_token": {
    "content": "[SEP]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "[UNK]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
tokenizer.json
ADDED

{
  "version": "1.0",
  "truncation": null,
  "padding": null,
  "added_tokens": [
    {
      "id": 0,
      "content": "[UNK]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 159,
      "content": "[BOS]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 160,
      "content": "[EOS]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 161,
      "content": "[SEP]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 162,
      "content": "[PAD]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 163,
      "content": "[CLS]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 164,
      "content": "[MASK]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    }
  ],
  "normalizer": {
    "type": "Sequence",
    "normalizers": [
      {
        "type": "Replace",
        "pattern": {
          "String": "++"
        },
        "content": "+2"
      },
      {
        "type": "Replace",
        "pattern": {
          "String": "--"
        },
        "content": "-2"
      },
      {
        "type": "Strip",
        "strip_left": true,
        "strip_right": true
      }
    ]
  },
  "pre_tokenizer": {
    "outer": "Br?|Cl?|F|I|N|O|P|S|b|c|n|o|p|s|\\*|[\\.\\-=\\#\\$:/\\\\]|\\d|%|\\(|\\)|\\[.*?]",
    "inner": "(\\d+)?(A[c|g|l|m|r|s|t|u]|B[a|e|h|i|k|r]?|C[a|d|e|f|l|m|n|o|r|s|u]?|D[b|s|y]|E[r|s|u]|F[e|l|m|r]?|G[a|d|e]|H[e|f|g|o|s]?|I[n|r]?|Kr?|L[a|i|r|u|v]|M[c|d|g|n|o|t]|N[a|b|d|e|h|i|o|p]?|O[g|s]?|P[a|b|d|m|o|r|t|u]?|R[a|b|e|f|g|h|n|u]|S[b|c|e|g|i|m|n|r]?|T[a|b|c|e|h|i|l|m|s]|U|V|W|Xe|Yb?|Z[n|r]|as|b|c|n|o|p|se?|\\*)(?:(@(?:@|AL|OH|SP|T[B|H])?)(\\d{1,2})?)?(?:(H)(\\d)?)?(?:([+-]{1,2})(\\d{0,2}))?(?:(:)(\\d+))?"
  },
  "post_processor": null,
  "decoder": {
    "type": "Fuse"
  },
  "model": {
    "type": "WordLevel",
    "vocab": {
      "[UNK]": 0,
      "#": 1,
      "$": 2,
      "%": 3,
      "(": 4,
      ")": 5,
      "*": 6,
      "+": 7,
      "-": 8,
      ".": 9,
      "/": 10,
      "0": 11,
      "1": 12,
      "2": 13,
      "3": 14,
      "4": 15,
      "5": 16,
      "6": 17,
      "7": 18,
      "8": 19,
      "9": 20,
      ":": 21,
      "=": 22,
      "@": 23,
      "@@": 24,
      "@AL": 25,
      "@OH": 26,
      "@SP": 27,
      "@TB": 28,
      "@TH": 29,
      "Ac": 30,
      "Ag": 31,
      "Al": 32,
      "Am": 33,
      "Ar": 34,
      "As": 35,
      "At": 36,
      "Au": 37,
      "B": 38,
      "Ba": 39,
      "Be": 40,
      "Bh": 41,
      "Bi": 42,
      "Bk": 43,
      "Br": 44,
      "C": 45,
      "Ca": 46,
      "Cd": 47,
      "Ce": 48,
      "Cf": 49,
      "Cl": 50,
      "Cm": 51,
      "Cn": 52,
      "Co": 53,
      "Cr": 54,
      "Cs": 55,
      "Cu": 56,
      "Db": 57,
      "Ds": 58,
      "Dy": 59,
      "Er": 60,
      "Es": 61,
      "Eu": 62,
      "F": 63,
      "Fe": 64,
      "Fl": 65,
      "Fm": 66,
      "Fr": 67,
      "Ga": 68,
      "Gd": 69,
      "Ge": 70,
      "H": 71,
      "He": 72,
      "Hf": 73,
      "Hg": 74,
      "Ho": 75,
      "Hs": 76,
      "I": 77,
      "In": 78,
      "Ir": 79,
      "K": 80,
      "Kr": 81,
      "La": 82,
      "Li": 83,
      "Lr": 84,
      "Lu": 85,
      "Lv": 86,
      "Mc": 87,
      "Md": 88,
      "Mg": 89,
      "Mn": 90,
      "Mo": 91,
      "Mt": 92,
      "N": 93,
      "Na": 94,
      "Nb": 95,
      "Nd": 96,
      "Ne": 97,
      "Nh": 98,
      "Ni": 99,
      "No": 100,
      "Np": 101,
      "O": 102,
      "Og": 103,
      "Os": 104,
      "P": 105,
      "Pa": 106,
      "Pb": 107,
      "Pd": 108,
      "Pm": 109,
      "Po": 110,
      "Pr": 111,
      "Pt": 112,
      "Pu": 113,
      "Ra": 114,
      "Rb": 115,
      "Re": 116,
      "Rf": 117,
      "Rg": 118,
      "Rh": 119,
      "Rn": 120,
      "Ru": 121,
      "S": 122,
      "Sb": 123,
      "Sc": 124,
      "Se": 125,
      "Sg": 126,
      "Si": 127,
      "Sm": 128,
      "Sn": 129,
      "Sr": 130,
      "Ta": 131,
      "Tb": 132,
      "Tc": 133,
      "Te": 134,
      "Th": 135,
      "Ti": 136,
      "Tl": 137,
      "Tm": 138,
      "Ts": 139,
      "U": 140,
      "V": 141,
      "W": 142,
      "Xe": 143,
      "Y": 144,
      "Yb": 145,
      "Zn": 146,
      "Zr": 147,
      "[": 148,
      "\\": 149,
      "]": 150,
      "as": 151,
      "b": 152,
      "c": 153,
      "n": 154,
      "o": 155,
      "p": 156,
      "s": 157,
      "se": 158
    },
    "unk_token": "[UNK]"
  }
}
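The `outer` pre-tokenizer pattern in tokenizer.json splits a SMILES string into atom-, bond-, and ring-level tokens before vocabulary lookup (bracket atoms like `[NH3+]` are kept whole and handled by the `inner` pattern). A quick sketch of what the outer pattern alone does, using Python's `re`; note this only mirrors the regex for illustration, while the shipped tokenizer is the `smirk` package's `SmirkTokenizerFast`:

```python
import re

# The "outer" pre-tokenizer regex, copied verbatim from tokenizer.json.
OUTER = r"Br?|Cl?|F|I|N|O|P|S|b|c|n|o|p|s|\*|[\.\-=\#\$:/\\]|\d|%|\(|\)|\[.*?]"

tokens = re.findall(OUTER, "CC(=O)Oc1ccccc1")  # aspirin
print(tokens)
# → ['C', 'C', '(', '=', 'O', ')', 'O', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1']

# Bracket atoms are captured as a single token by the lazy `\[.*?]` alternative.
print(re.findall(OUTER, "C[NH3+]"))
# → ['C', '[NH3+]']
```

Each resulting token is then looked up in the `WordLevel` vocab above, with anything unmatched falling back to `[UNK]`.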
tokenizer_config.json
ADDED

{
  "added_tokens_decoder": {
    "0": {
      "content": "[UNK]",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "159": {
      "content": "[BOS]",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "160": {
      "content": "[EOS]",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "161": {
      "content": "[SEP]",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "162": {
      "content": "[PAD]",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "163": {
      "content": "[CLS]",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "164": {
      "content": "[MASK]",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    }
  },
  "bos_token": "[BOS]",
  "clean_up_tokenization_spaces": false,
  "cls_token": "[CLS]",
  "eos_token": "[EOS]",
  "extra_special_tokens": {},
  "mask_token": "[MASK]",
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "tokenizer_class": "SmirkTokenizerFast",
  "unk_token": "[UNK]",
  "vocab_file": "/nfs/turbo/coe-venkvis/abhutani/electrolyte-fm/.venv/lib/python3.10/site-packages/smirk/vocab_smiles.json"
}