from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
from fla.modules import RMSNorm, ShortConvolution
from fla.modules.feature_map import ReLUFeatureMap, SwishFeatureMap, T2RFeatureMap
from fla.modules.layernorm import rms_norm_linear
from fla.ops.gsa import chunk_gsa, fused_recurrent_gsa

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


class GatedSlotAttention(nn.Module):
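    r"""
    The layer implementation of Gated Slot Attention (GSA),
    [Gated Slot Attention for Efficient Linear-Time Sequence Modeling](https://arxiv.org/abs/2409.07146).

    A minimal usage sketch (assuming the Triton kernels in `fla.ops.gsa` are
    usable on the current CUDA device):

        >>> layer = GatedSlotAttention(hidden_size=1024, num_heads=4, layer_idx=0).cuda()
        >>> x = torch.randn(2, 128, 1024, device='cuda')
        >>> o, _, past_key_values = layer(x)
        >>> o.shape
        torch.Size([2, 128, 1024])
    """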

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 1.,
        expand_v: float = 1.,
        num_heads: int = 4,
        num_kv_heads: Optional[int] = None,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        num_slots: Optional[int] = None,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        gate_logit_normalizer: int = 8,
        feature_map: str = 'swish',
        use_output_gate: bool = False,
        use_norm: bool = True,
        layer_idx: Optional[int] = None,
        scale: Optional[float] = 1.,
        **kwargs
    ) -> GatedSlotAttention:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.key_dim_per_group = self.key_dim // self.num_kv_groups
        self.value_dim_per_group = self.value_dim // self.num_kv_groups
        self.head_k_dim = self.key_dim // self.num_heads
        self.head_v_dim = self.value_dim // self.num_heads

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.gate_logit_normalizer = gate_logit_normalizer

        self.use_output_gate = use_output_gate
        self.use_norm = use_norm
        self.scale = scale

        if num_slots is None:
            num_slots = self.head_k_dim
        self.num_slots = num_slots

        self.layer_idx = layer_idx

        if layer_idx is None:
            warnings.warn(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.register_module('feature_map', None)
        if feature_map == 'swish':
            self.feature_map = SwishFeatureMap()
        elif feature_map == 'relu':
            self.feature_map = ReLUFeatureMap()
        elif feature_map == 't2r':
            self.feature_map = T2RFeatureMap(self.head_k_dim, self.head_k_dim)
        else:
            raise NotImplementedError(f"Feature map `{feature_map}` is not supported.")

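        # Projections: queries use the full key_dim, while keys/values use the
        # per-group dims so KV heads can be shared across query heads (GQA-style);
        # f_proj produces one forget-gate logit per (KV head, slot) pair.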
        self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.key_dim_per_group, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.value_dim_per_group, bias=False)
        self.f_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.num_slots, bias=False)

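        # Optionally pre-process q/k/v with short depthwise convolutions,
        # a cheap local-mixing step commonly paired with linear attention.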
        if use_short_conv:
            self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
            self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
            self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')

        self.g_norm = RMSNorm(self.hidden_size, elementwise_affine, eps=norm_eps)
        self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        batch_size, q_len, _ = hidden_states.shape
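        # For short inputs (e.g. single-token decoding steps) the fused recurrent
        # kernel is typically faster than chunked processing.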
        mode = 'fused_recurrent' if q_len <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
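        # With a padding mask, unpad the batch into a single packed sequence of
        # shape [1, total_tokens, hidden_size]; cu_seqlens marks sequence boundaries.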
        if attention_mask is not None:
            indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
            hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)

        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)
        f = self.f_proj(hidden_states)

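        # Split flat projections into heads: [..., h*d] -> [..., h, d];
        # f holds one gate logit per slot: [..., h*m] -> [..., h, m].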
        q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim)
        k = rearrange(k, '... (h d) -> ... h d', d=self.head_k_dim)
        v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
        f = rearrange(f, '... (h m) -> ... h m', m=self.num_slots)

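        # Pass q/k through the kernel feature map used for linearized attention;
        # v gets a SiLU nonlinearity.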
        if self.feature_map is not None:
            q, k = map(lambda x: self.feature_map(x), (q, k))
        v = F.silu(v)

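        # Forget gates in log space: f = log(sigmoid(logits)) / normalizer keeps
        # exp(f) in (0, 1) as the per-slot decay; s = 1 - exp(f) is presumably the
        # complementary write strength fed to the GSA kernels.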
        f = F.logsigmoid(f) / self.gate_logit_normalizer
        s = (1 - f.exp()).to(f.dtype)

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_gsa(
                q=q,
                k=k,
                v=v,
                s=s,
                g=f,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                scale=self.scale,
                cu_seqlens=cu_seqlens,
            )
        elif mode == 'chunk':
            o, recurrent_state = chunk_gsa(
                q=q,
                k=k,
                v=v,
                s=s,
                g=f,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                scale=self.scale,
                cu_seqlens=cu_seqlens,
            )
        else:
            raise NotImplementedError(f"Unsupported mode `{mode}`.")

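        # Persist the recurrent state (and conv states, if enabled) so the next
        # decoding step can resume from here.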
        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q_len
            )

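        # Merge heads, apply SiLU and RMSNorm fused with the output projection,
        # and re-pad the output if the batch was unpadded above.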
        o = rearrange(o, '... h d -> ... (h d)')
        o = rms_norm_linear(F.silu(o), self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias)
        if attention_mask is not None:
            o = pad_input(o.squeeze(0), indices, batch_size, q_len)

        return o, None, past_key_values