from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from fla.modules import RMSNorm
from fla.modules.feature_map import DPFPFeatureMap, HadamardFeatureMap, HedgehogFeatureMap, T2RFeatureMap
from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn, fused_recurrent_linear_attn


class LinearAttention(nn.Module):
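    r"""
    Multi-head linear attention (Katharopoulos et al., 2020): the softmax in
    standard attention is replaced by a kernel feature map ``phi``, so that
    ``softmax(q k^T) v`` becomes ``phi(q) (phi(k)^T v)`` and attention can be
    computed in time linear in the sequence length.
    """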

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 1.0,
        expand_v: float = 1.0,
        num_heads: int = 8,
        num_kv_heads: Optional[int] = None,
        feature_map: str = 'elementwise_product',
        tie_feature_map_qk: bool = False,
        output_norm: str = 'rmsnorm',
        norm_q: bool = False,
        norm_k: bool = False,
        do_feature_map_norm: bool = False,
        elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        **kwargs
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.mode = mode
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.key_dim_per_group = self.key_dim // self.num_kv_groups
        self.value_dim_per_group = self.value_dim // self.num_kv_groups

        assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_k_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads
        self.do_feature_map_norm = do_feature_map_norm

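        # Choose the feature map phi applied to q and k: learned maps
        # (Hedgehog, T2R), the deterministic DPFP expansion, or simple fixed
        # maps (1 + elu, relu, identity).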
        if feature_map == 'hedgehog':
            if tie_feature_map_qk:
                self.feature_map_q = self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_k_dim)
            else:
                self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_k_dim)
                self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_k_dim)

        elif feature_map == 't2r':
            if tie_feature_map_qk:
                self.feature_map_q = self.feature_map_k = T2RFeatureMap(head_dim=self.head_k_dim)
            else:
                self.feature_map_q = T2RFeatureMap(head_dim=self.head_k_dim)
                self.feature_map_k = T2RFeatureMap(head_dim=self.head_k_dim)

        elif feature_map == 'elementwise_product':
            if tie_feature_map_qk:
                self.feature_map_q = self.feature_map_k = HadamardFeatureMap(head_dim=self.head_k_dim)
            else:
                self.feature_map_q = HadamardFeatureMap(head_dim=self.head_k_dim)
                self.feature_map_k = HadamardFeatureMap(head_dim=self.head_k_dim)

        elif feature_map == 'dpfp':
            self.feature_map_q = DPFPFeatureMap(head_dim=self.head_k_dim)
            self.feature_map_k = DPFPFeatureMap(head_dim=self.head_k_dim)

        elif feature_map == 'elu':
            def elu(x):
                return F.elu(x) + 1
            self.feature_map_q = elu
            self.feature_map_k = elu

        elif feature_map == 'relu':
            self.feature_map_q = nn.ReLU()
            self.feature_map_k = nn.ReLU()

        elif feature_map == 'identity':
            self.feature_map_q = nn.Identity()
            self.feature_map_k = nn.Identity()
        else:
            raise NotImplementedError(f"Not supported feature map `{feature_map}`.")

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
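        # Per-head normalization of the attention output before the final projection.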
        if output_norm == 'rmsnorm':
            self.norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps)
        elif output_norm == 'identity':
            self.norm = nn.Identity()
        else:
            raise NotImplementedError(f"Not supported output norm `{output_norm}`.")

        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        self.norm_q = norm_q
        self.norm_k = norm_k

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
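        # hidden_states: (batch_size, seq_len, hidden_size)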
        mode = self.mode
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

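        # Split heads; with grouped KV heads, each k/v head is repeated so it
        # serves num_kv_groups query heads.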
        q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim)
        if self.num_kv_groups > 1:
            k = repeat(k, '... (h d) -> ... (h g) d', d=self.head_k_dim, g=self.num_kv_groups)
            v = repeat(v, '... (h d) -> ... (h g) d', d=self.head_v_dim, g=self.num_kv_groups)
        else:
            k = rearrange(k, '... (h d) -> ... h d', d=self.head_k_dim)
            v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)

        q = self.feature_map_q(q)
        k = self.feature_map_k(k)

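        # Optional sum-normalization of the feature-mapped q/k; the small
        # epsilon guards against division by zero.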
        if self.norm_q:
            q = q / (q.sum(-1, True) + 1e-4)
        if self.norm_k:
            k = k / (k.sum(-1, True) + 1e-4)
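        # The three kernels compute the same linear attention and differ only
        # in how the sequence is traversed (chunkwise, fused chunkwise, or
        # step-by-step recurrence).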
        if mode == 'chunk':
            o, final_state = chunk_linear_attn(
                q=q,
                k=k,
                v=v,
                normalize=self.do_feature_map_norm,
                head_first=False
            )
        elif mode == 'fused_chunk':
            o, final_state = fused_chunk_linear_attn(
                q=q,
                k=k,
                v=v,
                normalize=self.do_feature_map_norm,
            )
        elif mode == 'fused_recurrent':
            o, final_state = fused_recurrent_linear_attn(
                q=q,
                k=k,
                v=v,
                normalize=self.do_feature_map_norm,
            )
        else:
            raise NotImplementedError
        o = self.norm(o)
        o = rearrange(o, '... h d -> ... (h d)')
        o = self.o_proj(o)
        return o
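

if __name__ == '__main__':
    # Illustrative usage sketch, not part of the original module. The fla
    # kernels are implemented in Triton, so this assumes a CUDA device is
    # available.
    device = 'cuda'
    layer = LinearAttention(mode='fused_recurrent', hidden_size=1024, num_heads=8).to(device)
    x = torch.randn(2, 32, 1024, device=device)
    y = layer(x)
    print(y.shape)  # expected: torch.Size([2, 32, 1024])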