# MCDP_MLSTM / mcdpmlstm.py
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
import einops
from einops.layers.torch import Rearrange
import math

class LinearHeadwiseExpand(nn.Module):
    """Head-wise linear projection: splits the feature dimension into
    `num_heads` groups and applies an independent square linear map per group."""

    def __init__(self, dim, num_heads, bias=False):
        super().__init__()
        assert dim % num_heads == 0
        self.dim = dim
        self.num_heads = num_heads
        dim_per_head = dim // num_heads
        self.weight = nn.Parameter(torch.empty(num_heads, dim_per_head, dim_per_head))
        if bias:
            self.bias = nn.Parameter(torch.empty(dim))
        else:
            self.bias = None
        self.reset_parameters()

    def reset_parameters(self):
        # "small init": std = sqrt(2 / (5 * dim_per_head))
        nn.init.normal_(self.weight.data, mean=0.0, std=math.sqrt(2 / 5 / self.weight.shape[-1]))
        if self.bias is not None:
            nn.init.zeros_(self.bias.data)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = einops.rearrange(x, "... (nh d) -> ... nh d", nh=self.num_heads)
        x = einops.einsum(
            x,
            self.weight,
            "... nh d, nh out_d d -> ... nh out_d",
        )
        x = einops.rearrange(x, "... nh out_d -> ... (nh out_d)")
        if self.bias is not None:
            x = x + self.bias
        return x

    def extra_repr(self):
        return (
            f"dim={self.dim}, "
            f"num_heads={self.num_heads}, "
            f"bias={self.bias is not None}, "
        )
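
# Illustrative usage (comment only, not part of the original file): with
# dim=8 and num_heads=2, an input of shape (..., 8) is split into 2 heads of
# size 4, each transformed by its own 4x4 weight, then re-concatenated, so the
# output shape matches the input shape:
#   proj = LinearHeadwiseExpand(dim=8, num_heads=2)
#   y = proj(torch.randn(3, 10, 8))  # y.shape == (3, 10, 8)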

def wang_init_(param: torch.Tensor, dim: int, num_blocks: int):
    std = 2 / num_blocks / math.sqrt(dim)
    torch.nn.init.normal_(param, mean=0.0, std=std)
    return param


def small_init_(param: torch.Tensor, dim: int) -> torch.Tensor:
    std = math.sqrt(2 / (5 * dim))
    torch.nn.init.normal_(param, mean=0.0, std=std)
    return param


def bias_linspace_init_(param: torch.Tensor, start: float = 3.4, end: float = 6.0) -> torch.Tensor:
    assert param.dim() == 1, f"param must be 1-dimensional (typically a bias), got {param.dim()}"
    n_dims = param.shape[0]
    init_vals = torch.linspace(start, end, n_dims)
    with torch.no_grad():
        param.copy_(init_vals)
    return param

class CausalConv1d(nn.Module):
    def __init__(self, dim, kernel_size=4, bias=True):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size
        self.bias = bias
        # padding of this size assures temporal causality.
        self.pad = kernel_size - 1
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=kernel_size,
            padding=self.pad,
            groups=dim,
            bias=bias,
        )
        self.reset_parameters()

    def reset_parameters(self):
        self.conv.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = einops.rearrange(x, "b l d -> b d l")
        x = self.conv(x)
        # drop the trailing padded positions so the output length matches the input
        x = x[:, :, :-self.pad] if self.pad > 0 else x
        x = einops.rearrange(x, "b d l -> b l d")
        return x

class LayerNorm(nn.Module):
    def __init__(
        self,
        ndim: int = -1,
        weight: bool = True,
        bias: bool = False,
        eps: float = 1e-5,
        residual_weight: bool = True,
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(ndim)) if weight else None
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
        self.eps = eps
        self.residual_weight = residual_weight
        self.ndim = ndim
        self.reset_parameters()

    @property
    def weight_proxy(self) -> torch.Tensor:
        if self.weight is None:
            return None
        if self.residual_weight:
            return 1.0 + self.weight
        else:
            return self.weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(
            x,
            normalized_shape=(self.ndim,),
            weight=self.weight_proxy,
            bias=self.bias,
            eps=self.eps,
        )

    def reset_parameters(self):
        if self.weight_proxy is not None:
            if self.residual_weight:
                nn.init.zeros_(self.weight)
            else:
                nn.init.ones_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

class MultiHeadLayerNorm(LayerNorm):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.ndim == 4, "Input must be 4D tensor (B, NH, S, DH)"
        B, NH, S, DH = x.shape
        gn_in_1 = x.transpose(1, 2)  # (B, S, NH, DH)
        gn_in_2 = gn_in_1.reshape(B * S, NH * DH)  # (B * S, NH * DH)
        out = F.group_norm(
            gn_in_2,
            num_groups=NH,
            weight=self.weight_proxy,
            bias=self.bias,
            eps=self.eps,
        )  # .to(x.dtype)
        out = out.view(B, S, NH, DH).transpose(1, 2)
        return out

def parallel_stabilized_simple(
    queries: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    igate_preact: torch.Tensor,
    fgate_preact: torch.Tensor,
    lower_triangular_matrix: torch.Tensor = None,
    stabilize_rowwise: bool = True,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Parallel (non-recurrent) form of the mLSTM cell with log-space stabilization.

    Shapes: queries/keys/values (B, NH, S, DH); igate_preact/fgate_preact (B, NH, S, 1).
    """
    B, NH, S, DH = queries.shape
    _dtype, _device = queries.dtype, queries.device
    log_fgates = torch.nn.functional.logsigmoid(fgate_preact)  # (B, NH, S, 1)
    if lower_triangular_matrix is None or S < lower_triangular_matrix.size(-1):
        ltr = torch.tril(torch.ones((S, S), dtype=torch.bool, device=_device))
    else:
        ltr = lower_triangular_matrix
    assert ltr.dtype == torch.bool, f"lower_triangular_matrix must be of dtype bool, got {ltr.dtype}"
    # cumulative log forget gates, prepended with a zero for the initial state
    log_fgates_cumsum = torch.cat(
        [
            torch.zeros((B, NH, 1, 1), dtype=_dtype, device=_device),
            torch.cumsum(log_fgates, dim=-2),
        ],
        dim=-2,
    )
    rep_log_fgates_cumsum = log_fgates_cumsum.repeat(1, 1, 1, S + 1)
    # entry (i, j) holds the accumulated log forget gates between positions j and i
    _log_fg_matrix = rep_log_fgates_cumsum - rep_log_fgates_cumsum.transpose(-2, -1)
    log_fg_matrix = torch.where(ltr, _log_fg_matrix[:, :, 1:, 1:], -float("inf"))
    # gate decay matrix D combines forget-gate decay with the input-gate pre-activations
    log_D_matrix = log_fg_matrix + igate_preact.transpose(-2, -1)
    if stabilize_rowwise:
        max_log_D, _ = torch.max(log_D_matrix, dim=-1, keepdim=True)
    else:
        max_log_D = torch.max(log_D_matrix.view(B, NH, -1), dim=-1, keepdim=True)[0].unsqueeze(-1)
    log_D_matrix_stabilized = log_D_matrix - max_log_D
    D_matrix = torch.exp(log_D_matrix_stabilized)
    keys_scaled = keys / math.sqrt(DH)
    qk_matrix = queries @ keys_scaled.transpose(-2, -1)
    C_matrix = qk_matrix * D_matrix  # (B, NH, S, S)
    normalizer = torch.maximum(C_matrix.sum(dim=-1, keepdim=True).abs(), torch.exp(-max_log_D))  # (B, NH, S, 1)
    C_matrix_normalized = C_matrix / (normalizer + eps)
    h_tilde_state = C_matrix_normalized @ values
    return h_tilde_state
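
# Summary of the computation above (comment only, added for readability):
#   log D_ij = sum_{k=j+1..i} log sigmoid(fgate_k) + igate_j   for j <= i (else -inf)
#   D        = exp(log D - max(log D))                          (row-wise stabilization)
#   C        = (Q K^T / sqrt(DH)) * D
#   H~       = C V / (max(|row-sum of C|, exp(-max(log D))) + eps)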

class MatrixLSTMCell(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        # input and forget gates are computed from the concatenated (q, k, v)
        self.igate = nn.Linear(3 * dim, num_heads)
        self.fgate = nn.Linear(3 * dim, num_heads)
        self.outnorm = MultiHeadLayerNorm(ndim=dim, weight=True, bias=False)
        self.causal_mask_cache = {}
        self.reset_parameters()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        B, S, _ = q.shape
        if_gate_input = torch.cat([q, k, v], dim=-1)
        q = q.view(B, S, self.num_heads, -1)
        k = k.view(B, S, self.num_heads, -1)
        v = v.view(B, S, self.num_heads, -1)
        q = q.transpose(1, 2)  # (B, NH, S, DH)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        igate_preact = self.igate(if_gate_input)  # (B, S, NH)
        igate_preact = igate_preact.transpose(-1, -2).unsqueeze(-1)  # (B, NH, S, 1)
        fgate_preact = self.fgate(if_gate_input)
        fgate_preact = fgate_preact.transpose(-1, -2).unsqueeze(-1)  # (B, NH, S, 1)
        # cache the causal mask per (sequence length, device) to avoid rebuilding it
        cache_key = (S, str(q.device))
        if cache_key in self.causal_mask_cache:
            causal_mask = self.causal_mask_cache[cache_key]
        else:
            causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device=q.device))
            self.causal_mask_cache[cache_key] = causal_mask
        h_state = parallel_stabilized_simple(
            queries=q,
            keys=k,
            values=v,
            igate_preact=igate_preact,
            fgate_preact=fgate_preact,
            lower_triangular_matrix=causal_mask,
        )
        h_state_norm = self.outnorm(h_state)
        h_state_norm = h_state_norm.transpose(1, 2).reshape(B, S, -1)
        return h_state_norm

    def reset_parameters(self):
        self.outnorm.reset_parameters()
        # forget gate: zero weights, biases linearly spaced in [3, 6]
        torch.nn.init.zeros_(self.fgate.weight)
        bias_linspace_init_(self.fgate.bias, start=3.0, end=6.0)
        torch.nn.init.zeros_(self.igate.weight)
        torch.nn.init.normal_(self.igate.bias, mean=0.0, std=0.1)

class mLSTM(nn.Module):
    def __init__(
        self,
        dim,
        expansion=2,
        qkv_block_size=4,
        proj_bias=False,
        conv_bias=True,
        kernel_size=4,
    ):
        super().__init__()
        assert dim % qkv_block_size == 0
        self.dim = dim
        self.expansion = expansion
        self.qkv_block_size = qkv_block_size
        self.proj_bias = proj_bias
        self.conv_bias = conv_bias
        self.kernel_size = kernel_size
        inner_dim = expansion * dim
        num_heads = inner_dim // qkv_block_size
        self.proj_up = nn.Linear(
            in_features=dim,
            out_features=2 * inner_dim,
            bias=proj_bias,
        )
        self.q_proj = LinearHeadwiseExpand(
            dim=inner_dim,
            num_heads=num_heads,
            bias=proj_bias,
        )
        self.k_proj = LinearHeadwiseExpand(
            dim=inner_dim,
            num_heads=num_heads,
            bias=proj_bias,
        )
        self.v_proj = LinearHeadwiseExpand(
            dim=inner_dim,
            num_heads=num_heads,
            bias=proj_bias,
        )
        self.conv1d = CausalConv1d(
            dim=inner_dim,
            kernel_size=kernel_size,
            bias=conv_bias,
        )
        self.mlstm_cell = MatrixLSTMCell(
            dim=inner_dim,
            num_heads=qkv_block_size,
        )
        self.learnable_skip = nn.Parameter(torch.ones(inner_dim))
        self.proj_down = nn.Linear(
            in_features=inner_dim,
            out_features=dim,
            bias=proj_bias,
        )
        self.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, S, _ = x.shape
        # up-projection splits into the mLSTM branch and the gating branch z
        x_inner = self.proj_up(x)
        x_mlstm, z = torch.chunk(x_inner, chunks=2, dim=-1)
        # causal depthwise convolution before the q/k projections
        x_mlstm_conv = self.conv1d(x_mlstm)
        x_mlstm_conv_act = F.silu(x_mlstm_conv)
        q = self.q_proj(x_mlstm_conv_act)
        k = self.k_proj(x_mlstm_conv_act)
        v = self.v_proj(x_mlstm)
        h_tilde_state = self.mlstm_cell(q=q, k=k, v=v)
        h_tilde_state_skip = h_tilde_state + (self.learnable_skip * x_mlstm_conv_act)
        # output gating with SiLU(z), then down-projection back to `dim`
        h_state = h_tilde_state_skip * F.silu(z)
        x = self.proj_down(h_state)
        return x

    def reset_parameters(self):
        small_init_(self.proj_up.weight, dim=self.dim)
        if self.proj_up.bias is not None:
            nn.init.zeros_(self.proj_up.bias)
        wang_init_(self.proj_down.weight, dim=self.dim, num_blocks=1)
        if self.proj_down.bias is not None:
            nn.init.zeros_(self.proj_down.bias)
        nn.init.ones_(self.learnable_skip)

        def _init_qkv_proj(qkv_proj: LinearHeadwiseExpand):
            small_init_(qkv_proj.weight, dim=self.dim)
            if qkv_proj.bias is not None:
                nn.init.zeros_(qkv_proj.bias)

        _init_qkv_proj(self.q_proj)
        _init_qkv_proj(self.k_proj)
        _init_qkv_proj(self.v_proj)
        self.mlstm_cell.reset_parameters()

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class MCGatingUnit(nn.Module):
    def __init__(self, dim, hidden_dim, dropout):
        super().__init__()
        self.mlstm_1 = mLSTM(dim)
        self.mlstm_2 = mLSTM(dim)

    def forward(self, x):
        # two parallel mLSTM paths over the same input, combined multiplicatively
        u, v = x, x
        u = self.mlstm_1(u)
        v = self.mlstm_2(v)
        out = u * v
        return out

class MCDPMLSTMBlock(nn.Module):
    def __init__(self, d_model, d_ffn, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.mcgu = MCGatingUnit(d_model, d_ffn, dropout)
        self.ffn = FeedForward(d_model, d_ffn, dropout)

    def forward(self, x):
        residual = x
        x = self.norm(x)
        x = self.mcgu(x)
        x = x + residual
        residual = x
        x = self.norm(x)
        x = self.ffn(x)
        out = x + residual
        return out

class MCDPMLSTM(nn.Module):
    def __init__(self, d_model, d_ffn, num_layers, dropout):
        super().__init__()
        self.model = nn.Sequential(
            *[MCDPMLSTMBlock(d_model, d_ffn, dropout) for _ in range(num_layers)]
        )

    def forward(self, x):
        return self.model(x)
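

if __name__ == "__main__":
    # Minimal smoke test, added as an illustration only. The hyperparameters
    # below are assumptions chosen for the example, not the author's settings;
    # d_model must be divisible by the default qkv_block_size of 4.
    torch.manual_seed(0)
    model = MCDPMLSTM(d_model=64, d_ffn=128, num_layers=2, dropout=0.1)
    x = torch.randn(2, 16, 64)  # (batch, sequence length, d_model)
    y = model(x)
    print(y.shape)  # expected: torch.Size([2, 16, 64])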