| | |
| |
|
| | import math |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from einops import rearrange, repeat, pack, unpack |
| |
|
| | try: |
| | from causal_conv1d import causal_conv1d_fn, causal_conv1d_update |
| | except ImportError: |
| | causal_conv1d_fn, causal_conv1d_update = None, None |
| |
|
| | try: |
| | from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states |
| | except ImportError: |
| | causal_conv1d_varlen_states = None |
| |
|
| | from mamba_ssm.ops.triton.selective_state_update import selective_state_update |
| | from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated |
| |
|
| |
|
| | from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear |
| | from mamba_ssm.distributed.distributed_utils import all_reduce, reduce_scatter |
| |
|
| | from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined |
| | from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined |
| |
|
| |
|
class Mamba2(nn.Module):
    """Mamba-2 mixer layer built on the SSD Triton kernels (``mamba_chunk_scan_combined``).

    Width hyper-parameters are read from ``config``: ``hidden_size``, ``mamba_d_state``,
    ``mamba_d_conv``, ``mamba_expand`` and ``mamba2_headdim``. When ``process_group`` is
    given, the input/output projections are tensor-parallel and the per-rank sizes
    (``d_inner``, ``d_ssm``, ``ngroups``) are divided by the group size.
    """

    def __init__(
        self,
        config,
        conv_init=None,
        d_ssm=None,
        ngroups=1,
        A_init_range=(1, 16),
        D_has_hdim=False,
        rmsnorm=True,
        norm_before_gate=False,
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        dt_limit=(0.0, float("inf")),
        bias=False,
        conv_bias=True,
        chunk_size=256,
        use_mem_eff_path=False,
        layer_idx=None,
        process_group=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.config = config
        self.d_model = config.hidden_size
        self.d_state = config.mamba_d_state
        self.d_conv = config.mamba_d_conv

        self.conv_init = conv_init
        self.expand = config.mamba_expand
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.world_size = 1 if process_group is None else process_group.size()
        self.local_rank = 0 if process_group is None else process_group.rank()
        self.d_inner = (self.expand * self.d_model) // self.world_size
        assert self.d_inner * self.world_size == self.expand * self.d_model
        self.headdim = config.mamba2_headdim
        # d_ssm is the slice of d_inner that goes through the SSM; the remainder
        # (d_inner - d_ssm) becomes a gated-MLP branch (z0, x0) in forward/step.
        self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
        assert ngroups % self.world_size == 0
        self.ngroups = ngroups // self.world_size
        assert self.d_ssm % self.headdim == 0
        self.nheads = self.d_ssm // self.headdim
        self.D_has_hdim = D_has_hdim
        self.rmsnorm = rmsnorm
        self.norm_before_gate = norm_before_gate
        # (0, inf) means "no clamping"; only non-default limits are forwarded to kernels.
        self.dt_limit = dt_limit
        self.activation = "silu"
        self.chunk_size = chunk_size
        self.use_mem_eff_path = use_mem_eff_path
        self.layer_idx = layer_idx

        assert (self.d_model * self.expand / self.headdim) % 8 == 0

        # Single input projection producing [z0, x0, z, xBC, dt] along the last dim.
        d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
        if self.process_group is None:
            self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
        else:
            self.in_proj = ColumnParallelLinear(
                self.d_model,
                d_in_proj * self.world_size,
                bias=bias,
                process_group=self.process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )

        # Depthwise conv over x, B and C. The padding is d_conv - 1 on both sides, so
        # callers keep only the first L outputs to obtain a causal convolution.
        conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            bias=conv_bias,
            kernel_size=self.d_conv,
            groups=conv_dim,
            padding=self.d_conv - 1,
            **factory_kwargs,
        )
        if self.conv_init is not None:
            nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)

        self.act = nn.SiLU()

        # dt is sampled log-uniformly in [dt_min, dt_max] and floored at dt_init_floor.
        dt = torch.exp(
            torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: softplus(inv_dt) == dt, so the kernels (dt_softplus=True)
        # start from the sampled dt.
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        # Flag attribute; presumably read by the optimizer setup to skip weight decay.
        self.dt_bias._no_weight_decay = True

        # A is stored as log of a positive value; forward/step use A = -exp(A_log) < 0.
        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
        A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
        A_log = torch.log(A).to(dtype=dtype)
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # Skip connection: one value per head, or one per channel when D_has_hdim.
        self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device))
        self.D._no_weight_decay = True

        if self.rmsnorm:
            assert RMSNormGated is not None
            # NOTE(review): group_size uses the global ngroups argument, not self.ngroups.
            self.norm = RMSNormGated(
                self.d_ssm,
                eps=1e-5,
                norm_before_gate=self.norm_before_gate,
                group_size=self.d_ssm // ngroups,
                **factory_kwargs,
            )

        if self.process_group is None:
            self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        else:
            self.out_proj = RowParallelLinear(
                self.d_inner * self.world_size,
                self.d_model,
                bias=bias,
                process_group=self.process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )

    def _D_for_scan(self):
        """Return D shaped for the scan kernels: (nheads, headdim) if D_has_hdim, else (nheads,)."""
        if self.D_has_hdim:
            return rearrange(self.D, "(h p) -> h p", p=self.headdim)
        return self.D

    def forward(self, hidden_states, attention_mask=None, past_key_value=None, seqlen=None,
                seq_idx=None, cu_seqlens=None, inference_params=None):
        """Run the mixer over a sequence, or take a cached decoding step.

        hidden_states: (batch, seqlen, hidden_dim) if seqlen=None.
            If seqlen is not None, hidden_states is (batch * seqlen, hidden_dim). This is so
            that when we split hidden_states during sequence parallel, we split the
            batch * seqlen dimension (in case batch is small).
        attention_mask: accepted for interface compatibility; not used by this layer.
        past_key_value: returned unchanged.
        seq_idx: optional per-token sequence index, forwarded to the conv and scan kernels.
        cu_seqlens: optional cumulative sequence lengths for packed (varlen) inputs.
        inference_params: optional cache holder. When given, the conv/SSM states are
            written into it. If ``inference_params.seqlen_offset > 0``, the call is routed
            to :meth:`step`.

        Returns:
            ``(out, past_key_value)``, where ``out`` has the same shape as hidden_states.
        """
        seqlen_og = seqlen
        if seqlen is None:
            batch, seqlen, dim = hidden_states.shape
        else:
            batch_seqlen, dim = hidden_states.shape
            batch = batch_seqlen // seqlen

        conv_state, ssm_state = None, None
        if inference_params is not None:
            # For varlen input the cache holds one state per packed sequence.
            inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
            conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch)
            if inference_params.seqlen_offset > 0:
                # States already exist: update them incrementally.
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out, past_key_value

        zxbcdt = self.in_proj(hidden_states)
        if seqlen_og is not None:
            zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
        A = -torch.exp(self.A_log.float())
        dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
        D = self._D_for_scan()

        if self.use_mem_eff_path and inference_params is None:
            # Fused conv + scan + norm + out_proj kernel.
            out = mamba_split_conv1d_scan_combined(
                zxbcdt,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.dt_bias,
                A,
                D=D,
                chunk_size=self.chunk_size,
                seq_idx=seq_idx,
                activation=self.activation,
                rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
                rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
                outproj_weight=self.out_proj.weight,
                outproj_bias=self.out_proj.bias,
                headdim=None if self.D_has_hdim else self.headdim,
                ngroups=self.ngroups,
                norm_before_gate=self.norm_before_gate,
                **dt_limit_kwargs,
            )
            if seqlen_og is not None:
                out = rearrange(out, "b l d -> (b l) d")
            if self.process_group is not None:
                reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
                out = reduce_fn(out, self.process_group)
        else:
            d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
            z0, x0, z, xBC, dt = torch.split(
                zxbcdt,
                [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
                dim=-1,
            )

            if conv_state is not None:
                if cu_seqlens is None:
                    # Cache the last d_conv raw (pre-conv) inputs, newest last. The input
                    # is left-padded with zeros when the sequence is shorter than d_conv,
                    # and a negative pad amount crops it when longer.
                    xBC_t = rearrange(xBC, "b l d -> b d l")
                    conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0)))
                else:
                    assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package"
                    assert batch == 1, "varlen inference only supports batch dimension 1"
                    conv_varlen_states = causal_conv1d_varlen_states(
                        xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
                    )
                    conv_state.copy_(conv_varlen_states)

            assert self.activation in ["silu", "swish"]
            if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                assert seq_idx is None, "varlen conv1d requires the causal_conv1d package"
                # Keep the first L outputs of the both-sides-padded conv: output t then
                # depends only on inputs t - d_conv + 1 .. t (causal).
                xBC_len = xBC.shape[1]
                xBC = self.act(
                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, :xBC_len]
                )
            else:
                xBC = causal_conv1d_fn(
                    xBC.transpose(1, 2),
                    rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                    # Keeps packed sequences from seeing each other's conv context.
                    seq_idx=seq_idx,
                ).transpose(1, 2)

            x, B, C = torch.split(
                xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1
            )
            y = mamba_chunk_scan_combined(
                rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
                dt,
                A,
                rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
                rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
                chunk_size=self.chunk_size,
                D=D,
                z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None,
                dt_bias=self.dt_bias,
                dt_softplus=True,
                seq_idx=seq_idx,
                cu_seqlens=cu_seqlens,
                **dt_limit_kwargs,
                return_final_states=ssm_state is not None,
                return_varlen_states=cu_seqlens is not None and inference_params is not None,
            )
            if ssm_state is not None:
                y, last_state, *rest = y
                if cu_seqlens is None:
                    ssm_state.copy_(last_state)
                else:
                    # With varlen output, rest[0] holds one final state per packed sequence.
                    ssm_state.copy_(rest[0])
            y = rearrange(y, "b l h p -> b l (h p)")
            if self.rmsnorm:
                # Gated RMSNorm: z is applied here instead of inside the scan.
                y = self.norm(y, z)
            if d_mlp > 0:
                y = torch.cat([F.silu(z0) * x0, y], dim=-1)
            if seqlen_og is not None:
                y = rearrange(y, "b l d -> (b l) d")
            out = self.out_proj(y)

        return out, past_key_value

    def step(self, hidden_states, conv_state, ssm_state):
        """Advance the cached conv/SSM states by the tokens in ``hidden_states``.

        hidden_states: (batch, seq_len, hidden_dim). A seq_len of 1 takes the
            single-token decode path; longer inputs continue from the cached states.
        conv_state: (batch, conv_dim, d_conv) raw conv inputs, newest in the last slot.
            Updated in place.
        ssm_state: (batch, nheads, headdim, d_state). Updated in place.

        Returns:
            ``(out, conv_state, ssm_state)``, with ``out`` of shape (batch, seq_len, hidden_dim).
        """
        dtype = hidden_states.dtype
        batch_size, seq_len, _ = hidden_states.shape
        single_token = seq_len == 1

        # The single-token path works on 2-D (batch, dim) tensors throughout.
        zxbcdt = self.in_proj(hidden_states.squeeze(1) if single_token else hidden_states)
        d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
        z0, x0, z, xBC, dt = torch.split(
            zxbcdt,
            [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
            dim=-1,
        )

        if single_token:
            if causal_conv1d_update is None:
                # Shift the window left, append the new input, then apply the conv as a dot product.
                conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
                conv_state[:, :, -1] = xBC
                xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)
                if self.conv1d.bias is not None:
                    xBC = xBC + self.conv1d.bias
                xBC = self.act(xBC).to(dtype=dtype)
            else:
                xBC = causal_conv1d_update(
                    xBC,
                    conv_state,
                    rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    self.conv1d.bias,
                    self.activation,
                )
        else:
            # Prepend the last d_conv - 1 cached inputs so the causal conv continues from
            # the history instead of zero padding.
            xBC_t = rearrange(xBC, "b l d -> b d l")
            new_len = xBC_t.shape[-1]
            xBC_full = torch.cat([conv_state[:, :, 1:].to(dtype=xBC_t.dtype), xBC_t], dim=-1)
            # Same layout as the single-token path: the most recent d_conv inputs, newest last.
            conv_state.copy_(xBC_full[:, :, -self.d_conv:])
            if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                xBC_conv = self.act(self.conv1d(xBC_full)[..., : xBC_full.shape[-1]])
            else:
                xBC_conv = causal_conv1d_fn(
                    xBC_full,
                    rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                )
            # Drop outputs belonging to the history prefix.
            xBC = rearrange(xBC_conv[..., -new_len:], "b d l -> b l d")

        x, B, C = torch.split(
            xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1
        )
        A = -torch.exp(self.A_log.float())

        if single_token:
            if selective_state_update is None:
                assert self.ngroups == 1, "Only support ngroups=1 for this inference code path"
                # Reference recurrence: h <- exp(dt * A) * h + dt * B x;  y = C h + D x.
                dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype))
                dA = torch.exp(dt * A)
                x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
                dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
                ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
                y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
                if self.D_has_hdim:
                    D_skip = rearrange(self.D.to(dtype), "(h p) -> h p", p=self.headdim)
                else:
                    D_skip = rearrange(self.D.to(dtype), "h -> h 1")
                y = y + D_skip * x
                y = rearrange(y, "b h p -> b (h p)")
                if not self.rmsnorm:
                    y = y * self.act(z)
            else:
                # The Triton kernel expects per-(head, headdim) parameters.
                A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
                dt = repeat(dt, "b h -> b h p", p=self.headdim)
                dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
                if self.D_has_hdim:
                    D = rearrange(self.D, "(h p) -> h p", p=self.headdim)
                else:
                    D = repeat(self.D, "h -> h p", p=self.headdim)
                B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
                C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
                x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
                if not self.rmsnorm:
                    z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
                y = selective_state_update(
                    ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None,
                    dt_bias=dt_bias, dt_softplus=True,
                )
                y = rearrange(y, "b h p -> b (h p)")
        else:
            dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
            # Continue the scan from the cached SSM state rather than from zeros.
            y, final_ssm_state = mamba_chunk_scan_combined(
                rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
                dt,
                A,
                rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
                rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
                chunk_size=self.chunk_size,
                D=self._D_for_scan(),
                z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None,
                dt_bias=self.dt_bias,
                dt_softplus=True,
                initial_states=ssm_state,
                **dt_limit_kwargs,
                return_final_states=True,
            )
            ssm_state.copy_(final_ssm_state)
            y = rearrange(y, "b l h p -> b l (h p)")

        if self.rmsnorm:
            y = self.norm(y, z)
        if d_mlp > 0:
            y = torch.cat([F.silu(z0) * x0, y], dim=-1)
        out = self.out_proj(y)

        # Restore the sequence axis that was squeezed away for single-token decoding.
        if single_token and out.dim() == 2:
            out = out.unsqueeze(1)

        return out, conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate zeroed decoding states for ``batch_size`` sequences.

        Returns ``(conv_state, ssm_state)`` with shapes (batch, conv_dim, d_conv) and
        (batch, nheads, headdim, d_state). ``max_seqlen`` and ``kwargs`` are unused.
        """
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        # Allocated as (batch, d_conv, conv_dim) and then viewed as (batch, conv_dim, d_conv).
        conv_state = torch.zeros(
            batch_size, self.d_conv, self.conv1d.weight.shape[0], device=device, dtype=conv_dtype
        ).transpose(1, 2)
        ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
        ssm_state = torch.zeros(
            batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype
        )
        return conv_state, ssm_state

    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        """Fetch (or lazily create) this layer's states in ``inference_params.key_value_memory_dict``.

        States are keyed by ``self.layer_idx``. If ``initialize_states`` is true, the
        (possibly pre-existing) states are zeroed in place before being returned.
        """
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            conv_state = torch.zeros(
                batch_size,
                self.d_conv,
                self.conv1d.weight.shape[0],
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            ).transpose(1, 2)
            ssm_state = torch.zeros(
                batch_size,
                self.nheads,
                self.headdim,
                self.d_state,
                device=self.in_proj.weight.device,
                dtype=self.in_proj.weight.dtype,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]

        if initialize_states:
            conv_state.zero_()
            ssm_state.zero_()
        return conv_state, ssm_state