import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from einops import rearrange

from fla.modules import RMSNorm
from fla.modules.l2norm import l2norm


class emla(nn.Module):
    """Expert-routed gated linear attention with a short depthwise causal-conv front-end."""

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 1.0,
        expand_v: float = 1.0,
        num_heads: int = 8,
        output_norm: str = 'rmsnorm',
        elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        use_gate: bool = False,
        ratio: int = 2,
        top_k: int = 2,
        **kwargs
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.mode = mode
        self.num_heads = num_heads
        self.num_kv_heads = num_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.key_dim_per_group = self.key_dim // self.num_kv_groups
        self.value_dim_per_group = self.value_dim // self.num_kv_groups
        # `ratio` router experts per head; each token routes to `top_k` of them.
        # (`top_k` is exposed as an argument here because `gated_linear_attention`
        # reads `self.top_k`, which was otherwise never set.)
        self.ratio = ratio
        self.top_k = top_k

        assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
        assert top_k <= ratio, f"top_k ({top_k}) cannot exceed the number of experts ({ratio})"

        self.head_k_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
        self.use_gate = use_gate
        if use_gate:
            self.g_proj = nn.Linear(self.hidden_size, self.value_dim_per_group, bias=False)
        if output_norm == 'rmsnorm':
            self.norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps)
        elif output_norm == 'identity':
            self.norm = nn.Identity()
        else:
            raise NotImplementedError(f"Not supported output norm `{output_norm}`.")
        self.gate_fn = F.silu
        # per-head router: projects each value vector onto `ratio` expert logits
        self.router_weight = nn.Parameter(torch.empty(self.num_heads, self.ratio, self.head_v_dim))
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
        self.d_conv = 4
        self.conv1d = nn.Conv1d(
            in_channels=self.hidden_size,
            out_channels=self.hidden_size,
            bias=False,
            kernel_size=self.d_conv,
            groups=self.hidden_size,
            padding=self.d_conv - 1,
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.router_weight, a=math.sqrt(5))
        nn.init.xavier_uniform_(self.q_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_uniform_(self.k_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5)
        if self.use_gate:
            nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_uniform_(self.o_proj.weight, gain=2 ** -2.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        conv_states: Optional[torch.Tensor] = None,
        past_sum: Optional[torch.Tensor] = None,
        past_state: Optional[torch.Tensor] = None,
        **kwargs
    ):
        # `conv_states`, `past_sum` and `past_state` are the recurrent inference
        # caches returned by a previous call; all are None during training.
        x = hidden_states
        b, l, d = x.shape
        x = rearrange(x, 'b l d -> b d l').contiguous()
        if self.training:
            x = causal_conv1d_fn(
                x=x,
                weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                bias=self.conv1d.bias,
                activation="silu",
            )
        elif conv_states is None:
            # prefill: cache the last `d_conv` input columns (zero-padded on the
            # left when the sequence is shorter than the kernel) for decoding
            conv_states = F.pad(x, (self.d_conv - x.shape[-1], 0))
            x = causal_conv1d_fn(
                x=x,
                weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                bias=self.conv1d.bias,
                activation="silu",
            )
        else:
            # single-step decoding: update the rolling conv cache in place
            x = causal_conv1d_update(
                x,
                conv_states,
                weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                bias=self.conv1d.bias,
                activation="silu",
            )
        x = rearrange(x, 'b d l -> b l d').contiguous()
        q = self.gate_fn(self.q_proj(x))
        k = self.k_proj(x)
        v = self.v_proj(x)
        q = rearrange(q, 'b l (h d) -> b h l d', h=self.num_heads).contiguous()
        k = rearrange(k, 'b l (h d) -> b h l d', h=self.num_heads).contiguous()
        v = rearrange(v, 'b l (h d) -> b h l d', h=self.num_heads).contiguous()
        output, k_f, s_f = self.gated_linear_attention(q, k, v, past_sum, past_state)
        output = rearrange(output, 'b h l d -> b l h d')
        output = self.norm(output)
        output = rearrange(output, 'b l h d -> b l (h d)')
        if self.use_gate:
            output = self.gate_fn(self.g_proj(x)) * output
        output = self.o_proj(output)
        return output, k_f, s_f, conv_states

    def gated_linear_attention(self, q, k, v, past_sum=None, past_state=None):
        '''torch qk version'''
        b, h, l, d = v.shape
        dk = q.shape[-1]
        # route each value vector to `top_k` of the `ratio` experts per head
        logits = torch.einsum('b h l d, h r d -> b h l r', v, self.router_weight)
        scores = logits.softmax(dim=-1)
        topk_score, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
        if self.top_k > 1:
            sum_score = topk_score.sum(dim=-1, keepdim=True) + 1e-20
            topk_score = topk_score / sum_score

        masked_scores = torch.zeros_like(scores)
        masked_scores.scatter_(-1, topk_idx, topk_score)
        # hard 0/1 routing mask, cast to the key dtype so it can enter einsum
        masked_idx = masked_scores.bool().to(k.dtype)
        if self.training:
            # running per-expert key sums define the query router
            k_exp0 = torch.einsum('b h l d, b h l r -> b h l r d', k, masked_idx)
            router_weight_qk = torch.cumsum(k_exp0, dim=-3)
            k_exp = torch.einsum('b h l d, b h l r -> b h l r d', k, masked_scores)
            norm_k = l2norm(router_weight_qk)
            qlogit = torch.einsum('b h l d, b h l r d -> b h l r', q, norm_k).softmax(dim=-1)
            q_exp = torch.einsum('b h l d, b h l r -> b h l r d', q, qlogit)
            # causal quadratic attention over the expert-expanded q/k
            q_exp = rearrange(q_exp, 'b h l r d -> b h l (r d)')
            k_exp = rearrange(k_exp, 'b h l r d -> b h l (r d)')
            qk = q_exp @ k_exp.transpose(-1, -2) * (dk ** -0.5)
            qk = qk.tril(diagonal=0)
            o_moe = qk @ v
            return o_moe, None, None
        else:
            if past_sum is None:
                k_final = torch.zeros([b, h, self.ratio, dk]).to(q)
            else:
                k_final = past_sum
            k_exp0 = torch.einsum('b h l d, b h l r -> b h l r d', k, masked_idx)
            router_weight_qk = torch.cumsum(k_exp0, dim=-3) + k_final.unsqueeze(-3)
            norm_k = l2norm(router_weight_qk)
            k_exp = torch.einsum('b h l d, b h l r -> b h l r d', k, masked_scores)
            if past_state is None:
                s_final = torch.zeros([b, h, self.ratio, dk, d]).to(q)
            else:
                s_final = past_state
            qlogit = torch.einsum('b h l d, b h l r d -> b h l r', q, norm_k).softmax(dim=-1)
            q_exp = torch.einsum('b h l d, b h l r -> b h l r d', q, qlogit)
            # per-expert outer-product state update: S_r += sum_l k_l v_l^T
            k_transexp = rearrange(k_exp, 'b h l r d -> b h r d l')
            final_state = s_final + k_transexp @ v.unsqueeze(-3)
            if past_state is None:
                # prefill: plain causal quadratic attention inside the chunk
                q_exp = rearrange(q_exp, 'b h l r d -> b h l (r d)')
                k_exp = rearrange(k_exp, 'b h l r d -> b h l (r d)')
                qk = q_exp @ k_exp.transpose(-1, -2) * (dk ** -0.5)
                qk = qk.tril(diagonal=0)
                o_moe = qk @ v
            else:
                # decoding: cross-chunk contribution read from the cached state ...
                o_moe = torch.einsum('b h l r k, b h r k d -> b h l d', q_exp, s_final) * (dk ** -0.5)
                # ... plus the causal intra-chunk part
                q_exp = rearrange(q_exp, 'b h l r d -> b h l (r d)')
                k_exp = rearrange(k_exp, 'b h l r d -> b h l (r d)')
                qk = q_exp @ k_exp.transpose(-1, -2) * (dk ** -0.5)
                qk = qk.tril(diagonal=0)
                o_moe += qk @ v
            # return the output, the updated per-expert key sums, and the updated state
            return o_moe, router_weight_qk[:, :, -1, :, :], final_state
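
# A minimal smoke test; a sketch assuming a CUDA device and the `causal_conv1d`
# kernels are available. The shapes and hyperparameters below are illustrative
# choices, not values from the original source.
if __name__ == "__main__":
    torch.manual_seed(0)
    layer = emla(hidden_size=1024, num_heads=8, use_gate=True, ratio=2, top_k=2).cuda().half()
    x = torch.randn(2, 16, 1024, device='cuda', dtype=torch.float16)

    # training path: quadratic intra-chunk attention, no recurrent caches
    layer.train()
    y, _, _, _ = layer(x)
    print(y.shape)  # torch.Size([2, 16, 1024])

    # inference: prefill once, then decode a single token reusing the caches
    layer.eval()
    with torch.no_grad():
        y, k_f, s_f, conv_states = layer(x)
        x_step = torch.randn(2, 1, 1024, device='cuda', dtype=torch.float16)
        y_step, k_f, s_f, conv_states = layer(
            x_step, conv_states=conv_states, past_sum=k_f, past_state=s_f
        )
    print(y_step.shape)  # torch.Size([2, 1, 1024])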