victor-shirasuna committed on
Commit 607c6a7 · 1 Parent(s): 3d83373

Updated README.md

Files changed (3)
  1. .gitattributes +1 -0
  2. README.md +214 -3
  3. images/str-bamba.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,214 @@
- ---
- license: apache-2.0
- ---
+ # Molecular String-based Bamba Encoder-Decoder (STR-Bamba)
+
+ This repository provides the PyTorch source code associated with our publication, "STR-Bamba: Multimodal Molecular Textual Representation Encoder-Decoder Foundation Model".
+
+ **Paper:** [OpenReview Link](https://openreview.net/pdf?id=0uWNuJ1xtz)
+
+ **HuggingFace:** [HuggingFace Link](https://huggingface.co/ibm/materials.str-bamba)
+
+ For more information, contact: vshirasuna@ibm.com or evital@br.ibm.com.
+
+ ![str_bamba](images/str-bamba.png)
+
+ ## Introduction
+
+ We present a large encoder-decoder chemical foundation model based on the IBM Bamba architecture, a hybrid of Transformer and Mamba-2 layers, designed to support multi-representational molecular string inputs. The model is pre-trained in a BERT-style manner on 588 million samples, yielding a corpus of approximately 29 billion molecular tokens. It serves as a foundation for chemical language research, supporting complex tasks such as molecular property prediction, classification, and molecular translation. **Additionally, the STR-Bamba architecture allows multiple representations to be aggregated into a single text input, as it imposes no token-length limit beyond hardware constraints.** Our experiments across multiple benchmark datasets demonstrate state-of-the-art performance on a variety of tasks. Model weights are available at: [HuggingFace Link](https://huggingface.co/ibm/materials.str-bamba).
+
+ The STR-Bamba model supports the following **molecular representations** (an illustrative multi-representation input is sketched right after this list):
+ - SMILES
+ - SELFIES
+ - Molecular Formula
+ - InChI
+ - IUPAC Name
+ - Polymer SMILES in [SPG notation](https://openreview.net/pdf?id=L47GThI95d)
+ - Formulations
+
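+ Because representations are expressed as tagged text, several of them can in principle be concatenated into a single input, as highlighted above. The sketch below is purely illustrative: only the `<smiles>` tag is taken from the Feature Extraction example later in this README, while the `<iupac>` tag and the exact concatenation format are assumptions rather than the documented interface (`model` refers to a model loaded as shown in the Feature Extraction section).
+
+ ```python
+ # Hypothetical multi-representation input for ethanol.
+ # '<smiles>' is used in the Feature Extraction example below; '<iupac>' is a
+ # placeholder tag that may not match the tokenizer's actual vocabulary.
+ multimodal_input = '<smiles>CCO<iupac>ethanol'
+
+ # Tokenize it the same way as the single-representation examples below.
+ input_ids = model.tokenizer(multimodal_input, return_tensors='pt')['input_ids']
+ ```
+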
+ ## Table of Contents
+
+ 1. [Getting Started](#getting-started)
+    1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs)
+    2. [Replicating Conda Environment](#replicating-conda-environment)
+ 2. [Pretraining](#pretraining)
+ 3. [Finetuning](#finetuning)
+ 4. [Feature Extraction](#feature-extraction)
+ 5. [Citations](#citations)
+
+ ## Getting Started
+
+ **This code and environment have been tested on NVIDIA V100 and A100 GPUs.**
+
+ ### Pretrained Models and Training Logs
+
+ We provide checkpoints of the STR-Bamba model pre-trained on a dataset of ~118M small molecules, ~2M polymer structures, and 258 formulations. The pre-trained model shows competitive performance on classification and regression benchmarks spanning small molecules, polymers, and electrolyte formulations. Model weights: [HuggingFace Link](https://huggingface.co/ibm/materials.str-bamba)
+
+ Add the STR-Bamba pre-trained weights (`STR-Bamba_8.pt`) to the `inference/` or `finetune/` directory, according to your needs. The directory structure should look like the following:
+
+ ```
+ inference/
+ └── str_bamba/
+     ├── config/
+     ├── checkpoints/
+     │   └── STR-Bamba_8.pt
+     └── tokenizer/
+ ```
+ and/or:
+
+ ```
+ finetune/
+ └── str_bamba/
+     ├── config/
+     ├── checkpoints/
+     │   └── STR-Bamba_8.pt
+     └── tokenizer/
+ ```
+
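+ As a quick, optional check (not part of the original instructions) that the weights ended up in the expected location, you can run:
+
+ ```python
+ from pathlib import Path
+
+ # Adjust 'inference' to 'finetune' depending on which directory you populated.
+ ckpt = Path('inference/str_bamba/checkpoints/STR-Bamba_8.pt')
+ print('Checkpoint found:', ckpt.exists())
+ ```
+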
+ ### Replicating Conda Environment
+
+ Follow these steps to replicate our Conda environment and install the necessary libraries:
+
+ #### Create and Activate Conda Environment
+ ```shell
+ mamba create -n strbamba python=3.10.13
+ mamba activate strbamba
+ ```
+
+ #### PyTorch 2.4.0 and CUDA 12.4
+ ```shell
+ pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124
+ ```
+
+ #### Mamba2 Dependencies
+
+ Install the following packages in this order, on a machine with a **GPU**, because `mamba` requires `causal-conv1d` to be installed first.
+
+ ```shell
+ # causal-conv1d
+ git clone https://github.com/Dao-AILab/causal-conv1d.git
+ cd causal-conv1d && git checkout v1.5.0.post8 && pip install . && cd .. && rm -rf causal-conv1d
+ ```
+
+ ```shell
+ # mamba
+ git clone https://github.com/state-spaces/mamba.git
+ cd mamba && git checkout v2.2.4 && pip install --no-build-isolation . && cd .. && rm -rf mamba
+ ```
+
+ ```shell
+ # flash-attn
+ pip install flash-attn==2.6.1 --no-build-isolation
+ ```
+
+ #### Install Packages with Pip
+ ```shell
+ pip install -r requirements.txt
+ ```
+
+ #### Troubleshooting
+ If the source builds above fail (for example, due to memory pressure during compilation), try installing `mamba-ssm` directly from PyPI and limiting the number of parallel build jobs for `flash-attn`:
+ ```shell
+ pip install mamba-ssm==2.2.4
+ MAX_JOBS=2 pip install flash-attn==2.6.1 --no-build-isolation --verbose
+ ```
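+
+ As an optional sanity check (not part of the original instructions), you can confirm that the GPU is visible and that the compiled Mamba-2 dependencies import correctly; the version attributes used below are the standard ones exposed by these packages:
+
+ ```python
+ # Optional environment check: verify CUDA visibility and that the compiled
+ # extensions (causal-conv1d, mamba-ssm, flash-attn) import without errors.
+ import torch
+ import causal_conv1d
+ import mamba_ssm
+ import flash_attn
+
+ print('CUDA available:', torch.cuda.is_available())
+ print('mamba-ssm:', mamba_ssm.__version__, '| flash-attn:', flash_attn.__version__)
+ ```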
+
+ ## Pretraining
+
+ For pretraining, we use two strategies: masked language modeling to train the encoder, and next-token prediction to train the decoder, in order to refine molecular representation reconstruction and generation conditioned on the encoder.
+
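+ For intuition, the snippet below sketches the BERT-style masking used in the first (encoder) stage: a fraction of input tokens is replaced by a mask token and the model is trained to recover them. This is a generic illustration, not the repository's actual data collator; the 15% masking ratio and the `mask_token_id` name are assumptions.
+
+ ```python
+ import torch
+
+ def mask_tokens(input_ids: torch.Tensor, mask_token_id: int, mlm_probability: float = 0.15):
+     """Generic BERT-style masking sketch: returns (masked_inputs, labels)."""
+     labels = input_ids.clone()
+     # Sample which positions to mask.
+     mask = torch.rand(input_ids.shape) < mlm_probability
+     # Only masked positions contribute to the loss; -100 is ignored by PyTorch's CrossEntropyLoss.
+     labels[~mask] = -100
+     masked_inputs = input_ids.clone()
+     masked_inputs[mask] = mask_token_id
+     return masked_inputs, labels
+ ```
+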
+ The pretraining code provides examples of data processing and model training on a smaller dataset, and requires an A100 GPU.
+
+ To pre-train the two stages of the STR-Bamba model, run:
+
+ ```shell
+ bash training/run_model_encoder_training.sh
+ ```
+ or
+ ```shell
+ bash training/run_model_decoder_training.sh
+ ```
+
+ ## Finetuning
+
+ The finetuning datasets and environment can be found in the [finetune](finetune/) directory. After setting up the environment, you can run a finetuning task with:
+
+ ```shell
+ bash finetune/runs/esol/run_finetune_esol.sh
+ ```
+
+ Finetuning training and checkpointing artifacts will be written to directories named `checkpoint_<measure_name>`.
+
+ ## Feature Extraction
+
+ To load STR-Bamba, you can simply use:
+
+ ```python
+ model = load_strbamba('STR-Bamba_8.pt')
+ ```
+
+ To encode SMILES, SELFIES, InChI, or other supported molecular representations into embeddings, you can use:
+
+ ```python
+ import torch
+
+ # df is assumed to be a pandas DataFrame with a 'SMILES' column.
+ with torch.no_grad():
+     encoded_embeddings = model.encode(df['SMILES'], return_torch=True)
+ ```
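+
+ As a purely illustrative follow-up (not part of the original example), the frozen embeddings can be used as features for a lightweight downstream model; here the property column name `'y'` and the train/test split are assumptions:
+
+ ```python
+ # Hypothetical downstream use of the frozen STR-Bamba embeddings.
+ from sklearn.linear_model import Ridge
+ from sklearn.model_selection import train_test_split
+
+ X = encoded_embeddings.cpu().numpy()   # embeddings from the snippet above
+ y = df['y'].to_numpy()                 # 'y' is an assumed property column in df
+
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
+ regressor = Ridge().fit(X_train, y_train)
+ print('Held-out R^2:', regressor.score(X_test, y_test))
+ ```
+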
+ For the decoder, you can use the following code to generate new molecular representations conditioned on the encoder output:
+
+ ```python
+ with torch.no_grad():
+     # encoder and decoder inputs
+     encoder_input = '<smiles>CCO'
+     decoder_input = '<smiles>'
+     decoder_target = '<smiles>CCO'
+
+     # device on which the model was loaded (assumes `model` is a torch.nn.Module)
+     device = next(model.parameters()).device
+
+     # tokenization
+     encoder_input_ids = model.tokenizer(encoder_input,
+                                         padding=True,
+                                         truncation=True,
+                                         return_tensors='pt')['input_ids'].to(device)
+     decoder_input_ids = model.tokenizer(decoder_input,
+                                         padding=True,
+                                         truncation=True,
+                                         return_tensors='pt')['input_ids'][:, :-1].to(device)
+     decoder_target_ids = model.tokenizer(decoder_target,
+                                          padding=True,
+                                          truncation=True,
+                                          return_tensors='pt')['input_ids'].to(device)
+
+     # visualize input texts
+     print('Encoder input:', model.tokenizer.batch_decode(encoder_input_ids))
+     print('Decoder input:', model.tokenizer.batch_decode(decoder_input_ids))
+     print('Decoder target:', model.tokenizer.batch_decode(decoder_target_ids))
+     print('Target:', decoder_target_ids)
+
+     # encoder forward
+     encoder_hidden_states = model.encoder(encoder_input_ids).hidden_states
+
+     # model generation (top_k=1 with temperature=1 makes decoding effectively greedy)
+     output = model.decoder.generate(
+         input_ids=decoder_input_ids,
+         encoder_hidden_states=encoder_hidden_states,
+         max_length=decoder_target_ids.shape[1],
+         cg=True,
+         return_dict_in_generate=True,
+         output_scores=True,
+         enable_timing=False,
+         temperature=1,
+         top_k=1,
+         top_p=1.0,
+         min_p=0.,
+         repetition_penalty=1,
+     )
+
+     # visualize model output (join the decoded tokens and drop the spaces between them)
+     generated_text = ''.join(
+         ''.join(
+             model.tokenizer.batch_decode(
+                 output.sequences,
+                 clean_up_tokenization_spaces=True,
+                 skip_special_tokens=False
+             )
+         ).split(' ')
+     )
+     print(generated_text)
+ ```
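+
+ The same pattern also supports the molecular translation task mentioned in the introduction: condition the encoder on one representation and prompt the decoder with the tag of another. The `<inchi>` tag below is a hypothetical placeholder rather than a token documented in this README; the tokenization and generation steps are otherwise identical to the example above.
+
+ ```python
+ # Hypothetical translation prompt: condition on a SMILES string and ask the
+ # decoder for a different representation. '<inchi>' is a placeholder tag, not
+ # a token documented here; check the repository tokenizer for the real one.
+ encoder_input = '<smiles>CCO'
+ decoder_input = '<inchi>'
+ ```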
+
+ ## Citations
images/str-bamba.png ADDED

Git LFS Details

  • SHA256: 18565666078310fed55d2af8cb127a2a72d264be9aaf254044c7987f25411f96
  • Pointer size: 131 Bytes
  • Size of remote file: 695 kB