| import torch |
| import numpy as np |
| from torch import nn |
| from torch.nn import functional as F |
| import einops |
| from einops.layers.torch import Rearrange |
| import math |
|
|
|
|
class LinearHeadwiseExpand(nn.Module):
    """Block-diagonal linear layer: the feature dim is split into
    `num_heads` equal chunks and each chunk is transformed by its own
    square weight matrix (no mixing across heads).
    """

    def __init__(self, dim, num_heads, bias=False):
        super().__init__()
        assert dim % num_heads == 0
        self.dim = dim
        self.num_heads = num_heads

        head_dim = dim // num_heads
        self.weight = nn.Parameter(torch.empty(num_heads, head_dim, head_dim))
        self.bias = nn.Parameter(torch.empty(dim)) if bias else None
        self.reset_parameters()

    def reset_parameters(self):
        # "Small init" scheme: std = sqrt(2 / (5 * fan_in)) per head.
        std = math.sqrt(2 / 5 / self.weight.shape[-1])
        nn.init.normal_(self.weight.data, mean=0.0, std=std)
        if self.bias is not None:
            nn.init.zeros_(self.bias.data)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the per-head linear maps; x's last dim must equal `self.dim`."""
        lead = x.shape[:-1]
        heads = x.view(*lead, self.num_heads, -1)
        # (..., nh, d) x (nh, out_d, d) -> (..., nh, out_d)
        mixed = torch.einsum("...hd,hod->...ho", heads, self.weight)
        out = mixed.reshape(*lead, -1)
        if self.bias is not None:
            out = out + self.bias
        return out

    def extra_repr(self):
        return (
            f"dim={self.dim}, "
            f"num_heads={self.num_heads}, "
            f"bias={self.bias is not None}, "
        )
|
|
|
|
def wang_init_(param: torch.Tensor, dim: int, num_blocks: int):
    """In-place normal init with std = 2 / (num_blocks * sqrt(dim)) (Wang init).

    Returns the same tensor so calls can be chained.
    """
    torch.nn.init.normal_(param, mean=0.0, std=2.0 / (num_blocks * math.sqrt(dim)))
    return param
|
|
def small_init_(param: torch.Tensor, dim: int) -> torch.Tensor:
    """In-place "small init": normal with std = sqrt(2 / (5 * dim)).

    Returns the same tensor so calls can be chained.
    """
    torch.nn.init.normal_(param, mean=0.0, std=math.sqrt(0.4 / dim))
    return param
|
|
|
|
def bias_linspace_init_(param: torch.Tensor, start: float = 3.4, end: float = 6.0) -> torch.Tensor:
    """Fill a 1-D parameter with evenly spaced values from `start` to `end`.

    Typically used for forget-gate biases. Returns the same tensor.
    """
    assert param.dim() == 1, f"param must be 1-dimensional (typically a bias), got {param.dim()}"
    with torch.no_grad():
        param.copy_(torch.linspace(start, end, param.numel()))
    return param
|
|
|
|
|
|
class CausalConv1d(nn.Module):
    """Depthwise 1D convolution that is causal: the output at time t
    depends only on inputs at times <= t.  Layout is (B, L, D) in and out.
    """

    def __init__(self, dim, kernel_size=4, bias=True):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.bias = bias
        # Pad by (kernel_size - 1); nn.Conv1d pads BOTH sides, so the
        # right-hand overhang is trimmed in forward() to enforce causality.
        self.pad = kernel_size - 1
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=kernel_size,
            padding=self.pad,
            groups=dim,  # depthwise: one independent filter per channel
            bias=bias,
        )
        self.reset_parameters()

    def reset_parameters(self):
        self.conv.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, L, D) -> (B, D, L) for Conv1d.
        x = x.transpose(1, 2)
        x = self.conv(x)
        # BUGFIX: `x[:, :, :-self.pad]` returned an EMPTY tensor when
        # kernel_size == 1 (pad == 0, and `:-0` is the empty slice).
        if self.pad > 0:
            x = x[:, :, :-self.pad]
        # Back to (B, L, D).
        x = x.transpose(1, 2)
        return x
|
|
class LayerNorm(nn.Module):
    """LayerNorm with an optional "residual" weight parameterization.

    With residual_weight=True the effective scale is (1 + weight), so a
    zero-initialized weight yields an identity scale at initialization.
    """

    def __init__(
        self,
        ndim: int = -1,
        weight: bool = True,
        bias: bool = False,
        eps: float = 1e-5,
        residual_weight: bool = True,
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(ndim)) if weight else None
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
        self.eps = eps
        self.residual_weight = residual_weight
        self.ndim = ndim
        self.reset_parameters()

    @property
    def weight_proxy(self) -> torch.Tensor:
        # Effective scale used by forward(); None when the layer is weightless.
        if self.weight is None:
            return None
        return 1.0 + self.weight if self.residual_weight else self.weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(
            x,
            normalized_shape=(self.ndim,),
            weight=self.weight_proxy,
            bias=self.bias,
            eps=self.eps,
        )

    def reset_parameters(self):
        if self.weight_proxy is not None:
            # zeros under the residual parameterization == effective scale 1.
            init_fn = nn.init.zeros_ if self.residual_weight else nn.init.ones_
            init_fn(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)
|
|
class MultiHeadLayerNorm(LayerNorm):
    """Head-wise normalization of a (B, NH, S, DH) tensor.

    Implemented as a group norm with one group per head over the flattened
    (NH * DH) feature dim, computed independently for each (batch, step).
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.ndim == 4, "Input must be 4D tensor (B, NH, S, DH)"
        B, NH, S, DH = x.shape

        # (B, NH, S, DH) -> (B*S, NH*DH): group_norm treats each head as a group.
        flat = x.transpose(1, 2).reshape(B * S, NH * DH)
        normed = F.group_norm(
            flat,
            num_groups=NH,
            weight=self.weight_proxy,
            bias=self.bias,
            eps=self.eps,
        )
        # Restore the (B, NH, S, DH) layout.
        return normed.view(B, S, NH, DH).transpose(1, 2)
|
|
|
|
def parallel_stabilized_simple(
    queries: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    igate_preact: torch.Tensor,
    fgate_preact: torch.Tensor,
    lower_triangular_matrix: torch.Tensor = None,
    stabilize_rowwise: bool = True,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Parallel (attention-like, O(S^2)) form of the stabilized mLSTM.

    Computes h = ((Q K^T / sqrt(DH)) * D) V, where D encodes the cumulative
    forget-gate decay between time steps combined with the input-gate
    strength, with max-subtraction in log space for numerical stability.

    Args:
        queries, keys, values: (B, NH, S, DH) head-split tensors.
        igate_preact: input-gate pre-activations; the caller in this file
            passes shape (B, NH, S, 1).
        fgate_preact: forget-gate pre-activations; caller passes (B, NH, S, 1).
        lower_triangular_matrix: optional precomputed (S, S) bool causal mask.
        stabilize_rowwise: if True, stabilize per query row; otherwise use a
            single max per (batch, head).
        eps: added to the normalizer to avoid division by zero.

    Returns:
        (B, NH, S, DH) tensor of hidden pre-outputs.
    """
    B, NH, S, DH = queries.shape
    _dtype, _device = queries.dtype, queries.device

    # Forget gates in log space; logsigmoid is numerically stable.
    log_fgates = torch.nn.functional.logsigmoid(fgate_preact)
    # Rebuild the mask when none was given or the cached one is LARGER than
    # needed; reuse it only when the sizes line up.
    # NOTE(review): if S exceeds the cached mask's size, the oversized reuse
    # path is never corrected here — callers are assumed to pass an (S, S)
    # mask; confirm against call sites.
    if lower_triangular_matrix is None or S < lower_triangular_matrix.size(-1):
        ltr = torch.tril(torch.ones((S, S), dtype=torch.bool, device=_device))
    else:
        ltr = lower_triangular_matrix
    assert ltr.dtype == torch.bool, f"lower_triangular_matrix must be of dtype bool, got {ltr.dtype}"

    # Cumulative log forget gates with a leading 0 (no decay before t=0):
    # shape (B, NH, S+1, 1).
    log_fgates_cumsum = torch.cat(
        [
            torch.zeros((B, NH, 1, 1), dtype=_dtype, device=_device),
            torch.cumsum(log_fgates, dim=-2),
        ],
        dim=-2,
    )
    # Pairwise differences of the cumsums give the total decay accumulated
    # between any two time steps.
    rep_log_fgates_cumsum = log_fgates_cumsum.repeat(1, 1, 1, S + 1)
    _log_fg_matrix = rep_log_fgates_cumsum - rep_log_fgates_cumsum.transpose(-2, -1)
    # Mask future positions with -inf (exp -> 0), dropping the helper row/col.
    log_fg_matrix = torch.where(ltr, _log_fg_matrix[:, :, 1:, 1:], -float("inf"))

    # Log decay matrix D: decay plus the input-gate strength of the key step.
    log_D_matrix = log_fg_matrix + igate_preact.transpose(-2, -1)
    # Subtract the max before exponentiating to prevent overflow.
    if stabilize_rowwise:
        max_log_D, _ = torch.max(log_D_matrix, dim=-1, keepdim=True)
    else:
        max_log_D = torch.max(log_D_matrix.view(B, NH, -1), dim=-1, keepdim=True)[0].unsqueeze(-1)

    log_D_matrix_stabilized = log_D_matrix - max_log_D
    D_matrix = torch.exp(log_D_matrix_stabilized)

    # Scale keys as in scaled dot-product attention.
    keys_scaled = keys / math.sqrt(DH)

    # Gated attention-like combination.
    qk_matrix = queries @ keys_scaled.transpose(-2, -1)
    C_matrix = qk_matrix * D_matrix
    # Normalizer: |row sum|, floored by exp(-max) so it never collapses to 0.
    normalizer = torch.maximum(C_matrix.sum(dim=-1, keepdim=True).abs(), torch.exp(-max_log_D))
    C_matrix_normalized = C_matrix / (normalizer + eps)

    # Weighted combination of the values.
    h_tilde_state = C_matrix_normalized @ values
    return h_tilde_state
|
|
class MatrixLSTMCell(nn.Module):
    """mLSTM cell: per-head input/forget gate pre-activations computed from
    the concatenated (q, k, v), the parallel stabilized recurrence, and a
    multi-head output normalization.
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads

        # Gates see the full concatenated q, k, v and emit one pre-activation
        # per head and time step.
        self.igate = nn.Linear(3 * dim, num_heads)
        self.fgate = nn.Linear(3 * dim, num_heads)
        self.outnorm = MultiHeadLayerNorm(ndim=dim, weight=True, bias=False)
        # Causal masks cached by (sequence length, device string).
        self.causal_mask_cache = {}
        self.reset_parameters()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """q, k, v: (B, S, dim).  Returns (B, S, dim)."""
        B, S, _ = q.shape

        if_gate_input = torch.cat([q, k, v], dim=-1)
        # Split heads: (B, S, dim) -> (B, NH, S, DH).
        q = q.view(B, S, self.num_heads, -1).transpose(1, 2)
        k = k.view(B, S, self.num_heads, -1).transpose(1, 2)
        v = v.view(B, S, self.num_heads, -1).transpose(1, 2)

        # Gate pre-activations: (B, S, NH) -> (B, NH, S, 1).
        igate_preact = self.igate(if_gate_input).transpose(-1, -2).unsqueeze(-1)
        fgate_preact = self.fgate(if_gate_input).transpose(-1, -2).unsqueeze(-1)

        # BUGFIX: the cache was probed with the bare key `S` but populated
        # under `(S, str(device))`, so lookups never hit and the causal mask
        # was rebuilt on every forward pass.
        cache_key = (S, str(q.device))
        causal_mask = self.causal_mask_cache.get(cache_key)
        if causal_mask is None:
            causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device=q.device))
            self.causal_mask_cache[cache_key] = causal_mask

        h_state = parallel_stabilized_simple(
            queries=q,
            keys=k,
            values=v,
            igate_preact=igate_preact,
            fgate_preact=fgate_preact,
            lower_triangular_matrix=causal_mask,
        )

        # Normalize per head, then merge heads back to (B, S, dim).
        h_state_norm = self.outnorm(h_state)
        h_state_norm = h_state_norm.transpose(1, 2).reshape(B, S, -1)
        return h_state_norm

    def reset_parameters(self):
        self.outnorm.reset_parameters()
        # Forget gate: zero weights, bias linearly spaced in [3, 6] so the
        # gates start nearly open.
        torch.nn.init.zeros_(self.fgate.weight)
        bias_linspace_init_(self.fgate.bias, start=3.0, end=6.0)
        # Input gate: zero weights, small random bias.
        torch.nn.init.zeros_(self.igate.weight)
        torch.nn.init.normal_(self.igate.bias, mean=0.0, std=0.1)
|
|
class mLSTM(nn.Module):
    """mLSTM layer (xLSTM-style): up-projection into a cell branch and a
    gating branch, causal conv + headwise q/k/v projections, a matrix-LSTM
    cell with a learnable skip connection, SiLU output gating, and a
    down-projection back to the model dimension.
    """

    def __init__(
        self,
        dim,
        expansion=2,
        qkv_block_size=4,
        proj_bias=False,
        conv_bias=True,
        kernel_size=4,
    ):
        super().__init__()
        assert dim % qkv_block_size == 0
        self.dim = dim
        self.expansion = expansion
        self.qkv_block_size = qkv_block_size
        self.proj_bias = proj_bias
        self.conv_bias = conv_bias
        self.kernel_size = kernel_size

        inner_dim = expansion * dim
        num_heads = inner_dim // qkv_block_size
        # One linear produces both the cell input x_mlstm and the gate branch z.
        self.proj_up = nn.Linear(
            in_features=dim,
            out_features=2 * inner_dim,
            bias=proj_bias,
        )
        # Block-diagonal q/k/v projections: one qkv_block_size block per head.
        self.q_proj = LinearHeadwiseExpand(dim=inner_dim, num_heads=num_heads, bias=proj_bias)
        self.k_proj = LinearHeadwiseExpand(dim=inner_dim, num_heads=num_heads, bias=proj_bias)
        self.v_proj = LinearHeadwiseExpand(dim=inner_dim, num_heads=num_heads, bias=proj_bias)

        self.conv1d = CausalConv1d(dim=inner_dim, kernel_size=kernel_size, bias=conv_bias)
        # NOTE(review): the cell's head count is qkv_block_size itself, not
        # inner_dim // qkv_block_size — kept as in the original; confirm intended.
        self.mlstm_cell = MatrixLSTMCell(dim=inner_dim, num_heads=qkv_block_size)
        self.learnable_skip = nn.Parameter(torch.ones(inner_dim))

        self.proj_down = nn.Linear(
            in_features=inner_dim,
            out_features=dim,
            bias=proj_bias,
        )
        self.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, S, dim) -> (B, S, dim)."""
        B, S, _ = x.shape

        # Up-project and split into the cell branch and the gate branch.
        x_mlstm, z = torch.chunk(self.proj_up(x), chunks=2, dim=-1)

        # q and k come from the activated causal conv; v bypasses the conv.
        conv_act = F.silu(self.conv1d(x_mlstm))
        q = self.q_proj(conv_act)
        k = self.k_proj(conv_act)
        v = self.v_proj(x_mlstm)

        h_tilde = self.mlstm_cell(q=q, k=k, v=v)
        # Learnable element-wise skip from the conv activation.
        h_tilde = h_tilde + self.learnable_skip * conv_act

        # SiLU-gated output, projected back to the model dimension.
        return self.proj_down(h_tilde * F.silu(z))

    def reset_parameters(self):
        # Up projection: "small init"; down projection: Wang init.
        small_init_(self.proj_up.weight, dim=self.dim)
        if self.proj_up.bias is not None:
            nn.init.zeros_(self.proj_up.bias)
        wang_init_(self.proj_down.weight, dim=self.dim, num_blocks=1)
        if self.proj_down.bias is not None:
            nn.init.zeros_(self.proj_down.bias)

        nn.init.ones_(self.learnable_skip)

        # q/k/v projections share the small-init scheme.
        for proj in (self.q_proj, self.k_proj, self.v_proj):
            small_init_(proj.weight, dim=self.dim)
            if proj.bias is not None:
                nn.init.zeros_(proj.bias)

        self.mlstm_cell.reset_parameters()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Standard transformer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
|
|
|
|
|
|
|
class MCGatingUnit(nn.Module):
    """Multiplicative gating unit: two parallel mLSTM branches combined
    element-wise.

    `hidden_dim` and `dropout` are accepted for interface compatibility with
    the sibling blocks but are not used here.
    """

    def __init__(self, dim, hidden_dim, dropout):
        super().__init__()
        self.mlstm_1 = mLSTM(dim)
        self.mlstm_2 = mLSTM(dim)

    def forward(self, x):
        u, v = x, x
        u = self.mlstm_1(u)
        # BUGFIX: the second branch previously called self.mlstm_1 again,
        # leaving self.mlstm_2 constructed but never used.
        v = self.mlstm_2(v)
        return u * v
|
|
|
|
class MCDPMLSTMBlock(nn.Module):
    """Pre-norm residual block: gating unit, then a feed-forward MLP.

    NOTE(review): a single LayerNorm instance is reused before both
    sub-blocks, so the two normalizations share parameters — confirm this
    is intentional.
    """

    def __init__(self, d_model, d_ffn, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.mcgu = MCGatingUnit(d_model, d_ffn, dropout)
        self.ffn = FeedForward(d_model, d_ffn, dropout)

    def forward(self, x):
        # Pre-norm residual around the gating unit, then around the MLP.
        x = x + self.mcgu(self.norm(x))
        return x + self.ffn(self.norm(x))
|
|
|
|
class MCDPMLSTM(nn.Module):
    """A stack of `num_layers` MCDPMLSTMBlock layers applied in sequence."""

    def __init__(self, d_model, d_ffn, num_layers, dropout):
        super().__init__()
        blocks = [MCDPMLSTMBlock(d_model, d_ffn, dropout) for _ in range(num_layers)]
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|