---
license: apache-2.0
tags:
- chemistry
- foundation models
- AI4Science
- materials
- molecules
- smiles
- selfies
- molecular formula
- iupac name
- inchi
- polymer smiles
- formulation
- pytorch
- bamba
- transformers
- mamba2
---

# Molecular String-based Bamba Encoder-Decoder (STR-Bamba)

This repository provides PyTorch source code associated with our publication, "STR-Bamba: Multimodal Molecular Textual Representation Encoder-Decoder Foundation Model".

**Paper:** [OpenReview Link](https://openreview.net/pdf?id=0uWNuJ1xtz)

**GitHub:** [GitHub Link](https://github.com/IBM/materials/tree/main/models/str_bamba)

For more information contact: vshirasuna@ibm.com or evital@br.ibm.com.

![str_bamba](images/str-bamba.png)

## Introduction

We present a large encoder-decoder chemical foundation model based on the IBM Bamba architecture, a hybrid of Transformer and Mamba-2 layers, designed to support multi-representational molecular string inputs. The model is pre-trained BERT-style on 588 million samples, a corpus of approximately 29 billion molecular tokens. These models serve as a foundation for chemical language research and support complex tasks, including molecular property prediction, classification, and molecular translation. **Additionally, the STR-Bamba architecture allows multiple representations to be aggregated into a single text input, because it imposes no token-length limit beyond hardware constraints.** Our experiments across multiple benchmark datasets demonstrate state-of-the-art performance on various tasks. Code details are available at: [GitHub Link](https://github.com/IBM/materials/tree/main/models/str_bamba).

The STR-Bamba model supports the following **molecular representations**:
- SMILES
- SELFIES
- Molecular Formula
- InChI
- IUPAC Name
- Polymer SMILES in [SPG notation](https://openreview.net/pdf?id=L47GThI95d)
- Formulations
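
As an illustration of the list above, the snippet below writes ethanol in several of the supported representations and concatenates them into one tagged input string. Only the `<smiles>` prefix appears in this README (in the decoder example); the other tag names (`<selfies>`, `<formula>`, `<inchi>`, `<iupac>`) and the space-free concatenation are assumptions for illustration, so check the repository's tokenizer for the actual special tokens.

```python
# Ethanol written in several of the supported representations.
ETHANOL = {
    "smiles": "CCO",
    "selfies": "[C][C][O]",
    "formula": "C2H6O",
    "inchi": "InChI=1S/C2H6O/c1-2-3/h3H,2H2,1H3",
    "iupac": "ethanol",
}


def build_multirep_input(reps: dict[str, str]) -> str:
    """Concatenate representations into a single input string.

    Each value is prefixed with a tag of the form '<name>'. Only '<smiles>'
    is confirmed by the README; the other tag names are hypothetical.
    """
    return "".join(f"<{name}>{value}" for name, value in reps.items())


if __name__ == "__main__":
    print(build_multirep_input(ETHANOL))
```

Because the model has no fixed token-length limit, such a combined string can in principle be as long as GPU memory allows.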

## Table of Contents

1. [Getting Started](#getting-started)
    1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs)
    2. [Replicating Conda Environment](#replicating-conda-environment)
2. [Pretraining](#pretraining)
3. [Finetuning](#finetuning)
4. [Feature Extraction](#feature-extraction)
5. [Citations](#citations)

## Getting Started

**This code and environment have been tested on NVIDIA V100 and NVIDIA A100 GPUs.**

### Pretrained Models and Training Logs

We provide checkpoints of the STR-Bamba model pre-trained on a dataset of ~118M small molecules, ~2M polymer structures, and 258 formulations. The pre-trained model shows competitive performance on classification and regression benchmarks across small molecules, polymers, and electrolyte formulations. For code details: [GitHub Link](https://github.com/IBM/materials/tree/main/models/str_bamba)

Place the STR-Bamba pre-trained weights (`STR-Bamba_8.pt`) in the `checkpoints/` folder under `inference/` or `finetune/`, depending on your needs. The directory structure should look like the following:

```
inference/
└── str_bamba/
    ├── config/
    ├── checkpoints/
    │   └── STR-Bamba_8.pt
    └── tokenizer/
```
and/or:

```
finetune/
└── str_bamba/
    ├── config/
    ├── checkpoints/
    │   └── STR-Bamba_8.pt
    └── tokenizer/
```
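
A quick way to confirm the layout above before loading or finetuning is to check for the expected folders and checkpoint file. The helper below is a hypothetical sketch (not part of the repository) that assumes exactly the `str_bamba/{config,checkpoints,tokenizer}` structure and the `STR-Bamba_8.pt` file name shown in the trees.

```python
from pathlib import Path

# Sub-directories expected under <root>/str_bamba/, per the trees above.
EXPECTED_DIRS = ("config", "checkpoints", "tokenizer")


def check_layout(root: str, checkpoint: str = "STR-Bamba_8.pt") -> list[str]:
    """Return the missing paths; an empty list means the layout looks right."""
    base = Path(root) / "str_bamba"
    missing = [str(base / d) for d in EXPECTED_DIRS if not (base / d).is_dir()]
    ckpt = base / "checkpoints" / checkpoint
    if not ckpt.is_file():
        missing.append(str(ckpt))
    return missing


if __name__ == "__main__":
    for root in ("inference", "finetune"):
        problems = check_layout(root)
        print(root, "OK" if not problems else f"missing: {problems}")
```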

### Replicating Conda Environment

Follow these steps to replicate our Conda environment and install the necessary libraries:

#### Create and Activate Conda Environment
```shell
mamba create -n strbamba python=3.10.13
mamba activate strbamba
```

#### PyTorch 2.4.0 and CUDA 12.4
```shell
pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124
```

#### Mamba2 Dependencies

Install the following packages in this order on a machine with a **GPU**, because `mamba` requires `causal-conv1d` to be installed first.

```shell
# causal-conv1d
git clone https://github.com/Dao-AILab/causal-conv1d.git
cd causal-conv1d && git checkout v1.5.0.post8 && pip install . && cd .. && rm -rf causal-conv1d
```

```shell
# mamba
git clone https://github.com/state-spaces/mamba.git
cd mamba && git checkout v2.2.4 && pip install --no-build-isolation . && cd .. && rm -rf mamba
```

```shell
# flash-attn
pip install flash-attn==2.6.1 --no-build-isolation
```

#### Install Packages with Pip
```shell
pip install -r requirements.txt
```
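
After installation, it can help to confirm that the compiled dependencies are importable and that PyTorch sees a GPU. The sketch below is a hypothetical helper, not part of the repository; it relies only on the standard import names of these packages (`causal_conv1d`, `mamba_ssm`, `flash_attn`) and uses `importlib.util.find_spec`, which locates a module without importing it.

```python
import importlib.util

# pip package name -> import name for the dependencies installed above
REQUIRED_MODULES = {
    "torch": "torch",
    "causal-conv1d": "causal_conv1d",
    "mamba-ssm": "mamba_ssm",
    "flash-attn": "flash_attn",
}


def check_installation() -> dict[str, bool]:
    """Report which required modules can be found, without importing them."""
    return {pip_name: importlib.util.find_spec(mod) is not None
            for pip_name, mod in REQUIRED_MODULES.items()}


def cuda_available() -> bool:
    """True if PyTorch is installed and sees a CUDA device."""
    if importlib.util.find_spec("torch") is None:
        return False
    import torch
    return torch.cuda.is_available()


if __name__ == "__main__":
    for name, ok in check_installation().items():
        print(f"{name:15s} {'found' if ok else 'MISSING'}")
    print("CUDA available:", cuda_available())
```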

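#### Troubleshooting

If the source builds above fail, install the prebuilt `mamba-ssm` package from PyPI instead, and limit the number of parallel compilation jobs (`MAX_JOBS`) when building `flash-attn` to avoid running out of memory: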
```shell
pip install mamba-ssm==2.2.4
MAX_JOBS=2 pip install flash-attn==2.6.1 --no-build-isolation --verbose
```


## Pretraining

Pretraining uses two strategies: masked language modeling to train the encoder, and next-token prediction to train the decoder, which refines the reconstruction and generation of molecular representations conditioned on the encoder.
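
To make the two objectives concrete, the sketch below shows how training examples are typically built for each, on plain lists of token IDs. It follows the standard BERT recipe (select 15% of tokens; replace 80% of those with a mask token, 10% with a random token, and keep 10%) and the usual shift-by-one for next-token prediction. These rates, the `MASK_ID`/`VOCAB_SIZE` values, and the `-100` ignore label are assumptions for illustration, not STR-Bamba's confirmed settings; a real pipeline would also avoid masking special tokens.

```python
import random

MASK_ID = 4        # hypothetical ID of the [MASK] token
VOCAB_SIZE = 100   # hypothetical vocabulary size
IGNORE = -100      # label ignored by the loss (PyTorch CrossEntropyLoss default)


def mask_for_mlm(ids, rng, mask_prob=0.15):
    """BERT-style masking: return (inputs, labels) for the encoder objective."""
    inputs, labels = list(ids), [IGNORE] * len(ids)
    for i, tok in enumerate(ids):
        if rng.random() >= mask_prob:
            continue
        labels[i] = tok                            # loss is computed only here
        r = rng.random()
        if r < 0.8:
            inputs[i] = MASK_ID                    # 80%: replace with [MASK]
        elif r < 0.9:
            inputs[i] = rng.randrange(VOCAB_SIZE)  # 10%: random token
        # remaining 10%: keep the original token
    return inputs, labels


def shift_for_next_token(ids):
    """Decoder objective: predict token t+1 from tokens up to t."""
    return ids[:-1], ids[1:]


if __name__ == "__main__":
    rng = random.Random(0)
    seq = [10, 11, 12, 13, 14, 15, 16, 17]
    print(mask_for_mlm(seq, rng))
    print(shift_for_next_token(seq))
```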

The pretraining code provides examples of data processing and model training on a smaller dataset and requires an A100 GPU.

To pre-train the encoder stage of STR-Bamba, run:

```shell
bash training/run_model_encoder_training.sh
```

To pre-train the decoder stage, run:

```shell
bash training/run_model_decoder_training.sh
```

## Finetuning

The finetuning datasets and environment can be found in the [finetune](https://github.com/IBM/materials/tree/main/models/str_bamba/finetune) directory. After setting up the environment, you can run a finetuning task with:

```
bash finetune/runs/esol/run_finetune_esol.sh
```

Finetuning training outputs and checkpoints are saved in directories named `checkpoint_<measure_name>`.

## Feature Extraction

To load STR-Bamba, you can simply use:

```python
# load_strbamba is defined in the STR-Bamba repository code; import it from there first
model = load_strbamba('STR-Bamba_8.pt')
```

To encode SMILES, SELFIES, InChI, or other supported molecular representations into embeddings, you can use:

```python
import torch

# df is a pandas DataFrame with a 'SMILES' column of input strings
with torch.no_grad():
    encoded_embeddings = model.encode(df['SMILES'], return_torch=True)
```
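
Once molecules are encoded, the embeddings can be compared directly or used as features for downstream models. The sketch below computes pairwise cosine similarity with NumPy, assuming `model.encode` returns one fixed-size vector per input (shape `(n_molecules, hidden_dim)`); a torch tensor can be converted with `.cpu().numpy()` first. The random matrix is only a stand-in for real embeddings.

```python
import numpy as np


def cosine_similarity_matrix(emb: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between the rows of an (n, d) matrix."""
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    unit = emb / np.clip(norms, 1e-12, None)  # guard against zero vectors
    return unit @ unit.T


if __name__ == "__main__":
    # Stand-in for: encoded_embeddings.cpu().numpy()
    emb = np.random.default_rng(0).normal(size=(3, 8))
    print(np.round(cosine_similarity_matrix(emb), 3))
```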
To generate molecular representations with the decoder, conditioned on the encoder output, you can use:

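```python
import torch

# assumes `model` is already on a CUDA GPU; cg=True below enables CUDA-graph decoding
device = torch.device('cuda')

with torch.no_grad():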
    # encoder and decoder inputs
    encoder_input = '<smiles>CCO'
    decoder_input = '<smiles>'
    decoder_target = '<smiles>CCO'

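    # tokenization: convert strings to token-ID tensors on the target device
    encoder_input_ids = model.tokenizer(encoder_input,
                                        padding=True,
                                        truncation=True,
                                        return_tensors='pt')['input_ids'].to(device)
    # [:, :-1] drops the last token (assumed to be the end-of-sequence token the
    # tokenizer appends) so the decoder continues generating after the prefix
    decoder_input_ids = model.tokenizer(decoder_input,
                                        padding=True,
                                        truncation=True,
                                        return_tensors='pt')['input_ids'][:, :-1].to(device)
    decoder_target_ids = model.tokenizer(decoder_target,
                                         padding=True,
                                         truncation=True,
                                         return_tensors='pt')['input_ids'].to(device)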

    # visualize input texts
    print('Encoder input:', model.tokenizer.batch_decode(encoder_input_ids))
    print('Decoder input:', model.tokenizer.batch_decode(decoder_input_ids))
    print('Decoder target:', model.tokenizer.batch_decode(decoder_target_ids))
    print('Target:', decoder_target_ids)

    # encoder forward
    encoder_hidden_states = model.encoder(encoder_input_ids).hidden_states

    # greedy decoding (top_k=1): always pick the most likely next token,
    # conditioned on the encoder hidden states
    output = model.decoder.generate(
        input_ids=decoder_input_ids,
        encoder_hidden_states=encoder_hidden_states,
        max_length=decoder_target_ids.shape[1],
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
        temperature=1,
        top_k=1,
        top_p=1.0,
        min_p=0.,
        repetition_penalty=1,
    )

    # decode token IDs to text and strip the spaces placed between tokens
    generated_text = ''.join(
        ''.join(
            model.tokenizer.batch_decode(
                output.sequences, 
                clean_up_tokenization_spaces=True, 
                skip_special_tokens=False
            )
        ).split(' ')
    )
    print(generated_text)
```
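
The post-processing at the end of the example joins the decoded strings and removes the spaces the tokenizer places between tokens, which is equivalent to a single `str.replace(" ", "")`. The sketch below packages that step and adds an exact-match check of the reconstruction against the target. The helper names and the special tokens stripped in the usage example (`<bos>`, `<eos>`) are hypothetical; the example above keeps special tokens (`skip_special_tokens=False`), so adjust the list to the tokenizer's actual vocabulary.

```python
def clean_decoded(texts, special_tokens=()):
    """Join decoded strings, then drop inter-token spaces and given special tokens."""
    text = "".join(texts).replace(" ", "")
    for tok in special_tokens:
        text = text.replace(tok, "")
    return text


def exact_match_rate(generated, targets):
    """Fraction of generated strings identical to their targets."""
    if len(generated) != len(targets):
        raise ValueError("generated and targets must have the same length")
    if not targets:
        return 0.0
    hits = sum(g == t for g, t in zip(generated, targets))
    return hits / len(targets)


if __name__ == "__main__":
    # Hypothetical decoded output; real special-token names depend on the tokenizer.
    decoded = ["<bos> <smiles> C C O <eos>"]
    text = clean_decoded(decoded, special_tokens=("<bos>", "<eos>"))
    print(text)                                        # <smiles>CCO
    print(exact_match_rate([text], ["<smiles>CCO"]))   # 1.0
```

For batched generation, call `clean_decoded([s])` on each decoded sequence separately rather than joining the whole batch into one string.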

## Citations