anoushka2000 committed · Commit 9905e27 · verified · 1 Parent(s): f5fef0f

Upload folder using huggingface_hub
README.md ADDED
---
language: en
library_name: transformers
license: gpl-3.0
tags:
- mist
- chemistry
- molecular-property-prediction
---

# MIST: Molecular Insight SMILES Transformers

MIST is a family of molecular foundation models for molecular property prediction.
The models were pre-trained on SMILES strings from the [Enamine REAL Space](https://enamine.net/compound-collections/real-compounds/real-space-navigator) dataset using the Masked Language Modeling (MLM) objective, then fine-tuned for downstream prediction tasks.

## Model Details

- **Architecture**: Encoder-only transformer ([`RoBERTa-PreLayerNorm`](https://huggingface.co/docs/transformers/en/model_doc/roberta-prelayernorm))
- **Pre-training**: Masked Language Modeling on molecular SMILES
- **Tokenization**: [`Smirk`](https://eeg.engin.umich.edu/smirk/) tokenizer

### Quick Start

```python
from transformers import AutoModel

# Load the model
model = AutoModel.from_pretrained(
    "path/to/model",
    trust_remote_code=True
)

# Make predictions
smiles_batch = [
    "CCO",       # Ethanol
    "CC(=O)O",   # Acetic acid
    "c1ccccc1"   # Benzene
]
results = model.predict(smiles_batch)
```

### Setting Up Your Environment

Create a virtual environment and install dependencies:

```bash
python -m venv .venv
source .venv/bin/activate  # On Windows: .venv\Scripts\activate
pip install -r requirements.txt
```

> **Note**: Smirk tokenizers require Rust to be installed. See the [Rust installation guide](https://www.rust-lang.org/tools/install) for details.

## Model Inputs and Outputs

### Inputs
- **SMILES strings**: Standard SMILES notation for molecular structures
- **Batch size**: Variable; inputs are automatically padded during inference

### Outputs
- **Predictions**: Task-specific numerical or categorical predictions
- **Format**: Dictionary mapping channel names to predicted values (if channels are configured), or raw tensor output
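When channels are configured, the annotated output follows the shape sketched below. This mirrors the `annotate_prediction` helper in `modeling_mist_finetuned.py`; the channel metadata and values here are hypothetical, and the real model returns tensors rather than plain lists.

```python
# Sketch of the channel-annotated output format (illustrative values only).
def annotate(values, channels):
    out = {}
    for idx, chn in enumerate(channels):
        # Carry every metadata field except "name" alongside the predictions
        info = {k: v for k, v in chn.items() if k != "name"}
        out[chn["name"]] = {"value": [row[idx] for row in values], **info}
    return out

channels = [{"name": "logS", "description": "Aqueous solubility", "unit": "log(mol/L)"}]
preds = [[-0.77], [-0.29]]  # one row per input SMILES
annotated = annotate(preds, channels)
print(annotated["logS"]["value"])  # [-0.77, -0.29]
```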

## Provided Models

### Pre-trained
- `mist-1.8B-dh61satt`: Flagship MIST model (MIST-1.8B)
- `mist-28M-ti624ev1`: Smaller MIST model (MIST-28M)

Below is a full list of the finetuned variants hosted on Hugging Face.

### MoleculeNet Benchmark Models

| Folder | Encoder | Dataset |
| ---------------------------- | :-------: | ------------------------- |
| mist-1.8B-fbdn8e35-bbbp | MIST-1.8B | MoleculeNet BBBP |
| mist-1.8B-1a4puhg2-hiv | MIST-1.8B | MoleculeNet HIV |
| mist-1.8B-m50jgolp-bace | MIST-1.8B | MoleculeNet BACE |
| mist-1.8B-uop1z0dc-tox21 | MIST-1.8B | MoleculeNet Tox21 |
| mist-1.8B-lu1l5ieh-clintox | MIST-1.8B | MoleculeNet ClinTox |
| mist-1.8B-l1wfo7oa-sider | MIST-1.8B | MoleculeNet SIDER |
| mist-1.8B-hxiygjsm-esol | MIST-1.8B | MoleculeNet ESOL |
| mist-1.8B-iwqj2cld-freesolv | MIST-1.8B | MoleculeNet FreeSolv |
| mist-1.8B-jvt4azpz-lipo | MIST-1.8B | MoleculeNet Lipophilicity |
| mist-1.8B-8nd1ot5j-qm8 | MIST-1.8B | MoleculeNet QM8 |
| mist-28M-3xpfhv48-bbbp | MIST-28M | MoleculeNet BBBP |
| mist-28M-8fh43gke-hiv | MIST-28M | MoleculeNet HIV |
| mist-28M-8loj3bab-bace | MIST-28M | MoleculeNet BACE |
| mist-28M-kw4ks27p-tox21 | MIST-28M | MoleculeNet Tox21 |
| mist-28M-97vfcykk-clintox | MIST-28M | MoleculeNet ClinTox |
| mist-28M-z8qo16uy-sider | MIST-28M | MoleculeNet SIDER |
| mist-28M-kcwb9le5-esol | MIST-28M | MoleculeNet ESOL |
| mist-28M-0uiq7o7m-freesolv | MIST-28M | MoleculeNet FreeSolv |
| mist-28M-xzr5ulva-lipo | MIST-28M | MoleculeNet Lipophilicity |
| mist-28M-gzwqzpcr-qm8 | MIST-28M | MoleculeNet QM8 |

#### QM9 Benchmark Models

Single-target models (MIST-1.8B encoder) are available for the following QM9 properties.

| Folder | Encoder | Target |
| ---------------------------- | :-------: | ----------------------------------------------------------------- |
| mist-1.8B-ez05expv-mu | MIST-1.8B | μ - Dipole moment (unit: D) |
| mist-1.8B-rcwary93-alpha | MIST-1.8B | α - Isotropic polarizability (unit: Bohr^3) |
| mist-1.8B-jmjosq12-homo | MIST-1.8B | HOMO - Highest occupied molecular orbital energy (unit: Hartree) |
| mist-1.8B-n14wshc9-lumo | MIST-1.8B | LUMO - Lowest unoccupied molecular orbital energy (unit: Hartree) |
| mist-1.8B-kayun6v3-gap | MIST-1.8B | Gap - Gap between HOMO and LUMO (unit: Hartree) |
| mist-1.8B-xxe7t35e-r2 | MIST-1.8B | \<R2\> - Electronic spatial extent (unit: Bohr^2) |
| mist-1.8B-6nmcwyrp-zpve | MIST-1.8B | ZPVE - Zero point vibrational energy (unit: Hartree) |
| mist-1.8B-a7akimjj-u0 | MIST-1.8B | U0 - Internal energy at 0 K (unit: Hartree) |
| mist-1.8B-85f24xkj-u298 | MIST-1.8B | U298 - Internal energy at 298.15 K (unit: Hartree) |
| mist-1.8B-3fbbz4is-h298 | MIST-1.8B | H298 - Enthalpy at 298.15 K (unit: Hartree) |
| mist-1.8B-09sntn03-g298 | MIST-1.8B | G298 - Free energy at 298.15 K (unit: Hartree) |
| mist-1.8B-j356b3nf-cv | MIST-1.8B | Cv - Heat capacity at 298.15 K (unit: cal/(mol*K)) |

- `mist-ti624ev1-moleculenet`: Contains the MoleculeNet benchmark MIST-28M models trained as part of doi:10.5281/zenodo.13761263
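Several QM9 targets above are reported in Hartree, while downstream applications often use eV. A small conversion helper (the constant is the standard value of 1 Hartree in eV; this snippet is illustrative, not part of the model API):

```python
# Convert energies predicted in Hartree (e.g. HOMO, LUMO, gap) to eV.
HARTREE_TO_EV = 27.211386  # 1 Hartree expressed in eV

def hartree_to_ev(energy_hartree: float) -> float:
    return energy_hartree * HARTREE_TO_EV

print(hartree_to_ev(0.25))  # a 0.25 Hartree gap is ~6.80 eV
```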

### Finetuned Single-Task Models

These models consist of a MIST encoder and a task network finetuned on a single dataset, used in the applications demonstrated in the manuscript.

| Folder | Encoder | Dataset |
| ------------------------- | :------: | --------------------------------------------------------------- |
| mist-26.9M-48kpooqf-odour | MIST-28M | Olfaction |
| mist-26.9M-6hk5coof-dn | MIST-28M | Donor Number |
| mist-26.9M-0vxdbm36-kt | MIST-28M | Kamlet-Taft Solvatochromic Parameters |
| mist-26.9M-b302p09x-bp | MIST-28M | Boiling Point (part of the Characteristic Temperatures Dataset) |
| mist-26.9M-cyuo2xb6-fp | MIST-28M | Flash Point (part of the Characteristic Temperatures Dataset) |
| mist-26.9M-y3ge5pf9-mp | MIST-28M | Melting Point (part of the Characteristic Temperatures Dataset) |

### Finetuned Multi-Task Models

These are additional multi-target finetuned models consisting of a MIST encoder and task networks.

| Folder | Encoder | Dataset |
| -------------------------- | :------: | ------------------------------------- |
| mist-26.9M-kkgx0omx-qm9 | MIST-28M | QM9 Dataset with SMILES randomization |
| mist-28M-ttqcvt6fs-toxcast | MIST-28M | ToxCast |
| mist-28M-yr1urd2c-muv | MIST-28M | Maximum Unbiased Validation (MUV) |

### Finetuned Mixture Models

These models consist of a MIST encoder and a physics-informed task network for mixture property prediction.

| Folder | Encoder | Dataset |
| ------------------------------ | :------: | ------------------------------------------------ |
| mist-conductivity-28M-2mpg8dcd | MIST-28M | Ionic Conductivity |
| mist-mixtures-zffffbex | MIST-28M | Excess Density, Molar Volume, and Molar Enthalpy |

## Citation

If you use this model in your research, please cite:

```bibtex
@online{MIST,
  title = {Foundation Models for Discovery and Exploration in Chemical Space},
  author = {Wadell, Alexius and Bhutani, Anoushka and Azumah, Victor and Ellis-Mohr, Austin R. and Kelly, Celia and Zhao, Hancheng and Nayak, Anuj K. and Hegazy, Kareem and Brace, Alexander and Lin, Hongyi and Emani, Murali and Vishwanath, Venkatram and Gering, Kevin and Alkan, Melisa and Gibbs, Tom and Wells, Jack and Varshney, Lav R. and Ramsundar, Bharath and Duraisamy, Karthik and Mahoney, Michael W. and Ramanathan, Arvind and Viswanathan, Venkatasubramanian},
  date = {2025-10-20},
  eprint = {2510.18900},
  eprinttype = {arXiv},
  eprintclass = {physics},
  doi = {10.48550/arXiv.2510.18900},
  url = {http://arxiv.org/abs/2510.18900},
}
```

## License and Notice

Model weights are provided as-is for research purposes only, without guarantees of correctness, fitness for purpose, or warranties of any kind.

**Restrictions:**
- Research use only
- No redistribution without permission
- No commercial use without a licensing agreement

For questions, issues, or licensing inquiries, please contact [venkvis@umich.edu](mailto:venkvis@umich.edu).

<hr>
config.json ADDED
{
  "architectures": [
    "MISTFinetuned"
  ],
  "channels": null,
  "dtype": "float32",
  "encoder": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_probs_dropout_prob": 0.1,
    "attn_implementation": null,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 0,
    "chunk_size_feed_forward": 0,
    "classifier_dropout": null,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dtype": null,
    "early_stopping": false,
    "enable_token_counter": true,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 2,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 512,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_range": 0.02,
    "intermediate_size": 2048,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-12,
    "length_penalty": 1.0,
    "max_length": 20,
    "max_position_embeddings": 2048,
    "min_length": 0,
    "model_type": "roberta-prelayernorm",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 8,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 8,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": 1,
    "position_embedding_type": "absolute",
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torchscript": false,
    "transformers_version": "4.57.1",
    "type_vocab_size": 2,
    "typical_p": 1.0,
    "use_bfloat16": false,
    "use_cache": true,
    "vocab_size": 165
  },
  "model_type": "mist_finetuned",
  "task_network": {
    "dropout": 0.2,
    "embed_dim": 512,
    "output_size": 1
  },
  "tokenizer_class": "SmirkTokenizerFast",
  "transform": {
    "class": "Standardize",
    "num_outputs": 1
  },
  "transformers_version": "4.57.1",
  "auto_map": {
    "AutoConfig": "modeling_mist_finetuned.MISTFinetunedConfig",
    "AutoModel": "modeling_mist_finetuned.MISTFinetuned"
  }
}
model.safetensors ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:3d82a38329ae3821140679785e3fd2083c26ecd167cd9a15298f076501131da1
size 108591540
modeling_mist_finetuned.py ADDED
from datasets import IterableDataset
from pathlib import Path
from smirk import SmirkTokenizerFast
from torch.masked import MaskedTensor, masked_tensor
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    DataCollatorWithPadding,
    PreTrainedModel,
    PretrainedConfig,
)
from typing import Any, Callable, Dict, List, Optional, Union
import json
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

MODEL_TYPE_ALIASES = {}


def build_encoder(enc_dict: Dict[str, Any]):
    mtype = enc_dict.get("model_type")
    if mtype:
        base = MODEL_TYPE_ALIASES.get(mtype, mtype)
        cfg_cls = AutoConfig.for_model(base)
        enc_cfg = cfg_cls.from_dict(enc_dict)
    elif enc_dict.get("_name_or_path"):
        enc_cfg = AutoConfig.from_pretrained(enc_dict["_name_or_path"])
    else:
        raise KeyError("encoder config missing 'model_type' or '_name_or_path'")
    if hasattr(enc_cfg, "add_pooling_layer"):
        enc_cfg.add_pooling_layer = False
    return AutoModel.from_config(enc_cfg)


class MISTFinetunedConfig(PretrainedConfig):
    """HF config for a single-task MIST wrapper."""

    model_type = "mist_finetuned"

    def __init__(
        self,
        encoder: Optional[Dict[str, Any]] = None,
        task_network: Optional[Dict[str, Any]] = None,
        transform: Optional[Dict[str, Any]] = None,
        channels: Optional[List[Dict[str, Any]]] = None,
        tokenizer_class: Optional[str] = "SmirkTokenizer",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.encoder = encoder or {}
        self.task_network = task_network or {}
        self.transform = transform or {}
        self.channels = channels
        self.tokenizer_class = tokenizer_class


class MISTFinetuned(PreTrainedModel):
    config_class = MISTFinetunedConfig

    def __init__(self, config: MISTFinetunedConfig):
        super().__init__(config)
        self.encoder = build_encoder_from_dict(config.encoder)

        tn = config.task_network
        self.task_network = PredictionTaskHead(
            embed_dim=tn["embed_dim"],
            output_size=tn["output_size"],
            dropout=tn["dropout"],
        )
        self.transform = AbstractNormalizer.get(
            config.transform["class"], config.transform["num_outputs"]
        )
        self.channels = config.channels
        self.tokenizer = None
        self.post_init()

    @classmethod
    def from_components(
        cls,
        encoder: PreTrainedModel,
        task_network: nn.Module,
        transform: Any,
        tokenizer: Optional[Any] = None,
        channels: Optional[List[Dict[str, Any]]] = None,
    ) -> "MISTFinetuned":
        cfg = MISTFinetunedConfig(
            encoder=encoder.config.to_dict(),
            task_network={
                "embed_dim": encoder.config.hidden_size,
                "output_size": task_network.final.out_features,
                "dropout": task_network.dropout1.p,
            },
            transform=transform.to_config(),
            channels=channels,
            tokenizer_class=(
                getattr(tokenizer, "__class__", type("T", (), {})).__name__
                if tokenizer
                else "SmirkTokenizer"
            ),
        )
        model = cls(cfg)
        # Load component weights
        model.encoder.load_state_dict(encoder.state_dict(), strict=False)
        model.task_network.load_state_dict(task_network.state_dict())
        model.transform.load_state_dict(transform.state_dict())
        model.tokenizer = tokenizer
        return model
    def forward(self, input_ids, attention_mask=None):
        hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        y = self.task_network(hs)
        return self.transform.forward(y)

    def _resolve_tokenizer(self, tokenizer):
        if tokenizer is not None:
            return tokenizer
        if getattr(self, "tokenizer", None) is not None:
            return self.tokenizer
        try:
            return AutoTokenizer.from_pretrained(self.name_or_path, use_fast=True)
        except Exception:
            return AutoTokenizer.from_pretrained(
                self.config._name_or_path, use_fast=True
            )

    def embed(self, smi: List[str], tokenizer=None):
        tok = self._resolve_tokenizer(tokenizer)
        batch = tok(smi)
        batch = DataCollatorWithPadding(tok)(batch)
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        with torch.inference_mode():
            hs = self.encoder(
                input_ids, attention_mask=attention_mask
            ).last_hidden_state[:, 0, :]
        return hs.to("cpu")

    def predict(self, smi: List[str], return_dict: bool = True, tokenizer=None):
        tok = self._resolve_tokenizer(tokenizer)
        batch = tok(smi)
        batch = DataCollatorWithPadding(tok)(batch)
        inputs = {k: v.to(self.device) for k, v in batch.items()}
        with torch.inference_mode():
            out = self(**inputs).cpu()
        if self.channels is None or not return_dict:
            return out
        return annotate_prediction(out, self.channels)

    def save_pretrained(self, save_directory, **kwargs):
        super().save_pretrained(save_directory, **kwargs)
        if getattr(self, "tokenizer", None) is not None:
            self.tokenizer.save_pretrained(save_directory)


def maybe_get_annotated_channels(channels: List[Any]):
    for chn in channels:
        if isinstance(chn, str):
            yield {"name": chn, "description": None, "unit": None}
        else:
            yield chn


def annotate_prediction(
    y: torch.Tensor, channels: List[Dict[str, str]]
) -> Dict[str, Dict[str, Any]]:
    out: Dict[str, Dict[str, Any]] = {}
    for idx, chn in enumerate(channels):
        channel_info = {f: v for f, v in chn.items() if f != "name"}
        out[chn["name"]] = {"value": y[:, idx], **channel_info}
    return out

def build_encoder_from_dict(enc_dict):
    if "model_type" in enc_dict:
        cfg_cls = AutoConfig.for_model(enc_dict["model_type"])
        enc_cfg = cfg_cls.from_dict(enc_dict, strict=False)
    elif "_name_or_path" in enc_dict:
        enc_cfg = AutoConfig.from_pretrained(enc_dict["_name_or_path"], strict=False)
    else:
        raise KeyError("Encoder config is missing 'model_type' and '_name_or_path'.")

    # Ensure pooling layer is disabled to match saved checkpoints
    if hasattr(enc_cfg, "add_pooling_layer"):
        enc_cfg.add_pooling_layer = False

    return AutoModel.from_config(enc_cfg)


class MISTMultiTaskConfig(PretrainedConfig):
    """HuggingFace config for a multi-task MIST wrapper."""

    model_type = "mist_multitask"

    def __init__(
        self,
        encoder: Optional[Dict[str, Any]] = None,
        task_networks: Optional[List[Dict[str, Any]]] = None,
        transforms: Optional[List[Dict[str, Any]]] = None,
        channels: Optional[List[Dict[str, Any]]] = None,
        tokenizer_class: Optional[str] = "SmirkTokenizer",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.encoder = encoder or {}
        self.task_networks = task_networks or []
        self.transforms = transforms or []
        self.channels = channels
        self.tokenizer_class = tokenizer_class


class MISTMultiTask(PreTrainedModel):
    config_class = MISTMultiTaskConfig

    def __init__(self, config: MISTMultiTaskConfig):
        super().__init__(config)
        self.encoder = build_encoder_from_dict(config.encoder)

        self.task_networks = nn.ModuleList(
            [
                PredictionTaskHead(
                    embed_dim=tn["embed_dim"],
                    output_size=tn["output_size"],
                    dropout=tn["dropout"],
                )
                for tn in config.task_networks
            ]
        )
        self.transforms = nn.ModuleList(
            [
                AbstractNormalizer.get(tf_cfg["class"], tf_cfg["num_outputs"])
                for tf_cfg in config.transforms
            ]
        )

        assert len(self.task_networks) == len(
            self.transforms
        ), "task_networks and transforms must align"
        self.channels = config.channels
        self.tokenizer = None
        self.post_init()
    @classmethod
    def from_components(
        cls,
        encoder: PreTrainedModel,
        task_networks: List[nn.Module],
        transforms: List[Any],
        tokenizer: Optional[Any] = None,
        channels: Optional[List[Dict[str, Any]]] = None,
    ) -> "MISTMultiTask":
        cfg = MISTMultiTaskConfig(
            encoder=encoder.config.to_dict(),
            task_networks=[
                {
                    "embed_dim": encoder.config.hidden_size,
                    "output_size": tn.final.out_features,
                    "dropout": tn.dropout1.p,
                }
                for tn in task_networks
            ],
            transforms=[tf.to_config() for tf in transforms],
            channels=channels,
            tokenizer_class=(
                getattr(tokenizer, "__class__", type("T", (), {})).__name__
                if tokenizer
                else "SmirkTokenizer"
            ),
        )
        model = cls(cfg)
        model.encoder.load_state_dict(encoder.state_dict(), strict=False)
        for dst, src in zip(model.task_networks, task_networks):
            dst.load_state_dict(src.state_dict())
        for dst, src in zip(model.transforms, transforms):
            dst.load_state_dict(src.state_dict())
        model.tokenizer = tokenizer
        return model

    def forward(self, input_ids, attention_mask=None):
        hs = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        outs = []
        for tn, tf in zip(self.task_networks, self.transforms):
            outs.append(tf.forward(tn(hs)))
        return torch.cat(outs, dim=-1)

    def _resolve_tokenizer(self, tokenizer):
        if tokenizer is not None:
            return tokenizer
        if getattr(self, "tokenizer", None) is not None:
            return self.tokenizer
        try:
            return AutoTokenizer.from_pretrained(self.name_or_path, use_fast=True)
        except Exception:
            return AutoTokenizer.from_pretrained(
                self.config._name_or_path, use_fast=True
            )

    def predict(self, smi: List[str], tokenizer=None):
        tok = self._resolve_tokenizer(tokenizer)
        batch = tok(smi)
        batch = DataCollatorWithPadding(tok)(batch)
        inputs = {k: v.to(self.device) for k, v in batch.items()}
        with torch.inference_mode():
            out = self(**inputs).cpu()
        if self.channels is None:
            return out
        return annotate_prediction(out, self.channels)

    def embed(self, smi: List[str], tokenizer=None):
        tok = self._resolve_tokenizer(tokenizer)
        batch = tok(smi)
        batch = DataCollatorWithPadding(tok)(batch)
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        with torch.inference_mode():
            hs = self.encoder(
                input_ids, attention_mask=attention_mask
            ).last_hidden_state[:, 0, :]
        return hs.to("cpu")

    def save_pretrained(self, save_directory, **kwargs):
        super().save_pretrained(save_directory, **kwargs)
        if getattr(self, "tokenizer", None) is not None:
            self.tokenizer.save_pretrained(save_directory)

class PredictionTaskHead(nn.Module):
    def __init__(
        self, embed_dim: int, output_size: int = 1, dropout: float = 0.2
    ) -> None:
        super().__init__()
        self.desc_skip_connection = True

        self.fc1 = nn.Linear(embed_dim, embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.relu1 = nn.GELU()
        self.fc2 = nn.Linear(embed_dim, embed_dim)
        self.dropout2 = nn.Dropout(dropout)
        self.relu2 = nn.GELU()
        self.final = nn.Linear(embed_dim, output_size)

    def forward(self, emb):
        # Pool the first ([CLS]) token if given the full sequence of hidden states
        if emb.ndim > 2:
            emb = emb[:, 0, :]
        x_out = self.fc1(emb)
        x_out = self.dropout1(x_out)
        x_out = self.relu1(x_out)

        if self.desc_skip_connection is True:
            x_out = x_out + emb

        z = self.fc2(x_out)
        z = self.dropout2(z)
        z = self.relu2(z)
        if self.desc_skip_connection is True:
            z = self.final(z + x_out)
        else:
            z = self.final(z)
        return z

class AbstractNormalizer(torch.nn.Module):
    def __init__(self, num_outputs: Optional[int] = None):
        super().__init__()
        self.num_outputs = num_outputs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Remove normalization"""
        raise NotImplementedError

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        """Apply normalization"""
        raise NotImplementedError

    def _fit(self, x: MaskedTensor) -> dict:
        """Fit the normalization parameters"""
        raise NotImplementedError

    def to_config(self) -> dict:
        return {"class": self.__class__.__name__, "num_outputs": self.num_outputs}

    def leader_fit(self, ds, rank: int, broadcast: Callable):
        state = None
        if rank == 0:
            state = self.fit(ds)
        state = broadcast(state)
        self.load_state_dict(state)

    def fit(self, ds, name: str = "target") -> dict:
        """Fit the normalization parameters on a dataset"""
        if isinstance(ds, IterableDataset):
            target = []
            mask = []
            for x in ds:
                target.append(x[name])
                mask.append(x[f"{name}_mask"])

            target = torch.stack(target)
            mask = torch.stack(mask)

        else:
            target = torch.stack([torch.tensor(x) for x in ds[name]])
            mask = torch.stack([torch.tensor(x) for x in ds[f"{name}_mask"]])

        # Use masked tensor to compute normalization parameters
        target = masked_tensor(target, mask)

        state = self._fit(target)
        return state

    @classmethod
    def get(
        cls, transform: Optional[Union[list[str], str]], num_outputs: int
    ) -> "AbstractNormalizer":
        if isinstance(transform, list):
            assert len(transform) == num_outputs
            return ChannelWiseTransform([cls.get(t, 1) for t in transform])
        elif transform in ["standardize", Standardize.__name__]:
            return Standardize(num_outputs)
        elif transform in ["power_transform", PowerTransform.__name__]:
            return PowerTransform(num_outputs)
        elif transform in ["log_transform", LogTransform.__name__]:
            return LogTransform(num_outputs)
        elif transform in ["max_scale", MaxScaleTransform.__name__]:
            return MaxScaleTransform(num_outputs)
        else:
            return IdentityTransform()

class Standardize(AbstractNormalizer):
    def __init__(self, num_outputs: int, eps: float = 1e-8):
        super().__init__(num_outputs)
        self.register_buffer("mean", torch.zeros(num_outputs))
        self.register_buffer("std", torch.zeros(num_outputs))
        self.eps = float(eps)
        assert 0 <= self.eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (self.std * x) + self.mean

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / self.std

    def fit(self, ds, name: str = "target") -> dict:
        num_outputs = self.num_outputs
        assert num_outputs is not None
        mean = torch.zeros(num_outputs)
        m2 = torch.zeros(num_outputs)
        n = torch.zeros(num_outputs, dtype=torch.int)
        for row in ds:
            target = torch.tensor(row[name])
            mask = torch.tensor(row[f"{name}_mask"])
            x = masked_tensor(target, mask)
            n += mask.view(-1, num_outputs).sum(0)
            xs = x.view(-1, num_outputs).sum(0)
            delta = xs - mean
            # Only update masked values
            mean += (delta / n).get_data().masked_fill(~delta.get_mask(), 0)
            delta2 = xs - mean
            m2 += (delta * delta2).get_data().masked_fill(~delta.get_mask(), 0)

        self.mean = mean.to(self.mean)
        self.std = (m2 / n).sqrt().to(self.std) + self.eps
        self.mean[self.mean.isnan()] = 0
        self.std[self.std.isnan()] = 1
        logging.debug("Fitted %s", self.state_dict())
        return self.state_dict()

    def _fit(self, target: MaskedTensor) -> dict:
        self.mean = target.mean(0).get_data().to(self.mean)
        self.std = target.std(0).get_data().to(self.std) + self.eps
        return self.state_dict()

    def load_state_dict(
        self, state_dict: dict[str, Any], strict: bool = True, assign: bool = False
    ):
        # Handle legacy case where keys have a "transform." prefix
        if "transform.mean" in state_dict:
            state_dict = state_dict.copy()  # Don't modify the original
            state_dict["mean"] = state_dict.pop("transform.mean")
            state_dict["std"] = state_dict.pop("transform.std")

        if assign:
            # Manually assign buffers when assign=True
            for key, value in state_dict.items():
                if key in ["mean", "std"]:
                    # Use register_buffer to properly replace the buffer
                    self.register_buffer(key, value)
            result = None  # No incompatible keys when assigning manually
        else:
            result = super().load_state_dict(state_dict, strict=strict, assign=False)

        logging.debug(f"After loading: mean={self.mean}, std={self.std}")
        return result

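# The streaming `Standardize.fit` above is a per-channel, masked variant of
# Welford's online algorithm for running mean and variance. A minimal scalar
# sketch of the underlying update rule (pure Python, no masking; illustrative
# only, not part of the model code):

```python
# Welford's online mean/variance: single pass, numerically stable.
def welford(values):
    n, mean, m2 = 0, 0.0, 0.0
    for x in values:
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)  # delta against both old and new mean
    variance = m2 / n if n else float("nan")
    return mean, variance

mean, var = welford([1.0, 2.0, 3.0, 4.0])
print(mean, var)  # 2.5 1.25
```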
class TokenTaskHead(nn.Module):
    def __init__(
        self, embed_dim: int, output_size: int = 1, dropout: float = 0.2
    ) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(embed_dim, output_size),
        )

    def forward(self, emb):
        return self.layers(emb)

class TokenPairwiseDistance(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        dropout: float = 0.2,
        num_attention_heads: int = 1,
        num_layers: int = 1,
        activation: str = "relu",
        ff_ratio: int = 2,
    ) -> None:
        super().__init__()
        enc_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_attention_heads,
            dim_feedforward=ff_ratio * embed_dim,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.interaction = nn.TransformerEncoder(enc_layer, num_layers)
        self.pairwise_distance = PairwiseMLP(embed_dim, dropout)
        self.distance1 = nn.Sequential(
            nn.Linear(embed_dim, embed_dim), nn.Dropout(dropout), nn.GELU()
        )
        self.distance2 = nn.Linear(embed_dim, 1)

    def forward(self, hs: torch.Tensor) -> torch.Tensor:
        hs = self.interaction(hs)

        with torch.autocast("cuda", dtype=torch.float32):
            pw_dist = self.pairwise_distance(hs)
            d = self.distance1(pw_dist) + pw_dist
            d = self.distance2(d).squeeze(-1)
            return F.relu(F.elu(d) + 1)

544
+ class BiPairwiseBlock(nn.Module):
545
+ def __init__(self, d_model: int, bias: bool = True, device=None, dtype=None):
546
+ super().__init__()
547
+ factory_kwargs = {"device": device, "dtype": dtype}
548
+
549
+ self.bi_weight = nn.Parameter(torch.empty((d_model, d_model), **factory_kwargs))
550
+ self.lin_weight = nn.Parameter(
551
+ torch.empty((d_model, d_model), **factory_kwargs)
552
+ )
553
+ if bias:
554
+ self.bias = nn.Parameter(torch.empty(d_model, **factory_kwargs))
555
+ else:
556
+ self.register_parameter("bias", None)
557
+ self.reset_parameters()
558
+
559
+ # Gradient hook to enforce symmetry
560
+ self.bi_weight.register_hook(lambda grad: 0.5 * (grad + grad.T))
561
+
562
+ def reset_parameters(self):
563
+ nn.init.xavier_normal_(self.lin_weight, gain=nn.init.calculate_gain("relu"))
564
+ nn.init.xavier_normal_(self.bi_weight, gain=nn.init.calculate_gain("relu"))
565
+ with torch.no_grad():
566
+ self.bi_weight.copy_(0.5 * (self.bi_weight + self.bi_weight.T))
567
+
568
+ if self.bias is not None:
569
+ bound = 1 / math.sqrt(self.bias.size(0))
570
+ nn.init.uniform_(self.bias, -bound, bound)
571
+
572
+ def forward(self, x: torch.Tensor):
573
+ y_bi = torch.einsum("...ld,df,...rf->...lrf", x, self.bi_weight, x)
574
+ y_bi = 0.5 * (y_bi + y_bi.transpose(-3, -2)) # Enforce symmetry
575
+
576
+ x_linear = x.unsqueeze(-2) + x.unsqueeze(-3)
577
+ return y_bi + F.linear(x_linear, self.lin_weight, self.bias)
578
+
579
+ class PairwiseMLP(nn.Module):
580
+ def __init__(
581
+ self,
582
+ d_model: int,
583
+ dropout: float = 0.2,
584
+ device=None,
585
+ dtype=None,
586
+ ) -> None:
587
+ super().__init__()
588
+ self.mlp = nn.Sequential(
589
+ nn.Linear(2 * d_model, d_model),
590
+ nn.Dropout(dropout),
591
+ nn.GELU(),
592
+ nn.Linear(d_model, d_model),
593
+ nn.GELU(),
594
+ )
595
+
596
+ def forward(self, x: torch.Tensor):
597
+ _, N, _ = x.shape
598
+ x_l = x.unsqueeze(-2).expand(-1, N, N, -1)
599
+ x_r = x.unsqueeze(-3).expand(-1, N, N, -1)
600
+ x_pw = torch.cat([x_l, x_r], dim=-1)
601
+ y = self.mlp(x_pw)
602
+ return 0.5 * (y + y.transpose(1, 2))
603
+
604
+ class ChannelWiseTransform(AbstractNormalizer):
605
+ def __init__(self, transforms: list[AbstractNormalizer]):
606
+ super().__init__(len(transforms))
607
+ self.transforms = torch.nn.ModuleList(transforms)
608
+
609
+ def to_config(self) -> dict:
610
+ return {
611
+ "class": [t.__class__.__name__ for t in self.transforms],
612
+ "num_outputs": self.num_outputs,
613
+ }
614
+
615
+ def inverse(self, x: torch.Tensor) -> torch.Tensor:
616
+ return torch.cat(
617
+ [
618
+ transform.inverse(x[:, [idx]])
619
+ for idx, transform in enumerate(self.transforms)
620
+ ],
621
+ dim=1,
622
+ )
623
+
624
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
625
+ return torch.cat(
626
+ [
627
+ transform.forward(x[:, [idx]])
628
+ for idx, transform in enumerate(self.transforms)
629
+ ],
630
+ dim=1,
631
+ )
632
+
633
+ def _fit(self, x: MaskedTensor) -> dict:
634
+ for idx, transform in enumerate(self.transforms):
635
+ transform._fit(x[:, [idx]])
636
+ return self.state_dict()
637
+
638
+ class LogTransform(Standardize):
639
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
640
+ return torch.exp(super().forward(x))
641
+
642
+ def inverse(self, x: torch.Tensor) -> torch.Tensor:
643
+ return super().inverse(torch.log(x))
644
+
645
+ def _fit(self, target: MaskedTensor) -> dict:
646
+ return super()._fit(torch.log(target))
647
+
648

class PowerTransform(AbstractNormalizer):
    """
    Apply a featurewise power transform (Yeo-Johnson) to make the data more
    Gaussian-like, then standardize the transformed output to zero mean and
    unit variance.
    """

    def __init__(self, num_outputs, eps: float = 1e-8):
        super().__init__(num_outputs)
        self.num_outputs = num_outputs
        self.register_buffer("lmbdas", torch.zeros(num_outputs))
        self.register_buffer("mean", torch.zeros(num_outputs))
        self.register_buffer("std", torch.zeros(num_outputs))
        self.eps = float(eps)
        assert 0 <= self.eps

    def _yeo_johnson_transform(self, x, lmbda):
        """
        Return transformed input x following Yeo-Johnson transform with
        parameter lambda.
        Adapted from
        https://github.com/scikit-learn/scikit-learn/blob/fbb32eae5/sklearn/preprocessing/_data.py#L3354
        """
        x_out = x.clone()
        eps = torch.finfo(x.dtype).eps
        pos = x >= 0  # binary mask

        # when x >= 0
        if abs(lmbda) < eps:  # lmbda == 0
            x_out[pos] = torch.log1p(x[pos])
        else:  # lmbda != 0
            x_out[pos] = (torch.pow(x[pos] + 1, lmbda) - 1) / lmbda

        # when x < 0
        if abs(lmbda - 2) > eps:  # lmbda != 2
            x_out[~pos] = -(torch.pow(-x[~pos] + 1, 2 - lmbda) - 1) / (2 - lmbda)
        else:  # lmbda == 2
            x_out[~pos] = -torch.log1p(-x[~pos])

        return x_out

    def _yeo_johnson_inverse_transform(self, x, lmbda):
        """
        Return inverse-transformed input x following Yeo-Johnson inverse
        transform with parameter lambda.
        Adapted from
        https://github.com/scikit-learn/scikit-learn/blob/fbb32eae5/sklearn/preprocessing/_data.py#L3383
        """
        x_out = x.clone()
        pos = x >= 0
        eps = torch.finfo(x.dtype).eps

        # when x >= 0
        if abs(lmbda) < eps:  # lmbda == 0
            x_out[pos] = torch.exp(x[pos]) - 1
        else:  # lmbda != 0
            x_out[pos] = torch.pow(x[pos] * lmbda + 1, 1 / lmbda) - 1

        # when x < 0
        if abs(lmbda - 2) > eps:  # lmbda != 2
            x_out[~pos] = 1 - torch.pow(-(2 - lmbda) * x[~pos] + 1, 1 / (2 - lmbda))
        else:  # lmbda == 2
            x_out[~pos] = 1 - torch.exp(-x[~pos])
        return x_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Undo standardization
        x = (self.std * x) + self.mean
        x_out = torch.zeros_like(x)
        for i in range(self.num_outputs):
            x_out[:, i] = self._yeo_johnson_inverse_transform(x[:, i], self.lmbdas[i])
        return x_out

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        x_out = torch.zeros_like(x)
        for i in range(self.num_outputs):
            x_out[:, i] = self._yeo_johnson_transform(x[:, i], self.lmbdas[i])
        # Standardization
        x_out = (x_out - self.mean) / self.std
        return x_out

    def _fit(self, target: MaskedTensor) -> dict:
        # Fit Yeo-Johnson lambdas
        from sklearn.preprocessing import (
            PowerTransformer as _PowerTransformer,  # noqa: F811
        )

        transformer = _PowerTransformer(method="yeo-johnson", standardize=False)
        target = torch.tensor(transformer.fit_transform(target.get_data().numpy()))
        self.lmbdas = torch.tensor(transformer.lambdas_)
        # Fit standardization scaling
        self.mean = target.mean(0).to(self.mean)
        self.std = target.std(0).to(self.std) + self.eps
        return self.state_dict()


class MaxScaleTransform(AbstractNormalizer):
    """
    Divide by maximum value in training dataset.
    """

    def __init__(self, mx: int, eps: float = 1e-8):
        super().__init__(1)
        self.num_outputs = 1
        self.max = mx
        self.eps = float(eps)
        assert 0 <= self.eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Undo max scaling
        x_out = self.max * x
        return x_out

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        x_out = x / self.max
        return x_out

    def _fit(self, target: MaskedTensor) -> dict:
        return self.state_dict()


class IdentityTransform(AbstractNormalizer):
    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def _fit(self, x: MaskedTensor) -> dict:
        return self.state_dict()
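`PowerTransform` relies on the Yeo-Johnson forward and inverse transforms being exact inverses of each other for every lambda. A minimal scalar sketch (plain Python, no torch; the lambda and sample values below are arbitrary illustrations, not values from this repo) makes the round trip easy to check:

```python
import math

def yeo_johnson(x: float, lam: float) -> float:
    """Scalar Yeo-Johnson transform (mirrors the tensor version above)."""
    if x >= 0:
        return math.log1p(x) if abs(lam) < 1e-12 else ((x + 1) ** lam - 1) / lam
    # Negative branch: lam == 2 is the logarithmic special case
    if abs(lam - 2) < 1e-12:
        return -math.log1p(-x)
    return -(((-x + 1) ** (2 - lam)) - 1) / (2 - lam)

def yeo_johnson_inverse(y: float, lam: float) -> float:
    """Scalar inverse of the Yeo-Johnson transform."""
    if y >= 0:
        return math.expm1(y) if abs(lam) < 1e-12 else (y * lam + 1) ** (1 / lam) - 1
    if abs(lam - 2) < 1e-12:
        return 1 - math.exp(-y)
    return 1 - (-(2 - lam) * y + 1) ** (1 / (2 - lam))

# Round trip: inverse(transform(x)) should recover x for any lambda
for lam in (0.0, 0.5, 1.3, 2.0):
    for x in (-3.0, -0.5, 0.0, 0.7, 4.2):
        y = yeo_johnson(x, lam)
        assert abs(yeo_johnson_inverse(y, lam) - x) < 1e-9, (x, lam)
```

Because the transform preserves sign (non-negative inputs map to non-negative outputs and vice versa), the `pos`/`~pos` masking in the tensor implementation selects the same branch in both directions.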
requirements.txt ADDED
@@ -0,0 +1,4 @@
transformers==4.57.1
torch==2.9.0
scikit-learn==1.7.2
smirk==0.2.0.dev0
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
{
  "bos_token": {
    "content": "[BOS]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "cls_token": {
    "content": "[CLS]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "[EOS]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "mask_token": {
    "content": "[MASK]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "[PAD]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "sep_token": {
    "content": "[SEP]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "[UNK]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
tokenizer.json ADDED
@@ -0,0 +1,267 @@
{
  "version": "1.0",
  "truncation": null,
  "padding": null,
  "added_tokens": [
    {
      "id": 0,
      "content": "[UNK]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 159,
      "content": "[BOS]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 160,
      "content": "[EOS]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 161,
      "content": "[SEP]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 162,
      "content": "[PAD]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 163,
      "content": "[CLS]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 164,
      "content": "[MASK]",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    }
  ],
  "normalizer": {
    "type": "Sequence",
    "normalizers": [
      {
        "type": "Replace",
        "pattern": {
          "String": "++"
        },
        "content": "+2"
      },
      {
        "type": "Replace",
        "pattern": {
          "String": "--"
        },
        "content": "-2"
      },
      {
        "type": "Strip",
        "strip_left": true,
        "strip_right": true
      }
    ]
  },
  "pre_tokenizer": {
    "outer": "Br?|Cl?|F|I|N|O|P|S|b|c|n|o|p|s|\\*|[\\.\\-=\\#\\$:/\\\\]|\\d|%|\\(|\\)|\\[.*?]",
    "inner": "(\\d+)?(A[c|g|l|m|r|s|t|u]|B[a|e|h|i|k|r]?|C[a|d|e|f|l|m|n|o|r|s|u]?|D[b|s|y]|E[r|s|u]|F[e|l|m|r]?|G[a|d|e]|H[e|f|g|o|s]?|I[n|r]?|Kr?|L[a|i|r|u|v]|M[c|d|g|n|o|t]|N[a|b|d|e|h|i|o|p]?|O[g|s]?|P[a|b|d|m|o|r|t|u]?|R[a|b|e|f|g|h|n|u]|S[b|c|e|g|i|m|n|r]?|T[a|b|c|e|h|i|l|m|s]|U|V|W|Xe|Yb?|Z[n|r]|as|b|c|n|o|p|se?|\\*)(?:(@(?:@|AL|OH|SP|T[B|H])?)(\\d{1,2})?)?(?:(H)(\\d)?)?(?:([+-]{1,2})(\\d{0,2}))?(?:(:)(\\d+))?"
  },
  "post_processor": null,
  "decoder": {
    "type": "Fuse"
  },
  "model": {
    "type": "WordLevel",
    "vocab": {
      "[UNK]": 0,
      "#": 1,
      "$": 2,
      "%": 3,
      "(": 4,
      ")": 5,
      "*": 6,
      "+": 7,
      "-": 8,
      ".": 9,
      "/": 10,
      "0": 11,
      "1": 12,
      "2": 13,
      "3": 14,
      "4": 15,
      "5": 16,
      "6": 17,
      "7": 18,
      "8": 19,
      "9": 20,
      ":": 21,
      "=": 22,
      "@": 23,
      "@@": 24,
      "@AL": 25,
      "@OH": 26,
      "@SP": 27,
      "@TB": 28,
      "@TH": 29,
      "Ac": 30,
      "Ag": 31,
      "Al": 32,
      "Am": 33,
      "Ar": 34,
      "As": 35,
      "At": 36,
      "Au": 37,
      "B": 38,
      "Ba": 39,
      "Be": 40,
      "Bh": 41,
      "Bi": 42,
      "Bk": 43,
      "Br": 44,
      "C": 45,
      "Ca": 46,
      "Cd": 47,
      "Ce": 48,
      "Cf": 49,
      "Cl": 50,
      "Cm": 51,
      "Cn": 52,
      "Co": 53,
      "Cr": 54,
      "Cs": 55,
      "Cu": 56,
      "Db": 57,
      "Ds": 58,
      "Dy": 59,
      "Er": 60,
      "Es": 61,
      "Eu": 62,
      "F": 63,
      "Fe": 64,
      "Fl": 65,
      "Fm": 66,
      "Fr": 67,
      "Ga": 68,
      "Gd": 69,
      "Ge": 70,
      "H": 71,
      "He": 72,
      "Hf": 73,
      "Hg": 74,
      "Ho": 75,
      "Hs": 76,
      "I": 77,
      "In": 78,
      "Ir": 79,
      "K": 80,
      "Kr": 81,
      "La": 82,
      "Li": 83,
      "Lr": 84,
      "Lu": 85,
      "Lv": 86,
      "Mc": 87,
      "Md": 88,
      "Mg": 89,
      "Mn": 90,
      "Mo": 91,
      "Mt": 92,
      "N": 93,
      "Na": 94,
      "Nb": 95,
      "Nd": 96,
      "Ne": 97,
      "Nh": 98,
      "Ni": 99,
      "No": 100,
      "Np": 101,
      "O": 102,
      "Og": 103,
      "Os": 104,
      "P": 105,
      "Pa": 106,
      "Pb": 107,
      "Pd": 108,
      "Pm": 109,
      "Po": 110,
      "Pr": 111,
      "Pt": 112,
      "Pu": 113,
      "Ra": 114,
      "Rb": 115,
      "Re": 116,
      "Rf": 117,
      "Rg": 118,
      "Rh": 119,
      "Rn": 120,
      "Ru": 121,
      "S": 122,
      "Sb": 123,
      "Sc": 124,
      "Se": 125,
      "Sg": 126,
      "Si": 127,
      "Sm": 128,
      "Sn": 129,
      "Sr": 130,
      "Ta": 131,
      "Tb": 132,
      "Tc": 133,
      "Te": 134,
      "Th": 135,
      "Ti": 136,
      "Tl": 137,
      "Tm": 138,
      "Ts": 139,
      "U": 140,
      "V": 141,
      "W": 142,
      "Xe": 143,
      "Y": 144,
      "Yb": 145,
      "Zn": 146,
      "Zr": 147,
      "[": 148,
      "\\": 149,
      "]": 150,
      "as": 151,
      "b": 152,
      "c": 153,
      "n": 154,
      "o": 155,
      "p": 156,
      "s": 157,
      "se": 158
    },
    "unk_token": "[UNK]"
  }
}
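The `normalizer` and `pre_tokenizer` entries above fully determine how a SMILES string is split before vocabulary lookup. A small sketch using plain Python `re` (reproducing the `outer` pattern and the `++`/`--` replacement rules by hand, not by calling the smirk package) shows the behavior:

```python
import re

# The "outer" pre-tokenizer pattern from tokenizer.json (JSON escapes removed)
OUTER = re.compile(
    r"Br?|Cl?|F|I|N|O|P|S|b|c|n|o|p|s|\*|[\.\-=\#\$:/\\]|\d|%|\(|\)|\[.*?]"
)

def normalize(smiles: str) -> str:
    # Mirrors the Sequence normalizer: "++" -> "+2", "--" -> "-2", then strip
    return smiles.replace("++", "+2").replace("--", "-2").strip()

def outer_tokens(smiles: str) -> list[str]:
    # First-stage split; bracket atoms like [Fe+2] stay whole for the
    # "inner" pattern to decompose further
    return OUTER.findall(normalize(smiles))

print(outer_tokens("CCO"))       # ethanol
print(outer_tokens("c1ccccc1"))  # benzene
print(outer_tokens("CC(=O)O"))   # acetic acid
```

Note the ordering in the `outer` alternation: `Br?` and `Cl?` come first so that bromine and chlorine are matched as two-character symbols before `B` or `C` alone could claim the first character.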
tokenizer_config.json ADDED
@@ -0,0 +1,72 @@
{
  "added_tokens_decoder": {
    "0": {
      "content": "[UNK]",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "159": {
      "content": "[BOS]",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "160": {
      "content": "[EOS]",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "161": {
      "content": "[SEP]",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "162": {
      "content": "[PAD]",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "163": {
      "content": "[CLS]",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "164": {
      "content": "[MASK]",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    }
  },
  "bos_token": "[BOS]",
  "clean_up_tokenization_spaces": false,
  "cls_token": "[CLS]",
  "eos_token": "[EOS]",
  "extra_special_tokens": {},
  "mask_token": "[MASK]",
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "tokenizer_class": "SmirkTokenizerFast",
  "unk_token": "[UNK]",
  "vocab_file": "/nfs/turbo/coe-venkvis/abhutani/electrolyte-fm/.venv/lib/python3.10/site-packages/smirk/vocab_smiles.json"
}
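One value above worth decoding: the odd-looking `model_max_length` is not a typo. It is, to the best of my understanding, the sentinel `transformers` writes when a tokenizer declares no length limit (`VERY_LARGE_INTEGER`, defined as `int(1e30)`), which picks up floating-point rounding on the way through a Python `float`:

```python
# 1e30 is a double, so int(1e30) is the nearest representable value to 10**30,
# not 10**30 itself -- which is exactly the number stored in model_max_length.
print(int(1e30))  # 1000000000000000019884624838656
print(int(1e30) == 10**30)  # False
```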