# org_gdn_1B/fla3/layers/emla.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from fla.modules import RMSNorm
# from fla.modules.feature_map import DPFPFeatureMap, HadamardFeatureMap, HedgehogFeatureMap, T2RFeatureMap
# from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn, fused_recurrent_linear_attn
import torch.nn.init as init
import math
from fla.modules.l2norm import l2norm
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
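
# emla: applies a short depthwise causal convolution to the input, routes keys/values
# across `ratio` per-head experts (top-k softmax routing computed from the value
# projections), and evaluates a causal linear attention with an optional SiLU output
# gate and an RMSNorm over the per-head outputs.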
class emla(nn.Module):
def __init__(
self,
mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 1.0,
        expand_v: float = 1.0,
num_heads: int = 8,
output_norm: str = 'rmsnorm',
elementwise_affine: bool = True,
norm_eps: float = 1e-5,
        use_gate: bool = False,
        ratio: int = 2,
        top_k: int = 1,  # experts kept by the router; `self.top_k` is used below but was never set, so this parameter and its default are assumptions
        **kwargs
):
super().__init__()
self.hidden_size = hidden_size
self.mode = mode
self.num_heads = num_heads
self.num_kv_heads = num_heads
self.num_kv_groups = self.num_heads // self.num_kv_heads
self.key_dim = int(hidden_size * expand_k)
self.value_dim = int(hidden_size * expand_v)
self.key_dim_per_group = self.key_dim // self.num_kv_groups
self.value_dim_per_group = self.value_dim // self.num_kv_groups
        assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
self.head_k_dim = self.key_dim // num_heads
self.head_v_dim = self.value_dim // num_heads
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
        self.use_gate = use_gate
        if use_gate:
            self.g_proj = nn.Linear(self.hidden_size, self.value_dim_per_group, bias=False)
if output_norm == 'rmsnorm':
self.norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps)
elif output_norm == 'identity':
self.norm = nn.Identity()
else:
raise NotImplementedError(f"Not supported output norm `{output_norm}`.")
        self.ratio = ratio
        self.top_k = top_k
self.gate_fn = nn.functional.silu
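        # Per-head router weights: project each value vector onto `ratio` expert logits.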
self.router_weight = nn.Parameter(torch.empty((self.num_heads,self.ratio,self.head_v_dim)))
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
self.d_conv = 4
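        # Depthwise (groups == channels) causal convolution applied along the sequence dimension.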
self.conv1d = nn.Conv1d(
in_channels=self.hidden_size,
out_channels=self.hidden_size,
bias=False,
kernel_size=self.d_conv,
groups=self.hidden_size,
padding=self.d_conv - 1,
# **factory_kwargs,
)
self.reset_parameters()
    def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.router_weight, a=math.sqrt(5))
        nn.init.xavier_uniform_(self.q_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_uniform_(self.k_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5)
        if self.use_gate:
            nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_uniform_(self.o_proj.weight, gain=2 ** -2.5)
    def forward(self, hidden_state, k_f=None, s_f=None, conv_states=None, seqlen_offset=None):
        # k_f / s_f: recurrent routing-key sums and attention state carried across decoding steps;
        # conv_states: cached causal-conv window (channels first). All three are None for a fresh sequence.
        x = hidden_state
        # x = x.transpose(0, 1).contiguous()
        b, l, d = x.shape
        x = rearrange(x, 'b l d -> b d l').contiguous()
        if self.training:
            x = causal_conv1d_fn(
                x=x,
                weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                bias=self.conv1d.bias.to(x.dtype) if self.conv1d.bias is not None else None,
                activation="silu",
            )
        elif conv_states is None:
            # Prefill: cache the last `d_conv` timesteps as the conv state, then run the full causal conv.
            conv_states = nn.functional.pad(x, (self.d_conv - x.shape[-1], 0))
            x = causal_conv1d_fn(
                x=x,
                weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                bias=self.conv1d.bias.to(x.dtype) if self.conv1d.bias is not None else None,
                activation="silu",
            )
        else:
            # Single-step decoding: update the cached conv window in place.
            x = causal_conv1d_update(
                x,
                conv_states,
                weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                bias=self.conv1d.bias.to(x.dtype) if self.conv1d.bias is not None else None,
                activation="silu",
            )
        x = rearrange(x, 'b d l -> b l d').contiguous()
        q = self.gate_fn(self.q_proj(x))  # (b, l, key_dim)
        k = self.k_proj(x)                # (b, l, key_dim_per_group)
        v = self.v_proj(x)                # (b, l, value_dim_per_group)
        q = rearrange(q, 'b l (h d) -> b h l d', h=self.num_heads).contiguous()
        k = rearrange(k, 'b l (h d) -> b h l d', h=self.num_heads).contiguous()
        v = rearrange(v, 'b l (h d) -> b h l d', h=self.num_heads).contiguous()
        output, k_f, s_f = self.gated_linear_attention(q, k, v, k_f, s_f)
        output = rearrange(output, 'b h l d -> b l h d')
        output = self.norm(output)
        output = output.reshape(b, l, -1)
        if self.use_gate:
            output = self.gate_fn(self.g_proj(x)) * output
        output = self.o_proj(output)
        # output = output.transpose(0, 1)
        return output, k_f, s_f, conv_states
    def gated_linear_attention(self, q, k, v, past_sum=None, past_state=None):
        '''Routed (MoE-style) linear attention in plain PyTorch: masked quadratic form during
        training, recurrent per-expert state (`past_sum`, `past_state`) during inference.'''
        b, h, l, d = v.shape   # d = head_v_dim
        dk = q.shape[-1]       # head_k_dim
        # Route each value vector to `ratio` experts and keep the top-k scores.
        logits = torch.einsum('b h l d, h r d -> b h l r', v, self.router_weight)  # (b, h, l, ratio)
        scores = logits.softmax(dim=-1)
        topk_score, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)  # (b, h, l, top_k)
        if self.top_k > 1:
            # Renormalize the kept scores so they sum to one.
            sum_score = topk_score.sum(dim=-1, keepdim=True) + 1e-20
            topk_score = topk_score / sum_score
        # Everything up to this point is shared by the training and inference paths.
        # Scatter the kept scores back into a dense (b, h, l, ratio) map; zero elsewhere.
        masked_scores = torch.zeros_like(scores, device=q.device)
        masked_scores.scatter_(-1, topk_idx, topk_score)
        masked_idx = masked_scores.bool().to(k.dtype)  # hard 0/1 assignment mask, cast so it can be used in einsum
        if self.training:
            # Per-expert keys: hard-assigned copies of k, and their running (causal) sums used to route queries.
            k_exp0 = torch.einsum('b h l d, b h l r -> b h l r d', k, masked_idx)
            router_weight_qk = torch.cumsum(k_exp0, dim=-3)                          # (b, h, l, r, dk)
            k_exp = torch.einsum('b h l d, b h l r -> b h l r d', k, masked_scores)
            norm_k = l2norm(router_weight_qk)
            # Route each query across experts by its similarity to the normalized cumulative keys.
            qlogit = torch.einsum('b h l d, b h l r d -> b h l r', q, norm_k).softmax(dim=-1)
            q_exp = torch.einsum('b h l d, b h l r -> b h l r d', q, qlogit)
            q_exp = rearrange(q_exp, 'b h l r d -> b h l (r d)')
            k_exp = rearrange(k_exp, 'b h l r d -> b h l (r d)')
            # Causal (lower-triangular) attention over the expert-expanded features.
            qk = q_exp @ k_exp.transpose(-1, -2) * (dk ** -0.5)
            qk = qk.tril(diagonal=0)
            o_moe = qk @ v
            return o_moe, None, None
        else:
            # Inference: carry the cumulative per-expert key sums and the per-expert linear-attention state.
            if past_sum is None:
                k_final = torch.zeros([b, h, self.ratio, dk]).to(q)
            else:
                k_final = past_sum                                                   # (b, h, r, dk)
            k_exp0 = torch.einsum('b h l d, b h l r -> b h l r d', k, masked_idx)
            router_weight_qk = torch.cumsum(k_exp0, dim=-3) + k_final.unsqueeze(-3)  # (b, h, l, r, dk)
            norm_k = l2norm(router_weight_qk)
            k_exp = torch.einsum('b h l d, b h l r -> b h l r d', k, masked_scores)  # (b, h, l, r, dk)
            if past_state is None:
                s_final = torch.zeros([b, h, self.ratio, dk, d]).to(q)               # (b, h, r, dk, d)
            else:
                s_final = past_state
            qlogit = torch.einsum('b h l d, b h l r d -> b h l r', q, norm_k).softmax(dim=-1)
            q_exp = torch.einsum('b h l d, b h l r -> b h l r d', q, qlogit)
            # Update the per-expert state with this chunk's key/value outer products.
            k_transexp = rearrange(k_exp, 'b h l r d -> b h r d l')
            final_state = s_final + k_transexp @ v.unsqueeze(-3)                     # (b, h, r, dk, d)
            if past_state is None:
                # Prefill: only the causal intra-chunk term contributes.
                q_exp = rearrange(q_exp, 'b h l r d -> b h l (r d)')
                k_exp = rearrange(k_exp, 'b h l r d -> b h l (r d)')
                qk = q_exp @ k_exp.transpose(-1, -2) * (dk ** -0.5)
                qk = qk.tril(diagonal=0)
                o_moe = qk @ v
            else:
                # Decoding: inter-chunk term from the cached state plus the causal intra-chunk term.
                o_moe = torch.einsum('b h l r k, b h r k d -> b h l d', q_exp, s_final) * (dk ** -0.5)
                q_exp = rearrange(q_exp, 'b h l r d -> b h l (r d)')
                k_exp = rearrange(k_exp, 'b h l r d -> b h l (r d)')
                qk = q_exp @ k_exp.transpose(-1, -2) * (dk ** -0.5)
                qk = qk.tril(diagonal=0)
                o_moe += qk @ v
            # Return the last cumulative key sum and the updated state for the next step.
            return o_moe, router_weight_qk[:, :, -1, :, :], final_state
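

# Minimal usage sketch: exercises the training path, the eval-mode prefill that
# materializes the recurrent state, and a single-step decode that reuses it.
# The hidden size, dtype, and `top_k=1` below are illustrative assumptions; a CUDA
# device with the `causal_conv1d` kernels installed is assumed.
if __name__ == "__main__":
    torch.manual_seed(0)
    layer = emla(hidden_size=256, num_heads=4, use_gate=True, ratio=2, top_k=1)
    layer = layer.cuda().to(torch.bfloat16)
    x = torch.randn(2, 16, 256, device='cuda', dtype=torch.bfloat16)
    # Training path: recurrent states are not materialized (returned as None).
    y, k_f, s_f, conv_states = layer(x)
    layer.eval()
    with torch.no_grad():
        # Prefill: returns the routing-key sums, the attention state, and the conv cache.
        y, k_f, s_f, conv_states = layer(x)
        # Single-step decode reusing the cached states (assumes a causal_conv1d build
        # whose update op accepts a (b, d, l) input).
        step = torch.randn(2, 1, 256, device='cuda', dtype=torch.bfloat16)
        y, k_f, s_f, conv_states = layer(step, k_f, s_f, conv_states)
    print(y.shape)  # expected: torch.Size([2, 1, 256])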