victor-shirasuna committed on
Commit 3d83373 · 1 Parent(s): 94a0645

Upload files
STR-Bamba_8.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:db6d7a2561bfbaf9bd8a5f910321b2ff21671b6bc47cad955a323898203a9967
+ size 1372194320
STR-Bamba_8.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:db6d7a2561bfbaf9bd8a5f910321b2ff21671b6bc47cad955a323898203a9967
+ size 1372194320
config.json ADDED
@@ -0,0 +1,62 @@
+ {
+     "encoder_config": {
+         "d_model": 1024,
+         "d_intermediate": 0,
+         "n_layer": 24,
+         "vocab_size": 5000,
+         "max_position_embeddings": 4096,
+         "ssm_cfg": {
+             "layer": "Mamba2"
+         },
+         "attn_layer_idx": [
+             6,
+             18
+         ],
+         "attn_cfg": {
+             "causal": false,
+             "d_conv": 0,
+             "head_dim": 64,
+             "num_heads": 16,
+             "num_heads_kv": 8,
+             "out_proj_bias": false,
+             "qkv_proj_bias": false,
+             "rotary_emb_dim": 64
+         },
+         "rms_norm": true,
+         "residual_in_fp32": true,
+         "fused_add_norm": true,
+         "pad_vocab_size_multiple": 8,
+         "tie_embeddings": false
+     },
+     "decoder_config": {
+         "d_model": 1024,
+         "d_intermediate": 0,
+         "n_layer": 24,
+         "vocab_size": 5000,
+         "max_position_embeddings": 4096,
+         "ssm_cfg": {
+             "layer": "Mamba2"
+         },
+         "attn_layer_idx": [
+             6,
+             18
+         ],
+         "attn_cfg": {
+             "causal": true,
+             "d_conv": 0,
+             "head_dim": 64,
+             "num_heads": 16,
+             "num_heads_kv": 8,
+             "out_proj_bias": false,
+             "qkv_proj_bias": false,
+             "rotary_emb_dim": 64
+         },
+         "rms_norm": true,
+         "residual_in_fp32": true,
+         "fused_add_norm": true,
+         "pad_vocab_size_multiple": 8,
+         "tie_embeddings": false
+     },
+     "tie_word_embeddings": true,
+     "seed": 0
+ }
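
The encoder and decoder halves above are identical 24-layer, d_model 1024 stacks (Mamba2 mixers, with attention blocks at layers 6 and 18); they differ only in attn_cfg.causal, which is false for the bidirectional encoder and true for the autoregressive decoder. tie_embeddings is false within each half, while the top-level tie_word_embeddings ties the decoder's input embedding to the encoder's word embeddings (see tie_weights in str_bamba/bamba.py below).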
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ numpy==1.26.4
+ pandas==2.2.3
+ scikit-learn>=1.6.1
+ datasets==3.5.0
+ transformers==4.52.1
+ tokenizers==0.21.1
+ deepspeed==0.16.7
+ einops==0.8.1
+ tqdm==4.67.1
+ torch-optimizer==0.3.0
+ rdkit>=2024.3.5
+ selfies>=2.2.0
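
Note that str_bamba/bamba.py and str_bamba/bamba_modules.py below import torch and mamba_ssm (Mamba2, MHA, and the Triton LayerNorm/RMSNorm kernels), neither of which is pinned in this requirements.txt; presumably PyTorch and mamba-ssm are expected to be installed separately, and mamba-ssm needs a CUDA toolchain to build.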
str_bamba/bamba.py ADDED
@@ -0,0 +1,534 @@
+ from .generation import GenerationMixin
+ from mamba_ssm.modules.mamba_simple import Mamba
+ from mamba_ssm.modules.mamba2 import Mamba2
+ from mamba_ssm.modules.mha import MHA
+ from mamba_ssm.modules.mlp import GatedMLP
+ from mamba_ssm.modules.block import Block
+ from mamba_ssm.models.mixer_seq_simple import _init_weights
+ from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
+ from .bamba_modules import BertEmbeddings, BertPooler, BertPreTrainingHeads, BlockCrossAttention
+ from .bamba_config import BambaConfig, BambaEncoderDecoderConfig
+
+ try:
+     from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
+ except ImportError:
+     RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
+
+ from typing import List, Optional, Tuple, Union
+ from collections import namedtuple
+ import torch.backends.cudnn as cudnn
+ import math
+ import random
+ from functools import partial
+ import json
+ import os
+ import copy
+ import torch
+ import torch.nn as nn
+ import pandas as pd
+ import numpy as np
+ import gc
+ from tqdm import tqdm
+
+
+ def create_block(
+     d_model,
+     d_intermediate,
+     block_class,
+     ssm_cfg=None,
+     attn_layer_idx=None,
+     attn_cfg=None,
+     norm_epsilon=1e-5,
+     rms_norm=False,
+     residual_in_fp32=False,
+     fused_add_norm=False,
+     layer_idx=None,
+     device=None,
+     dtype=None,
+ ):
+     if ssm_cfg is None:
+         ssm_cfg = {}
+     if attn_layer_idx is None:
+         attn_layer_idx = []
+     if attn_cfg is None:
+         attn_cfg = {}
+     factory_kwargs = {"device": device, "dtype": dtype}
+     if layer_idx not in attn_layer_idx:
+         # Create a copy of the config to modify
+         ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
+         ssm_layer = ssm_cfg.pop("layer", "Mamba1")
+         if ssm_layer not in ["Mamba1", "Mamba2"]:
+             raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
+         mixer_cls = partial(
+             Mamba2 if ssm_layer == "Mamba2" else Mamba,
+             layer_idx=layer_idx,
+             **ssm_cfg,
+             **factory_kwargs
+         )
+     else:
+         mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
+     norm_cls = partial(
+         nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
+     )
+     if d_intermediate == 0:
+         mlp_cls = nn.Identity
+     else:
+         mlp_cls = partial(
+             GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
+         )
+     block = block_class(
+         d_model,
+         mixer_cls,
+         mlp_cls,
+         norm_cls=norm_cls,
+         fused_add_norm=fused_add_norm,
+         residual_in_fp32=residual_in_fp32,
+     )
+     if isinstance(block, BlockCrossAttention) and factory_kwargs["dtype"] is not None:
+         block.encoder_attn.type(factory_kwargs["dtype"]).to(factory_kwargs["device"])
+     block.layer_idx = layer_idx
+     return block
+
+
+ class BambaMixerModel(nn.Module):
+     def __init__(
+         self,
+         d_model: int,
+         n_layer: int,
+         d_intermediate: int,
+         vocab_size: int,
+         max_position_embeddings: int,
+         is_decoder: bool = False,
+         ssm_cfg=None,
+         attn_layer_idx=None,
+         attn_cfg=None,
+         norm_epsilon: float = 1e-5,
+         rms_norm: bool = False,
+         initializer_cfg=None,
+         fused_add_norm=False,
+         residual_in_fp32=False,
+         device=None,
+         dtype=None,
+     ) -> None:
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.residual_in_fp32 = residual_in_fp32
+
+         self.is_decoder = is_decoder
+
+         if is_decoder:
+             self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
+         else:
+             self.embedding = BertEmbeddings(vocab_size, d_model, max_position_embeddings, **factory_kwargs)
+
+         # We change the order of residual and layer norm:
+         # Instead of LN -> Attn / MLP -> Add, we do:
+         # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
+         # the main branch (output of MLP / Mixer). The model definition is unchanged.
+         # This is for performance reason: we can fuse add + layer_norm.
+         self.fused_add_norm = fused_add_norm
+         if self.fused_add_norm:
+             if layer_norm_fn is None or rms_norm_fn is None:
+                 raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
+
+         if is_decoder:
+             block_class = BlockCrossAttention
+         else:
+             block_class = Block
+
+         self.layers = nn.ModuleList(
+             [
+                 create_block(
+                     d_model,
+                     d_intermediate=d_intermediate,
+                     block_class=block_class,
+                     ssm_cfg=ssm_cfg,
+                     attn_layer_idx=attn_layer_idx,
+                     attn_cfg=attn_cfg,
+                     norm_epsilon=norm_epsilon,
+                     rms_norm=rms_norm,
+                     residual_in_fp32=residual_in_fp32,
+                     fused_add_norm=fused_add_norm,
+                     layer_idx=i,
+                     **factory_kwargs,
+                 )
+                 for i in range(n_layer)
+             ]
+         )
+
+         self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
+             d_model, eps=norm_epsilon, **factory_kwargs
+         )
+
+         if not is_decoder:
+             self.pooler = BertPooler(d_model, **factory_kwargs)
+
+         self.apply(
+             partial(
+                 _init_weights,
+                 n_layer=n_layer,
+                 **(initializer_cfg if initializer_cfg is not None else {}),
+                 n_residuals_per_layer=1 if d_intermediate == 0 else 2,  # 2 if we have MLP
+             )
+         )
+
+     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+         return {
+             i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+             for i, layer in enumerate(self.layers)
+         }
+
+     def forward(self, input_ids, token_type_ids=None, inference_params=None, encoder_hidden_states=None, attention_mask=None, **mixer_kwargs):
+         if self.is_decoder:
+             hidden_states = self.embedding(input_ids)
+         else:
+             hidden_states = self.embedding(input_ids, token_type_ids)
+         residual = None
+         for layer in self.layers:
+             if self.is_decoder:
+                 hidden_states, residual = layer(
+                     hidden_states, residual, inference_params=inference_params, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **mixer_kwargs
+                 )
+             else:
+                 hidden_states, residual = layer(
+                     hidden_states, residual, inference_params=inference_params, **mixer_kwargs
+                 )
+         if not self.fused_add_norm:
+             residual = (hidden_states + residual) if residual is not None else hidden_states
+             hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
+         else:
+             # Set prenorm=False here since we don't need the residual
+             hidden_states = layer_norm_fn(
+                 hidden_states,
+                 self.norm_f.weight,
+                 self.norm_f.bias,
+                 eps=self.norm_f.eps,
+                 residual=residual,
+                 prenorm=False,
+                 residual_in_fp32=self.residual_in_fp32,
+                 is_rms_norm=isinstance(self.norm_f, RMSNorm)
+             )
+         if not self.is_decoder:
+             pooled_output = self.pooler(hidden_states)
+             return hidden_states, pooled_output
+         return hidden_states
+
+
+ class BambaEncoder(nn.Module):
+
+     def __init__(
+         self,
+         config: BambaConfig,
+         initializer_cfg=None,
+         device=None,
+         dtype=None,
+     ) -> None:
+         self.config = config
+         d_model = config.d_model
+         n_layer = config.n_layer
+         d_intermediate = config.d_intermediate
+         vocab_size = config.vocab_size
+         max_position_embeddings = config.max_position_embeddings
+         ssm_cfg = config.ssm_cfg
+         attn_layer_idx = config.attn_layer_idx
+         attn_cfg = config.attn_cfg
+         rms_norm = config.rms_norm
+         residual_in_fp32 = config.residual_in_fp32
+         fused_add_norm = config.fused_add_norm
+         pad_vocab_size_multiple = config.pad_vocab_size_multiple
+         factory_kwargs = {"device": device, "dtype": dtype}
+
+         super().__init__()
+         if vocab_size % pad_vocab_size_multiple != 0:
+             vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
+         self.backbone = BambaMixerModel(
+             d_model=d_model,
+             n_layer=n_layer,
+             d_intermediate=d_intermediate,
+             vocab_size=vocab_size,
+             max_position_embeddings=max_position_embeddings,
+             is_decoder=False,
+             ssm_cfg=ssm_cfg,
+             attn_layer_idx=attn_layer_idx,
+             attn_cfg=attn_cfg,
+             rms_norm=rms_norm,
+             initializer_cfg=initializer_cfg,
+             fused_add_norm=fused_add_norm,
+             residual_in_fp32=residual_in_fp32,
+             **factory_kwargs,
+         )
+         self.cls = BertPreTrainingHeads(vocab_size, d_model, **factory_kwargs)
+
+         # Initialize weights and apply final processing
+         self.apply(
+             partial(
+                 _init_weights,
+                 n_layer=n_layer,
+                 **(initializer_cfg if initializer_cfg is not None else {}),
+             )
+         )
+         self.tie_weights()
+
+     def tie_weights(self):
+         if self.config.tie_embeddings:
+             # The encoder has no lm_head; its prediction head decoder plays that role.
+             self.cls.predictions.decoder.weight = self.backbone.embedding.word_embeddings.weight
+
+     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+         return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
+     def forward(self, input_ids, token_type_ids=None, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
+         """
+         "position_ids" is just to be compatible with Transformer generation. We don't use it.
+         num_last_tokens: if > 0, only return the logits for the last n tokens
+         """
+         hidden_states, pooled_output = self.backbone(input_ids, token_type_ids, inference_params=inference_params, **mixer_kwargs)
+         if num_last_tokens > 0:
+             hidden_states = hidden_states[:, -num_last_tokens:]
+         lm_logits, seq_relationship_score = self.cls(hidden_states, pooled_output)
+         CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "seq_relationship_logits", "hidden_states"])
+         return CausalLMOutput(logits=lm_logits, seq_relationship_logits=seq_relationship_score, hidden_states=hidden_states)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
+         config_data = load_config_hf(pretrained_model_name)
+         config = BambaConfig(**config_data)
+         model = cls(config, device=device, dtype=dtype, **kwargs)
+         model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
+         return model
+
+     def save_pretrained(self, save_directory):
+         """
+         Minimal implementation of save_pretrained for BambaEncoder.
+         Save the model and its configuration file to a directory.
+         """
+         # Ensure save_directory exists
+         os.makedirs(save_directory, exist_ok=True)
+
+         # Save the model's state_dict
+         model_path = os.path.join(save_directory, 'pytorch_model.bin')
+         torch.save(self.state_dict(), model_path)
+
+         # Save the configuration of the model
+         config_path = os.path.join(save_directory, 'config.json')
+         with open(config_path, 'w') as f:
+             json.dump(self.config.__dict__, f, indent=4)
+
+
+ class BambaDecoder(nn.Module, GenerationMixin):
+
+     def __init__(
+         self,
+         config: BambaConfig,
+         initializer_cfg=None,
+         device=None,
+         dtype=None,
+     ) -> None:
+         self.config = config
+         d_model = config.d_model
+         n_layer = config.n_layer
+         d_intermediate = config.d_intermediate
+         vocab_size = config.vocab_size
+         max_position_embeddings = config.max_position_embeddings
+         ssm_cfg = config.ssm_cfg
+         attn_layer_idx = config.attn_layer_idx
+         attn_cfg = config.attn_cfg
+         rms_norm = config.rms_norm
+         residual_in_fp32 = config.residual_in_fp32
+         fused_add_norm = config.fused_add_norm
+         pad_vocab_size_multiple = config.pad_vocab_size_multiple
+         factory_kwargs = {"device": device, "dtype": dtype}
+
+         super().__init__()
+         if vocab_size % pad_vocab_size_multiple != 0:
+             vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
+         self.backbone = BambaMixerModel(
+             d_model=d_model,
+             n_layer=n_layer,
+             d_intermediate=d_intermediate,
+             vocab_size=vocab_size,
+             max_position_embeddings=max_position_embeddings,
+             is_decoder=True,
+             ssm_cfg=ssm_cfg,
+             attn_layer_idx=attn_layer_idx,
+             attn_cfg=attn_cfg,
+             rms_norm=rms_norm,
+             initializer_cfg=initializer_cfg,
+             fused_add_norm=fused_add_norm,
+             residual_in_fp32=residual_in_fp32,
+             **factory_kwargs,
+         )
+         self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
+
+         # Initialize weights and apply final processing
+         self.apply(
+             partial(
+                 _init_weights,
+                 n_layer=n_layer,
+                 **(initializer_cfg if initializer_cfg is not None else {}),
+             )
+         )
+         self.tie_weights()
+
+     def tie_weights(self):
+         if self.config.tie_embeddings:
+             self.lm_head.weight = self.backbone.embedding.weight
+
+     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+         return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
+     def forward(self, input_ids, token_type_ids=None, position_ids=None, inference_params=None, num_last_tokens=0, encoder_hidden_states=None, attention_mask=None, **mixer_kwargs):
+         """
+         "position_ids" is just to be compatible with Transformer generation. We don't use it.
+         num_last_tokens: if > 0, only return the logits for the last n tokens
+         """
+         hidden_states = self.backbone(
+             input_ids, token_type_ids, inference_params=inference_params, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **mixer_kwargs
+         )
+         if num_last_tokens > 0:
+             hidden_states = hidden_states[:, -num_last_tokens:]
+         lm_logits = self.lm_head(hidden_states)
+         CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
+         return CausalLMOutput(logits=lm_logits)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
+         config_data = load_config_hf(pretrained_model_name)
+         config = BambaConfig(**config_data)
+         model = cls(config, device=device, dtype=dtype, **kwargs)
+         model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
+         return model
+
+     def save_pretrained(self, save_directory):
+         """
+         Minimal implementation of save_pretrained for BambaDecoder.
+         Save the model and its configuration file to a directory.
+         """
+         # Ensure save_directory exists
+         os.makedirs(save_directory, exist_ok=True)
+
+         # Save the model's state_dict
+         model_path = os.path.join(save_directory, 'pytorch_model.bin')
+         torch.save(self.state_dict(), model_path)
+
+         # Save the configuration of the model
+         config_path = os.path.join(save_directory, 'config.json')
+         with open(config_path, 'w') as f:
+             json.dump(self.config.__dict__, f, indent=4)
+
+
+ class BambaEncoderDecoder(nn.Module, GenerationMixin):
+
+     def __init__(
+         self,
+         config: BambaEncoderDecoderConfig,
+         tokenizer=None,
+         initializer_cfg=None,
+         device=None,
+         dtype=None,
+     ) -> None:
+         self.config = config
+         self.encoder_config = config.encoder_config
+         self.decoder_config = config.decoder_config
+         factory_kwargs = {"device": device, "dtype": dtype}
+         self.tokenizer = tokenizer
+
+         super().__init__()
+         self.encoder = BambaEncoder(self.encoder_config, **factory_kwargs)
+         self.decoder = BambaDecoder(self.decoder_config, **factory_kwargs)
+
+         self.device = device
+
+         self.tie_weights()
+         self._set_seed(config.seed)
+
+     def tie_weights(self):
+         if self.config.tie_word_embeddings:
+             self.decoder.backbone.embedding.weight = self.encoder.backbone.embedding.word_embeddings.weight
+
+     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+         return self.decoder.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
+     def forward(self, encoder_input_ids, decoder_input_ids, token_type_ids=None, attention_mask=None, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
+         """
+         "position_ids" is just to be compatible with Transformer generation. We don't use it.
+         num_last_tokens: if > 0, only return the logits for the last n tokens
+         """
+         encoder_hidden_states = self.encoder(encoder_input_ids, inference_params=inference_params, **mixer_kwargs).hidden_states
+         lm_logits = self.decoder(decoder_input_ids, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, inference_params=inference_params, **mixer_kwargs).logits
+         if num_last_tokens > 0:
+             lm_logits = lm_logits[:, -num_last_tokens:]
+         CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
+         return CausalLMOutput(logits=lm_logits)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
+         config_data = load_config_hf(pretrained_model_name)
+         config = BambaEncoderDecoderConfig(
+             encoder_config=BambaConfig(**config_data["encoder_config"]),
+             decoder_config=BambaConfig(**config_data["decoder_config"]),
+             tie_word_embeddings=config_data["tie_word_embeddings"],
+             seed=config_data["seed"],
+         )
+         model = cls(config, device=device, dtype=dtype, **kwargs)
+         model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
+         return model
+
+     def save_pretrained(self, save_directory):
+         """
+         Minimal implementation of save_pretrained for BambaEncoderDecoder.
+         Save the model and its configuration file to a directory.
+         """
+         # Ensure save_directory exists
+         os.makedirs(save_directory, exist_ok=True)
+
+         # Save the model's state_dict
+         model_path = os.path.join(save_directory, 'pytorch_model.bin')
+         torch.save(self.state_dict(), model_path)
+
+         # Save the configuration of the model
+         config_path = os.path.join(save_directory, 'config.json')
+         with open(config_path, 'w') as f:
+             json.dump(self.config.__dict__, f, indent=4)
+
+     def _set_seed(self, value):
+         print('Random Seed:', value)
+         random.seed(value)
+         torch.manual_seed(value)
+         torch.cuda.manual_seed(value)
+         torch.cuda.manual_seed_all(value)
+         np.random.seed(value)
+         cudnn.deterministic = True
+         cudnn.benchmark = False
+
+     def extract_embeddings(self, smiles):
+         tokens = self.tokenizer(smiles, padding=True, truncation=True, return_tensors='pt')
+
+         idx = tokens['input_ids'].to(self.device)
+         mask = tokens['attention_mask'].to(self.device)
+         outputs = self.encoder(input_ids=idx)
+         hidden_states = outputs.hidden_states
+
+         # Mean-pool the token embeddings over non-padding positions
+         token_embeddings = hidden_states
+         input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+         sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
+         sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+         embeddings = sum_embeddings / sum_mask
+
+         return embeddings
+
+     def encode(self, smiles, useCuda=False, batch_size=100, return_torch=False):
+         """Efficiently extract SMILES embeddings in batches."""
+         # TODO: remove useCuda argument
+
+         # handle a single str or a list of str
+         smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles))
+
+         # process in batches
+         n_split = smiles.shape[0] // batch_size if smiles.shape[0] >= batch_size else smiles.shape[0]
+         embeddings = [
+             self.extract_embeddings(list(batch)).cpu().detach().numpy()
+             for batch in tqdm(np.array_split(smiles, n_split))
+         ]
+         flat_list = [item for sublist in embeddings for item in sublist]
+
+         # clear GPU memory
+         torch.cuda.empty_cache()
+         gc.collect()
+
+         if return_torch:
+             return torch.tensor(flat_list)
+         return pd.DataFrame(flat_list)
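
For orientation, a minimal sketch (untrained, randomly initialized weights; assumes a CUDA device with the mamba-ssm Triton kernels available) of how the classes in this file compose, using the hyperparameters from config.json above:

    import torch
    from str_bamba.bamba_config import BambaConfig, BambaEncoderDecoderConfig
    from str_bamba.bamba import BambaEncoderDecoder

    attn_cfg = {"causal": False, "d_conv": 0, "head_dim": 64, "num_heads": 16,
                "num_heads_kv": 8, "out_proj_bias": False, "qkv_proj_bias": False,
                "rotary_emb_dim": 64}
    common = dict(d_model=1024, d_intermediate=0, n_layer=24, vocab_size=5000,
                  max_position_embeddings=4096, ssm_cfg={"layer": "Mamba2"},
                  attn_layer_idx=[6, 18], tie_embeddings=False)

    cfg = BambaEncoderDecoderConfig(
        encoder_config=BambaConfig(attn_cfg=attn_cfg, **common),                      # bidirectional
        decoder_config=BambaConfig(attn_cfg={**attn_cfg, "causal": True}, **common),  # autoregressive
    )
    model = BambaEncoderDecoder(cfg, tokenizer=None, device="cuda:0", dtype=torch.float32)

    ids = torch.randint(0, 5000, (2, 16), device="cuda:0")
    enc_out = model.encoder(ids)        # CausalLMOutput(logits, seq_relationship_logits, hidden_states)
    print(enc_out.hidden_states.shape)  # torch.Size([2, 16, 1024])

This mirrors what str_bamba/load.py does below, minus the tokenizer and checkpoint loading.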
str_bamba/bamba_config.py ADDED
@@ -0,0 +1,28 @@
+ from dataclasses import dataclass, field
+
+
+ @dataclass
+ class BambaConfig:
+
+     d_model: int = 2560
+     d_intermediate: int = 0
+     n_layer: int = 64
+     vocab_size: int = 50277
+     max_position_embeddings: int = 262144
+     ssm_cfg: dict = field(default_factory=dict)
+     attn_layer_idx: list = field(default_factory=list)
+     attn_cfg: dict = field(default_factory=dict)
+     rms_norm: bool = True
+     residual_in_fp32: bool = True
+     fused_add_norm: bool = True
+     pad_vocab_size_multiple: int = 8
+     tie_embeddings: bool = True
+
+
+ @dataclass
+ class BambaEncoderDecoderConfig:
+
+     encoder_config: BambaConfig = None
+     decoder_config: BambaConfig = None
+     tie_word_embeddings: bool = True
+     seed: int = 0
str_bamba/bamba_modules.py ADDED
@@ -0,0 +1,229 @@
+ from typing import Optional
+
+ import torch
+ from torch import nn, Tensor
+
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
+ from transformers.models.bart.modeling_bart import BartSdpaAttention
+ from transformers.activations import ACT2FN
+
+
+ class BertEmbeddings(nn.Module):
+     """Construct the embeddings from word, position and token_type embeddings."""
+
+     def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size=2, pad_token_id=2, layer_norm_eps=1e-12, hidden_dropout_prob=0.1, device=None, dtype=None):
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id, **factory_kwargs)
+         self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size, **factory_kwargs)
+
+         # self.LayerNorm is not snake-cased to stick with the TensorFlow model variable name and be able to load
+         # any TensorFlow checkpoint file
+         self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, **factory_kwargs)
+         self.dropout = nn.Dropout(hidden_dropout_prob)
+         # self.position_embedding_type = "rotary"
+         # self.register_buffer(
+         #     "position_ids", torch.arange(max_position_embeddings).expand((1, -1)), persistent=False
+         # )
+         # self.register_buffer(
+         #     "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+         # )
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         past_key_values_length: int = 0,
+     ) -> torch.Tensor:
+         if input_ids is not None:
+             input_shape = input_ids.size()
+         else:
+             input_shape = inputs_embeds.size()[:-1]
+
+         seq_length = input_shape[1]
+
+         # if position_ids is None:
+         #     position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+         # Setting token_type_ids to the registered buffer from the constructor (all zeros) usually happens
+         # when it is auto-generated; the registered buffer helps users trace the model without passing
+         # token_type_ids, and solves issue #5664
+         if token_type_ids is None:
+             # if hasattr(self, "token_type_ids"):
+             #     buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+             #     buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+             #     token_type_ids = buffered_token_type_ids_expanded
+             # else:
+             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)
+
+         if inputs_embeds is None:
+             inputs_embeds = self.word_embeddings(input_ids)
+         token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+         embeddings = inputs_embeds + token_type_embeddings
+         embeddings = self.LayerNorm(embeddings)
+         embeddings = self.dropout(embeddings)
+         return embeddings
+
+
+ class BertPooler(nn.Module):
+     def __init__(self, hidden_size, device=None, dtype=None):
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.dense = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
+         self.activation = nn.Tanh()
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         # We "pool" the model by simply taking the hidden state corresponding
+         # to the first token.
+         first_token_tensor = hidden_states[:, 0]
+         pooled_output = self.dense(first_token_tensor)
+         pooled_output = self.activation(pooled_output)
+         return pooled_output
+
+
+ class BertPredictionHeadTransform(nn.Module):
+     def __init__(self, hidden_size, hidden_act="gelu", layer_norm_eps=1e-12, device=None, dtype=None):
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.dense = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
+         if isinstance(hidden_act, str):
+             self.transform_act_fn = ACT2FN[hidden_act]
+         else:
+             self.transform_act_fn = hidden_act
+         self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, **factory_kwargs)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.transform_act_fn(hidden_states)
+         hidden_states = self.LayerNorm(hidden_states)
+         return hidden_states
+
+
+ class BertLMPredictionHead(nn.Module):
+     def __init__(self, vocab_size, hidden_size, hidden_act="gelu", layer_norm_eps=1e-12, device=None, dtype=None):
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.transform = BertPredictionHeadTransform(hidden_size, hidden_act, layer_norm_eps, **factory_kwargs)
+
+         # The output weights are the same as the input embeddings, but there is
+         # an output-only bias for each token.
+         self.decoder = nn.Linear(hidden_size, vocab_size, bias=False, **factory_kwargs)
+
+         self.bias = nn.Parameter(torch.zeros(vocab_size, **factory_kwargs))
+
+         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+         self.decoder.bias = self.bias
+
+     def _tie_weights(self):
+         self.decoder.bias = self.bias
+
+     def forward(self, hidden_states):
+         hidden_states = self.transform(hidden_states)
+         hidden_states = self.decoder(hidden_states)
+         return hidden_states
+
+
+ class BertPreTrainingHeads(nn.Module):
+     def __init__(self, vocab_size, hidden_size, hidden_act="gelu", layer_norm_eps=1e-12, device=None, dtype=None):
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.predictions = BertLMPredictionHead(vocab_size, hidden_size, hidden_act, layer_norm_eps, **factory_kwargs)
+         self.seq_relationship = nn.Linear(hidden_size, 2, **factory_kwargs)
+
+     def forward(self, sequence_output, pooled_output):
+         prediction_scores = self.predictions(sequence_output)
+         seq_relationship_score = self.seq_relationship(pooled_output)
+         return prediction_scores, seq_relationship_score
+
+
+ class BlockCrossAttention(nn.Module):
+     def __init__(
+         self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
+     ):
+         """
+         Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.
+
+         This Block has a slightly different structure compared to a regular
+         prenorm Transformer block.
+         The standard block is: LN -> MHA/MLP -> Add.
+         [Ref: https://arxiv.org/abs/2002.04745]
+         Here we have: Add -> LN -> Mixer, returning both
+         the hidden_states (output of the mixer) and the residual.
+         This is purely for performance reasons, as we can fuse add and LayerNorm.
+         The residual needs to be provided (except for the very first block).
+         """
+         super().__init__()
+         self.residual_in_fp32 = residual_in_fp32
+         self.fused_add_norm = fused_add_norm
+         self.norm = norm_cls(dim)
+         self.mixer = mixer_cls(dim)
+         self.encoder_attn = BartSdpaAttention(embed_dim=dim, num_heads=1)
+         if mlp_cls is not nn.Identity:
+             self.norm2 = norm_cls(dim)
+             self.mlp = mlp_cls(dim)
+         else:
+             self.mlp = None
+         if self.fused_add_norm:
+             assert RMSNorm is not None, "RMSNorm import fails"
+             assert isinstance(
+                 self.norm, (nn.LayerNorm, RMSNorm)
+             ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
+
+     def forward(
+         self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, encoder_hidden_states=None, attention_mask=None, **mixer_kwargs
+     ):
+         r"""Pass the input through the encoder layer.
+
+         Args:
+             hidden_states: the sequence to the encoder layer (required).
+             residual: hidden_states = Mixer(LN(residual))
+         """
+         if not self.fused_add_norm:
+             residual = (hidden_states + residual) if residual is not None else hidden_states
+             hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
+             if self.residual_in_fp32:
+                 residual = residual.to(torch.float32)
+         else:
+             hidden_states, residual = layer_norm_fn(
+                 hidden_states,
+                 self.norm.weight,
+                 self.norm.bias,
+                 residual=residual,
+                 prenorm=True,
+                 residual_in_fp32=self.residual_in_fp32,
+                 eps=self.norm.eps,
+                 is_rms_norm=isinstance(self.norm, RMSNorm)
+             )
+         hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
+
+         # cross-attention
+         hidden_states, _, _ = self.encoder_attn(hidden_states, encoder_hidden_states, attention_mask=attention_mask)
+
+         if self.mlp is not None:
+             if not self.fused_add_norm:
+                 residual = hidden_states + residual
+                 hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+                 if self.residual_in_fp32:
+                     residual = residual.to(torch.float32)
+             else:
+                 hidden_states, residual = layer_norm_fn(
+                     hidden_states,
+                     self.norm2.weight,
+                     self.norm2.bias,
+                     residual=residual,
+                     prenorm=True,
+                     residual_in_fp32=self.residual_in_fp32,
+                     eps=self.norm2.eps,
+                     is_rms_norm=isinstance(self.norm2, RMSNorm)
+                 )
+             hidden_states = self.mlp(hidden_states)
+
+         return hidden_states, residual
+
+     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+         return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
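
In words, each decoder block runs Add -> LN -> mixer (Mamba2, or MHA at the configured attention layers), then single-head cross-attention over the encoder hidden states, and finally, only when d_intermediate > 0 (the shipped configs use 0), a second Add -> LN -> MLP; the residual stream is carried in fp32 when residual_in_fp32 is set.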
str_bamba/config/config_encoder-decoder_436M.json ADDED
@@ -0,0 +1,62 @@
+ {
+     "encoder_config": {
+         "d_model": 1024,
+         "d_intermediate": 0,
+         "n_layer": 24,
+         "vocab_size": 5000,
+         "max_position_embeddings": 4096,
+         "ssm_cfg": {
+             "layer": "Mamba2"
+         },
+         "attn_layer_idx": [
+             6,
+             18
+         ],
+         "attn_cfg": {
+             "causal": false,
+             "d_conv": 0,
+             "head_dim": 64,
+             "num_heads": 16,
+             "num_heads_kv": 8,
+             "out_proj_bias": false,
+             "qkv_proj_bias": false,
+             "rotary_emb_dim": 64
+         },
+         "rms_norm": true,
+         "residual_in_fp32": true,
+         "fused_add_norm": true,
+         "pad_vocab_size_multiple": 8,
+         "tie_embeddings": false
+     },
+     "decoder_config": {
+         "d_model": 1024,
+         "d_intermediate": 0,
+         "n_layer": 24,
+         "vocab_size": 5000,
+         "max_position_embeddings": 4096,
+         "ssm_cfg": {
+             "layer": "Mamba2"
+         },
+         "attn_layer_idx": [
+             6,
+             18
+         ],
+         "attn_cfg": {
+             "causal": true,
+             "d_conv": 0,
+             "head_dim": 64,
+             "num_heads": 16,
+             "num_heads_kv": 8,
+             "out_proj_bias": false,
+             "qkv_proj_bias": false,
+             "rotary_emb_dim": 64
+         },
+         "rms_norm": true,
+         "residual_in_fp32": true,
+         "fused_add_norm": true,
+         "pad_vocab_size_multiple": 8,
+         "tie_embeddings": false
+     },
+     "tie_word_embeddings": true,
+     "seed": 0
+ }
str_bamba/generation.py ADDED
@@ -0,0 +1,398 @@
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
+ import gc
+ import time
+ from collections import namedtuple
+ from dataclasses import dataclass, field
+ from functools import partial
+ from typing import Callable, Optional, Sequence, Union
+
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange, repeat
+ from torch import Tensor
+ from torch.profiler import ProfilerActivity, profile, record_function
+ from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
+
+
+ @dataclass
+ class InferenceParams:
+     """Inference parameters that are passed to the main model in order
+     to efficiently calculate and store the context during inference."""
+
+     max_seqlen: int
+     max_batch_size: int
+     seqlen_offset: int = 0
+     batch_size_offset: int = 0
+     key_value_memory_dict: dict = field(default_factory=dict)
+     lengths_per_sample: Optional[Tensor] = None
+
+     def reset(self, max_seqlen, max_batch_size):
+         self.max_seqlen = max_seqlen
+         self.max_batch_size = max_batch_size
+         self.seqlen_offset = 0
+         if self.lengths_per_sample is not None:
+             self.lengths_per_sample.zero_()
+
+
+ def modify_logits_for_min_p_filtering(logits, min_p):
+     """Set the logits for non-min_p values to -inf. Done in-place."""
+     if min_p <= 0.0 or min_p >= 1.0:
+         return
+     indices_to_remove = logits < min_p
+     logits.masked_fill_(indices_to_remove, float("-Inf"))
+
+
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
+ def modify_logits_for_top_k_filtering(logits, top_k):
+     """Set the logits for non-top-k values to -inf. Done in-place."""
+     indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+     logits.masked_fill_(indices_to_remove, float("-Inf"))
+
+
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
+ def modify_logits_for_top_p_filtering(logits, top_p):
+     """Set the logits for non-top-p values to -inf. Done in-place."""
+     if top_p <= 0.0 or top_p >= 1.0:
+         return
+     # First sort and calculate cumulative sum of probabilities.
+     sorted_logits, sorted_indices = torch.sort(logits, descending=False)
+     cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+     # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
+     sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
+     # scatter sorted tensors to original indexing
+     indices_to_remove = sorted_indices_to_remove.scatter(
+         1, sorted_indices, sorted_indices_to_remove
+     )
+     logits.masked_fill_(indices_to_remove, float("-inf"))
+
+
+ def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
+     """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
+     logits: (batch_size, vocab_size)
+     prev_output_tokens: (batch_size, seq_len)
+     """
+     if repetition_penalty == 1.0:
+         return logits
+     score = torch.gather(logits, 1, prev_output_tokens)
+     # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+     score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
+     logits.scatter_(1, prev_output_tokens, score)
+     return logits
+
+
+ def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
+     """Sample from top-k logits.
+     Arguments:
+         logits: Tensor of shape (batch_size, vocab_size)
+     """
+     if top_k == 1:  # Short-circuit for greedy decoding
+         return logits.argmax(dim=-1)
+     else:
+         if top_p > 0.0:
+             assert top_p <= 1.0, "top-p should be in (0, 1]."
+         if top_k > 0:
+             top_k = min(top_k, logits.size(-1))  # Safety check
+             logits_top, indices = torch.topk(logits, top_k, dim=-1)
+             if temperature != 1.0:
+                 logits_top /= temperature
+             modify_logits_for_top_p_filtering(logits_top, top_p)
+             return indices[
+                 torch.arange(indices.shape[0], device=indices.device),
+                 torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
+             ]
+         else:
+             if min_p > 0.0:
+                 logits_top = logits.clone()
+                 max_prob = logits_top[..., 0].item()
+                 min_prob = max_prob * min_p
+                 modify_logits_for_min_p_filtering(logits_top, min_prob)
+                 if temperature != 1.0:
+                     logits_top /= temperature
+                 return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
+             # Clone so that when we modify for top_p we don't change the original logits
+             logits_top = logits / temperature if temperature != 1.0 else logits.clone()
+             modify_logits_for_top_p_filtering(logits_top, top_p)
+             return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
+                 dim=-1
+             )
+
+
+ @torch.inference_mode()
+ def decode(
+     input_ids,
+     encoder_hidden_states,
+     model,
+     max_length,
+     top_k=1,
+     top_p=0.0,
+     min_p=0.0,
+     temperature=1.0,
+     repetition_penalty=1.0,
+     eos_token_id=None,
+     teacher_outputs=None,
+     vocab_size=None,
+     cg=False,
+     enable_timing=False,
+     output_scores=False,
+     streamer: Optional[TextStreamer] = None
+ ):
+     """Decoding, either greedy or with top-k or top-p sampling.
+     If top-k = 0, don't limit the number of candidates (pure sampling).
+     Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
+     then top-p.
+     We assume that all sequences in the same batch have the same length.
+
+     Arguments:
+         input_ids: (batch, seq_len)
+         max_length: int
+         teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
+             logits, the next token is taken from the teacher_outputs. Useful for testing.
+     Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
+         sequences: (batch, max_length)
+         scores: tuples of (batch, vocab_size)
+     """
+     if streamer is not None:
+         streamer.put(input_ids.cpu())
+
+     batch_size, seqlen_og = input_ids.shape
+     teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
+     if cg:
+         if not hasattr(model, "_decoding_cache"):
+             model._decoding_cache = None
+         model._decoding_cache = update_graph_cache(
+             model,
+             encoder_hidden_states,
+             model._decoding_cache,
+             batch_size,
+             seqlen_og,
+             max_length,
+         )
+         inference_params = model._decoding_cache.inference_params
+         inference_params.reset(max_length, batch_size)
+     else:
+         inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
+
+     def get_logits(input_ids, inference_params):
+         decoding = inference_params.seqlen_offset > 0
+         if decoding:
+             position_ids = torch.full(
+                 (batch_size, 1),
+                 inference_params.seqlen_offset,
+                 dtype=torch.long,
+                 device=input_ids.device,
+             )
+         else:
+             position_ids = None
+         if not cg or not decoding:
+             logits = model(
+                 input_ids,
+                 encoder_hidden_states=encoder_hidden_states,
+                 position_ids=position_ids,
+                 inference_params=inference_params,
+                 num_last_tokens=1,
+             ).logits.squeeze(dim=1)
+         else:
+             logits = model._decoding_cache.run(
+                 input_ids, position_ids, inference_params.seqlen_offset
+             ).squeeze(dim=1)
+         return logits[..., :vocab_size] if vocab_size is not None else logits
+
+     def sample_tokens(logits, inference_params):
+         if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
+             token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
+         else:
+             token = teacher_outputs[:, inference_params.seqlen_offset]
+         # return rearrange(token, "b -> b 1")
+         return token.unsqueeze(1)
+
+     def should_stop(current_token, inference_params):
+         if inference_params.seqlen_offset == 0:
+             return False
+         if eos_token_id is not None and (current_token == eos_token_id).all():
+             return True
+         if inference_params.seqlen_offset >= max_length - 1:
+             return True
+         return False
+
+     start = torch.cuda.Event(enable_timing=enable_timing)
+     end = torch.cuda.Event(enable_timing=enable_timing)
+
+     if enable_timing:
+         start.record()
+     scores, sequences = [], [input_ids]
+     sequences_cat = input_ids
+     while not should_stop(sequences[-1], inference_params):
+         logits = get_logits(sequences[-1], inference_params)
+         if output_scores:
+             scores.append(logits.clone())
+         inference_params.seqlen_offset += sequences[-1].shape[1]
+         if repetition_penalty == 1.0:
+             sampled_tokens = sample_tokens(logits, inference_params)
+         else:
+             logits = modify_logit_for_repetition_penalty(
+                 logits, sequences_cat, repetition_penalty
+             )
+             sampled_tokens = sample_tokens(logits, inference_params)
+             sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
+         sequences.append(sampled_tokens)
+         if streamer is not None:
+             streamer.put(sampled_tokens.cpu())
+     if streamer is not None:
+         streamer.end()
+     if enable_timing:
+         end.record()
+         torch.cuda.synchronize()
+         print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
+     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
+     return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
+
+
+ class GenerationMixin:
+     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+         raise NotImplementedError
+
+     def generate(
+         self,
+         input_ids,
+         encoder_hidden_states,
+         max_length,
+         top_k=1,
+         top_p=0.0,
+         min_p=0.0,
+         temperature=1.0,
+         return_dict_in_generate=False,
+         output_scores=False,
+         **kwargs,
+     ):
+         output = decode(
+             input_ids, encoder_hidden_states, self, max_length, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature, output_scores=output_scores, **kwargs
+         )
+         if not output_scores:
+             output.scores = None
+         return output if return_dict_in_generate else output.sequences
+
+
+ @dataclass
+ class DecodingCGCache:
+     max_batch_size: int = 0
+     max_seqlen: int = 0
+     device = None
+     dtype = None
+     callables: dict = field(default_factory=dict)
+     mempool = None
+     inference_params: Optional[InferenceParams] = None
+     run: Optional[Callable] = None
+
+
+ @torch.inference_mode()
+ def update_graph_cache(
+     model,
+     encoder_hidden_states,
+     cache,
+     batch_size,
+     seqlen_og,
+     max_seqlen,
+     decoding_seqlens=(1,),
+     dtype=None,
+     n_warmups=2,
+ ):
+     if cache is None:
+         cache = DecodingCGCache()
+     param_example = next(iter(model.parameters()))
+     device = param_example.device
+     if dtype is None:
+         dtype = param_example.dtype
+     if (
+         (device, dtype) != (cache.device, cache.dtype)
+         or batch_size > cache.max_batch_size
+         or max_seqlen > cache.max_seqlen
+     ):  # Invalidate the cache
+         cache.callables = {}
+         cache.mempool = None
+         cache.inference_params = None
+         gc.collect()
+         cache.device, cache.dtype = device, dtype
+         cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
+         assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
+         inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
+         lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
+         cache.inference_params = InferenceParams(
+             max_seqlen=max_seqlen,
+             max_batch_size=batch_size,
+             seqlen_offset=seqlen_og,
+             key_value_memory_dict=inf_cache,
+             lengths_per_sample=lengths_per_sample,
+         )
+         cache.mempool = torch.cuda.graphs.graph_pool_handle()
+     for decoding_seqlen in decoding_seqlens:
+         if (batch_size, decoding_seqlen) not in cache.callables:
+             cache.callables[batch_size, decoding_seqlen] = capture_graph(
+                 model,
+                 encoder_hidden_states,
+                 cache.inference_params,
+                 batch_size,
+                 max_seqlen,
+                 decoding_seqlen=decoding_seqlen,
+                 mempool=cache.mempool,
+                 n_warmups=n_warmups,
+             )
+
+     def dispatch(input_ids, position_ids, seqlen):
+         batch_size, decoding_seqlen = input_ids.shape[:2]
+         return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
+
+     cache.run = dispatch
+     cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
+     return cache
+
+
+ def capture_graph(
+     model, encoder_hidden_states, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
+ ):
+     device = next(iter(model.parameters())).device
+     input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
+     position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
+     seqlen_offset_og = inference_params.seqlen_offset
+     inference_params.seqlen_offset = max_seqlen - decoding_seqlen
+     inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
+
+     # Warmup before capture
+     s = torch.cuda.Stream()
+     s.wait_stream(torch.cuda.current_stream())
+     with torch.cuda.stream(s):
+         for _ in range(n_warmups):
+             logits = model(
+                 input_ids,
+                 encoder_hidden_states=encoder_hidden_states,
+                 position_ids=position_ids,
+                 inference_params=inference_params,
+                 num_last_tokens=decoding_seqlen,
+             ).logits
+         s.synchronize()
+         # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
+         # which requires that graph launch and non-captured launch to not overlap (I think,
+         # that's how I interpret the documentation). I'm not sure if this is required.
+         if torch.distributed.is_initialized():
+             torch.distributed.barrier()
+     torch.cuda.current_stream().wait_stream(s)
+     # Captures the graph
+     # To allow capture, automatically sets a side stream as the current stream in the context
+     graph = torch.cuda.CUDAGraph()
+     with torch.cuda.graph(graph, pool=mempool):
+         logits = model(
+             input_ids,
+             encoder_hidden_states=encoder_hidden_states,
+             position_ids=position_ids,
+             inference_params=inference_params,
+             num_last_tokens=decoding_seqlen,
+         ).logits
+
+     def run(new_input_ids, new_position_ids, seqlen):
+         inference_params.lengths_per_sample[:] = seqlen
+         input_ids.copy_(new_input_ids)
+         position_ids.copy_(new_position_ids)
+         graph.replay()
+         return logits.clone()
+
+     inference_params.seqlen_offset = seqlen_offset_og
+     return run
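
A rough usage sketch (assumptions: a loaded BambaEncoderDecoder called model, a SMILES input carrying the tokenizer's <smiles> prefix, and that the trained vocabulary assigns ids to <bos>/<sep>). Note that decode()'s model(...) call matches BambaDecoder.forward, so generation is driven through the decoder, with the encoder hidden states computed once up front:

    import torch

    # model: a loaded BambaEncoderDecoder (see str_bamba/load.py below)
    enc_ids = model.tokenizer('<smiles>CCO', return_tensors='pt')['input_ids'].to(model.device)
    enc_hidden = model.encoder(enc_ids).hidden_states

    bos = torch.full((1, 1), model.tokenizer.bos_token_id, dtype=torch.long, device=model.device)
    seqs = model.decoder.generate(
        bos, enc_hidden, max_length=64,
        top_k=1,                                    # greedy; e.g. top_k=0, top_p=0.9 for nucleus sampling
        eos_token_id=model.tokenizer.eos_token_id,  # stop once every row has emitted <sep>
    )
    print(model.tokenizer.convert_tokens_to_string(
        model.tokenizer.convert_ids_to_tokens(seqs[0].tolist())))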
str_bamba/load.py ADDED
@@ -0,0 +1,60 @@
+ from .bamba_config import BambaEncoderDecoderConfig
+ from .bamba import BambaConfig, BambaEncoderDecoder
+ from .tokenizer.str_tokenizer import load_tokenizer
+ import torch
+ import numpy as np
+ import random
+ import json
+ import os
+
+
+ def load_strbamba(ckpt_filename,
+                   base_folder='./str_bamba',
+                   config_filename='config_encoder-decoder_436M.json',
+                   tokenizer_filename='str_bamba_tokenizer.json',
+                   eval_model=True,
+                   device='cuda:0',
+                   dtype=torch.float32
+                   ):
+     # load config
+     with open(os.path.join(base_folder, f'config/{config_filename}')) as json_data:
+         config_json = json.load(json_data)
+     bamba_config = BambaEncoderDecoderConfig(
+         encoder_config=BambaConfig(**config_json['encoder_config']),
+         decoder_config=BambaConfig(**config_json['decoder_config']),
+         tie_word_embeddings=config_json['tie_word_embeddings'],
+         seed=config_json['seed']
+     )
+
+     # load tokenizer
+     tokenizer = load_tokenizer(os.path.join(base_folder, f'tokenizer/{tokenizer_filename}'))
+
+     # load model
+     model = BambaEncoderDecoder(bamba_config, tokenizer, device=device, dtype=dtype)
+
+     # load weights
+     ckpt_dict = torch.load(
+         os.path.join(base_folder, f'checkpoints/{ckpt_filename}'),
+         map_location=device,
+         weights_only=False
+     )
+     model.load_state_dict(ckpt_dict['module'])
+
+     # restore RNG states whenever the model and states are loaded from a checkpoint
+     if 'rng' in ckpt_dict:
+         rng = ckpt_dict['rng']
+         for key, value in rng.items():
+             if key == 'torch_state':
+                 torch.set_rng_state(value.cpu())
+             elif key == 'cuda_state':
+                 torch.cuda.set_rng_state(value.cpu())
+             elif key == 'numpy_state':
+                 np.random.set_state(value)
+             elif key == 'python_state':
+                 random.setstate(value)
+             else:
+                 print('unrecognized state')
+
+     if eval_model:
+         return model.eval()
+     return model
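
Presumably the intended entry point. A sketch assuming the repository is laid out as load.py expects, i.e. the uploaded STR-Bamba_8.pt checkpoint has been placed under str_bamba/checkpoints/ (the 'module' key suggests a DeepSpeed-style checkpoint dict):

    import torch
    from str_bamba.load import load_strbamba

    model = load_strbamba('STR-Bamba_8.pt', device='cuda:0', dtype=torch.float32)

    # Mean-pooled encoder embeddings: one 1024-dim row per molecule (pandas DataFrame)
    emb = model.encode(['<smiles>CCO', '<smiles>c1ccccc1'])
    print(emb.shape)  # (2, 1024)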
str_bamba/tokenizer/special_tokens.py ADDED
@@ -0,0 +1,39 @@
+ STR_SPECIAL_TOKENS = {
+     ### basic tokens ###
+     "BOS_TOKEN": "<bos>",
+     "EOS_TOKEN": "<sep>",
+     "PAD_TOKEN": "<pad>",
+     "MASK_TOKEN": "<mask>",
+     "UNK_TOKEN": "<unk>",
+
+     ### molecular representations ###
+     # molecular formula
+     "MOLECULAR_FORMULA_TOKEN": "<formula>",
+
+     # canonical SMILES
+     "SMILES_TOKEN": "<smiles>",
+
+     # IUPAC name
+     "IUPAC_TOKEN": "<iupac>",
+
+     # InChI
+     "INCHI_TOKEN": "<inchi>",
+     "INCHI_INITIAL_TOKEN": "InChI=",  # force `InChI=` to be a unique token
+     "INCHI_COMMA_TOKEN": ",",  # force `,` to be a unique token
+     "INCHI_DASH_TOKEN": "-",  # force `-` to be a unique token
+     "INCHI_FORWARDSLASH_TOKEN": "/",  # force `/` to be a unique token
+     "INCHI_QUESTIONMARK_TOKEN": "?",  # force `?` to be a unique token
+     "INCHI_PARENTHESIS_OPEN_TOKEN": "(",  # force `(` to be a unique token
+     "INCHI_PARENTHESIS_CLOSE_TOKEN": ")",  # force `)` to be a unique token
+
+     # SELFIES
+     "SELFIES_TOKEN": "<selfies>",
+
+     # polymer SPG
+     "POLYMER_SPG_TOKEN": "<polymer_spg>",
+     "POLYMER_ARROW_TOKEN": "->",  # force `->` to be a unique token
+
+     # formulation
+     "FORMULATION_START_TOKEN": "<formulation_start>",
+     "FORMULATION_END_TOKEN": "<formulation_end>",
+ }
str_bamba/tokenizer/str_bamba_tokenizer.json ADDED
The diff for this file is too large to render.
str_bamba/tokenizer/str_tokenizer.py ADDED
@@ -0,0 +1,101 @@
+ from typing import List
+
+ from tokenizers import NormalizedString, PreTokenizedString
+ from tokenizers.pre_tokenizers import PreTokenizer
+ from transformers import PreTrainedTokenizerFast
+
+ import re
+
+
+ ATOM_REGEX_PATTERN = r"""(<(.*?)>|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""
+ FORMULATION_REGEX_PATTERN = r"""(<(.*?)>|[-+]?\d*\.\d+|[-+]?\d+\.?\d*[eE][-+]?\d+|[-+]?\d+|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""
+ NUMBER_REGEX_PATTERN = r"""(\d{2}|\d[a-zA-Z]\d|\d[a-zA-Z]|[a-zA-Z]\d+|\(|\))"""
+ # NUMBER_REGEX_PATTERN = r"""((?<!\d)\d{2}(?!\d)|\d[a-zA-Z]\d|\d[a-zA-Z]|[a-zA-Z]\d)"""
+ # NUMBER_REGEX_PATTERN = r"""(\d[a-zA-Z]\d|\d[a-zA-Z]|[a-zA-Z]\d|\b\d{2}\b)"""
+ SPECIAL_REGEX_PATTERN = r"""<(.*?)>"""
+
+
+ class MoleculePreTokenizer:
+
+     def molecule_based_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
+         splits = []
+         if str(normalized_string).startswith(('<smiles>', '<selfies>', '<polymer_spg>')):
+             for m in re.finditer(ATOM_REGEX_PATTERN, str(normalized_string)):
+                 start = m.start(0)
+                 stop = m.end(0)
+                 if start == 0:  # remove special tokens
+                     continue
+                 splits.append(normalized_string[start:stop])
+         elif str(normalized_string).startswith('<formulation_start>'):
+             for m in re.finditer(FORMULATION_REGEX_PATTERN, str(normalized_string)):
+                 start = m.start(0)
+                 stop = m.end(0)
+                 if start == 0 or stop == len(str(normalized_string)):  # remove special tokens
+                     continue
+                 splits.append(normalized_string[start:stop])
+         elif str(normalized_string).startswith(('<formula>', '<inchi>')):
+             for m in re.finditer(NUMBER_REGEX_PATTERN, str(normalized_string)):
+                 start = m.start(0)
+                 stop = m.end(0)
+                 if start == 0:  # remove special tokens
+                     continue
+                 splits.append(normalized_string[start:stop])
+         else:
+             last = 0
+             for m in re.finditer(SPECIAL_REGEX_PATTERN, str(normalized_string)):  # remove special tokens
+                 start = m.start(0)
+                 stop = m.end(0)
+                 # splits.append(normalized_string[start:stop])
+                 last = stop
+             splits.append(normalized_string[last:])
+
+         return splits
+
+     def pre_tokenize(self, pretok: PreTokenizedString):
+         pretok.split(self.molecule_based_split)
+
+
+ class MultiMolTranBertTokenizer(PreTrainedTokenizerFast):
+     def __init__(self, vocab_file: str = '',
+                  do_lower_case=False,
+                  cls_token='<bos>',
+                  eos_token='<sep>',
+                  pad_token='<pad>',
+                  unk_token='<unk>',
+                  mask_token='<mask>',
+                  **kwargs):
+
+         super().__init__(
+             tokenizer_file=vocab_file,
+             bos_token=cls_token,
+             eos_token=eos_token,
+             pad_token=pad_token,
+             unk_token=unk_token,
+             mask_token=mask_token
+         )
+
+     def get_padding_idx(self):
+         return 2
+
+     def convert_idx_to_tokens(self, idx_tensor):
+         tokens = [self.convert_ids_to_tokens(idx) for idx in idx_tensor.tolist()]
+         return tokens
+
+     def convert_tokens_to_string(self, tokens):
+         stopwords = ['<bos>', '<eos>']
+         clean_tokens = [word for word in tokens if word not in stopwords]
+         out_string = ''.join(clean_tokens)
+         return out_string
+
+     def idx_to_smiles(self, torch_model, idx):
+         '''Convert token ids back to SMILES text'''
+         rev_tokens = torch_model.tokenizer.convert_idx_to_tokens(idx)
+         flat_list_tokens = [item for sublist in rev_tokens for item in sublist]
+         decoded_smiles = torch_model.tokenizer.convert_tokens_to_string(flat_list_tokens)
+         return decoded_smiles
+
+
+ def load_tokenizer(vocab_file, **kwargs):
+     tokenizer = MultiMolTranBertTokenizer(vocab_file, **kwargs)
+     tokenizer.backend_tokenizer.pre_tokenizer = PreTokenizer.custom(MoleculePreTokenizer())
+     return tokenizer
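
A small sketch of the custom pre-tokenization (the tokenizer file path is an assumption; exact ids and special-token handling depend on the trained vocabulary in str_bamba_tokenizer.json). The <smiles> prefix routes the string to the atom-level regex, and the prefix itself is dropped by the start == 0 check:

    from str_bamba.tokenizer.str_tokenizer import load_tokenizer

    tokenizer = load_tokenizer('str_bamba/tokenizer/str_bamba_tokenizer.json')
    enc = tokenizer('<smiles>CC(=O)O')
    print(tokenizer.convert_ids_to_tokens(enc['input_ids']))
    # atom-level pieces, e.g. ['C', 'C', '(', '=', 'O', ')', 'O'], plus whatever
    # special tokens the saved post-processor adds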