| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """PyTorch MAMBA model.""" |
|
|
| import math |
| from dataclasses import dataclass |
| from typing import Any, Dict, Optional, Tuple, Union |
|
|
| import torch |
| import torch.utils.checkpoint |
| from torch import nn |
| from transformers.activations import ACT2FN |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import ModelOutput, logging |
|
|
| from fla.models.mamba.configuration_mamba import MambaConfig |
| from fla.modules import FusedCrossEntropyLoss, RMSNorm |
|
|
logger = logging.get_logger(__name__)

# Optional fused kernels. Every name that fails to import is bound to None so that the
# availability flag below can be computed without further try/except blocks.
try:
    from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None
    selective_scan_fn = None
    mamba_inner_fn = None

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn = None
    causal_conv1d_update = None

# The CUDA fast path is usable only when every one of the optional kernels was imported.
is_fast_path_available = all(
    kernel
    for kernel in (
        selective_state_update,
        selective_scan_fn,
        causal_conv1d_fn,
        causal_conv1d_update,
        mamba_inner_fn,
    )
)
|
|
|
|
class MambaCache:
    """Per-layer recurrent state used for incremental decoding.

    Args:
        config: model config; `intermediate_size`, `state_size`, `conv_kernel` and
            `num_hidden_layers` are read from it.
        batch_size: number of sequences held by the cache.
        dtype: dtype of the zero-initialised state tensors.
        device: device on which the state tensors are allocated.

    Attributes:
        seqlen_offset: number of positions processed so far (starts at 0).
        conv_states: layer index -> zeros of shape `(batch_size, intermediate_size, conv_kernel)`.
        ssm_states: layer index -> zeros of shape `(batch_size, intermediate_size, state_size)`.
    """

    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
        self.seqlen_offset = 0
        self.dtype = dtype

        conv_shape = (batch_size, config.intermediate_size, config.conv_kernel)
        ssm_shape = (batch_size, config.intermediate_size, config.state_size)
        layer_ids = range(config.num_hidden_layers)

        # Each layer owns independent buffers so in-place cache updates never alias.
        self.conv_states = {}
        self.ssm_states = {}
        for layer in layer_ids:
            self.conv_states[layer] = torch.zeros(conv_shape, device=device, dtype=dtype)
        for layer in layer_ids:
            self.ssm_states[layer] = torch.zeros(ssm_shape, device=device, dtype=dtype)
|
|
|
|
class MambaMixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)

    Two implementations are provided: `cuda_kernels_forward`, which calls the optional `mamba_ssm` /
    `causal_conv1d` kernels, and `slow_forward`, a pure-PyTorch reference. `forward` chooses between them.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel
        self.intermediate_size = config.intermediate_size
        self.time_step_rank = config.time_step_rank
        self.layer_idx = layer_idx
        self.use_conv_bias = config.use_conv_bias
        # Depthwise convolution (groups == channels). It pads kernel-1 on both sides; the
        # slow path keeps only the first `seq_len` outputs, which makes it causal.
        self.conv1d = nn.Conv1d(
            in_channels=self.intermediate_size,
            out_channels=self.intermediate_size,
            bias=config.use_conv_bias,
            kernel_size=config.conv_kernel,
            groups=self.intermediate_size,
            padding=config.conv_kernel - 1,
        )

        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]

        # Projects the input into the SSM branch and the gate branch (hence `* 2`).
        self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
        # Produces the input-dependent low-rank time step, B and C.
        self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
        # Lifts the low-rank time step back to one value per channel.
        self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)

        # A[c, n] = n + 1 for every channel. It is stored as a log so that A = -exp(A_log),
        # as used in both forward paths, stays strictly negative.
        A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
        A = A.expand(self.intermediate_size, -1).contiguous()

        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.intermediate_size))
        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
        self.use_bias = config.use_bias

        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because one of "
                "`(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
                " is None. Falling back to the naive implementation. "
                "To install follow https://github.com/state-spaces/mamba/#installation and"
                " https://github.com/Dao-AILab/causal-conv1d"
            )

    def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: Optional[MambaCache] = None):
        """Forward pass built on the optional fused kernels.

        Args:
            hidden_states: input of shape `(batch, seq_len, hidden_size)`.
            cache_params: optional recurrent cache. When `cache_params.seqlen_offset > 0` a single-token
                decoding step is run; otherwise the whole sequence is processed. Cache tensors for
                `self.layer_idx` are updated in place.

        Returns:
            Tensor of shape `(batch, seq_len, hidden_size)`.
        """
        # 1. Gated linear projection: (batch, 2 * intermediate_size, seq_len)
        projected_states = self.in_proj(hidden_states).transpose(1, 2)

        if self.training and cache_params is None:
            # Training without a cache: one fused call covers conv, scan, gating and out_proj.
            contextualized_states = mamba_inner_fn(
                projected_states,
                self.conv1d.weight,
                self.conv1d.bias if self.use_conv_bias else None,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias.float() if self.use_bias else None,
                -torch.exp(self.A_log.float()),
                None,  # B: not passed explicitly (presumably derived from x_proj by the kernel)
                None,  # C: not passed explicitly (presumably derived from x_proj by the kernel)
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )
        else:
            # Each half: (batch, intermediate_size, seq_len)
            hidden_states, gate = projected_states.chunk(2, dim=1)

            # 2. Causal depthwise convolution
            conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
            if cache_params is not None and cache_params.seqlen_offset > 0:
                # Decoding: advance the cached conv window by the single new token.
                hidden_states = causal_conv1d_update(
                    hidden_states.squeeze(-1),
                    cache_params.conv_states[self.layer_idx],
                    conv_weights,
                    self.conv1d.bias,
                    self.activation,
                )
                hidden_states = hidden_states.unsqueeze(-1)
            else:
                if cache_params is not None:
                    # Keep the last `conv_kernel_size` inputs (left-padded with zeros when shorter).
                    conv_states = nn.functional.pad(
                        hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
                    )
                    cache_params.conv_states[self.layer_idx].copy_(conv_states)
                hidden_states = causal_conv1d_fn(
                    hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
                )

            # 3. Selective state space model
            # 3.a. Input-dependent time step, B and C: (batch, seq_len, rank + 2 * state_size)
            ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
            time_step, B, C = torch.split(
                ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
            )
            # dt_proj without its bias; the bias is handed to the scan kernels below instead.
            discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)

            A = -torch.exp(self.A_log.float())
            # nn.Linear always defines `bias` (None when disabled), so test the value, not the attribute.
            time_proj_bias = self.dt_proj.bias.float() if self.dt_proj.bias is not None else None
            if cache_params is not None and cache_params.seqlen_offset > 0:
                # Decoding: single recurrent step that updates the cached SSM state in place.
                scan_outputs = selective_state_update(
                    cache_params.ssm_states[self.layer_idx],
                    hidden_states[..., 0],
                    discrete_time_step[..., 0],
                    A,
                    B[:, 0],
                    C[:, 0],
                    self.D,
                    gate[..., 0],
                    time_proj_bias,
                    dt_softplus=True,
                ).unsqueeze(-1)
            else:
                scan_outputs, ssm_state = selective_scan_fn(
                    hidden_states,
                    discrete_time_step,
                    A,
                    B.transpose(1, 2),
                    C.transpose(1, 2),
                    self.D.float(),
                    gate,
                    time_proj_bias,
                    delta_softplus=True,
                    return_last_state=True,
                )
                if ssm_state is not None and cache_params is not None:
                    cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

            # 4. Output projection: (batch, seq_len, hidden_size)
            contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
        return contextualized_states

    def slow_forward(self, input_states, cache_params: Optional[MambaCache] = None):
        """Pure-PyTorch reference implementation (sequential scan over time).

        Args:
            input_states: input of shape `(batch, seq_len, hidden_size)`.
            cache_params: optional recurrent cache, read and updated in place for `self.layer_idx`.
                When `cache_params.seqlen_offset > 0` only the first position is used for the conv step,
                so a single new token per call is presumably expected.

        Returns:
            Tensor of shape `(batch, seq_len, hidden_size)`.
        """
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype

        # 1. Gated linear projection: (batch, 2 * intermediate_size, seq_len)
        projected_states = self.in_proj(input_states).transpose(1, 2)
        hidden_states, gate = projected_states.chunk(2, dim=1)

        # 2. Causal depthwise convolution
        if cache_params is not None:
            ssm_state = cache_params.ssm_states[self.layer_idx].clone()
            if cache_params.seqlen_offset > 0:
                # Decoding: shift the conv window left by one and append the new token.
                conv_state = cache_params.conv_states[self.layer_idx]  # (batch, intermediate_size, conv_kernel)
                conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
                conv_state[:, :, -1] = hidden_states[:, :, 0]
                cache_params.conv_states[self.layer_idx].copy_(conv_state)
                hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
                if self.use_conv_bias:
                    hidden_states += self.conv1d.bias
                hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1)  # (batch, intermediate_size, 1)
            else:
                # Prefill: store the last `conv_kernel_size` inputs, left-padded with zeros when shorter.
                conv_state = nn.functional.pad(
                    hidden_states,
                    (self.conv_kernel_size - hidden_states.shape[-1], 0)
                )
                cache_params.conv_states[self.layer_idx].copy_(conv_state)
                hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])  # (batch, intermediate_size, seq_len)
        else:
            ssm_state = torch.zeros(
                (batch_size, self.intermediate_size, self.ssm_state_size),
                device=hidden_states.device, dtype=dtype
            )
            hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])  # (batch, intermediate_size, seq_len)

        # 3. Selective state space model
        # 3.a. Input-dependent time step, B and C: (batch, seq_len, rank + 2 * state_size)
        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
        time_step, B, C = torch.split(
            ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
        )
        discrete_time_step = self.dt_proj(time_step)  # (batch, seq_len, intermediate_size)
        discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2)  # (batch, intermediate_size, seq_len)

        # 3.b. Discretization, all (batch, intermediate_size, seq_len, state_size)
        A = -torch.exp(self.A_log.float())  # (intermediate_size, state_size)
        discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])
        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()
        deltaB_u = discrete_B * hidden_states[:, :, :, None].float()

        # 3.c. Recurrence: h_t = dA_t * h_{t-1} + dB_t * x_t ;  y_t = h_t @ C_t
        scan_outputs = []
        for i in range(seq_len):
            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # (batch, intermediate_size, state_size)
            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # (batch, intermediate_size, 1)
            scan_outputs.append(scan_output[:, :, 0])
        scan_output = torch.stack(scan_outputs, dim=-1)  # (batch, intermediate_size, seq_len)
        scan_output = scan_output + (hidden_states * self.D[None, :, None])
        scan_output = (scan_output * self.act(gate))

        if cache_params is not None:
            cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

        # 4. Output projection: (batch, seq_len, hidden_size)
        contextualized_states = self.out_proj(scan_output.transpose(1, 2))
        return contextualized_states

    def forward(self, hidden_states, cache_params: Optional[MambaCache] = None):
        """Use the fused kernels when they imported and the weights are on CUDA; otherwise the slow path."""
        if is_fast_path_available and "cuda" in self.x_proj.weight.device.type:
            return self.cuda_kernels_forward(hidden_states, cache_params)
        return self.slow_forward(hidden_states, cache_params)
|
|
|
|
class MambaBlock(nn.Module):
    """Pre-norm residual block computing ``x + mixer(norm(x))``."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # NOTE(review): stored from the config but not read in `forward`, so the residual
        # stream keeps the dtype of the incoming hidden states.
        self.residual_in_fp32 = config.residual_in_fp32
        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mixer = MambaMixer(config, layer_idx=layer_idx)

    def forward(self, hidden_states, cache_params: Optional[MambaCache] = None):
        """Apply norm -> mixer and add the result back onto the un-normalised input."""
        mixed = self.mixer(self.norm(hidden_states), cache_params=cache_params)
        return hidden_states + mixed
|
|
|
|
class MambaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = MambaConfig
    base_model_prefix = "backbone"
    _no_split_modules = ["MambaBlock"]
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of `module`.

        - `MambaMixer`: flags `A_log` / `D` with `_no_weight_decay`, initialises `dt_proj.weight`
          according to `config.time_step_init_scheme`, and sets `dt_proj.bias` so that softplus(bias)
          is log-uniform in [time_step_min, time_step_max] (floored at time_step_floor).
        - `nn.Linear`: zeroes the bias unless it is marked `_no_reinit`.
        - `nn.Embedding`: normal init with std `config.initializer_range`.
        - With `config.rescale_prenorm_residual`, a direct `out_proj.weight` child is re-initialised
          and scaled down by sqrt(num_hidden_layers).
        """
        if isinstance(module, MambaMixer):
            # Flag read elsewhere (presumably by optimizer setup) to exclude these from weight decay.
            module.A_log._no_weight_decay = True
            module.D._no_weight_decay = True

            dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
            if self.config.time_step_init_scheme == "constant":
                nn.init.constant_(module.dt_proj.weight, dt_init_std)
            elif self.config.time_step_init_scheme == "random":
                nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)

            # Sample dt log-uniformly in [time_step_min, time_step_max].
            dt = torch.exp(
                torch.rand(self.config.intermediate_size)
                * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
                + math.log(self.config.time_step_min)
            ).clamp(min=self.config.time_step_floor)
            # Inverse of softplus: softplus(inv_dt) == dt, since the forward pass applies softplus to dt_proj's output.
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            with torch.no_grad():
                module.dt_proj.bias.copy_(inv_dt)
            # Stops the generic nn.Linear branch from zeroing this bias on a later call.
            module.dt_proj.bias._no_reinit = True

        if isinstance(module, nn.Linear):
            if module.bias is not None:
                if not getattr(module.bias, "_no_reinit", False):
                    nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=self.config.initializer_range)

        if self.config.rescale_prenorm_residual:
            # Re-initialise (rather than rescale in place) so repeated calls don't compound the
            # scaling. kaiming_uniform_(a=sqrt(5)) is nn.Linear's default weight init.
            for name, p in module.named_parameters():
                if name in ["out_proj.weight"]:
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        # The layer count lives in `num_hidden_layers`, as used for the cache and the layer stack.
                        p /= math.sqrt(self.config.num_hidden_layers)
|
|
|
|
@dataclass
class MambaOutput(ModelOutput):
    """
    Output of the MAMBA backbone (`MambaModel`).

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden states produced by the last layer of the model.
        cache_params (`MambaCache`):
            Recurrent state of the model after the last time step. Pass it to the next forward call together
            with only the new `input_ids` to avoid re-processing the previous ones.

            It holds both the state-space (SSM) states after the selective scan and the convolutional states.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`: one for the output
            of the embeddings, if the model has an embedding layer, plus one for the output of each layer.

            Hidden states of the model at the output of each layer, plus the optional initial embedding outputs.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    cache_params: Optional[MambaCache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
|
|
|
@dataclass
class MambaCausalLMOutput(ModelOutput):
    """
    Output of a causal (autoregressive) MAMBA language model.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language-modeling (next-token prediction) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language-modeling head, i.e. per-token vocabulary scores before SoftMax.
        cache_params (`MambaCache`):
            Recurrent state of the model after the last time step. Pass it to the next forward call together
            with only the new `input_ids` to avoid re-processing the previous ones.

            It holds both the state-space (SSM) states after the selective scan and the convolutional states.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`: one for the output
            of the embeddings, if the model has an embedding layer, plus one for the output of each layer.

            Hidden states of the model at the output of each layer, plus the optional initial embedding outputs.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    cache_params: Optional[MambaCache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
|
|
|
class MambaModel(MambaPreTrainedModel):
    """Mamba backbone: token embeddings, a stack of `MambaBlock`s and a final RMSNorm."""

    def __init__(self, config):
        super().__init__(config)

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])

        self.gradient_checkpointing = False
        self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_params: Optional[MambaCache] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, MambaOutput]:
        """Run the backbone.

        Args:
            input_ids: token ids; exactly one of `input_ids` / `inputs_embeds` must be given.
            inputs_embeds: precomputed (float) embeddings used instead of `self.embeddings(input_ids)`.
            cache_params: recurrent cache that the layers read and update in place. A new one is
                created when `use_cache` is enabled and none is passed.
            use_cache: defaults to `config.use_cache` in eval mode and to `False` in training mode;
                also forced off under gradient checkpointing while training.
            output_hidden_states: also collect the output of every layer plus the final normed states.
            return_dict: return a `MambaOutput` instead of a plain tuple.

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are given.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # True when both or neither are supplied.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if cache_params is None and use_cache:
            cache_params = MambaCache(
                self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
            )

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        for mixer_block in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(mixer_block.__call__, hidden_states, cache_params)
            else:
                hidden_states = mixer_block(hidden_states, cache_params=cache_params)

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if use_cache:
            # Advance the cache so the next call takes the single-token decoding path.
            cache_params.seqlen_offset += inputs_embeds.shape[1]

        hidden_states = self.norm_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)

        return MambaOutput(
            last_hidden_state=hidden_states,
            cache_params=cache_params if use_cache else None,
            hidden_states=all_hidden_states,
        )
|
|
|
|
class MambaForCausalLM(MambaPreTrainedModel):
    """Mamba backbone topped with a bias-free linear language-modeling head."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.backbone = MambaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_input_embeddings(self):
        return self.backbone.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        return self.backbone.set_input_embeddings(new_embeddings)

    def _update_model_kwargs_for_generation(
        self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
    ) -> Dict[str, Any]:
        """Forward the recurrent cache from one generation step to the next."""
        model_kwargs["cache_params"] = outputs.get("cache_params", None)
        return model_kwargs

    def prepare_inputs_for_generation(
        self, input_ids, cache_params: Optional[MambaCache] = None, inputs_embeds=None, attention_mask=None, **kwargs
    ):
        """Build the model inputs for one generation step.

        With a cache only the most recent token is fed; `inputs_embeds` is used only on the first
        step (no cache yet). `attention_mask` is accepted but unused.
        """
        if cache_params is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        feed_embeds = inputs_embeds is not None and cache_params is None
        model_inputs = {"inputs_embeds": inputs_embeds} if feed_embeds else {"input_ids": input_ids}
        model_inputs["cache_params"] = cache_params
        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_params: Optional[MambaCache] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, MambaCausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. The labels are shifted inside the model, so `labels = input_ids` is valid.
            Indices are selected in `[-100, 0, ..., config.vocab_size]`; positions labelled `-100` are ignored
            (masked), and the loss is computed only over labels in `[0, ..., config.vocab_size]`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        mamba_outputs = self.backbone(
            input_ids,
            cache_params=cache_params,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
        )
        logits = self.lm_head(mamba_outputs[0])

        loss = None
        if labels is not None:
            loss_fct = (
                FusedCrossEntropyLoss(inplace_backward=True)
                if self.config.fuse_cross_entropy
                else nn.CrossEntropyLoss()
            )
            # Shift targets left by one so position t is scored against token t + 1; the final
            # position has no target and receives the loss's ignore index.
            targets = labels.to(logits.device)
            tail = torch.full_like(targets[:, :1], loss_fct.ignore_index)
            targets = torch.cat((targets[..., 1:], tail), 1)
            loss = loss_fct(logits.view(-1, self.config.vocab_size), targets.view(-1))

        if not return_dict:
            output = (logits,) + mamba_outputs[1:]
            if loss is None:
                return output
            return (loss,) + output

        return MambaCausalLMOutput(
            loss=loss,
            logits=logits,
            cache_params=mamba_outputs.cache_params,
            hidden_states=mamba_outputs.hidden_states,
        )
|
|