victor-shirasuna commited on
Commit ·
3d83373
1
Parent(s): 94a0645
Upload files
Browse files- STR-Bamba_8.bin +3 -0
- STR-Bamba_8.pt +3 -0
- config.json +62 -0
- requirements.txt +12 -0
- str_bamba/bamba.py +534 -0
- str_bamba/bamba_config.py +28 -0
- str_bamba/bamba_modules.py +229 -0
- str_bamba/config/config_encoder-decoder_436M.json +62 -0
- str_bamba/generation.py +398 -0
- str_bamba/load.py +60 -0
- str_bamba/tokenizer/special_tokens.py +39 -0
- str_bamba/tokenizer/str_bamba_tokenizer.json +0 -0
- str_bamba/tokenizer/str_tokenizer.py +101 -0
STR-Bamba_8.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:db6d7a2561bfbaf9bd8a5f910321b2ff21671b6bc47cad955a323898203a9967
|
| 3 |
+
size 1372194320
|
STR-Bamba_8.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:db6d7a2561bfbaf9bd8a5f910321b2ff21671b6bc47cad955a323898203a9967
|
| 3 |
+
size 1372194320
|
config.json
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"encoder_config": {
|
| 3 |
+
"d_model": 1024,
|
| 4 |
+
"d_intermediate": 0,
|
| 5 |
+
"n_layer": 24,
|
| 6 |
+
"vocab_size": 5000,
|
| 7 |
+
"max_position_embeddings": 4096,
|
| 8 |
+
"ssm_cfg": {
|
| 9 |
+
"layer": "Mamba2"
|
| 10 |
+
},
|
| 11 |
+
"attn_layer_idx": [
|
| 12 |
+
6,
|
| 13 |
+
18
|
| 14 |
+
],
|
| 15 |
+
"attn_cfg": {
|
| 16 |
+
"causal": false,
|
| 17 |
+
"d_conv": 0,
|
| 18 |
+
"head_dim": 64,
|
| 19 |
+
"num_heads": 16,
|
| 20 |
+
"num_heads_kv": 8,
|
| 21 |
+
"out_proj_bias": false,
|
| 22 |
+
"qkv_proj_bias": false,
|
| 23 |
+
"rotary_emb_dim": 64
|
| 24 |
+
},
|
| 25 |
+
"rms_norm": true,
|
| 26 |
+
"residual_in_fp32": true,
|
| 27 |
+
"fused_add_norm": true,
|
| 28 |
+
"pad_vocab_size_multiple": 8,
|
| 29 |
+
"tie_embeddings": false
|
| 30 |
+
},
|
| 31 |
+
"decoder_config": {
|
| 32 |
+
"d_model": 1024,
|
| 33 |
+
"d_intermediate": 0,
|
| 34 |
+
"n_layer": 24,
|
| 35 |
+
"vocab_size": 5000,
|
| 36 |
+
"max_position_embeddings": 4096,
|
| 37 |
+
"ssm_cfg": {
|
| 38 |
+
"layer": "Mamba2"
|
| 39 |
+
},
|
| 40 |
+
"attn_layer_idx": [
|
| 41 |
+
6,
|
| 42 |
+
18
|
| 43 |
+
],
|
| 44 |
+
"attn_cfg": {
|
| 45 |
+
"causal": true,
|
| 46 |
+
"d_conv": 0,
|
| 47 |
+
"head_dim": 64,
|
| 48 |
+
"num_heads": 16,
|
| 49 |
+
"num_heads_kv": 8,
|
| 50 |
+
"out_proj_bias": false,
|
| 51 |
+
"qkv_proj_bias": false,
|
| 52 |
+
"rotary_emb_dim": 64
|
| 53 |
+
},
|
| 54 |
+
"rms_norm": true,
|
| 55 |
+
"residual_in_fp32": true,
|
| 56 |
+
"fused_add_norm": true,
|
| 57 |
+
"pad_vocab_size_multiple": 8,
|
| 58 |
+
"tie_embeddings": false
|
| 59 |
+
},
|
| 60 |
+
"tie_word_embeddings": true,
|
| 61 |
+
"seed": 0
|
| 62 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy==1.26.4
|
| 2 |
+
pandas==2.2.3
|
| 3 |
+
scikit-learn>=1.6.1
|
| 4 |
+
datasets==3.5.0
|
| 5 |
+
transformers==4.52.1
|
| 6 |
+
tokenizers==0.21.1
|
| 7 |
+
deepspeed==0.16.7
|
| 8 |
+
einops==0.8.1
|
| 9 |
+
tqdm==4.67.1
|
| 10 |
+
torch-optimizer==0.3.0
|
| 11 |
+
rdkit>=2024.3.5
|
| 12 |
+
selfies>=2.2.0
|
str_bamba/bamba.py
ADDED
|
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .generation import GenerationMixin
|
| 2 |
+
from mamba_ssm.modules.mamba2 import Mamba2
|
| 3 |
+
from mamba_ssm.modules.mha import MHA
|
| 4 |
+
from mamba_ssm.modules.mlp import GatedMLP
|
| 5 |
+
from mamba_ssm.modules.block import Block
|
| 6 |
+
from mamba_ssm.models.mixer_seq_simple import _init_weights
|
| 7 |
+
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
|
| 8 |
+
from .bamba_modules import BertEmbeddings, BertPooler, BertPreTrainingHeads, BlockCrossAttention
|
| 9 |
+
from .bamba_config import BambaConfig, BambaEncoderDecoderConfig
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
|
| 13 |
+
except ImportError:
|
| 14 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
| 15 |
+
|
| 16 |
+
from typing import List, Optional, Tuple, Union
|
| 17 |
+
from collections import namedtuple
|
| 18 |
+
import torch.backends.cudnn as cudnn
|
| 19 |
+
import math
|
| 20 |
+
import random
|
| 21 |
+
from functools import partial
|
| 22 |
+
import json
|
| 23 |
+
import os
|
| 24 |
+
import copy
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn as nn
|
| 27 |
+
import pandas as pd
|
| 28 |
+
import numpy as np
|
| 29 |
+
import gc
|
| 30 |
+
from tqdm import tqdm
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def create_block(
|
| 34 |
+
d_model,
|
| 35 |
+
d_intermediate,
|
| 36 |
+
block_class,
|
| 37 |
+
ssm_cfg=None,
|
| 38 |
+
attn_layer_idx=None,
|
| 39 |
+
attn_cfg=None,
|
| 40 |
+
norm_epsilon=1e-5,
|
| 41 |
+
rms_norm=False,
|
| 42 |
+
residual_in_fp32=False,
|
| 43 |
+
fused_add_norm=False,
|
| 44 |
+
layer_idx=None,
|
| 45 |
+
device=None,
|
| 46 |
+
dtype=None,
|
| 47 |
+
):
|
| 48 |
+
if ssm_cfg is None:
|
| 49 |
+
ssm_cfg = {}
|
| 50 |
+
if attn_layer_idx is None:
|
| 51 |
+
attn_layer_idx = []
|
| 52 |
+
if attn_cfg is None:
|
| 53 |
+
attn_cfg = {}
|
| 54 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 55 |
+
if layer_idx not in attn_layer_idx:
|
| 56 |
+
# Create a copy of the config to modify
|
| 57 |
+
ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
|
| 58 |
+
ssm_layer = ssm_cfg.pop("layer", "Mamba1")
|
| 59 |
+
if ssm_layer not in ["Mamba1", "Mamba2"]:
|
| 60 |
+
raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
|
| 61 |
+
mixer_cls = partial(
|
| 62 |
+
Mamba2 if ssm_layer == "Mamba2" else Mamba,
|
| 63 |
+
layer_idx=layer_idx,
|
| 64 |
+
**ssm_cfg,
|
| 65 |
+
**factory_kwargs
|
| 66 |
+
)
|
| 67 |
+
else:
|
| 68 |
+
mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
|
| 69 |
+
norm_cls = partial(
|
| 70 |
+
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
|
| 71 |
+
)
|
| 72 |
+
if d_intermediate == 0:
|
| 73 |
+
mlp_cls = nn.Identity
|
| 74 |
+
else:
|
| 75 |
+
mlp_cls = partial(
|
| 76 |
+
GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
|
| 77 |
+
)
|
| 78 |
+
block = block_class(
|
| 79 |
+
d_model,
|
| 80 |
+
mixer_cls,
|
| 81 |
+
mlp_cls,
|
| 82 |
+
norm_cls=norm_cls,
|
| 83 |
+
fused_add_norm=fused_add_norm,
|
| 84 |
+
residual_in_fp32=residual_in_fp32,
|
| 85 |
+
)
|
| 86 |
+
if isinstance(block, BlockCrossAttention) and factory_kwargs["dtype"] is not None:
|
| 87 |
+
block.encoder_attn.type(factory_kwargs["dtype"]).to(factory_kwargs["device"])
|
| 88 |
+
block.layer_idx = layer_idx
|
| 89 |
+
return block
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class BambaMixerModel(nn.Module):
|
| 93 |
+
def __init__(
|
| 94 |
+
self,
|
| 95 |
+
d_model: int,
|
| 96 |
+
n_layer: int,
|
| 97 |
+
d_intermediate: int,
|
| 98 |
+
vocab_size: int,
|
| 99 |
+
max_position_embeddings: int,
|
| 100 |
+
is_decoder: bool = False,
|
| 101 |
+
ssm_cfg=None,
|
| 102 |
+
attn_layer_idx=None,
|
| 103 |
+
attn_cfg=None,
|
| 104 |
+
norm_epsilon: float = 1e-5,
|
| 105 |
+
rms_norm: bool = False,
|
| 106 |
+
initializer_cfg=None,
|
| 107 |
+
fused_add_norm=False,
|
| 108 |
+
residual_in_fp32=False,
|
| 109 |
+
device=None,
|
| 110 |
+
dtype=None,
|
| 111 |
+
) -> None:
|
| 112 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 113 |
+
super().__init__()
|
| 114 |
+
self.residual_in_fp32 = residual_in_fp32
|
| 115 |
+
|
| 116 |
+
self.is_decoder = is_decoder
|
| 117 |
+
|
| 118 |
+
if is_decoder:
|
| 119 |
+
self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
|
| 120 |
+
else:
|
| 121 |
+
self.embedding = BertEmbeddings(vocab_size, d_model, max_position_embeddings, **factory_kwargs)
|
| 122 |
+
|
| 123 |
+
# We change the order of residual and layer norm:
|
| 124 |
+
# Instead of LN -> Attn / MLP -> Add, we do:
|
| 125 |
+
# Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
|
| 126 |
+
# the main branch (output of MLP / Mixer). The model definition is unchanged.
|
| 127 |
+
# This is for performance reason: we can fuse add + layer_norm.
|
| 128 |
+
self.fused_add_norm = fused_add_norm
|
| 129 |
+
if self.fused_add_norm:
|
| 130 |
+
if layer_norm_fn is None or rms_norm_fn is None:
|
| 131 |
+
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
|
| 132 |
+
|
| 133 |
+
if is_decoder:
|
| 134 |
+
block_class = BlockCrossAttention
|
| 135 |
+
else:
|
| 136 |
+
block_class = Block
|
| 137 |
+
|
| 138 |
+
self.layers = nn.ModuleList(
|
| 139 |
+
[
|
| 140 |
+
create_block(
|
| 141 |
+
d_model,
|
| 142 |
+
d_intermediate=d_intermediate,
|
| 143 |
+
block_class=block_class,
|
| 144 |
+
ssm_cfg=ssm_cfg,
|
| 145 |
+
attn_layer_idx=attn_layer_idx,
|
| 146 |
+
attn_cfg=attn_cfg,
|
| 147 |
+
norm_epsilon=norm_epsilon,
|
| 148 |
+
rms_norm=rms_norm,
|
| 149 |
+
residual_in_fp32=residual_in_fp32,
|
| 150 |
+
fused_add_norm=fused_add_norm,
|
| 151 |
+
layer_idx=i,
|
| 152 |
+
**factory_kwargs,
|
| 153 |
+
)
|
| 154 |
+
for i in range(n_layer)
|
| 155 |
+
]
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
|
| 159 |
+
d_model, eps=norm_epsilon, **factory_kwargs
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
if not is_decoder:
|
| 163 |
+
self.pooler = BertPooler(d_model, **factory_kwargs)
|
| 164 |
+
|
| 165 |
+
self.apply(
|
| 166 |
+
partial(
|
| 167 |
+
_init_weights,
|
| 168 |
+
n_layer=n_layer,
|
| 169 |
+
**(initializer_cfg if initializer_cfg is not None else {}),
|
| 170 |
+
n_residuals_per_layer=1 if d_intermediate == 0 else 2, # 2 if we have MLP
|
| 171 |
+
)
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 175 |
+
return {
|
| 176 |
+
i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
| 177 |
+
for i, layer in enumerate(self.layers)
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
def forward(self, input_ids, token_type_ids=None, inference_params=None, encoder_hidden_states=None, attention_mask=None, **mixer_kwargs):
|
| 181 |
+
if self.is_decoder:
|
| 182 |
+
hidden_states = self.embedding(input_ids)
|
| 183 |
+
else:
|
| 184 |
+
hidden_states = self.embedding(input_ids, token_type_ids)
|
| 185 |
+
residual = None
|
| 186 |
+
for layer in self.layers:
|
| 187 |
+
if self.is_decoder:
|
| 188 |
+
hidden_states, residual = layer(
|
| 189 |
+
hidden_states, residual, inference_params=inference_params, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **mixer_kwargs
|
| 190 |
+
)
|
| 191 |
+
else:
|
| 192 |
+
hidden_states, residual = layer(
|
| 193 |
+
hidden_states, residual, inference_params=inference_params, **mixer_kwargs
|
| 194 |
+
)
|
| 195 |
+
if not self.fused_add_norm:
|
| 196 |
+
residual = (hidden_states + residual) if residual is not None else hidden_states
|
| 197 |
+
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
|
| 198 |
+
else:
|
| 199 |
+
# Set prenorm=False here since we don't need the residual
|
| 200 |
+
hidden_states = layer_norm_fn(
|
| 201 |
+
hidden_states,
|
| 202 |
+
self.norm_f.weight,
|
| 203 |
+
self.norm_f.bias,
|
| 204 |
+
eps=self.norm_f.eps,
|
| 205 |
+
residual=residual,
|
| 206 |
+
prenorm=False,
|
| 207 |
+
residual_in_fp32=self.residual_in_fp32,
|
| 208 |
+
is_rms_norm=isinstance(self.norm_f, RMSNorm)
|
| 209 |
+
)
|
| 210 |
+
if not self.is_decoder:
|
| 211 |
+
pooled_output = self.pooler(hidden_states)
|
| 212 |
+
return hidden_states, pooled_output
|
| 213 |
+
return hidden_states
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class BambaEncoder(nn.Module):
|
| 217 |
+
|
| 218 |
+
def __init__(
|
| 219 |
+
self,
|
| 220 |
+
config: BambaConfig,
|
| 221 |
+
initializer_cfg=None,
|
| 222 |
+
device=None,
|
| 223 |
+
dtype=None,
|
| 224 |
+
) -> None:
|
| 225 |
+
self.config = config
|
| 226 |
+
d_model = config.d_model
|
| 227 |
+
n_layer = config.n_layer
|
| 228 |
+
d_intermediate = config.d_intermediate
|
| 229 |
+
vocab_size = config.vocab_size
|
| 230 |
+
max_position_embeddings = config.max_position_embeddings
|
| 231 |
+
ssm_cfg = config.ssm_cfg
|
| 232 |
+
attn_layer_idx = config.attn_layer_idx
|
| 233 |
+
attn_cfg = config.attn_cfg
|
| 234 |
+
rms_norm = config.rms_norm
|
| 235 |
+
residual_in_fp32 = config.residual_in_fp32
|
| 236 |
+
fused_add_norm = config.fused_add_norm
|
| 237 |
+
pad_vocab_size_multiple = config.pad_vocab_size_multiple
|
| 238 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 239 |
+
|
| 240 |
+
super().__init__()
|
| 241 |
+
if vocab_size % pad_vocab_size_multiple != 0:
|
| 242 |
+
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
|
| 243 |
+
self.backbone = BambaMixerModel(
|
| 244 |
+
d_model=d_model,
|
| 245 |
+
n_layer=n_layer,
|
| 246 |
+
d_intermediate=d_intermediate,
|
| 247 |
+
vocab_size=vocab_size,
|
| 248 |
+
max_position_embeddings=max_position_embeddings,
|
| 249 |
+
is_decoder=False,
|
| 250 |
+
ssm_cfg=ssm_cfg,
|
| 251 |
+
attn_layer_idx=attn_layer_idx,
|
| 252 |
+
attn_cfg=attn_cfg,
|
| 253 |
+
rms_norm=rms_norm,
|
| 254 |
+
initializer_cfg=initializer_cfg,
|
| 255 |
+
fused_add_norm=fused_add_norm,
|
| 256 |
+
residual_in_fp32=residual_in_fp32,
|
| 257 |
+
**factory_kwargs,
|
| 258 |
+
)
|
| 259 |
+
self.cls = BertPreTrainingHeads(vocab_size, d_model, **factory_kwargs)
|
| 260 |
+
|
| 261 |
+
# Initialize weights and apply final processing
|
| 262 |
+
self.apply(
|
| 263 |
+
partial(
|
| 264 |
+
_init_weights,
|
| 265 |
+
n_layer=n_layer,
|
| 266 |
+
**(initializer_cfg if initializer_cfg is not None else {}),
|
| 267 |
+
)
|
| 268 |
+
)
|
| 269 |
+
self.tie_weights()
|
| 270 |
+
|
| 271 |
+
def tie_weights(self):
|
| 272 |
+
if self.config.tie_embeddings:
|
| 273 |
+
self.lm_head.weight = self.backbone.embedding.weight
|
| 274 |
+
|
| 275 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 276 |
+
return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
| 277 |
+
|
| 278 |
+
def forward(self, input_ids, token_type_ids=None, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
|
| 279 |
+
"""
|
| 280 |
+
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
| 281 |
+
num_last_tokens: if > 0, only return the logits for the last n tokens
|
| 282 |
+
"""
|
| 283 |
+
hidden_states, pooled_output = self.backbone(input_ids, token_type_ids, inference_params=inference_params, **mixer_kwargs)
|
| 284 |
+
if num_last_tokens > 0:
|
| 285 |
+
hidden_states = hidden_states[:, -num_last_tokens:]
|
| 286 |
+
lm_logits, seq_relationship_score = self.cls(hidden_states, pooled_output)
|
| 287 |
+
CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "seq_relationship_logits", "hidden_states"])
|
| 288 |
+
return CausalLMOutput(logits=lm_logits, seq_relationship_logits=seq_relationship_score, hidden_states=hidden_states)
|
| 289 |
+
|
| 290 |
+
@classmethod
|
| 291 |
+
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
| 292 |
+
config_data = load_config_hf(pretrained_model_name)
|
| 293 |
+
config = MambaConfig(**config_data)
|
| 294 |
+
model = cls(config, device=device, dtype=dtype, **kwargs)
|
| 295 |
+
model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
|
| 296 |
+
return model
|
| 297 |
+
|
| 298 |
+
def save_pretrained(self, save_directory):
|
| 299 |
+
"""
|
| 300 |
+
Minimal implementation of save_pretrained for MambaLMHeadModel.
|
| 301 |
+
Save the model and its configuration file to a directory.
|
| 302 |
+
"""
|
| 303 |
+
# Ensure save_directory exists
|
| 304 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 305 |
+
|
| 306 |
+
# Save the model's state_dict
|
| 307 |
+
model_path = os.path.join(save_directory, 'pytorch_model.bin')
|
| 308 |
+
torch.save(self.state_dict(), model_path)
|
| 309 |
+
|
| 310 |
+
# Save the configuration of the model
|
| 311 |
+
config_path = os.path.join(save_directory, 'config.json')
|
| 312 |
+
with open(config_path, 'w') as f:
|
| 313 |
+
json.dump(self.config.__dict__, f, indent=4)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class BambaDecoder(nn.Module, GenerationMixin):
|
| 317 |
+
|
| 318 |
+
def __init__(
|
| 319 |
+
self,
|
| 320 |
+
config: BambaConfig,
|
| 321 |
+
initializer_cfg=None,
|
| 322 |
+
device=None,
|
| 323 |
+
dtype=None,
|
| 324 |
+
) -> None:
|
| 325 |
+
self.config = config
|
| 326 |
+
d_model = config.d_model
|
| 327 |
+
n_layer = config.n_layer
|
| 328 |
+
d_intermediate = config.d_intermediate
|
| 329 |
+
vocab_size = config.vocab_size
|
| 330 |
+
max_position_embeddings = config.max_position_embeddings
|
| 331 |
+
ssm_cfg = config.ssm_cfg
|
| 332 |
+
attn_layer_idx = config.attn_layer_idx
|
| 333 |
+
attn_cfg = config.attn_cfg
|
| 334 |
+
rms_norm = config.rms_norm
|
| 335 |
+
residual_in_fp32 = config.residual_in_fp32
|
| 336 |
+
fused_add_norm = config.fused_add_norm
|
| 337 |
+
pad_vocab_size_multiple = config.pad_vocab_size_multiple
|
| 338 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 339 |
+
|
| 340 |
+
super().__init__()
|
| 341 |
+
if vocab_size % pad_vocab_size_multiple != 0:
|
| 342 |
+
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
|
| 343 |
+
self.backbone = BambaMixerModel(
|
| 344 |
+
d_model=d_model,
|
| 345 |
+
n_layer=n_layer,
|
| 346 |
+
d_intermediate=d_intermediate,
|
| 347 |
+
vocab_size=vocab_size,
|
| 348 |
+
max_position_embeddings=max_position_embeddings,
|
| 349 |
+
is_decoder=True,
|
| 350 |
+
ssm_cfg=ssm_cfg,
|
| 351 |
+
attn_layer_idx=attn_layer_idx,
|
| 352 |
+
attn_cfg=attn_cfg,
|
| 353 |
+
rms_norm=rms_norm,
|
| 354 |
+
initializer_cfg=initializer_cfg,
|
| 355 |
+
fused_add_norm=fused_add_norm,
|
| 356 |
+
residual_in_fp32=residual_in_fp32,
|
| 357 |
+
**factory_kwargs,
|
| 358 |
+
)
|
| 359 |
+
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
| 360 |
+
|
| 361 |
+
# Initialize weights and apply final processing
|
| 362 |
+
self.apply(
|
| 363 |
+
partial(
|
| 364 |
+
_init_weights,
|
| 365 |
+
n_layer=n_layer,
|
| 366 |
+
**(initializer_cfg if initializer_cfg is not None else {}),
|
| 367 |
+
)
|
| 368 |
+
)
|
| 369 |
+
self.tie_weights()
|
| 370 |
+
|
| 371 |
+
def tie_weights(self):
|
| 372 |
+
if self.config.tie_embeddings:
|
| 373 |
+
self.lm_head.weight = self.backbone.embedding.weight
|
| 374 |
+
|
| 375 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 376 |
+
return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
| 377 |
+
|
| 378 |
+
def forward(self, input_ids, token_type_ids=None, position_ids=None, inference_params=None, num_last_tokens=0, encoder_hidden_states=None, attention_mask=None, **mixer_kwargs):
|
| 379 |
+
"""
|
| 380 |
+
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
| 381 |
+
num_last_tokens: if > 0, only return the logits for the last n tokens
|
| 382 |
+
"""
|
| 383 |
+
hidden_states = self.backbone(
|
| 384 |
+
input_ids, token_type_ids, inference_params=inference_params, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **mixer_kwargs
|
| 385 |
+
)
|
| 386 |
+
if num_last_tokens > 0:
|
| 387 |
+
hidden_states = hidden_states[:, -num_last_tokens:]
|
| 388 |
+
lm_logits = self.lm_head(hidden_states)
|
| 389 |
+
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
| 390 |
+
return CausalLMOutput(logits=lm_logits)
|
| 391 |
+
|
| 392 |
+
@classmethod
|
| 393 |
+
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
| 394 |
+
config_data = load_config_hf(pretrained_model_name)
|
| 395 |
+
config = MambaConfig(**config_data)
|
| 396 |
+
model = cls(config, device=device, dtype=dtype, **kwargs)
|
| 397 |
+
model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
|
| 398 |
+
return model
|
| 399 |
+
|
| 400 |
+
def save_pretrained(self, save_directory):
|
| 401 |
+
"""
|
| 402 |
+
Minimal implementation of save_pretrained for MambaLMHeadModel.
|
| 403 |
+
Save the model and its configuration file to a directory.
|
| 404 |
+
"""
|
| 405 |
+
# Ensure save_directory exists
|
| 406 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 407 |
+
|
| 408 |
+
# Save the model's state_dict
|
| 409 |
+
model_path = os.path.join(save_directory, 'pytorch_model.bin')
|
| 410 |
+
torch.save(self.state_dict(), model_path)
|
| 411 |
+
|
| 412 |
+
# Save the configuration of the model
|
| 413 |
+
config_path = os.path.join(save_directory, 'config.json')
|
| 414 |
+
with open(config_path, 'w') as f:
|
| 415 |
+
json.dump(self.config.__dict__, f, indent=4)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
class BambaEncoderDecoder(nn.Module, GenerationMixin):
|
| 419 |
+
|
| 420 |
+
def __init__(
|
| 421 |
+
self,
|
| 422 |
+
config: BambaEncoderDecoderConfig,
|
| 423 |
+
tokenizer=None,
|
| 424 |
+
initializer_cfg=None,
|
| 425 |
+
device=None,
|
| 426 |
+
dtype=None,
|
| 427 |
+
) -> None:
|
| 428 |
+
self.config = config
|
| 429 |
+
self.encoder_config = config.encoder_config
|
| 430 |
+
self.decoder_config = config.decoder_config
|
| 431 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 432 |
+
self.tokenizer = tokenizer
|
| 433 |
+
|
| 434 |
+
super().__init__()
|
| 435 |
+
self.encoder = BambaEncoder(self.encoder_config, **factory_kwargs)
|
| 436 |
+
self.decoder = BambaDecoder(self.decoder_config, **factory_kwargs)
|
| 437 |
+
|
| 438 |
+
self.device = device
|
| 439 |
+
|
| 440 |
+
self.tie_weights()
|
| 441 |
+
self._set_seed(config.seed)
|
| 442 |
+
|
| 443 |
+
def tie_weights(self):
|
| 444 |
+
if self.config.tie_word_embeddings:
|
| 445 |
+
self.decoder.backbone.embedding.weight = self.encoder.backbone.embedding.word_embeddings.weight
|
| 446 |
+
|
| 447 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 448 |
+
return self.decoder.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
| 449 |
+
|
| 450 |
+
def forward(self, encoder_input_ids, decoder_input_ids, token_type_ids=None, attention_mask=None, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
|
| 451 |
+
"""
|
| 452 |
+
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
| 453 |
+
num_last_tokens: if > 0, only return the logits for the last n tokens
|
| 454 |
+
"""
|
| 455 |
+
encoder_hidden_states = self.encoder(encoder_input_ids, inference_params=inference_params, **mixer_kwargs).hidden_states
|
| 456 |
+
lm_logits = self.decoder(decoder_input_ids, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, inference_params=inference_params, **mixer_kwargs).logits
|
| 457 |
+
if num_last_tokens > 0:
|
| 458 |
+
hidden_states = hidden_states[:, -num_last_tokens:]
|
| 459 |
+
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
| 460 |
+
return CausalLMOutput(logits=lm_logits)
|
| 461 |
+
|
| 462 |
+
@classmethod
|
| 463 |
+
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
| 464 |
+
config_data = load_config_hf(pretrained_model_name)
|
| 465 |
+
config = MambaConfig(**config_data)
|
| 466 |
+
model = cls(config, device=device, dtype=dtype, **kwargs)
|
| 467 |
+
model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
|
| 468 |
+
return model
|
| 469 |
+
|
| 470 |
+
def save_pretrained(self, save_directory):
|
| 471 |
+
"""
|
| 472 |
+
Minimal implementation of save_pretrained for MambaLMHeadModel.
|
| 473 |
+
Save the model and its configuration file to a directory.
|
| 474 |
+
"""
|
| 475 |
+
# Ensure save_directory exists
|
| 476 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 477 |
+
|
| 478 |
+
# Save the model's state_dict
|
| 479 |
+
model_path = os.path.join(save_directory, 'pytorch_model.bin')
|
| 480 |
+
torch.save(self.state_dict(), model_path)
|
| 481 |
+
|
| 482 |
+
# Save the configuration of the model
|
| 483 |
+
config_path = os.path.join(save_directory, 'config.json')
|
| 484 |
+
with open(config_path, 'w') as f:
|
| 485 |
+
json.dump(self.config.__dict__, f, indent=4)
|
| 486 |
+
|
| 487 |
+
def _set_seed(self, value):
|
| 488 |
+
print('Random Seed:', value)
|
| 489 |
+
random.seed(value)
|
| 490 |
+
torch.manual_seed(value)
|
| 491 |
+
torch.cuda.manual_seed(value)
|
| 492 |
+
torch.cuda.manual_seed_all(value)
|
| 493 |
+
np.random.seed(value)
|
| 494 |
+
cudnn.deterministic = True
|
| 495 |
+
cudnn.benchmark = False
|
| 496 |
+
|
| 497 |
+
def extract_embeddings(self, smiles):
|
| 498 |
+
tokens = self.tokenizer(smiles, padding=True, truncation=True, return_tensors='pt')
|
| 499 |
+
|
| 500 |
+
idx = tokens['input_ids'].to(self.device)
|
| 501 |
+
mask = tokens['attention_mask'].to(self.device)
|
| 502 |
+
outputs = self.encoder(input_ids=idx)
|
| 503 |
+
hidden_states = outputs.hidden_states
|
| 504 |
+
|
| 505 |
+
token_embeddings = hidden_states
|
| 506 |
+
input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
| 507 |
+
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
|
| 508 |
+
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
| 509 |
+
embeddings = sum_embeddings / sum_mask
|
| 510 |
+
|
| 511 |
+
return embeddings
|
| 512 |
+
|
| 513 |
+
def encode(self, smiles, useCuda=False, batch_size=100, return_torch=False):
|
| 514 |
+
"""Extract efficiently SMILES embeddings per batches."""
|
| 515 |
+
# TODO: remove useCuda argument
|
| 516 |
+
|
| 517 |
+
# handle single str or a list of str
|
| 518 |
+
smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles))
|
| 519 |
+
|
| 520 |
+
# process in batches
|
| 521 |
+
n_split = smiles.shape[0] // batch_size if smiles.shape[0] >= batch_size else smiles.shape[0]
|
| 522 |
+
embeddings = [
|
| 523 |
+
self.extract_embeddings(list(batch)).cpu().detach().numpy()
|
| 524 |
+
for batch in tqdm(np.array_split(smiles, n_split))
|
| 525 |
+
]
|
| 526 |
+
flat_list = [item for sublist in embeddings for item in sublist]
|
| 527 |
+
|
| 528 |
+
# clear GPU memory
|
| 529 |
+
torch.cuda.empty_cache()
|
| 530 |
+
gc.collect()
|
| 531 |
+
|
| 532 |
+
if return_torch:
|
| 533 |
+
return torch.tensor(flat_list)
|
| 534 |
+
return pd.DataFrame(flat_list)
|
str_bamba/bamba_config.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
@dataclass
|
| 5 |
+
class BambaConfig:
|
| 6 |
+
|
| 7 |
+
d_model: int = 2560
|
| 8 |
+
d_intermediate: int = 0
|
| 9 |
+
n_layer: int = 64
|
| 10 |
+
vocab_size: int = 50277
|
| 11 |
+
max_position_embeddings: int = 262144
|
| 12 |
+
ssm_cfg: dict = field(default_factory=dict)
|
| 13 |
+
attn_layer_idx: list = field(default_factory=list)
|
| 14 |
+
attn_cfg: dict = field(default_factory=dict)
|
| 15 |
+
rms_norm: bool = True
|
| 16 |
+
residual_in_fp32: bool = True
|
| 17 |
+
fused_add_norm: bool = True
|
| 18 |
+
pad_vocab_size_multiple: int = 8
|
| 19 |
+
tie_embeddings: bool = True
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class BambaEncoderDecoderConfig:
|
| 24 |
+
|
| 25 |
+
encoder_config: BambaConfig = None
|
| 26 |
+
decoder_config: BambaConfig = None
|
| 27 |
+
tie_word_embeddings: bool = True
|
| 28 |
+
seed: int = 0
|
str_bamba/bamba_modules.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn, Tensor
|
| 5 |
+
|
| 6 |
+
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
|
| 7 |
+
from transformers.models.bart.modeling_bart import BartSdpaAttention
|
| 8 |
+
from transformers.activations import ACT2FN
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BertEmbeddings(nn.Module):
|
| 12 |
+
"""Construct the embeddings from word, position and token_type embeddings."""
|
| 13 |
+
|
| 14 |
+
def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size=2, pad_token_id=2, layer_norm_eps=1e-12, hidden_dropout_prob=0.1, device=None, dtype=None):
|
| 15 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id, **factory_kwargs)
|
| 18 |
+
self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size, **factory_kwargs)
|
| 19 |
+
|
| 20 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
| 21 |
+
# any TensorFlow checkpoint file
|
| 22 |
+
self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, **factory_kwargs)
|
| 23 |
+
self.dropout = nn.Dropout(hidden_dropout_prob)
|
| 24 |
+
# self.position_embedding_type = "rotary"
|
| 25 |
+
# self.register_buffer(
|
| 26 |
+
# "position_ids", torch.arange(max_position_embeddings).expand((1, -1)), persistent=False
|
| 27 |
+
# )
|
| 28 |
+
# self.register_buffer(
|
| 29 |
+
# "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
| 30 |
+
# )
|
| 31 |
+
|
| 32 |
+
def forward(
|
| 33 |
+
self,
|
| 34 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 35 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 36 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 37 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 38 |
+
past_key_values_length: int = 0,
|
| 39 |
+
) -> torch.Tensor:
|
| 40 |
+
if input_ids is not None:
|
| 41 |
+
input_shape = input_ids.size()
|
| 42 |
+
else:
|
| 43 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 44 |
+
|
| 45 |
+
seq_length = input_shape[1]
|
| 46 |
+
|
| 47 |
+
# if position_ids is None:
|
| 48 |
+
# position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
| 49 |
+
|
| 50 |
+
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
|
| 51 |
+
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
|
| 52 |
+
# issue #5664
|
| 53 |
+
if token_type_ids is None:
|
| 54 |
+
# if hasattr(self, "token_type_ids"):
|
| 55 |
+
# import ipdb; ipdb.set_trace()
|
| 56 |
+
# buffered_token_type_ids = self.token_type_ids[:, :seq_length]
|
| 57 |
+
# buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
|
| 58 |
+
# token_type_ids = buffered_token_type_ids_expanded
|
| 59 |
+
# else:
|
| 60 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)
|
| 61 |
+
|
| 62 |
+
if inputs_embeds is None:
|
| 63 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 64 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
| 65 |
+
|
| 66 |
+
embeddings = inputs_embeds + token_type_embeddings
|
| 67 |
+
embeddings = self.LayerNorm(embeddings)
|
| 68 |
+
embeddings = self.dropout(embeddings)
|
| 69 |
+
return embeddings
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class BertPooler(nn.Module):
|
| 73 |
+
def __init__(self, hidden_size, device=None, dtype=None):
|
| 74 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 75 |
+
super().__init__()
|
| 76 |
+
self.dense = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
|
| 77 |
+
self.activation = nn.Tanh()
|
| 78 |
+
|
| 79 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 80 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 81 |
+
# to the first token.
|
| 82 |
+
first_token_tensor = hidden_states[:, 0]
|
| 83 |
+
pooled_output = self.dense(first_token_tensor)
|
| 84 |
+
pooled_output = self.activation(pooled_output)
|
| 85 |
+
return pooled_output
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class BertPredictionHeadTransform(nn.Module):
|
| 89 |
+
def __init__(self, hidden_size, hidden_act="gelu", layer_norm_eps=1e-12, device=None, dtype=None):
|
| 90 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 91 |
+
super().__init__()
|
| 92 |
+
self.dense = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
|
| 93 |
+
if isinstance(hidden_act, str):
|
| 94 |
+
self.transform_act_fn = ACT2FN[hidden_act]
|
| 95 |
+
else:
|
| 96 |
+
self.transform_act_fn = hidden_act
|
| 97 |
+
self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, **factory_kwargs)
|
| 98 |
+
|
| 99 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 100 |
+
hidden_states = self.dense(hidden_states)
|
| 101 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 102 |
+
hidden_states = self.LayerNorm(hidden_states)
|
| 103 |
+
return hidden_states
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class BertLMPredictionHead(nn.Module):
|
| 107 |
+
def __init__(self, vocab_size, hidden_size, hidden_act="gelu", layer_norm_eps=1e-12, device=None, dtype=None):
|
| 108 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 109 |
+
super().__init__()
|
| 110 |
+
self.transform = BertPredictionHeadTransform(hidden_size, hidden_act, layer_norm_eps, **factory_kwargs)
|
| 111 |
+
|
| 112 |
+
# The output weights are the same as the input embeddings, but there is
|
| 113 |
+
# an output-only bias for each token.
|
| 114 |
+
self.decoder = nn.Linear(hidden_size, vocab_size, bias=False, **factory_kwargs)
|
| 115 |
+
|
| 116 |
+
self.bias = nn.Parameter(torch.zeros(vocab_size, **factory_kwargs))
|
| 117 |
+
|
| 118 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
| 119 |
+
self.decoder.bias = self.bias
|
| 120 |
+
|
| 121 |
+
def _tie_weights(self):
|
| 122 |
+
self.decoder.bias = self.bias
|
| 123 |
+
|
| 124 |
+
def forward(self, hidden_states):
|
| 125 |
+
hidden_states = self.transform(hidden_states)
|
| 126 |
+
hidden_states = self.decoder(hidden_states)
|
| 127 |
+
return hidden_states
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class BertPreTrainingHeads(nn.Module):
|
| 131 |
+
def __init__(self, vocab_size, hidden_size, hidden_act="gelu", layer_norm_eps=1e-12, device=None, dtype=None):
|
| 132 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 133 |
+
super().__init__()
|
| 134 |
+
self.predictions = BertLMPredictionHead(vocab_size, hidden_size, hidden_act, layer_norm_eps, **factory_kwargs)
|
| 135 |
+
self.seq_relationship = nn.Linear(hidden_size, 2, **factory_kwargs)
|
| 136 |
+
|
| 137 |
+
def forward(self, sequence_output, pooled_output):
|
| 138 |
+
prediction_scores = self.predictions(sequence_output)
|
| 139 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
| 140 |
+
return prediction_scores, seq_relationship_score
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class BlockCrossAttention(nn.Module):
|
| 145 |
+
def __init__(
|
| 146 |
+
self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
|
| 147 |
+
):
|
| 148 |
+
"""
|
| 149 |
+
Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
|
| 150 |
+
|
| 151 |
+
This Block has a slightly different structure compared to a regular
|
| 152 |
+
prenorm Transformer block.
|
| 153 |
+
The standard block is: LN -> MHA/MLP -> Add.
|
| 154 |
+
[Ref: https://arxiv.org/abs/2002.04745]
|
| 155 |
+
Here we have: Add -> LN -> Mixer, returning both
|
| 156 |
+
the hidden_states (output of the mixer) and the residual.
|
| 157 |
+
This is purely for performance reasons, as we can fuse add and LayerNorm.
|
| 158 |
+
The residual needs to be provided (except for the very first block).
|
| 159 |
+
"""
|
| 160 |
+
super().__init__()
|
| 161 |
+
self.residual_in_fp32 = residual_in_fp32
|
| 162 |
+
self.fused_add_norm = fused_add_norm
|
| 163 |
+
self.norm = norm_cls(dim)
|
| 164 |
+
self.mixer = mixer_cls(dim)
|
| 165 |
+
self.encoder_attn = BartSdpaAttention(embed_dim=dim, num_heads=1)
|
| 166 |
+
if mlp_cls is not nn.Identity:
|
| 167 |
+
self.norm2 = norm_cls(dim)
|
| 168 |
+
self.mlp = mlp_cls(dim)
|
| 169 |
+
else:
|
| 170 |
+
self.mlp = None
|
| 171 |
+
if self.fused_add_norm:
|
| 172 |
+
assert RMSNorm is not None, "RMSNorm import fails"
|
| 173 |
+
assert isinstance(
|
| 174 |
+
self.norm, (nn.LayerNorm, RMSNorm)
|
| 175 |
+
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
|
| 176 |
+
|
| 177 |
+
def forward(
|
| 178 |
+
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, encoder_hidden_states=None, attention_mask=None, **mixer_kwargs
|
| 179 |
+
):
|
| 180 |
+
r"""Pass the input through the encoder layer.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
hidden_states: the sequence to the encoder layer (required).
|
| 184 |
+
residual: hidden_states = Mixer(LN(residual))
|
| 185 |
+
"""
|
| 186 |
+
if not self.fused_add_norm:
|
| 187 |
+
residual = (hidden_states + residual) if residual is not None else hidden_states
|
| 188 |
+
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
|
| 189 |
+
if self.residual_in_fp32:
|
| 190 |
+
residual = residual.to(torch.float32)
|
| 191 |
+
else:
|
| 192 |
+
hidden_states, residual = layer_norm_fn(
|
| 193 |
+
hidden_states,
|
| 194 |
+
self.norm.weight,
|
| 195 |
+
self.norm.bias,
|
| 196 |
+
residual=residual,
|
| 197 |
+
prenorm=True,
|
| 198 |
+
residual_in_fp32=self.residual_in_fp32,
|
| 199 |
+
eps=self.norm.eps,
|
| 200 |
+
is_rms_norm=isinstance(self.norm, RMSNorm)
|
| 201 |
+
)
|
| 202 |
+
hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
|
| 203 |
+
|
| 204 |
+
# cross-attention
|
| 205 |
+
hidden_states, _, _ = self.encoder_attn(hidden_states, encoder_hidden_states, attention_mask=attention_mask)
|
| 206 |
+
|
| 207 |
+
if self.mlp is not None:
|
| 208 |
+
if not self.fused_add_norm:
|
| 209 |
+
residual = hidden_states + residual
|
| 210 |
+
hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
|
| 211 |
+
if self.residual_in_fp32:
|
| 212 |
+
residual = residual.to(torch.float32)
|
| 213 |
+
else:
|
| 214 |
+
hidden_states, residual = layer_norm_fn(
|
| 215 |
+
hidden_states,
|
| 216 |
+
self.norm2.weight,
|
| 217 |
+
self.norm2.bias,
|
| 218 |
+
residual=residual,
|
| 219 |
+
prenorm=True,
|
| 220 |
+
residual_in_fp32=self.residual_in_fp32,
|
| 221 |
+
eps=self.norm2.eps,
|
| 222 |
+
is_rms_norm=isinstance(self.norm2, RMSNorm)
|
| 223 |
+
)
|
| 224 |
+
hidden_states = self.mlp(hidden_states)
|
| 225 |
+
|
| 226 |
+
return hidden_states, residual
|
| 227 |
+
|
| 228 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 229 |
+
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
str_bamba/config/config_encoder-decoder_436M.json
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"encoder_config": {
|
| 3 |
+
"d_model": 1024,
|
| 4 |
+
"d_intermediate": 0,
|
| 5 |
+
"n_layer": 24,
|
| 6 |
+
"vocab_size": 5000,
|
| 7 |
+
"max_position_embeddings": 4096,
|
| 8 |
+
"ssm_cfg": {
|
| 9 |
+
"layer": "Mamba2"
|
| 10 |
+
},
|
| 11 |
+
"attn_layer_idx": [
|
| 12 |
+
6,
|
| 13 |
+
18
|
| 14 |
+
],
|
| 15 |
+
"attn_cfg": {
|
| 16 |
+
"causal": false,
|
| 17 |
+
"d_conv": 0,
|
| 18 |
+
"head_dim": 64,
|
| 19 |
+
"num_heads": 16,
|
| 20 |
+
"num_heads_kv": 8,
|
| 21 |
+
"out_proj_bias": false,
|
| 22 |
+
"qkv_proj_bias": false,
|
| 23 |
+
"rotary_emb_dim": 64
|
| 24 |
+
},
|
| 25 |
+
"rms_norm": true,
|
| 26 |
+
"residual_in_fp32": true,
|
| 27 |
+
"fused_add_norm": true,
|
| 28 |
+
"pad_vocab_size_multiple": 8,
|
| 29 |
+
"tie_embeddings": false
|
| 30 |
+
},
|
| 31 |
+
"decoder_config": {
|
| 32 |
+
"d_model": 1024,
|
| 33 |
+
"d_intermediate": 0,
|
| 34 |
+
"n_layer": 24,
|
| 35 |
+
"vocab_size": 5000,
|
| 36 |
+
"max_position_embeddings": 4096,
|
| 37 |
+
"ssm_cfg": {
|
| 38 |
+
"layer": "Mamba2"
|
| 39 |
+
},
|
| 40 |
+
"attn_layer_idx": [
|
| 41 |
+
6,
|
| 42 |
+
18
|
| 43 |
+
],
|
| 44 |
+
"attn_cfg": {
|
| 45 |
+
"causal": true,
|
| 46 |
+
"d_conv": 0,
|
| 47 |
+
"head_dim": 64,
|
| 48 |
+
"num_heads": 16,
|
| 49 |
+
"num_heads_kv": 8,
|
| 50 |
+
"out_proj_bias": false,
|
| 51 |
+
"qkv_proj_bias": false,
|
| 52 |
+
"rotary_emb_dim": 64
|
| 53 |
+
},
|
| 54 |
+
"rms_norm": true,
|
| 55 |
+
"residual_in_fp32": true,
|
| 56 |
+
"fused_add_norm": true,
|
| 57 |
+
"pad_vocab_size_multiple": 8,
|
| 58 |
+
"tie_embeddings": false
|
| 59 |
+
},
|
| 60 |
+
"tie_word_embeddings": true,
|
| 61 |
+
"seed": 0
|
| 62 |
+
}
|
str_bamba/generation.py
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023, Albert Gu, Tri Dao.
|
| 2 |
+
import gc
|
| 3 |
+
import time
|
| 4 |
+
from collections import namedtuple
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from functools import partial
|
| 7 |
+
from typing import Callable, Optional, Sequence, Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from einops import rearrange, repeat
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
from torch.profiler import ProfilerActivity, profile, record_function
|
| 14 |
+
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class InferenceParams:
|
| 19 |
+
"""Inference parameters that are passed to the main model in order
|
| 20 |
+
to efficienly calculate and store the context during inference."""
|
| 21 |
+
|
| 22 |
+
max_seqlen: int
|
| 23 |
+
max_batch_size: int
|
| 24 |
+
seqlen_offset: int = 0
|
| 25 |
+
batch_size_offset: int = 0
|
| 26 |
+
key_value_memory_dict: dict = field(default_factory=dict)
|
| 27 |
+
lengths_per_sample: Optional[Tensor] = None
|
| 28 |
+
|
| 29 |
+
def reset(self, max_seqlen, max_batch_size):
|
| 30 |
+
self.max_seqlen = max_seqlen
|
| 31 |
+
self.max_batch_size = max_batch_size
|
| 32 |
+
self.seqlen_offset = 0
|
| 33 |
+
if self.lengths_per_sample is not None:
|
| 34 |
+
self.lengths_per_sample.zero_()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def modify_logits_for_min_p_filtering(logits, min_p):
|
| 38 |
+
"""Set the logits for none min_p values to -inf. Done in-place."""
|
| 39 |
+
if min_p <= 0.0 or min_p >= 1.0:
|
| 40 |
+
return
|
| 41 |
+
indices_to_remove = logits < min_p
|
| 42 |
+
logits.masked_fill_(indices_to_remove, float("-Inf"))
|
| 43 |
+
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
| 44 |
+
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
|
| 45 |
+
def modify_logits_for_top_k_filtering(logits, top_k):
|
| 46 |
+
"""Set the logits for none top-k values to -inf. Done in-place."""
|
| 47 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 48 |
+
logits.masked_fill_(indices_to_remove, float("-Inf"))
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
| 52 |
+
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
|
| 53 |
+
def modify_logits_for_top_p_filtering(logits, top_p):
|
| 54 |
+
"""Set the logits for none top-p values to -inf. Done in-place."""
|
| 55 |
+
if top_p <= 0.0 or top_p >= 1.0:
|
| 56 |
+
return
|
| 57 |
+
# First sort and calculate cumulative sum of probabilities.
|
| 58 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
|
| 59 |
+
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
| 60 |
+
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
| 61 |
+
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
|
| 62 |
+
# scatter sorted tensors to original indexing
|
| 63 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
| 64 |
+
1, sorted_indices, sorted_indices_to_remove
|
| 65 |
+
)
|
| 66 |
+
logits.masked_fill_(indices_to_remove, float("-inf"))
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
|
| 70 |
+
"""Apply repetition penalty. See https://arxiv.org/abs/1909.05858
|
| 71 |
+
logits: (batch_size, vocab_size)
|
| 72 |
+
prev_output_tokens: (batch_size, seq_len)
|
| 73 |
+
"""
|
| 74 |
+
if repetition_penalty == 1.0:
|
| 75 |
+
return logits
|
| 76 |
+
score = torch.gather(logits, 1, prev_output_tokens)
|
| 77 |
+
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
| 78 |
+
score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
|
| 79 |
+
logits.scatter_(1, prev_output_tokens, score)
|
| 80 |
+
return logits
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
|
| 84 |
+
"""Sample from top-k logits.
|
| 85 |
+
Arguments:
|
| 86 |
+
logits: Tensor of shape (batch_size, vocab_size)
|
| 87 |
+
"""
|
| 88 |
+
if top_k == 1: # Short-circuit for greedy decoding
|
| 89 |
+
return logits.argmax(dim=-1)
|
| 90 |
+
else:
|
| 91 |
+
if top_p > 0.0:
|
| 92 |
+
assert top_p <= 1.0, "top-p should be in (0, 1]."
|
| 93 |
+
if top_k > 0:
|
| 94 |
+
top_k = min(top_k, logits.size(-1)) # Safety check
|
| 95 |
+
logits_top, indices = torch.topk(logits, top_k, dim=-1)
|
| 96 |
+
if temperature != 1.0:
|
| 97 |
+
logits_top /= temperature
|
| 98 |
+
modify_logits_for_top_p_filtering(logits_top, top_p)
|
| 99 |
+
return indices[
|
| 100 |
+
torch.arange(indices.shape[0], device=indices.device),
|
| 101 |
+
torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
|
| 102 |
+
]
|
| 103 |
+
else:
|
| 104 |
+
if min_p > 0.0:
|
| 105 |
+
logits_top = logits.clone()
|
| 106 |
+
max_prob = logits_top[..., 0].item()
|
| 107 |
+
min_prob = max_prob * min_p
|
| 108 |
+
modify_logits_for_min_p_filtering(logits_top, min_prob)
|
| 109 |
+
if temperature != 1.0:
|
| 110 |
+
logits_top /= temperature
|
| 111 |
+
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
|
| 112 |
+
# Clone so that when we modify for top_p we don't change the original logits
|
| 113 |
+
logits_top = logits / temperature if temperature != 1.0 else logits.clone()
|
| 114 |
+
modify_logits_for_top_p_filtering(logits_top, top_p)
|
| 115 |
+
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
|
| 116 |
+
dim=-1
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@torch.inference_mode()
|
| 121 |
+
def decode(
|
| 122 |
+
input_ids,
|
| 123 |
+
encoder_hidden_states,
|
| 124 |
+
model,
|
| 125 |
+
max_length,
|
| 126 |
+
top_k=1,
|
| 127 |
+
top_p=0.0,
|
| 128 |
+
min_p=0.0,
|
| 129 |
+
temperature=1.0,
|
| 130 |
+
repetition_penalty=1.0,
|
| 131 |
+
eos_token_id=None,
|
| 132 |
+
teacher_outputs=None,
|
| 133 |
+
vocab_size=None,
|
| 134 |
+
cg=False,
|
| 135 |
+
enable_timing=False,
|
| 136 |
+
output_scores=False,
|
| 137 |
+
streamer: Optional[TextStreamer] = None
|
| 138 |
+
):
|
| 139 |
+
"""Decoding, either greedy or with top-k or top-p sampling.
|
| 140 |
+
If top-k = 0, don't limit the number of candidates (pure sampling).
|
| 141 |
+
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
|
| 142 |
+
then top-p.
|
| 143 |
+
We assume that all sequences in the same batch have the same length.
|
| 144 |
+
|
| 145 |
+
Arguments:
|
| 146 |
+
input_ids: (batch, seq_len)
|
| 147 |
+
max_length: int
|
| 148 |
+
teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
|
| 149 |
+
logits, the next token is taken from the teacher_outputs. Useful for testing.
|
| 150 |
+
Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
|
| 151 |
+
sequences: (batch, max_length)
|
| 152 |
+
scores: tuples of (batch, vocab_size)
|
| 153 |
+
"""
|
| 154 |
+
if streamer is not None:
|
| 155 |
+
streamer.put(input_ids.cpu())
|
| 156 |
+
|
| 157 |
+
batch_size, seqlen_og = input_ids.shape
|
| 158 |
+
teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
|
| 159 |
+
if cg:
|
| 160 |
+
if not hasattr(model, "_decoding_cache"):
|
| 161 |
+
model._decoding_cache = None
|
| 162 |
+
model._decoding_cache = update_graph_cache(
|
| 163 |
+
model,
|
| 164 |
+
encoder_hidden_states,
|
| 165 |
+
model._decoding_cache,
|
| 166 |
+
batch_size,
|
| 167 |
+
seqlen_og,
|
| 168 |
+
max_length,
|
| 169 |
+
)
|
| 170 |
+
inference_params = model._decoding_cache.inference_params
|
| 171 |
+
inference_params.reset(max_length, batch_size)
|
| 172 |
+
else:
|
| 173 |
+
inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
|
| 174 |
+
|
| 175 |
+
def get_logits(input_ids, inference_params):
|
| 176 |
+
decoding = inference_params.seqlen_offset > 0
|
| 177 |
+
if decoding:
|
| 178 |
+
position_ids = torch.full(
|
| 179 |
+
(batch_size, 1),
|
| 180 |
+
inference_params.seqlen_offset,
|
| 181 |
+
dtype=torch.long,
|
| 182 |
+
device=input_ids.device,
|
| 183 |
+
)
|
| 184 |
+
else:
|
| 185 |
+
position_ids = None
|
| 186 |
+
if not cg or not decoding:
|
| 187 |
+
logits = model(
|
| 188 |
+
input_ids,
|
| 189 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 190 |
+
position_ids=position_ids,
|
| 191 |
+
inference_params=inference_params,
|
| 192 |
+
num_last_tokens=1,
|
| 193 |
+
).logits.squeeze(dim=1)
|
| 194 |
+
else:
|
| 195 |
+
logits = model._decoding_cache.run(
|
| 196 |
+
input_ids, position_ids, inference_params.seqlen_offset
|
| 197 |
+
).squeeze(dim=1)
|
| 198 |
+
return logits[..., :vocab_size] if vocab_size is not None else logits
|
| 199 |
+
|
| 200 |
+
def sample_tokens(logits, inference_params):
|
| 201 |
+
if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
|
| 202 |
+
token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
|
| 203 |
+
else:
|
| 204 |
+
token = teacher_outputs[:, inference_params.seqlen_offset]
|
| 205 |
+
# return rearrange(token, "b -> b 1")
|
| 206 |
+
return token.unsqueeze(1)
|
| 207 |
+
|
| 208 |
+
def should_stop(current_token, inference_params):
|
| 209 |
+
if inference_params.seqlen_offset == 0:
|
| 210 |
+
return False
|
| 211 |
+
if eos_token_id is not None and (current_token == eos_token_id).all():
|
| 212 |
+
return True
|
| 213 |
+
if inference_params.seqlen_offset >= max_length - 1:
|
| 214 |
+
return True
|
| 215 |
+
return False
|
| 216 |
+
|
| 217 |
+
start = torch.cuda.Event(enable_timing=enable_timing)
|
| 218 |
+
end = torch.cuda.Event(enable_timing=enable_timing)
|
| 219 |
+
|
| 220 |
+
if enable_timing:
|
| 221 |
+
start.record()
|
| 222 |
+
scores, sequences = [], [input_ids]
|
| 223 |
+
sequences_cat = input_ids
|
| 224 |
+
while not should_stop(sequences[-1], inference_params):
|
| 225 |
+
logits = get_logits(sequences[-1], inference_params)
|
| 226 |
+
if output_scores:
|
| 227 |
+
scores.append(logits.clone())
|
| 228 |
+
inference_params.seqlen_offset += sequences[-1].shape[1]
|
| 229 |
+
if repetition_penalty == 1.0:
|
| 230 |
+
sampled_tokens = sample_tokens(logits, inference_params)
|
| 231 |
+
else:
|
| 232 |
+
logits = modify_logit_for_repetition_penalty(
|
| 233 |
+
logits, sequences_cat, repetition_penalty
|
| 234 |
+
)
|
| 235 |
+
sampled_tokens = sample_tokens(logits, inference_params)
|
| 236 |
+
sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
|
| 237 |
+
sequences.append(sampled_tokens)
|
| 238 |
+
if streamer is not None:
|
| 239 |
+
streamer.put(sampled_tokens.cpu())
|
| 240 |
+
if streamer is not None:
|
| 241 |
+
streamer.end()
|
| 242 |
+
if enable_timing:
|
| 243 |
+
end.record()
|
| 244 |
+
torch.cuda.synchronize()
|
| 245 |
+
print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
|
| 246 |
+
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
|
| 247 |
+
return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class GenerationMixin:
|
| 251 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 252 |
+
raise NotImplementedError
|
| 253 |
+
|
| 254 |
+
def generate(
|
| 255 |
+
self,
|
| 256 |
+
input_ids,
|
| 257 |
+
encoder_hidden_states,
|
| 258 |
+
max_length,
|
| 259 |
+
top_k=1,
|
| 260 |
+
top_p=0.0,
|
| 261 |
+
min_p=0.0,
|
| 262 |
+
temperature=1.0,
|
| 263 |
+
return_dict_in_generate=False,
|
| 264 |
+
output_scores=False,
|
| 265 |
+
**kwargs,
|
| 266 |
+
):
|
| 267 |
+
output = decode(
|
| 268 |
+
input_ids, encoder_hidden_states, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, output_scores=output_scores, **kwargs
|
| 269 |
+
)
|
| 270 |
+
if not output_scores:
|
| 271 |
+
output.scores = None
|
| 272 |
+
return output if return_dict_in_generate else output.sequences
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
@dataclass
|
| 276 |
+
class DecodingCGCache:
|
| 277 |
+
max_batch_size: int = 0
|
| 278 |
+
max_seqlen: int = 0
|
| 279 |
+
device = None
|
| 280 |
+
dtype = None
|
| 281 |
+
callables: dict = field(default_factory=dict)
|
| 282 |
+
mempool = None
|
| 283 |
+
inference_params: Optional[InferenceParams] = None
|
| 284 |
+
run: Optional[Callable] = None
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
@torch.inference_mode()
|
| 288 |
+
def update_graph_cache(
|
| 289 |
+
model,
|
| 290 |
+
encoder_hidden_states,
|
| 291 |
+
cache,
|
| 292 |
+
batch_size,
|
| 293 |
+
seqlen_og,
|
| 294 |
+
max_seqlen,
|
| 295 |
+
decoding_seqlens=(1,),
|
| 296 |
+
dtype=None,
|
| 297 |
+
n_warmups=2,
|
| 298 |
+
):
|
| 299 |
+
if cache is None:
|
| 300 |
+
cache = DecodingCGCache()
|
| 301 |
+
param_example = next(iter(model.parameters()))
|
| 302 |
+
device = param_example.device
|
| 303 |
+
if dtype is None:
|
| 304 |
+
dtype = param_example.dtype
|
| 305 |
+
if (
|
| 306 |
+
(device, dtype) != (cache.device, cache.dtype)
|
| 307 |
+
or batch_size > cache.max_batch_size
|
| 308 |
+
or max_seqlen > cache.max_seqlen
|
| 309 |
+
): # Invalidate the cache
|
| 310 |
+
cache.callables = {}
|
| 311 |
+
cache.mempool = None
|
| 312 |
+
cache.inference_params = None
|
| 313 |
+
gc.collect()
|
| 314 |
+
cache.device, cache.dtype = device, dtype
|
| 315 |
+
cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
|
| 316 |
+
assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
|
| 317 |
+
inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
|
| 318 |
+
lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
|
| 319 |
+
cache.inference_params = InferenceParams(
|
| 320 |
+
max_seqlen=max_seqlen,
|
| 321 |
+
max_batch_size=batch_size,
|
| 322 |
+
seqlen_offset=seqlen_og,
|
| 323 |
+
key_value_memory_dict=inf_cache,
|
| 324 |
+
lengths_per_sample=lengths_per_sample,
|
| 325 |
+
)
|
| 326 |
+
cache.mempool = torch.cuda.graphs.graph_pool_handle()
|
| 327 |
+
for decoding_seqlen in decoding_seqlens:
|
| 328 |
+
if (batch_size, decoding_seqlen) not in cache.callables:
|
| 329 |
+
cache.callables[batch_size, decoding_seqlen] = capture_graph(
|
| 330 |
+
model,
|
| 331 |
+
encoder_hidden_states,
|
| 332 |
+
cache.inference_params,
|
| 333 |
+
batch_size,
|
| 334 |
+
max_seqlen,
|
| 335 |
+
decoding_seqlen=decoding_seqlen,
|
| 336 |
+
mempool=cache.mempool,
|
| 337 |
+
n_warmups=n_warmups,
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
def dispatch(input_ids, position_ids, seqlen):
|
| 341 |
+
batch_size, decoding_seqlen = input_ids.shape[:2]
|
| 342 |
+
return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
|
| 343 |
+
|
| 344 |
+
cache.run = dispatch
|
| 345 |
+
cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
|
| 346 |
+
return cache
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def capture_graph(
|
| 350 |
+
model, encoder_hidden_states, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
|
| 351 |
+
):
|
| 352 |
+
device = next(iter(model.parameters())).device
|
| 353 |
+
input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
|
| 354 |
+
position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
|
| 355 |
+
seqlen_offset_og = inference_params.seqlen_offset
|
| 356 |
+
inference_params.seqlen_offset = max_seqlen - decoding_seqlen
|
| 357 |
+
inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
|
| 358 |
+
|
| 359 |
+
# Warmup before capture
|
| 360 |
+
s = torch.cuda.Stream()
|
| 361 |
+
s.wait_stream(torch.cuda.current_stream())
|
| 362 |
+
with torch.cuda.stream(s):
|
| 363 |
+
for _ in range(n_warmups):
|
| 364 |
+
logits = model(
|
| 365 |
+
input_ids,
|
| 366 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 367 |
+
position_ids=position_ids,
|
| 368 |
+
inference_params=inference_params,
|
| 369 |
+
num_last_tokens=decoding_seqlen,
|
| 370 |
+
).logits
|
| 371 |
+
s.synchronize()
|
| 372 |
+
# This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
|
| 373 |
+
# which requires that graph launch and non-captured launch to not overlap (I think,
|
| 374 |
+
# that's how I interpret the documentation). I'm not sure if this is required.
|
| 375 |
+
if torch.distributed.is_initialized():
|
| 376 |
+
torch.distributed.barrier()
|
| 377 |
+
torch.cuda.current_stream().wait_stream(s)
|
| 378 |
+
# Captures the graph
|
| 379 |
+
# To allow capture, automatically sets a side stream as the current stream in the context
|
| 380 |
+
graph = torch.cuda.CUDAGraph()
|
| 381 |
+
with torch.cuda.graph(graph, pool=mempool):
|
| 382 |
+
logits = model(
|
| 383 |
+
input_ids,
|
| 384 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 385 |
+
position_ids=position_ids,
|
| 386 |
+
inference_params=inference_params,
|
| 387 |
+
num_last_tokens=decoding_seqlen,
|
| 388 |
+
).logits
|
| 389 |
+
|
| 390 |
+
def run(new_input_ids, new_position_ids, seqlen):
|
| 391 |
+
inference_params.lengths_per_sample[:] = seqlen
|
| 392 |
+
input_ids.copy_(new_input_ids)
|
| 393 |
+
position_ids.copy_(new_position_ids)
|
| 394 |
+
graph.replay()
|
| 395 |
+
return logits.clone()
|
| 396 |
+
|
| 397 |
+
inference_params.seqlen_offset = seqlen_offset_og
|
| 398 |
+
return run
|
str_bamba/load.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .bamba_config import BambaEncoderDecoderConfig
|
| 2 |
+
from .bamba import BambaConfig, BambaEncoderDecoder
|
| 3 |
+
from .tokenizer.str_tokenizer import load_tokenizer
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
import random
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def load_strbamba(ckpt_filename,
|
| 12 |
+
base_folder='./str_bamba',
|
| 13 |
+
config_filename='config_encoder-decoder_436M.json',
|
| 14 |
+
tokenizer_filename='str_bamba_tokenizer.json',
|
| 15 |
+
eval_model=True,
|
| 16 |
+
device='cuda:0',
|
| 17 |
+
dtype=torch.float32
|
| 18 |
+
):
|
| 19 |
+
# load config
|
| 20 |
+
with open(os.path.join(base_folder, f'config/{config_filename}')) as json_data:
|
| 21 |
+
config_json = json.load(json_data)
|
| 22 |
+
bamba_config = BambaEncoderDecoderConfig(
|
| 23 |
+
encoder_config=BambaConfig(**config_json['encoder_config']),
|
| 24 |
+
decoder_config=BambaConfig(**config_json['decoder_config']),
|
| 25 |
+
tie_word_embeddings=config_json['tie_word_embeddings'],
|
| 26 |
+
seed=config_json['seed']
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
# load tokenizer
|
| 30 |
+
tokenizer = load_tokenizer(os.path.join(base_folder, f'tokenizer/{tokenizer_filename}'))
|
| 31 |
+
|
| 32 |
+
# load model
|
| 33 |
+
model = BambaEncoderDecoder(bamba_config, tokenizer, device=device, dtype=dtype)
|
| 34 |
+
|
| 35 |
+
# load weights
|
| 36 |
+
ckpt_dict = torch.load(
|
| 37 |
+
os.path.join(base_folder, f'checkpoints/{ckpt_filename}'),
|
| 38 |
+
map_location=device,
|
| 39 |
+
weights_only=False
|
| 40 |
+
)
|
| 41 |
+
model.load_state_dict(ckpt_dict['module'])
|
| 42 |
+
|
| 43 |
+
# load RNG states each time the model and states are loaded from checkpoint
|
| 44 |
+
if 'rng' in ckpt_dict:
|
| 45 |
+
rng = ckpt_dict['rng']
|
| 46 |
+
for key, value in rng.items():
|
| 47 |
+
if key =='torch_state':
|
| 48 |
+
torch.set_rng_state(value.cpu())
|
| 49 |
+
elif key =='cuda_state':
|
| 50 |
+
torch.cuda.set_rng_state(value.cpu())
|
| 51 |
+
elif key =='numpy_state':
|
| 52 |
+
np.random.set_state(value)
|
| 53 |
+
elif key =='python_state':
|
| 54 |
+
random.setstate(value)
|
| 55 |
+
else:
|
| 56 |
+
print('unrecognized state')
|
| 57 |
+
|
| 58 |
+
if eval_model:
|
| 59 |
+
return model.eval()
|
| 60 |
+
return model
|
str_bamba/tokenizer/special_tokens.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
STR_SPECIAL_TOKENS = {
|
| 2 |
+
### basic tokens ###
|
| 3 |
+
"BOS_TOKEN": "<bos>",
|
| 4 |
+
"EOS_TOKEN": "<sep>",
|
| 5 |
+
"PAD_TOKEN": "<pad>",
|
| 6 |
+
"MASK_TOKEN": "<mask>",
|
| 7 |
+
"UNK_TOKEN": "<unk>",
|
| 8 |
+
|
| 9 |
+
### molecular representations ###
|
| 10 |
+
# molecular formula
|
| 11 |
+
"MOLECULAR_FORMULA_TOKEN": "<formula>",
|
| 12 |
+
|
| 13 |
+
# canonical SMILES
|
| 14 |
+
"SMILES_TOKEN": "<smiles>",
|
| 15 |
+
|
| 16 |
+
# IUPAC name
|
| 17 |
+
"IUPAC_TOKEN": "<iupac>",
|
| 18 |
+
|
| 19 |
+
# InChI
|
| 20 |
+
"INCHI_TOKEN": "<inchi>",
|
| 21 |
+
"INCHI_INITIAL_TOKEN": "InChI=", # force `InChI=` to be a unique token
|
| 22 |
+
"INCHI_COMMA_TOKEN": ",", # force `,` to be a unique token
|
| 23 |
+
"INCHI_DASH_TOKEN": "-", # force `-` to be a unique token
|
| 24 |
+
"INCHI_FORWARDSLASH_TOKEN": "/", # force `/` to be a unique token
|
| 25 |
+
"INCHI_QUESTIONMARK_TOKEN": "?", # force `?` to be a unique token
|
| 26 |
+
"INCHI_PARENTHESIS_OPEN_TOKEN": "(", # force `(` to be a unique token
|
| 27 |
+
"INCHI_PARENTHESIS_CLOSE_TOKEN": ")", # force `)` to be a unique token
|
| 28 |
+
|
| 29 |
+
# SELFIES
|
| 30 |
+
"SELFIES_TOKEN": "<selfies>",
|
| 31 |
+
|
| 32 |
+
# polymer SPG
|
| 33 |
+
"POLYMER_SPG_TOKEN": "<polymer_spg>",
|
| 34 |
+
"POLYMER_ARROW_TOKEN": "->", # force `->` to be a unique token
|
| 35 |
+
|
| 36 |
+
# formulation
|
| 37 |
+
"FORMULATION_START_TOKEN": "<formulation_start>",
|
| 38 |
+
"FORMULATION_END_TOKEN": "<formulation_end>",
|
| 39 |
+
}
|
str_bamba/tokenizer/str_bamba_tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
str_bamba/tokenizer/str_tokenizer.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
from tokenizers import NormalizedString, PreTokenizedString
|
| 4 |
+
from tokenizers.pre_tokenizers import PreTokenizer
|
| 5 |
+
from transformers import PreTrainedTokenizerFast
|
| 6 |
+
|
| 7 |
+
import re
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
ATOM_REGEX_PATTERN = r"""(<(.*?)>|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""
|
| 11 |
+
FORMULATION_REGEX_PATTERN = r"""(<(.*?)>|[-+]?\d*\.\d+|[-+]?\d+\.?\d*[eE][-+]?\d+|[-+]?\d+|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""
|
| 12 |
+
NUMBER_REGEX_PATTERN = r"""(\d{2}|\d[a-zA-Z]\d|\d[a-zA-Z]|[a-zA-Z]\d+|\(|\))"""
|
| 13 |
+
# NUMBER_REGEX_PATTERN = r"""((?<!\d)\d{2}(?!\d)|\d[a-zA-Z]\d|\d[a-zA-Z]|[a-zA-Z]\d)"""
|
| 14 |
+
# NUMBER_REGEX_PATTERN = r"""(\d[a-zA-Z]\d|\d[a-zA-Z]|[a-zA-Z]\d|\b\d{2}\b)"""
|
| 15 |
+
SPECIAL_REGEX_PATTERN = r"""<(.*?)>"""
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MoleculePreTokenizer:
|
| 19 |
+
|
| 20 |
+
def molecule_based_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
|
| 21 |
+
splits = []
|
| 22 |
+
if str(normalized_string).startswith(('<smiles>', '<selfies>', '<polymer_spg>')):
|
| 23 |
+
for m in re.finditer(ATOM_REGEX_PATTERN, str(normalized_string)):
|
| 24 |
+
start = m.start(0)
|
| 25 |
+
stop = m.end(0)
|
| 26 |
+
if start == 0: # remove special tokens
|
| 27 |
+
continue
|
| 28 |
+
splits.append(normalized_string[start:stop])
|
| 29 |
+
elif str(normalized_string).startswith('<formulation_start>'):
|
| 30 |
+
for m in re.finditer(FORMULATION_REGEX_PATTERN, str(normalized_string)):
|
| 31 |
+
start = m.start(0)
|
| 32 |
+
stop = m.end(0)
|
| 33 |
+
if start == 0 or stop == len(str(normalized_string)): # remove special tokens
|
| 34 |
+
continue
|
| 35 |
+
splits.append(normalized_string[start:stop])
|
| 36 |
+
elif str(normalized_string).startswith(('<formula>', '<inchi>')):
|
| 37 |
+
for m in re.finditer(NUMBER_REGEX_PATTERN, str(normalized_string)):
|
| 38 |
+
start = m.start(0)
|
| 39 |
+
stop = m.end(0)
|
| 40 |
+
if start == 0: # remove special tokens
|
| 41 |
+
continue
|
| 42 |
+
splits.append(normalized_string[start:stop])
|
| 43 |
+
else:
|
| 44 |
+
last = 0
|
| 45 |
+
for m in re.finditer(SPECIAL_REGEX_PATTERN, str(normalized_string)): # remove special tokens
|
| 46 |
+
start = m.start(0)
|
| 47 |
+
stop = m.end(0)
|
| 48 |
+
# splits.append(normalized_string[start:stop])
|
| 49 |
+
last = stop
|
| 50 |
+
splits.append(normalized_string[last:])
|
| 51 |
+
|
| 52 |
+
return splits
|
| 53 |
+
|
| 54 |
+
def pre_tokenize(self, pretok: PreTokenizedString):
|
| 55 |
+
pretok.split(self.molecule_based_split)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class MultiMolTranBertTokenizer(PreTrainedTokenizerFast):
|
| 59 |
+
def __init__(self, vocab_file: str = '',
|
| 60 |
+
do_lower_case=False,
|
| 61 |
+
cls_token='<bos>',
|
| 62 |
+
eos_token='<sep>',
|
| 63 |
+
pad_token='<pad>',
|
| 64 |
+
unk_token='<unk>',
|
| 65 |
+
mask_token='<mask>',
|
| 66 |
+
**kwargs):
|
| 67 |
+
|
| 68 |
+
super().__init__(
|
| 69 |
+
tokenizer_file=vocab_file,
|
| 70 |
+
bos_token=cls_token,
|
| 71 |
+
eos_token=eos_token,
|
| 72 |
+
pad_token=pad_token,
|
| 73 |
+
unk_token=unk_token,
|
| 74 |
+
mask_token=mask_token
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
def get_padding_idx(self):
|
| 78 |
+
return 2
|
| 79 |
+
|
| 80 |
+
def convert_idx_to_tokens(self, idx_tensor):
|
| 81 |
+
tokens = [self.convert_ids_to_tokens(idx) for idx in idx_tensor.tolist()]
|
| 82 |
+
return tokens
|
| 83 |
+
|
| 84 |
+
def convert_tokens_to_string(self, tokens):
|
| 85 |
+
stopwords = ['<bos>', '<eos>']
|
| 86 |
+
clean_tokens = [word for word in tokens if word not in stopwords]
|
| 87 |
+
out_string = ''.join(clean_tokens)
|
| 88 |
+
return out_string
|
| 89 |
+
|
| 90 |
+
def idx_to_smiles(self, torch_model, idx):
|
| 91 |
+
'''Convert tokens idx back to SMILES text'''
|
| 92 |
+
rev_tokens = torch_model.tokenizer.convert_idx_to_tokens(idx)
|
| 93 |
+
flat_list_tokens = [item for sublist in rev_tokens for item in sublist]
|
| 94 |
+
decoded_smiles = torch_model.tokenizer.convert_tokens_to_string(flat_list_tokens)
|
| 95 |
+
return decoded_smiles
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def load_tokenizer(vocab_file, **kwargs):
|
| 99 |
+
tokenizer = MultiMolTranBertTokenizer(vocab_file, **kwargs)
|
| 100 |
+
tokenizer.backend_tokenizer.pre_tokenizer = PreTokenizer.custom(MoleculePreTokenizer())
|
| 101 |
+
return tokenizer
|