|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections import OrderedDict |
|
|
from typing import List, Optional |
|
|
|
|
|
import tensorrt as trt |
|
|
|
|
|
from ..._common import default_net |
|
|
from ..._utils import str_dtype_to_trt |
|
|
from ...functional import (Tensor, arange, cast, concat, expand, |
|
|
gather_last_token_logits, shape, unsqueeze) |
|
|
from ...layers import Embedding, LayerNorm, Linear, Mamba, Mamba2, RmsNorm |
|
|
from ...module import Module, ModuleList |
|
|
from ...plugin import current_all_reduce_helper |
|
|
from ..generation_mixin import GenerationMixin |
|
|
from ..modeling_utils import PretrainedConfig, PretrainedModel |
|
|
|
|
|
|
|
|
class MambaLayer(Module):
    """One Mamba block: pre-layernorm, SSM mixer, residual add.

    The residual stream may be kept in fp32 (``residual_in_fp32``) for
    numerical stability; the last layer of the stack drops the residual
    output since no further block consumes it.
    """

    def __init__(self, config: PretrainedConfig, layer_idx: int):
        super().__init__()
        self.dtype = config.dtype
        self.residual_in_fp32 = config.residual_in_fp32
        # Final layer returns no residual (see forward()).
        self.last_layer = layer_idx == config.num_hidden_layers - 1

        if config.mamba_version == 'Mamba1':
            self.ssm = Mamba(config.hidden_size,
                             config.rnn_hidden_size,
                             d_state=config.state_size,
                             d_conv=config.conv_kernel,
                             bias=config.use_bias,
                             dtype=config.dtype)
        elif config.mamba_version == 'Mamba2':
            self.ssm = Mamba2(config.hidden_size,
                              config.rnn_hidden_size,
                              d_state=config.state_size,
                              d_conv=config.conv_kernel,
                              headdim=config.rnn_head_size,
                              ngroups=config.ngroups,
                              chunk_size=config.chunk_size,
                              bias=config.use_bias,
                              rmsnorm=config.ssm_rmsnorm,
                              dtype=config.dtype)

        # Pre-norm: RMSNorm or classic LayerNorm, per config.
        norm_cls = RmsNorm if config.rms_norm else LayerNorm
        self.input_layernorm = norm_cls(normalized_shape=config.hidden_size,
                                        eps=config.norm_epsilon,
                                        dtype=config.dtype)

    def forward(self,
                hidden_states: Tensor,
                residual: Tensor,
                conv_state: Tensor,
                ssm_state: Tensor,
                host_request_types: Tensor,
                last_token_ids: Tensor,
                host_context_lengths: Optional[Tensor] = None,
                slot_mapping: Optional[Tensor] = None,
                conv_indices: Optional[Tensor] = None):
        """Run norm + SSM and fold the result into the residual stream.

        Returns ``(hidden_states, residual, present_conv, present_ssm)``;
        ``residual`` is ``None`` for the last layer.
        """
        normed = self.input_layernorm(hidden_states)

        ssm_out, present_conv, present_ssm = self.ssm(
            normed,
            conv_state=conv_state,
            ssm_state=ssm_state,
            host_request_types=host_request_types,
            last_token_ids=last_token_ids,
            host_context_lengths=host_context_lengths,
            slot_mapping=slot_mapping,
            conv_indices=conv_indices)

        if self.residual_in_fp32:
            # Accumulate in fp32, then cast back to the model dtype.
            residual = residual + cast(ssm_out, 'float32')
            out = cast(residual, self.dtype)
        else:
            residual = residual + ssm_out
            out = residual

        next_residual = None if self.last_layer else residual
        return out, next_residual, present_conv, present_ssm
|
|
|
|
|
|
|
|
class MambaModel(Module):
    """Stack of MambaLayers with token embedding and final norm."""

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.d_conv = config.conv_kernel
        self.d_inner = config.rnn_hidden_size
        self.residual_in_fp32 = config.residual_in_fp32

        # Pad the vocabulary up to the next multiple, mutating the config
        # so the LM head sees the padded size as well.
        multiple = config.pad_vocab_size_multiple
        leftover = config.vocab_size % multiple
        if leftover != 0:
            config.vocab_size += multiple - leftover

        self.vocab_embedding = Embedding(config.vocab_size,
                                         config.hidden_size,
                                         dtype=config.dtype)
        self.layers = ModuleList([
            MambaLayer(config, i) for i in range(config.num_hidden_layers)
        ])

        # Final norm mirrors the per-layer norm choice.
        norm_cls = RmsNorm if config.rms_norm else LayerNorm
        self.ln_f = norm_cls(normalized_shape=config.hidden_size,
                             eps=config.norm_epsilon,
                             dtype=config.dtype)

    def forward(self,
                input_ids,
                conv_states,
                ssm_states,
                host_request_types,
                last_token_ids,
                host_context_lengths,
                slot_mapping: Optional[Tensor] = None):
        hidden_states = self.vocab_embedding(input_ids)

        # Without the mamba_conv1d plugin, build gather indices that select
        # the trailing (d_conv - 1) positions of every sequence so the conv
        # state can be updated with plain TRT ops.
        indices = None
        if not default_net().plugin_config.mamba_conv1d_plugin:
            batch_size = shape(input_ids, 0)
            tail = expand(
                unsqueeze(arange(0, self.d_conv - 1, dtype='int32'), 0),
                concat([batch_size, self.d_conv - 1]))
            offsets = expand(unsqueeze(last_token_ids, 1),
                             concat([batch_size, self.d_conv - 1]))
            indices = expand(
                unsqueeze(tail + offsets, 1),
                concat([batch_size, self.d_inner, self.d_conv - 1]))

        residual = cast(hidden_states,
                        'float32') if self.residual_in_fp32 else hidden_states

        present_convs, present_ssms = [], []
        for layer, past_conv, past_ssm in zip(self.layers, conv_states,
                                              ssm_states):
            hidden_states, residual, present_conv, present_ssm = layer(
                hidden_states, residual, past_conv, past_ssm,
                host_request_types, last_token_ids, host_context_lengths,
                slot_mapping, indices)
            present_convs.append(present_conv)
            present_ssms.append(present_ssm)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states, tuple(present_convs), tuple(present_ssms)
|
|
|
|
|
|
|
|
class MambaForCausalLM(PretrainedModel):
    """Mamba backbone plus a linear LM head, wired for TensorRT-LLM
    engine building (dynamic shapes, optimization profiles, optional
    paged state)."""

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        dtype = config.dtype
        logits_dtype = config.logits_dtype
        # Normalize dtype to trt.DataType; configs may carry either a
        # string name or an already-resolved trt.DataType.
        if isinstance(dtype, str):
            self.dtype = str_dtype_to_trt(dtype)
        else:
            assert isinstance(dtype, trt.DataType)
            self.dtype = dtype

        self.config = config
        self.mamba_version = config.mamba_version
        self.d_inner = config.rnn_hidden_size
        self.d_conv = config.conv_kernel
        self.d_state = config.state_size
        self.conv_dim = config.rnn_conv_dim_size
        # Overwritten by prepare_inputs(); controls last-token gathering
        # in forward().
        self.gather_context_logits = False

        # Same string-or-DataType normalization for the logits dtype.
        if isinstance(logits_dtype, str):
            self._logits_dtype = str_dtype_to_trt(logits_dtype)
        else:
            assert isinstance(logits_dtype, trt.DataType)
            self._logits_dtype = logits_dtype

        self.backbone = MambaModel(config)
        # NOTE: MambaModel.__init__ may pad config.vocab_size in place
        # before this line reads it, so the head matches the padded vocab.
        self.lm_head = Linear(config.hidden_size,
                              config.vocab_size,
                              bias=False,
                              dtype=dtype,
                              gather_output=False)

    def __post_init__(self):
        # Intentional no-op: disables whatever post-init the base class
        # would otherwise run for this model. -- presumably quantization or
        # weight-loading hooks; confirm against PretrainedModel.
        return

    def forward(self,
                input_ids,
                conv_states,
                ssm_states,
                host_request_types,
                last_token_ids,
                last_token_ids_for_logits,
                host_context_lengths,
                slot_mapping: Optional[Tensor] = None):
        """Build the forward graph and mark logits / present states as
        network outputs.

        Returns (lm_logits, present_convs, present_ssms)."""
        hidden_states, present_convs, present_ssms = self.backbone(
            input_ids, conv_states, ssm_states, host_request_types,
            last_token_ids, host_context_lengths, slot_mapping)

        # Unless context logits are requested, only the last token of each
        # sequence is fed to the LM head.
        if not self.gather_context_logits:
            hidden_states = gather_last_token_logits(
                hidden_states, last_token_ids_for_logits,
                default_net().plugin_config.remove_input_padding)

        lm_logits = self.lm_head(hidden_states)
        lm_logits.mark_output('logits', self._logits_dtype)
        # With paged state, conv/ssm states are updated in place through
        # pointers, so no per-layer output tensors are marked.
        if not default_net().plugin_config.paged_state:
            for i, present_conv in enumerate(present_convs):
                present_conv.mark_output(f'present_conv_state_{i}', self.dtype)
            for i, present_ssm in enumerate(present_ssms):
                present_ssm.mark_output(f'present_rnn_state_{i}', self.dtype)

        return (lm_logits, present_convs, present_ssms)

    def prepare_inputs(
            self,
            max_batch_size,
            max_input_len,
            max_seq_len,
            max_num_tokens,
            use_cache,
            max_beam_width: int = 1,
            opt_num_tokens: int = None,
            opt_batch_size: int = 0,
            prompt_embedding_table_size: int = 0,
            max_draft_len: int = 0,
            gather_context_logits: bool = False,
            gather_generation_logits: bool = False,
            lora_target_modules: List[str] = None,
            speculative_decoding_draft_tokens_external: bool = False):
        '''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the
            ranges of the dimensions of when using TRT dynamic shapes.

            @return: a list contains values which can be fed into the self.forward()
        '''
        assert speculative_decoding_draft_tokens_external == False, "Speculative decoding is not supported in Mamba"
        assert max_beam_width == 1, "We don't support beam search for the Mamba model."

        # Snapshot the plugin configuration that shapes the input tensors.
        remove_input_padding = default_net().plugin_config.remove_input_padding
        use_gemm_plugin = default_net().plugin_config.gemm_plugin
        paged_state = default_net().plugin_config.paged_state
        multiple_profiles = default_net().plugin_config.multiple_profiles
        use_mamba_conv1d_plugin = default_net(
        ).plugin_config.mamba_conv1d_plugin

        self.gather_context_logits = gather_context_logits
        mapping = self.config.mapping

        # Decide whether separate context/generation optimization profiles
        # are needed, then compute the dynamic-shape ranges for each.
        enable_ctx_gen_opt_profiles = GenerationMixin.has_ctx_gen_opt_profiles(
            True, use_gemm_plugin, remove_input_padding, paged_state)

        num_profiles, ranges = GenerationMixin.get_profiles_ranges(
            max_batch_size=max_batch_size,
            max_beam_width=max_beam_width,
            max_input_len=max_input_len,
            max_num_tokens=max_num_tokens,
            max_draft_len=max_draft_len,
            opt_batch_size=opt_batch_size,
            opt_num_tokens=opt_num_tokens,
            enable_ctx_gen_opt_profiles=enable_ctx_gen_opt_profiles,
            multiple_profiles=multiple_profiles)

        # Packed tokens (1-D) vs padded [batch, seq] input layout.
        if remove_input_padding:
            assert use_mamba_conv1d_plugin, "mamba_conv1d_plugin is needed to support remove_input_padding"
            input_ids = Tensor(name='input_ids',
                               dtype=trt.int32,
                               shape=[-1],
                               dim_range=OrderedDict([
                                   ('num_tokens', ranges['num_tokens_range']),
                               ]))
        else:
            input_ids = Tensor(name='input_ids',
                               dtype=trt.int32,
                               shape=[-1, -1],
                               dim_range=OrderedDict([
                                   ('batch_size_beam_width',
                                    ranges['bb_range']),
                                   ('input_len', ranges['inlen_range']),
                               ]))
        if mapping.tp_size > 1:
            current_all_reduce_helper().set_workspace_tensor(
                mapping, num_profiles)

        # Per-layer recurrent state inputs. The conv state layout differs
        # by plugin: [B, d_conv-1, conv_dim] with the plugin, transposed
        # without it.
        conv_states = []
        ssm_states = []
        if use_mamba_conv1d_plugin:
            conv_state_dim_range = OrderedDict([
                ('batch_size', ranges['bb_range']),
                ('kernel_size', [self.d_conv - 1] * num_profiles),
                ('dim_size', [self.conv_dim] * num_profiles),
            ])
        else:
            conv_state_dim_range = OrderedDict([
                ('batch_size', ranges['bb_range']),
                ('dim_size', [self.conv_dim] * num_profiles),
                ('kernel_size', [self.d_conv - 1] * num_profiles),
            ])

        # Mamba2 uses a head-structured SSM state [B, nheads, d_state,
        # headdim]; Mamba1 a flat [B, d_state, d_inner].
        if self.mamba_version == 'Mamba2':
            headdim = self.config.rnn_head_size
            nheads = self.d_inner // headdim
            ssm_state_dim_range = OrderedDict([
                ('batch_size', ranges['bb_range']),
                ('head_size', [nheads] * num_profiles),
                ('state_size', [self.d_state] * num_profiles),
                ('headdim_size', [headdim] * num_profiles),
            ])
            ssm_state_shape = [-1, nheads, self.d_state, headdim]
        else:
            ssm_state_dim_range = OrderedDict([
                ('batch_size', ranges['bb_range']),
                ('state_size', [self.d_state] * num_profiles),
                ('dim_size', [self.d_inner] * num_profiles),
            ])
            ssm_state_shape = [-1, self.d_state, self.d_inner]
        one_dim_range = OrderedDict([
            ('buffer_count', [1] * num_profiles),
        ])

        for i in range(self.config.num_hidden_layers):
            if default_net().plugin_config.paged_state:
                # Paged state passes int64 pointers to externally managed
                # buffers instead of the state tensors themselves.
                conv_state = Tensor(name=f'conv_state_ptr_{i}',
                                    dtype=str_dtype_to_trt('int64'),
                                    shape=[1],
                                    dim_range=one_dim_range)

                ssm_state = Tensor(name=f'rnn_state_ptr_{i}',
                                   dtype=str_dtype_to_trt('int64'),
                                   shape=[1],
                                   dim_range=one_dim_range)
            else:
                if use_mamba_conv1d_plugin:
                    conv_state = Tensor(
                        name=f'past_conv_state_{i}',
                        dtype=self.dtype,
                        shape=[-1, self.d_conv - 1, self.conv_dim],
                        dim_range=conv_state_dim_range)
                else:
                    conv_state = Tensor(
                        name=f'past_conv_state_{i}',
                        dtype=self.dtype,
                        shape=[-1, self.conv_dim, self.d_conv - 1],
                        dim_range=conv_state_dim_range)

                ssm_state = Tensor(name=f'past_rnn_state_{i}',
                                   dtype=self.dtype,
                                   shape=ssm_state_shape,
                                   dim_range=ssm_state_dim_range)

            conv_states.append(conv_state)
            ssm_states.append(ssm_state)

        # Per-request context/generation flags (host side).
        host_request_types = Tensor(
            name='host_request_types',
            dtype=trt.int32,
            shape=[-1],
            dim_range=OrderedDict([('batch_size', ranges['bb_range'])]),
        )

        # Only needed to delimit sequences in the packed-token layout.
        if remove_input_padding:
            host_context_lengths = Tensor(
                name='host_context_lengths',
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([('batch_size', ranges['bb_range'])]),
            )
        else:
            host_context_lengths = None

        last_token_ids = Tensor(
            name='last_token_ids',
            dtype=trt.int32,
            shape=[-1],
            dim_range=OrderedDict([
                ('batch_size', ranges['bbd_range']),
            ]),
        )
        # When context logits are gathered, forward() skips last-token
        # gathering, so no separate tensor is needed.
        last_token_ids_for_logits = None
        if not gather_context_logits:
            last_token_ids_for_logits = last_token_ids

        return_dict = {
            'input_ids': input_ids,
            'conv_states': conv_states,
            'ssm_states': ssm_states,
            'host_request_types': host_request_types,
            'last_token_ids': last_token_ids,
            'last_token_ids_for_logits': last_token_ids_for_logits,
            'host_context_lengths': host_context_lengths,
        }

        # Paged state additionally needs the per-request slot indices.
        if default_net().plugin_config.paged_state:
            slot_mapping = Tensor(
                name='slot_mapping',
                dtype=trt.int32,
                shape=[-1],
                dim_range=OrderedDict([('batch_size', ranges['bb_range'])]),
            )
            return_dict['slot_mapping'] = slot_mapping

        return return_dict
|
|
|