from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states

from apps.mamba.component.causal_conv1d_compilable import (
    causal_conv1d_fn,
    causal_conv1d_update,
)
from apps.mamba.component.ssm_compilable import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.selective_state_update import selective_state_update

from lingua.transformer import InitStdFactor, RMSNorm
from lingua.probe import log_stats


@dataclass
class InitArgs:
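    """Ranges used by SSM.reset_parameters to initialize dt_bias and A_log."""
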
    dt_max: float = 0.1
    dt_min: float = 0.001

    dt_init_floor: float = 1e-4

    A_init_min: float = 1
    A_init_max: float = 16


@dataclass
class BaseMambaArgs:
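    """Configuration for the Mamba backbone: model sizes, SSM settings, and init options."""
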
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8

    state_dim: int = 128
    n_groups: int = 1
    conv_size: Optional[int] = None

    dt_bias: bool = False
    D_has_head_dim: bool = False
    learnable_init_states: bool = False

    ssm_chunk_size: int = 256

    vocab_size: int = -1

    ffn_dim_multiplier: Optional[float] = None

    multiple_of: int = 256
    """
    Enforces that the SwiGLU hidden layer size is a multiple
    of a large power of 2.
    """
    norm_eps: float = 1e-5

    init_use_depth: bool = False
    init_base_std: Optional[float] = None
    init_std_factor: str = "disabled"

    init_args: InitArgs = field(default_factory=InitArgs)
    seed: int = 42


class SSM(nn.Module):
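    """
    Mamba-style state-space mixing layer.

    A single fused projection produces the gate z, the SSM input x, the
    projections B and C, and the per-head step size dt. An optional depthwise
    causal conv1d is applied to (x, B, C); the selective scan (or a single-step
    state update at decode time) mixes the sequence, and the gated, normalized
    result is projected back to the model dimension.
    """
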
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        state_dim: int,
        n_heads: int,
        n_groups: int,
        conv_size: Optional[int],
        dt_bias: bool,
        D_has_head_dim: Optional[bool],
        learnable_init_states: bool,
        dt_limit: Tuple[float, float] = (0.0, float("inf")),
        chunk_size=256,
    ):
        super().__init__()

        self.dim = dim

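        # Size the SSM width like a SwiGLU FFN: scale by 2/3, apply the optional
        # multiplier, then round up to a multiple of `multiple_of`.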
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        self.hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        assert (
            self.hidden_dim % n_heads == 0
        ), f"Hidden dim must be divisible by n_heads: {self.hidden_dim} % {n_heads} != 0"

        self.state_dim = state_dim
        self.head_dim = self.hidden_dim // n_heads
        self.n_heads = n_heads
        self.n_groups = n_groups

        self.dt_limit = dt_limit
        self.chunk_size = chunk_size

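        # One fused input projection produces, in order: z (gate), x (SSM input),
        # B and C (one copy per group), and dt (one scalar per head).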
        d_in_proj = (
            2 * self.hidden_dim + 2 * self.n_groups * self.state_dim + self.n_heads
        )
        self.in_proj = nn.Linear(dim, d_in_proj, bias=False)

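        # Optional short depthwise causal conv over the concatenated x/B/C
        # channels; the weight has shape (conv_dim, conv_size).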
        self.conv_size = conv_size
        self.conv_dim = None
        if conv_size is not None:
            self.conv_dim = self.hidden_dim + 2 * self.n_groups * self.state_dim
            assert (self.conv_dim % 8 == 0) and (
                conv_size in [2, 3, 4]
            ), f"Causal conv1d only supports conv_size in [2, 3, 4] and hidden_dim/head_dim % 8 == 0, got {self.conv_dim} and {conv_size}"
            self.conv_weight = nn.Parameter(torch.empty((self.conv_dim, conv_size)))

        self.learnable_init_states = learnable_init_states
        if learnable_init_states:
            self.init_states = nn.Parameter(
                torch.zeros(n_heads, self.head_dim, state_dim)
            )

        self.dt_bias = None
        if dt_bias:
            self.dt_bias = nn.Parameter(torch.empty(n_heads))
        self.A_log = nn.Parameter(torch.empty(n_heads))

        if D_has_head_dim is None:
            self.D = None
        elif D_has_head_dim:
            self.D = nn.Parameter(torch.ones(n_heads, self.head_dim))
        else:
            self.D = nn.Parameter(torch.ones(n_heads))

        self.out_proj = nn.Linear(self.hidden_dim, self.dim, bias=False)

        self.ssm_norm = RMSNorm(self.hidden_dim, eps=1e-5)

        self.dt_limit_kwargs = (
            {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
        )

    def forward(
        self,
        x: torch.Tensor,
        tok_idx: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        ssm_impl: str = "ssm",
    ) -> torch.Tensor:
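        # "ssm" runs the chunked selective scan over the full (packed) sequence;
        # "ssm_update" performs a single-step recurrent update against self.cache
        # when decoding.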
        bsz, seq_len, _ = x.shape

        zxbcdt = self.in_proj(x)

        if self.conv_size is not None:
            z, xBC, dt = torch.split(
                zxbcdt,
                [
                    self.hidden_dim,
                    self.hidden_dim + 2 * self.n_groups * self.state_dim,
                    self.n_heads,
                ],
                dim=-1,
            )

            conv1d = log_stats(self.conv_weight, "conv1d.w")
            xBC = log_stats(xBC, "conv1d.in")

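            # Full-sequence conv for training/prefill; when a cache is attached,
            # the final conv state of each packed sequence is stored so decoding
            # can resume from it.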
            if ssm_impl == "ssm":
                if hasattr(self, "cache"):
                    conv_varlen_states = causal_conv1d_varlen_states(
                        xBC.squeeze(0),
                        cu_seqlens,
                        state_len=self.cache.conv_cache.shape[-1],
                    )
                    self.cache.conv_cache.copy_(conv_varlen_states)

                xBC = causal_conv1d_fn(
                    x=xBC.transpose(1, 2),
                    weight=conv1d,
                    bias=None,
                    activation="silu",
                    seq_idx=tok_idx,
                ).transpose(1, 2)

            elif ssm_impl == "ssm_update":
                xBC = causal_conv1d_update(
                    x=xBC.squeeze(0),
                    conv_state=self.cache.conv_cache,
                    weight=self.conv_weight,
                    bias=None,
                    activation="silu",
                ).unsqueeze(0)

            else:
                raise NotImplementedError(
                    f"SSM implementation {ssm_impl} not supported"
                )

            xBC = log_stats(xBC, "conv1d.out")

            x, B, C = torch.split(
                xBC,
                [
                    self.hidden_dim,
                    self.n_groups * self.state_dim,
                    self.n_groups * self.state_dim,
                ],
                dim=-1,
            )
        else:
            z, x, B, C, dt = torch.split(
                zxbcdt,
                [
                    self.hidden_dim,
                    self.hidden_dim,
                    self.n_groups * self.state_dim,
                    self.n_groups * self.state_dim,
                    self.n_heads,
                ],
                dim=-1,
            )

        initial_states = None
        if self.learnable_init_states:
            initial_states = self.init_states.expand(bsz, -1, -1, -1)

        x = x.view(bsz, seq_len, self.n_heads, self.head_dim)

        A_log = log_stats(self.A_log, "A_log")
        A = -torch.exp(A_log.float())
        B = B.view(bsz, seq_len, self.n_groups, self.state_dim)
        C = C.view(bsz, seq_len, self.n_groups, self.state_dim)

        A, B, C = log_stats(A, "A"), log_stats(B, "B"), log_stats(C, "C")

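        # Training/prefill: chunked selective scan over the whole sequence,
        # optionally stashing each packed sequence's final state into the cache.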
        if ssm_impl == "ssm":
            y = mamba_chunk_scan_combined(
                x,
                dt,
                A,
                B,
                C,
                dt_bias=self.dt_bias,
                dt_softplus=True,
                chunk_size=self.chunk_size,
                D=self.D,
                z=None,
                seq_idx=tok_idx,
                cu_seqlens=cu_seqlens,
                initial_states=initial_states,
                **self.dt_limit_kwargs,
            )

            if hasattr(self, "cache"):
                y, varlen_states = y
                self.cache.state_cache.copy_(varlen_states)

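        # Decoding: single-token recurrent update on the cached SSM state; A, dt
        # and D are broadcast to the per-(head, head_dim) shapes the kernel expects.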
        elif ssm_impl == "ssm_update":
            x = x.squeeze(0)
            A = A[..., None, None].expand(self.n_heads, self.head_dim, self.state_dim)
            dt = dt.permute(1, 2, 0).expand(seq_len, self.n_heads, self.head_dim)
            D = self.D
            if D is not None and D.dim() == 1:
                D = D.unsqueeze(1).expand(self.n_heads, self.head_dim)
            B, C = B.squeeze(0), C.squeeze(0)
            y = selective_state_update(
                self.cache.state_cache,
                x,
                dt,
                A,
                B,
                C,
                D,
                z=None,
                dt_bias=(
                    torch.zeros(self.n_heads, self.head_dim).to(x)
                    if self.dt_bias is None
                    else self.dt_bias.unsqueeze(1).expand(self.n_heads, self.head_dim)
                ),
                dt_softplus=True,
            ).unsqueeze(0)

        else:
            raise NotImplementedError(f"SSM implementation {ssm_impl} not supported")

        y = y.view(bsz, seq_len, self.hidden_dim)

        y = log_stats(y, "ssm_out")
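        # Gate with SiLU(z), normalize, then project back to the model dimension.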
        y = self.ssm_norm(y * F.silu(z))

        out = self.out_proj(y)

        return out

    def reset_parameters(self, init_std, factor, init_args: InitArgs):
        in_init_std = init_std or (self.dim ** (-0.5))
        out_init_std = init_std or (self.hidden_dim ** (-0.5))
        out_init_std = out_init_std / factor

        nn.init.trunc_normal_(
            self.in_proj.weight,
            mean=0.0,
            std=in_init_std,
            a=-3 * in_init_std,
            b=3 * in_init_std,
        )

        nn.init.trunc_normal_(
            self.out_proj.weight,
            mean=0.0,
            std=out_init_std,
            a=-3 * out_init_std,
            b=3 * out_init_std,
        )

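        # dt_bias: sample in [dt_min, dt_max], floor it, then apply the inverse of
        # softplus so that softplus(dt_bias) starts out in the sampled range.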
        if self.dt_bias is not None:
            self.dt_bias.uniform_(init_args.dt_min, init_args.dt_max)
            self.dt_bias.clamp_(min=init_args.dt_init_floor)
            self.dt_bias.data = self.dt_bias.data + torch.log(
                -torch.expm1(-self.dt_bias.data)
            )

        if self.conv_size is not None:
            conv_std = init_std or (self.conv_size ** (-0.5))
            nn.init.trunc_normal_(
                self.conv_weight,
                mean=0.0,
                std=conv_std,
                a=-3 * conv_std,
                b=3 * conv_std,
            )

        if self.learnable_init_states:
            self.init_states.zero_()

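        # A is sampled uniformly in [A_init_min, A_init_max] and stored as its log;
        # the forward pass uses A = -exp(A_log).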
        self.A_log.uniform_(init_args.A_init_min, init_args.A_init_max)
        self.A_log.log_()

        if self.D is not None:
            self.D.data.fill_(1.0)
        self.ssm_norm.reset_parameters()


class MambaBlock(nn.Module):
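    """Pre-norm residual block: x + SSM(RMSNorm(x))."""
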
    def __init__(self, args: BaseMambaArgs):
        super().__init__()

        self.ssm_norm = RMSNorm(args.dim, args.norm_eps)
        self.ssm = SSM(
            dim=args.dim,
            hidden_dim=3 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
            state_dim=args.state_dim,
            n_heads=args.n_heads,
            n_groups=args.n_groups,
            conv_size=args.conv_size,
            dt_bias=args.dt_bias,
            D_has_head_dim=args.D_has_head_dim,
            learnable_init_states=args.learnable_init_states,
            chunk_size=args.ssm_chunk_size,
        )

    def forward(
        self,
        x: torch.Tensor,
        tok_idx: Optional[torch.Tensor],
        cu_seqlens: Optional[torch.Tensor],
        ssm_impl: str = "ssm",
    ) -> torch.Tensor:
        x = x + self.ssm(
            self.ssm_norm(x), tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl
        )
        return x

    def init_weights(self, init_std=None, factor=1.0, init_args: InitArgs = InitArgs()):
        self.ssm_norm.reset_parameters()
        self.ssm.reset_parameters(init_std, factor, init_args)


class BaseMamba(nn.Module):
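    """A stack of MambaBlock layers applied to an already-embedded sequence."""
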
    def __init__(self, args: BaseMambaArgs):
        super().__init__()
        self.model_dim = args.dim
        self.init_base_std = args.init_base_std

        self.init_args = args.init_args
        self.init_std_factor = InitStdFactor(args.init_std_factor)

        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(MambaBlock(args))

    def forward(
        self,
        h: torch.Tensor,
        tok_idx: Optional[torch.Tensor],
        cu_seqlens: Optional[torch.Tensor],
        ssm_impl: str = "ssm",
    ) -> torch.Tensor:
        for layer in self.layers:
            h = layer(h, tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl)
        return h

    def reset_parameters(self):
        pass

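    # Each layer's output-projection std is divided by a depth-dependent factor
    # chosen by init_std_factor (e.g. sqrt(2 * (depth + 1)) for CURRENT_DEPTH).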
    def init_weights(self):
        self.reset_parameters()
        for depth, layer in enumerate(self.layers):
            factor = {
                InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
                InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
                InitStdFactor.DIM_RATIO: self.model_dim / 4096,
                InitStdFactor.DISABLED: 1.0,
            }[self.init_std_factor]

            layer.init_weights(self.init_base_std, factor, self.init_args)