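These modules are laid out for Hugging Face remote-code loading (`__init__.py` below exists only "for HF remote code", and the model dynamically patches `config_class` for the same reason). A minimal loading sketch follows; the repo id is a placeholder, and `trust_remote_code=True` is assumed to be required so the custom classes in this diff get picked up:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# "org/forgetting-transformer" is a hypothetical repo id that ships these files
# together with matching `auto_map` entries in its config.json.
model = AutoModelForCausalLM.from_pretrained(
    "org/forgetting-transformer",
    trust_remote_code=True,   # loads modeling_forgetting_transformer.py from the repo
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("org/forgetting-transformer")
inputs = tokenizer("The forgetting transformer", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(out[0]))
```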
diff --git a/.ipynb_checkpoints/modeling_forgetting_transformer-checkpoint.py b/.ipynb_checkpoints/modeling_forgetting_transformer-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..deff96ee4c8558b72b8a525d4b38184b2a46af77
--- /dev/null
+++ b/.ipynb_checkpoints/modeling_forgetting_transformer-checkpoint.py
@@ -0,0 +1,910 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import math
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import (BaseModelOutputWithPast,
+                                           CausalLMOutputWithPast)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+
+# from fla.layers.attn import Attention
+from fla.modules import FusedCrossEntropyLoss, RMSNorm
+from fla.modules.layernorm import group_norm_fn
+from fla.modules.activations import swiglu_linear
+
+from fla.modules import RotaryEmbedding
+from einops import rearrange
+
+# Dynamically import the config class so it works both locally and when loaded from the Hugging Face Hub
+try:
+    from .configuration_forgetting_transformer import ForgettingTransformerConfig
+except (ImportError, ValueError):
+    try:
+        from configuration_forgetting_transformer import ForgettingTransformerConfig
+    except ImportError:
+        from forgetting_transformer.model.forgetting_transformer.configuration_forgetting_transformer import ForgettingTransformerConfig
+
+from forgetting_transformer.ops.forgetting_attention_std import forgetting_attention_std as forgetting_attention
+from .fgate_cache import FgateDynamicCache
+from .glu_linear import glu_linear
+from .token_shift import token_shift
+
+from functools import partial
+
+logger = logging.get_logger(__name__)
+
+
+class ShiftLinear(nn.Module):
+
+    def __init__(
+        self,
+        input_dim: int,
+        output_dim: int,
+        num_heads: int,
+        bias: bool,
+        shift_bias: bool = False
+    ):
+        super().__init__()
+
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.num_heads = num_heads
+        assert self.output_dim % self.num_heads == 0
+
+        self.linear = nn.Linear(input_dim, output_dim, bias=bias)
+        self.shift_proj = nn.Linear(input_dim, num_heads, bias=shift_bias)
+
+    def __repr__(self) -> str:
+        s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim})"
+        return s
+
+    def forward(self, x: torch.Tensor, shift_state: Optional[torch.Tensor]) -> torch.Tensor:
+        assert x.ndim == 3, "Input must be (B, T, D)"
+        B, T, D = x.size()
+        out = self.linear(x)
+        # (B, T, H)
+        alpha = torch.sigmoid(self.shift_proj(x).float())
+        # left, right, top, bottom (B, T=H, D=W)
+        # out_prev = nn.functional.pad(out, (0, 0, 1, -1))
+        # out_prev = torch.roll(out, shifts=1, dims=1)
+
+        out_per_head = rearrange(out, 'b t (h d) -> b t h d', h=self.num_heads)
+        if T > 1:
+            # TODO: note that in this case the cache is not used
+            result_per_head = token_shift(out_per_head, alpha, 1.0 - alpha)
+        else:
+            shift_state_per_head = rearrange(shift_state, 'b (h d) -> b 1 h d', h=self.num_heads)
+            result_per_head = (alpha[..., None] * shift_state_per_head + (1 - alpha[..., None]) * out_per_head)
+
+        result_per_head = result_per_head.to(out.dtype)
+
+        if shift_state is not None:
+            shift_state.copy_(out[:, -1, :])
+
+        result = rearrange(result_per_head, 'b t h d -> b t (h d)', h=self.num_heads)
+        return result
+
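+
+# A plain-PyTorch sketch of what the fused `token_shift` kernel above is assumed
+# to compute (kept for documentation only; nothing calls it at runtime). Each
+# position becomes a per-head convex combination of itself and its predecessor,
+# with position 0 mixing with zeros, matching the padded variant commented out
+# in `ShiftLinear.forward`.
+def _token_shift_reference(out: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
+    # out: (B, T, H, D); alpha: (B, T, H) with entries in [0, 1].
+    prev = torch.cat([torch.zeros_like(out[:, :1]), out[:, :-1]], dim=1)
+    return alpha[..., None] * prev + (1.0 - alpha[..., None]) * out
+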
+class GroupRMSNorm(nn.Module):
+    def __init__(
+        self,
+        num_groups: int,
+        hidden_size: int,
+        elementwise_affine: bool = True,
+        bias: bool = False,
+        eps: float = 1e-5
+    ) -> None:
+        super().__init__()
+
+        if hidden_size % num_groups != 0:
+            raise ValueError('hidden_size must be divisible by num_groups')
+
+        self.num_groups = num_groups
+        self.hidden_size = hidden_size
+        self.elementwise_affine = elementwise_affine
+        self.eps = eps
+
+        self.register_parameter("weight", None)
+        self.register_parameter("bias", None)
+        if elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(hidden_size))
+            if bias:
+                self.bias = nn.Parameter(torch.zeros(hidden_size))
+
+    def __repr__(self) -> str:
+        s = f"{self.__class__.__name__}({self.num_groups}, {self.hidden_size}"
+        if not self.elementwise_affine:
+            s += f", elementwise_affine={self.elementwise_affine}"
+        s += f", eps={self.eps}"
+        s += ")"
+        return s
+
+    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
+        return group_norm_fn(
+            x,
+            self.weight,
+            self.bias,
+            residual=residual,
+            eps=self.eps,
+            prenorm=prenorm,
+            residual_in_fp32=residual_in_fp32,
+            is_rms_norm=True,
+            num_groups=self.num_groups
+        )
+
+
+class ForgettingAttentionLayer(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int = 2048,
+        num_heads: int = 32,
+        num_kv_heads: Optional[int] = None,
+        window_size: Optional[int] = None,
+        max_position_embeddings: Optional[int] = None,
+        use_rope: bool = False,
+        rope_base: float = 500000.0,
+        use_output_gate: bool = False,
+        ogate_act: str = "sigmoid",
+        fgate_type: str = "full",
+        fgate_bias_init: bool = False,
+        decay_time_min: Optional[float] = None,
+        decay_time_max: Optional[float] = None,
+        use_output_norm: bool = False,
+        norm_eps: float = 1e-6,
+        qk_norm: bool = False,
+        qk_norm_share_param_across_head: bool = False,
+        use_k_shift: bool = False,
+        use_v_shift: bool = False,
+        initializer_range: float = 0.02,
+        layer_idx: Optional[int] = None
+    ):
+        """
+        Forgetting Attention layer.
+
+        Arguments:
+            - hidden_size: Input dimension and QKV dimension
+            - num_heads: Number of heads
+            - num_kv_heads: Not used. Should be None
+            - window_size: Not used. Should be None
+            - max_position_embeddings: Not used. Should be None
+            - use_rope: Whether to use RoPE. Default is False
+            - rope_base: The theta hyperparameter in RoPE. This has no effect if
+              use_rope=False
+            - use_output_gate: Whether to use output gates. Note that output gates
+              introduce extra parameters, so you may want to reduce parameters in
+              other components (e.g., the MLPs) to compensate
+            - ogate_act: Activation for the output gate. Either "sigmoid" or "silu"
+            - fgate_type: Forget gate type. The following are supported:
+                - "full": the default data-dependent forget gate
+                - "bias_only": the data-independent forget gate
+                - "fixed": forget gates with fixed values
+                - "none": no forget gates; equivalent to forget gates that are all ones
+            - fgate_bias_init: Whether to use the special initialization for the bias
+              terms of the forget gate. This should only be used with fgate types in
+              ["bias_only", "fixed"]
+            - decay_time_min: T_min for the forget gate bias initialization. See the
+              paper for details
+            - decay_time_max: T_max for the forget gate bias initialization. See the
+              paper for details
+            - use_output_norm: Whether to use output normalization
+            - norm_eps: Epsilon for the RMSNorms
+            - qk_norm: Whether to use QK-norm
+            - qk_norm_share_param_across_head: In QK-norm, whether to share the
+              RMSNorm scaling parameters across heads. Kept only for backward
+              compatibility
+            - use_k_shift: Whether to use data-dependent key shift
+            - use_v_shift: Whether to use data-dependent value shift
+            - initializer_range: Standard deviation for initialization
+            - layer_idx: The block index of this layer. Needed for the KV-cache
+        """
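+        # Note on the forget gate bias initialization performed below (a sketch of
+        # the reasoning): with a constant forget gate f = sigmoid(b), a memory decays
+        # as f^t = exp(t * log(sigmoid(b))), so its characteristic decay time is
+        # t = -1 / log(sigmoid(b)). Solving sigmoid(b) = exp(-1 / t) for b gives
+        # b = -log(expm1(1 / t)), which is the `bias_init` formula used further down.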
+        super().__init__()
+
+        self.num_heads = num_heads
+        if num_kv_heads is None:
+            self.num_kv_heads = self.num_heads
+        else:
+            raise NotImplementedError("GQA has not been tested.")
+        self.num_kv_groups = num_heads // self.num_kv_heads
+        self.hidden_size = hidden_size
+        self.head_dim = self.hidden_size // self.num_heads
+        self.kv_dim = self.num_kv_heads * self.head_dim
+        self.window_size = window_size
+        self.max_position_embeddings = max_position_embeddings
+        self.layer_idx = layer_idx
+
+        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        if use_k_shift:
+            self.k_proj = ShiftLinear(self.hidden_size, self.kv_dim, self.num_heads, bias=False)
+        else:
+            self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
+
+        if use_v_shift:
+            self.v_proj = ShiftLinear(self.hidden_size, self.kv_dim, self.num_heads, bias=False)
+        else:
+            self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
+
+        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.use_k_shift = use_k_shift
+        self.use_v_shift = use_v_shift
+
+        device = next(self.parameters()).device
+        # Forget gate
+        assert fgate_type in ["full", "bias_only", "fixed", "none"]
+        self.fgate_type = fgate_type
+        self.fgate_bias_init = fgate_bias_init
+        if fgate_type == "full":
+            assert not fgate_bias_init
+            self.fgate_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
+        elif fgate_type == "bias_only":
+            self.fgate_bias = nn.Parameter(torch.zeros(size=(self.num_heads,), device=device))
+            self.fgate_bias._no_weight_decay = True
+        elif fgate_type == "fixed":
+            assert fgate_bias_init, "You must set fgate_bias_init = True with fixed fgate"
+            fgate_bias = torch.zeros(size=(self.num_heads,), device=device)
+            self.register_buffer("fgate_bias", fgate_bias)
+        elif fgate_type == "none":
+            pass
+        else:
+            raise ValueError(f"Unknown fgate type {fgate_type}")
+
+        # Forget gate initialization for data-independent and fixed forget gates
+        if fgate_bias_init:
+            assert decay_time_min is not None and decay_time_max is not None
+            assert decay_time_min > 0 and decay_time_max > 0
+            with torch.no_grad():
+                log_decay_time = torch.linspace(math.log(decay_time_min), math.log(decay_time_max), steps=self.num_heads)
+                decay_time = torch.exp(log_decay_time)
+                # Such that t = -1 / log(sigmoid(b))
+                bias_init = -torch.log(torch.expm1(1 / decay_time))
+                self.fgate_bias.copy_(bias_init)
+        else:
+            assert decay_time_min is None and decay_time_max is None
+
+        if use_output_gate:
+            self.ogate_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+            self.ogate_act = ogate_act
+            assert ogate_act in ["silu", "sigmoid"]
+        else:
+            self.ogate_proj = None
+
+        if use_output_norm:
+            self.output_norm = GroupRMSNorm(num_groups=self.num_heads, hidden_size=self.hidden_size, eps=norm_eps)
+        else:
+            self.output_norm = None
+
+        if use_rope:
+            self.rotary = RotaryEmbedding(self.head_dim, base=rope_base)
+        else:
+            self.rotary = None
+
+        self.qk_norm = qk_norm
+        self.qk_norm_share_param_across_head = qk_norm_share_param_across_head
+        if qk_norm:
+            if self.qk_norm_share_param_across_head:
+                # This is an incorrect implementation kept only for backward compatibility
+                self.q_norm = RMSNorm(self.head_dim)
+                self.k_norm = RMSNorm(self.head_dim)
+            else:
+                self.q_norm = GroupRMSNorm(num_groups=self.num_heads, hidden_size=self.hidden_size)
+                self.k_norm = GroupRMSNorm(num_groups=self.num_heads, hidden_size=self.hidden_size)
+
+        self.initializer_range = initializer_range
+        self.apply(self._initialize_weights)
+
+    def _initialize_weights(self, module: nn.Module):
+        # This will actually be overwritten by the outer init.
+        if isinstance(module, nn.Linear):
+            nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """
+        We assume the attention mask is always all ones during decoding; otherwise this will not work.
+        """
+        batch_size, q_len, _ = hidden_states.size()
+        if use_cache:
+            key_shift_state = past_key_values.key_shift_cache[self.layer_idx]
+            value_shift_state = past_key_values.value_shift_cache[self.layer_idx]
+        else:
+            key_shift_state = value_shift_state = None
+
+        # Shift states are updated in place
+        q = self.q_proj(hidden_states)
+        if self.use_k_shift:
+            k = self.k_proj(hidden_states, key_shift_state)
+        else:
+            k = self.k_proj(hidden_states)
+        if self.use_v_shift:
+            v = self.v_proj(hidden_states, value_shift_state)
+        else:
+            v = self.v_proj(hidden_states)
+
+        if self.qk_norm and (not self.qk_norm_share_param_across_head):
+            q = self.q_norm(q).to(q.dtype)
+            k = self.k_norm(k).to(k.dtype)
+
+        q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
+        k = rearrange(k, '... (h d) -> ... h d', h=self.num_kv_heads)
+        v = rearrange(v, 'b t (h d) -> b h t d', h=self.num_kv_heads)
+
+        if self.qk_norm and (self.qk_norm_share_param_across_head):
+            q = self.q_norm(q).to(q.dtype)
+            k = self.k_norm(k).to(k.dtype)
+
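+        # Forgetting attention, as a sketch of the formula the fused kernel below is
+        # assumed to implement: with per-head log forget gates log f_t and prefix
+        # sums L_i = sum_{l <= i} log f_l, the output is
+        #     o_i = sum_{j <= i} softmax_j(q_i . k_j / sqrt(d) + L_i - L_j) v_j,
+        # i.e. ordinary causal softmax attention with an additive decay bias
+        # L_i - L_j <= 0 that down-weights distant positions. The `output_attentions`
+        # branch at the end of this method materializes exactly this
+        # (score, decay_bias) pair for a few heads.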
+        seqlen_offset, max_seqlen = 0, q.shape[1]
+        if past_key_values is not None:
+            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
+            max_seqlen = q.shape[1] + seqlen_offset
+
+            if attention_mask is not None:
+                # subtract the number of padding tokens from the offsets
+                seqlen_offset = (seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1])
+                max_seqlen = q.shape[1] + max(seqlen_offset)
+
+        if self.max_position_embeddings is not None:
+            max_seqlen = max(max_seqlen, self.max_position_embeddings)
+        if self.rotary is not None:
+            q, k = self.rotary(q, k, seqlen_offset, max_seqlen)
+
+        if self.fgate_type == "full":
+            fgate_logit = self.fgate_proj(hidden_states)
+            fgate_logit = rearrange(fgate_logit, "b t h -> b h t")
+            log_fgate = torch.nn.functional.logsigmoid(fgate_logit.float())
+        elif self.fgate_type == "none":
+            log_fgate = torch.zeros((batch_size, self.num_heads, q_len), dtype=torch.float32, device=hidden_states.device)
+        else:
+            assert self.fgate_type in ["fixed", "bias_only"]
+            fgate_logit = torch.broadcast_to(self.fgate_bias, (batch_size, q_len, self.num_heads))
+            fgate_logit = rearrange(fgate_logit, "b t h -> b h t")
+            log_fgate = torch.nn.functional.logsigmoid(fgate_logit.float())
+
+        k = rearrange(k, 'b t h d -> b h t d')
+        if past_key_values is not None:
+            k, v, log_fgate = past_key_values.update(k, v, log_fgate, self.layer_idx)
+            # k, v = rearrange(k, 'b h t d -> b t h d'), rearrange(v, 'b h t d -> b t h d')
+        q = rearrange(q, 'b t h d -> b h t d')
+
+        if self.num_kv_groups > 1:
+            assert False, "GQA is not supported"
+            k = rearrange(k.unsqueeze(-2).repeat(1, 1, 1, self.num_kv_groups, 1), 'b t h g d -> b t (h g) d')
+            v = rearrange(v.unsqueeze(-2).repeat(1, 1, 1, self.num_kv_groups, 1), 'b t h g d -> b t (h g) d')
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            B, _, T = log_fgate.size()
+            assert attention_mask.size() == (B, T), ((B, T), attention_mask.size())
+            seq_start = T - attention_mask.sum(dim=-1)
+            o = forgetting_attention(
+                q, k, v,
+                log_fgate,
+                head_first=True,
+                seq_start=seq_start,
+                sm_scale=1 / math.sqrt(self.head_dim),
+            )
+            o = rearrange(o, "b h t d -> b t h d")
+        else:
+            o = forgetting_attention(
+                q, k, v,
+                log_fgate,
+                head_first=True,
+                sm_scale=1 / math.sqrt(self.head_dim),
+            )
+            o = rearrange(o, "b h t d -> b t h d")
+
+        o = o.reshape(batch_size, q_len, self.hidden_size)
+
+        if self.output_norm is not None:
+            o = self.output_norm(o)
+
+        if self.ogate_proj is not None:
+            # ogate = self.ogate_act(self.ogate_proj(hidden_states))
+            # o = o * ogate
+            # ogate = act_gate(self.ogate_proj(hidden_states), o)
+            ogate_logit = self.ogate_proj(hidden_states)
+            dtype = ogate_logit.dtype
+            if self.ogate_act == "silu":
+                o = swiglu_linear(ogate_logit, o, self.o_proj.weight.to(dtype), self.o_proj.bias.to(dtype) if self.o_proj.bias is not None else self.o_proj.bias)
+            elif self.ogate_act == "sigmoid":
+                o = glu_linear(ogate_logit, o, self.o_proj.weight.to(dtype), self.o_proj.bias.to(dtype) if self.o_proj.bias is not None else self.o_proj.bias)
+            else:
+                raise ValueError(f"Unknown ogate act {self.ogate_act}")
+        else:
+            o = self.o_proj(o)
+
+        if not output_attentions:
+            attentions = None
+        else:
+            SAVE_HEADS = [0, 1, 2, 3]
+            # (B, H, T, T)
+            score = q[:, SAVE_HEADS] @ k[:, SAVE_HEADS].mT
+            log_lambda = torch.cumsum(log_fgate, dim=-1)
+            decay_bias = (log_lambda[:, SAVE_HEADS, :, None] - log_lambda[:, SAVE_HEADS, None, :]).to(torch.bfloat16)
+            # normalized_score = torch.softmax(score, dim=-1)
+            attentions = (score, decay_bias)
+
+        return o, attentions, past_key_values
+
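+    # Minimal smoke test for this layer (a sketch; it assumes the Triton kernels
+    # from `forgetting_transformer` are installed, a CUDA device is available, and
+    # the fused attention kernel accepts bf16 inputs):
+    #
+    #     layer = ForgettingAttentionLayer(hidden_size=512, num_heads=4, layer_idx=0)
+    #     layer = layer.cuda().to(torch.bfloat16)
+    #     x = torch.randn(2, 16, 512, device="cuda", dtype=torch.bfloat16)
+    #     o, _, _ = layer(x)  # o: (2, 16, 512)
+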
+    def init_shift_state(self, batch_size: int):
+        param = next(self.parameters())
+        state = dict()
+        try:
+            dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled("cuda") else torch.float32
+        except TypeError:
+            # Support legacy torch versions
+            dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else torch.float32
+        if self.use_k_shift:
+            state['key_shift'] = param.new_zeros(batch_size, self.kv_dim, dtype=dtype)
+        else:
+            state['key_shift'] = None
+        if self.use_v_shift:
+            state['value_shift'] = param.new_zeros(batch_size, self.kv_dim, dtype=dtype)
+        else:
+            state['value_shift'] = None
+        return state
+
+
+class ForgettingTransformerMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        hidden_ratio: Optional[float] = None,
+        intermediate_size: Optional[int] = None,
+        hidden_act: str = 'swish'
+    ) -> None:
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        # the final number of params is `hidden_ratio * hidden_size^2`
+        # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
+        if hidden_ratio is None:
+            hidden_ratio = 4
+        if intermediate_size is None:
+            intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
+            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
+        self.hidden_ratio = hidden_ratio
+        self.intermediate_size = intermediate_size
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[hidden_act]
+        self.hidden_act = hidden_act
+        assert hidden_act in ["swish", "sigmoid"]
+
+    def forward(self, x):
+        y = self.gate_proj(x)
+        gate, y = y.chunk(2, -1)
+        # TODO: maybe wrap swiglu_linear in custom_fwd/custom_bwd
+        if self.hidden_act == "swish":
+            return swiglu_linear(
+                gate, y,
+                self.down_proj.weight.to(y.dtype),
+                self.down_proj.bias.to(y.dtype) if self.down_proj.bias is not None else self.down_proj.bias
+            )
+        elif self.hidden_act == "sigmoid":
+            return glu_linear(
+                gate, y,
+                self.down_proj.weight.to(y.dtype),
+                self.down_proj.bias.to(y.dtype) if self.down_proj.bias is not None else self.down_proj.bias
+            )
+        else:
+            raise ValueError(f"Unknown hidden_act {self.hidden_act}")
+
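+
+# Worked example of the sizing rule above (a sketch): for hidden_size=2048 and
+# hidden_ratio=4, int(2048 * 4 * 2 / 3) = 5461, which rounds up to the next
+# multiple of 256 as 256 * ((5461 + 255) // 256) = 5632. gate_proj then maps
+# 2048 -> 2 * 5632 and down_proj maps 5632 -> 2048.
+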
+class ForgettingTransformerBlock(nn.Module):
+    def __init__(self, config, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
+        self.attn = ForgettingAttentionLayer(
+            hidden_size=config.hidden_size,
+            num_heads=config.num_heads,
+            num_kv_heads=config.num_kv_heads,
+            window_size=config.window_size,
+            max_position_embeddings=config.max_position_embeddings,
+            rope_base=config.rope_base,
+            use_rope=config.use_rope,
+            use_output_gate=config.use_output_gate,
+            ogate_act=config.ogate_act,
+            fgate_type=config.fgate_type,
+            fgate_bias_init=config.fgate_bias_init,
+            decay_time_min=config.decay_time_min,
+            decay_time_max=config.decay_time_max,
+            use_output_norm=config.use_output_norm,
+            norm_eps=config.norm_eps,
+            qk_norm=config.qk_norm,
+            qk_norm_share_param_across_head=config.qk_norm_share_param_across_head,
+            use_k_shift=config.use_k_shift,
+            use_v_shift=config.use_v_shift,
+            initializer_range=config.initializer_range,
+            layer_idx=layer_idx
+        )
+        self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
+        self.mlp = ForgettingTransformerMLP(
+            hidden_size=config.hidden_size,
+            hidden_ratio=config.hidden_ratio,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act
+        )
+
+    def forward_attn(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        **kwargs,
+    ):
+        # residual is handled outside of this
+        # residual = hidden_states
+        hidden_states = self.attn_norm(hidden_states)
+        hidden_states, attentions, past_key_values = self.attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions
+        )
+        return hidden_states, attentions, past_key_values
+
+    def forward_mlp(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+    ):
+        hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        gradient_checkpointing: bool = False
+        # **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+
+        residual = hidden_states
+
+        if gradient_checkpointing:
+            forward_attn = partial(torch.utils.checkpoint.checkpoint, self.forward_attn, use_reentrant=False)
+            forward_mlp = partial(torch.utils.checkpoint.checkpoint, self.forward_mlp, use_reentrant=False)
+        else:
+            forward_attn = self.forward_attn
+            forward_mlp = self.forward_mlp
+
+        hidden_states, attentions, past_key_values = forward_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions
+        )
+
+        hidden_states = forward_mlp(
+            hidden_states,
+            residual,
+        )
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attentions,)
+
+        if use_cache:
+            outputs += (past_key_values,)
+
+        return outputs
+
+
+class ForgettingTransformerPreTrainedModel(PreTrainedModel):
+
+    config_class = ForgettingTransformerConfig
+    supports_gradient_checkpointing = True
+    _no_split_modules = ['ForgettingTransformerBlock']
+
+    def __init__(self, config, *inputs, **kwargs):
+        # Dynamically patch config_class to support remote-code loading
+        if hasattr(config, '__class__'):
+            config_module = config.__class__.__module__
+            if 'transformers_modules' in config_module or config_module == 'configuration_forgetting_transformer':
+                self.__class__.config_class = config.__class__
+        super().__init__(config, *inputs, **kwargs)
+
+    def _init_weights(
+        self,
+        module: nn.Module,
+    ):
+        # if isinstance(module, (nn.Linear, nn.Conv1d)):
+        if isinstance(module, nn.Linear):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+
if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class ForgettingTransformerModel(ForgettingTransformerPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([ForgettingTransformerBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, CausalLMOutputWithPast]: + # if output_attentions: + # warnings.warn( + # "`ForgettingTransformerModel` does not support output attention weights now, so `output_attentions` is set to `False`." + # ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if use_cache: + # use_legacy_cache = not isinstance(past_key_values, Cache) + # if use_legacy_cache: + # past_key_values = FgateDynamicCache.from_legacy_cache(past_key_values) + if past_key_values is None: + past_key_values = FgateDynamicCache() + for layer_idx, layer in enumerate(self.layers): + shift_state = layer.attn.init_shift_state( + batch_size=input_ids.size(0), + ) + past_key_values.update_shift_cache( + key_shift_state=shift_state["key_shift"], + value_shift_state=shift_state["value_shift"], + layer_idx=layer_idx + ) + else: + assert isinstance(past_key_values, FgateDynamicCache) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + # embed positions + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = {} if output_attentions else None + next_decoder_cache = None + + for layer_id, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + gradient_checkpointing=self.gradient_checkpointing and self.training + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + OUTPUT_ATTN_LAYERS = [0, 7, 15, 23] + if layer_id in OUTPUT_ATTN_LAYERS: + # all_attns += (layer_outputs[1],) + all_attns[layer_id] = layer_outputs[1] + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + # next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + next_cache = next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class ForgettingTransformerForCausalLM(ForgettingTransformerPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = ForgettingTransformerModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. + if past_key_values is not None: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. 
+ model_inputs = {'input_ids': input_ids.contiguous()} + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + }) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True, reduction='none') + else: + loss_fct = nn.CrossEntropyLoss(reduction='none') + logits = self.lm_head(hidden_states) + # Enable model parallelism + labels = labels.to(logits.device) + # labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + loss = loss.view(*labels.size()) + del logits + logits = None + else: + logits = self.lm_head(hidden_states) + + if not return_dict: + raise NotImplementedError + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7f434981dc7bf9b220ed13f2cf53f70c18da7df0 --- /dev/null +++ b/__init__.py @@ -0,0 +1 @@ +# for HF remote code diff --git a/__pycache__/__init__.cpython-310.pyc b/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a28a7d4ba90993f27488ffb8cc3583fc77551ef8 Binary files /dev/null and b/__pycache__/__init__.cpython-310.pyc differ diff --git a/__pycache__/configuration_forgetting_transformer.cpython-310.pyc b/__pycache__/configuration_forgetting_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d96c62d5aa527a06659247c804b1983d91ad468 Binary files /dev/null and b/__pycache__/configuration_forgetting_transformer.cpython-310.pyc differ diff --git a/__pycache__/fgate_cache.cpython-310.pyc b/__pycache__/fgate_cache.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec6b896076d3b5f02940a665a10dff4dcbf3b286 Binary files /dev/null and b/__pycache__/fgate_cache.cpython-310.pyc differ diff --git a/__pycache__/glu_linear.cpython-310.pyc b/__pycache__/glu_linear.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3b7097c4b724b5926a20af310597db54124c7042 Binary files /dev/null and b/__pycache__/glu_linear.cpython-310.pyc differ diff --git a/__pycache__/modeling_forgetting_transformer.cpython-310.pyc b/__pycache__/modeling_forgetting_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d117d03f35cee067a7c2cae3680f54ecf6004d3 Binary files /dev/null and b/__pycache__/modeling_forgetting_transformer.cpython-310.pyc differ diff --git a/__pycache__/token_shift.cpython-310.pyc b/__pycache__/token_shift.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c36abad55ebe157bb655b3acddd7d4615e8fb395 Binary files /dev/null and b/__pycache__/token_shift.cpython-310.pyc differ diff --git a/configuration_forgetting_transformer.py b/configuration_forgetting_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..7f4e7d87a4fd925592b9ef29a66cf924fa95f36e --- /dev/null +++ b/configuration_forgetting_transformer.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +from typing import Optional +from transformers.configuration_utils import PretrainedConfig + +class ForgettingTransformerConfig(PretrainedConfig): + model_type = 'forgetting_transformer' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + vocab_size: int = 32000, + hidden_size: int = 2048, + hidden_ratio: Optional[float] = 4, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 32, + num_kv_heads: int = None, + hidden_act: str = "swish", + window_size: Optional[int] = None, + max_position_embeddings: int = 2048, + initializer_range: float = 0.02, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + attention_bias: bool = False, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + rope_base: float = 500000.0, + use_rope: bool = False, + use_output_gate: bool = False, + ogate_act: str = "sigmoid", + fgate_type: str = "full", + fgate_bias_init: bool = False, + decay_time_min: Optional[float] = None, + decay_time_max: Optional[float] = None, + use_output_norm: bool = False, + qk_norm: bool = False, + qk_norm_share_param_across_head: bool = False, + use_k_shift: bool = False, + use_v_shift: bool = False, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.window_size = window_size + self.max_position_embeddings = max_position_embeddings + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + self.attention_bias = attention_bias + self.fuse_cross_entropy = fuse_cross_entropy + self.fuse_norm = fuse_norm + self.rope_base = rope_base + self.use_rope = use_rope + self.use_output_gate = use_output_gate + self.ogate_act = ogate_act + self.fgate_type = fgate_type + self.fgate_bias_init = fgate_bias_init + self.decay_time_min = decay_time_min + self.decay_time_max = decay_time_max + self.use_output_norm = use_output_norm + self.qk_norm = qk_norm + self.qk_norm_share_param_across_head = qk_norm_share_param_across_head + self.use_k_shift = use_k_shift + 
self.use_v_shift = use_v_shift + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) \ No newline at end of file diff --git a/fgate_cache.py b/fgate_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..ba956105bfa56d210a1e4e7ac9ba887abf6b6de7 --- /dev/null +++ b/fgate_cache.py @@ -0,0 +1,143 @@ +from typing import List, Tuple, Optional, Any, Dict +import torch + +class FgateDynamicCache: + """ + A cache that grows dynamically as more tokens are generated. + Custom cache for Forgetting Transformer that does not inherit from transformers.Cache. + """ + + def __init__(self, num_hidden_layers: Optional[int] = None) -> None: + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + self.log_fgate_cache: List[torch.Tensor] = [] + self.key_shift_cache: List[torch.Tensor] = [] + self.value_shift_cache: List[torch.Tensor] = [] + self._seen_tokens = 0 + + def update_shift_cache( + self, + key_shift_state: torch.Tensor, + value_shift_state: torch.Tensor, + layer_idx, + ): + assert layer_idx == len(self.key_shift_cache) == len(self.value_shift_cache) + self.key_shift_cache.append(key_shift_state) + self.value_shift_cache.append(value_shift_state) + + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: + if layer_idx < len(self): + return (self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx]) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def __iter__(self): + for layer_idx in range(len(self)): + yield (self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx]) + + def __len__(self): + return len(self.key_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + log_fgate_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert log_fgate_states.ndim == 3, f"log_fgate must be (B, H, T), but get {log_fgate_states.size()}" + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + self.log_fgate_cache.append(log_fgate_states) + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + self.log_fgate_cache[layer_idx] = torch.cat([self.log_fgate_cache[layer_idx], log_fgate_states], dim=-1) + + return self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_max_length(self) -> Optional[int]: + return None + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], ...]: + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx]),) + return legacy_cache + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_layers: Optional[int] = None) -> "FgateDynamicCache": + """ + Converts a cache in the legacy cache format into an equivalent FgateDynamicCache. 
+ + Args: + past_key_values: Optional legacy cache format + num_layers: Not used in this implementation + + Returns: + FgateDynamicCache instance + """ + cache = cls() + + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states, log_fgate_states = past_key_values[layer_idx] + cache.update(key_states, value_states, log_fgate_states, layer_idx) + + return cache + + def crop(self, max_length: int): + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + self._seen_tokens = max_length + for idx in range(len(self.key_cache)): + self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] + self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] + self.log_fgate_cache[idx] = self.log_fgate_cache[idx][..., :max_length] + + def batch_split(self, full_batch_size: int, split_size: int) -> List["FgateDynamicCache"]: + out = [] + for i in range(0, full_batch_size, split_size): + current_split = FgateDynamicCache() + current_split._seen_tokens = self._seen_tokens + current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] + current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] + current_split.log_fgate_cache = [tensor[i : i + split_size] for tensor in self.log_fgate_cache] + out.append(current_split) + return out + + @classmethod + def from_batch_splits(cls, splits: List["FgateDynamicCache"]) -> "FgateDynamicCache": + cache = cls() + for idx in range(len(splits[0])): + layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0) + layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0) + layer_log_fgates = torch.cat([current.log_fgate_cache[idx] for current in splits], dim=0) + cache.update(layer_keys, layer_values, layer_log_fgates, idx) + return cache + + def batch_repeat_interleave(self, repeats: int): + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.log_fgate_cache[layer_idx] = self.log_fgate_cache[layer_idx].repeat_interleave(repeats, dim=0) + + def batch_select_indices(self, indices: torch.Tensor): + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] + self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] + self.log_fgate_cache[layer_idx] = self.log_fgate_cache[layer_idx][indices, ...] \ No newline at end of file diff --git a/fgate_cache.py.backup b/fgate_cache.py.backup new file mode 100644 index 0000000000000000000000000000000000000000..bef4bad152a1679650aa4c271e559cc76e0f67e1 --- /dev/null +++ b/fgate_cache.py.backup @@ -0,0 +1,203 @@ +from typing import List, Tuple, Optional, Any, Dict +import torch +from transformers.cache_utils import Cache + +class FgateDynamicCache(Cache): + """ + A cache that grows dynamically as more tokens are generated. This is the default for generative models. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. 
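+    In addition to keys and values, it stores the per-layer log forget gates,
+    with expected shape `[batch_size, num_heads, seq_len]`.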
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = DynamicCache() + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + DynamicCache() + ``` + """ + + def __init__(self) -> None: + super().__init__() + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + self.log_fgate_cache: List[torch.Tensor] = [] + + self.key_shift_cache: List[torch.Tensor] = [] + self.value_shift_cache: List[torch.Tensor] = [] + + self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + + def update_shift_cache( + self, + key_shift_state: torch.Tensor, + value_shift_state: torch.Tensor, + layer_idx, + ): + assert layer_idx == len(self.key_shift_cache) == len(self.value_shift_cache) + self.key_shift_cache.append(key_shift_state) + self.value_shift_cache.append(value_shift_state) + + + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return (self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx]) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx]) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.key_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + log_fgate_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. 
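+            (The updated log forget gate states are returned as well.)
+
+        Example (a minimal sketch of the append-then-concatenate behavior,
+        assuming `torch` is imported and `cache = FgateDynamicCache()`):
+
+        ```python
+        >>> k = torch.randn(1, 4, 8, 16)
+        >>> v = torch.randn(1, 4, 8, 16)
+        >>> f = torch.zeros(1, 4, 8)  # log forget gates, (B, H, T)
+        >>> k_all, v_all, f_all = cache.update(k, v, f, layer_idx=0)  # first call appends
+        >>> k_all, v_all, f_all = cache.update(k[..., -1:, :], v[..., -1:, :], f[..., -1:], layer_idx=0)
+        >>> k_all.shape[-2], f_all.shape[-1]  # later calls concatenate along the time axis
+        (9, 9)
+        ```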
+ """ + assert log_fgate_states.ndim == 3, f"log_fgate must be (B, H, T), but get {log_fgate_states.size()}" + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + self.log_fgate_cache.append(log_fgate_states) + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + self.log_fgate_cache[layer_idx] = torch.cat([self.log_fgate_cache[layer_idx], log_fgate_states], dim=-1) + + return self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" + return None + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility.""" + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx]),) + return legacy_cache + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_layers: Optional[int] = None) -> "DynamicCache": + """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for + backward compatibility.""" + raise NotImplementedError + assert num_layers is not None + cache = cls(num_layers) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states, log_fgate_states = past_key_values[layer_idx] + cache.update(key_states, value_states, log_fgate_states, layer_idx) + return cache + + def crop(self, max_length: int): + """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be + negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.""" + # In case it is negative + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + self._seen_tokens = max_length + for idx in range(len(self.key_cache)): + self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] + self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] + self.log_fgate_cache[idx] = self.log_fgate_cache[idx][..., :max_length] + + def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]: + """Split the current instance into a list of `DynamicCache` by the batch size. 
This will be used by + `_split_model_inputs()` in `generation.utils`""" + out = [] + for i in range(0, full_batch_size, split_size): + current_split = DynamicCache() + current_split._seen_tokens = self._seen_tokens + current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] + current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] + current_split.log_fgate_cache = [tensor[i : i + split_size] for tensor in self.log_fgate_cache] + out.append(current_split) + return out + + @classmethod + def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache": + """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in + `generation.utils`""" + cache = cls() + for idx in range(len(splits[0])): + layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0) + layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0) + layer_log_fgates = torch.cat([current.log_fgate_cache[idx] for current in splits], dim=0) + cache.update(layer_keys, layer_values, layer_log_fgates, idx) + return cache + + def batch_repeat_interleave(self, repeats: int): + """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.log_fgate_cache[layer_idx] = self.log_fgate_cache[layer_idx].repeat_interleave(repeats, dim=0) + + def batch_select_indices(self, indices: torch.Tensor): + """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] + self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] + self.log_fgate_cache[layer_idx] = self.log_fgate_cache[layer_idx][indices, ...] diff --git a/glu_linear.py b/glu_linear.py new file mode 100644 index 0000000000000000000000000000000000000000..fff265ac29aee9c5804bfccd5716ba7b446cbbfb --- /dev/null +++ b/glu_linear.py @@ -0,0 +1,61 @@ +import torch +import torch.nn.functional as F + + +glu_fwd_codestring = """ +template T glu_fwd(T x, T y) { + return float(y) / (1.0f + ::exp(-float(x))); +} +""" +glu_bwd_codestring = """ +template T glu_bwd(T x, T y, T g, T& dx, T& dy) { + float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); + dx = x_sigmoid * (1.0f - x_sigmoid) * float(g) * float(y); + dy = x_sigmoid * float(g); +} +""" + +glu_bwd_with_output_codestring = """ +template T glu_bwd_with_output(T x, T y, T g, T& dx, T& dy, T& z) { + float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); + dx = x_sigmoid * (1.0f - x_sigmoid) * float(g) * float(y); + dy = x_sigmoid * float(g); + z = x_sigmoid * float(y); +} +""" + +glu_fwd = torch.cuda.jiterator._create_jit_fn(glu_fwd_codestring) +glu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(glu_bwd_codestring, num_outputs=2) +glu_bwd_with_output = torch.cuda.jiterator._create_multi_output_jit_fn(glu_bwd_with_output_codestring, num_outputs=3) + + +class GLULinearFunction(torch.autograd.Function): + r""" + Gated Linear Unit (GLU) function followed by a linear transformation. + + .. math:: + \text{GLULinear}(x, y, W, b) = (sh(x) * y) W + b + + This simple wrap discards the intermediate results of GLU(x, y) to save memory. 
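+
+    Here `sh` denotes the sigmoid function (see `glu_fwd` above). A reference
+    (unfused) computation, for illustration only; the jiterator kernels above
+    are CUDA-only, so the inputs must live on a CUDA device:
+
+    ```python
+    >>> x, y = torch.randn(2, 8, device="cuda"), torch.randn(2, 8, device="cuda")
+    >>> w, b = torch.randn(4, 8, device="cuda"), torch.randn(4, device="cuda")
+    >>> reference = F.linear(torch.sigmoid(x) * y, w, b)
+    >>> torch.allclose(glu_linear(x, y, w, b), reference, atol=1e-5)
+    True
+    ```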
+ """ + + @staticmethod + def forward(ctx, x, y, weight, bias): + z = glu_fwd(x, y) + out = F.linear(z.to(weight.dtype), weight, bias) + # We don't store z, will be recomputed in the backward pass to save memory + ctx.save_for_backward(x, y, weight) + ctx.linear_bias_is_none = bias is None + return out + + @staticmethod + def backward(ctx, dout, *args): + x, y, weight = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dz = F.linear(dout, weight.t()).view_as(x) + dx, dy, z = glu_bwd_with_output(x, y, dz) + dlinear_weight = torch.einsum("bo,bi->oi", dout, z.reshape(-1, z.shape[-1])) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + return dx, dy, dlinear_weight, dlinear_bias + +glu_linear = GLULinearFunction.apply diff --git a/modeling_forgetting_transformer.py b/modeling_forgetting_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..deff96ee4c8558b72b8a525d4b38184b2a46af77 --- /dev/null +++ b/modeling_forgetting_transformer.py @@ -0,0 +1,910 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +# from fla.layers.attn import Attention +from fla.modules import FusedCrossEntropyLoss, RMSNorm +from fla.modules.layernorm import group_norm_fn +from fla.modules.activations import swiglu_linear + +from fla.modules import RotaryEmbedding +from einops import rearrange + +# 动态导入配置类以支持本地和HuggingFace Hub加载 +try: + from .configuration_forgetting_transformer import ForgettingTransformerConfig +except (ImportError, ValueError): + try: + from configuration_forgetting_transformer import ForgettingTransformerConfig + except ImportError: + from forgetting_transformer.model.forgetting_transformer.configuration_forgetting_transformer import ForgettingTransformerConfig + +from forgetting_transformer.ops.forgetting_attention_std import forgetting_attention_std as forgetting_attention +from .fgate_cache import FgateDynamicCache +from .glu_linear import glu_linear +from .token_shift import token_shift + +from functools import partial + +logger = logging.get_logger(__name__) + + +class ShiftLinear(nn.Module): + + def __init__( + self, + input_dim: int, + output_dim: int, + num_heads: int, + bias: bool, + shift_bias: bool = False + ): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.num_heads = num_heads + assert self.output_dim % self.num_heads == 0 + + self.linear = nn.Linear(input_dim, output_dim, bias=bias) + self.shift_proj = nn.Linear(input_dim, num_heads, bias=shift_bias) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim})" + return s + + def forward(self, x: torch.Tensor, shift_state: Optional[torch.Tensor]) -> torch.Tensor: + assert x.ndim == 3, "Input must be (B, T, D)" + B, T, D = x.size() + out = self.linear(x) + # (B, T, H, 1) + alpha = torch.sigmoid(self.shift_proj(x).float()).float() + # left, right, top, bottom (B, T=H, D=W) + # out_prev = nn.functional.pad(out, (0, 0, 1, -1)) + # out_prev = torch.roll(out, shifts=1, dims=1) + + out_per_head = rearrange(out, 'b t (h d) -> b t h d', h=self.num_heads) + if T > 1: + 
# TODO: note in this case cache is not used + result_per_head = token_shift(out_per_head, alpha, 1.0 - alpha) + else: + shift_state_per_head = rearrange(shift_state, 'b (h d) -> b 1 h d', h=self.num_heads) + result_per_head = (alpha[..., None] * shift_state_per_head + (1 - alpha[..., None]) * out_per_head) + + result_per_head = result_per_head.to(out.dtype) + + if shift_state is not None: + shift_state.copy_(out[:, -1, :]) + + result = rearrange(result_per_head, 'b t h d -> b t (h d)', h=self.num_heads) + return result + +class GroupRMSNorm(nn.Module): + def __init__( + self, + num_groups: int, + hidden_size: int, + elementwise_affine: bool = True, + bias: bool = False, + eps: float = 1e-5 + ) -> GroupRMSNorm: + super().__init__() + + if hidden_size % num_groups != 0: + raise ValueError('num_channels must be divisible by num_groups') + + self.num_groups = num_groups + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + if bias: + self.bias = nn.Parameter(torch.zeros(hidden_size)) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.num_groups}, {self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return group_norm_fn( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + is_rms_norm=True, + num_groups=self.num_groups + ) + +class ForgettingAttentionLayer(nn.Module): + + def __init__( + self, + hidden_size: int = 2048, + num_heads: int = 32, + num_kv_heads: Optional[int] = None, + window_size: Optional[int] = None, + max_position_embeddings: Optional[int] = None, + use_rope: bool = False, + rope_base: float = 500000.0, + use_output_gate: bool = False, + ogate_act: str = "sigmoid", + fgate_type: str = "full", + fgate_bias_init: bool = False, + decay_time_min: Optional[float] = None, + decay_time_max: Optional[float] = None, + use_output_norm: bool = False, + norm_eps: float = 1e-6, + qk_norm: bool = False, + qk_norm_share_param_across_head: bool = False, + use_k_shift: bool = False, + use_v_shift: bool = False, + initializer_range: float = 0.02, + layer_idx: int = None + ): + """ + Forgetting Attention layer. + + Arguments: + - hidden_size: Input dimension and qkv dimension + - num_heads: Number of heads + - num_kv_heads: Not used. Should be None + - window_size: Not used. Should be None + - max_position_embeddings: Not used. Should be None + - use_rope: Whether to use RoPE. Default is False + - rope_base: the theta hyperparameter in RoPE. This has no effect if + use_rope=False + - use_output_gate: Whether to use output gates. Note that using output gates + introduces extra parameters and you may want to reduce parameters from + other components (e.g., MLPs) + - ogate_act: Activation for the output gate. Either "sigmoid" or "silu" + - fgate_type: Forget gate type. The following are supported: + - "full": The default data-dependent forget gate + - "bias_only": The data-independent forget gate + - "fixed": Forget gates with fixed values + - "none": Not using forget gates. Equivalent to forget gates with all + ones. + - fgate_bias_init: Whether to use special initalization for the bias terms in + the forget gate. 
This should only be used with fgate types in + ["bias_only", "fixed"]. + - decay_time_min: T_min for the forget gate bias initialization. See paper + for details. + - decay_time_max: T_max for the forget gate bias initalization. See paper + for details. + - use_output_norm: Whether to use output normalization. + - norm_eps: Epsilon for the RMSNorms + - qk_norm: Whether to use qk_norm + - qk_norm_share_param_across_head: In QK-norm, whether to share the RMSNorm + scaling parameters across heads. This is just for backward compatibility. + - use_k_shift: Whether to use data-dependent key shift + - use_v_shift: Whether to use data-dependent value shift + - initializer_range: standard deviation for initialization + - layer_idx: The block index of this layer. Needed for KV-cache + """ + super().__init__() + + self.num_heads = num_heads + if num_kv_heads is None: + self.num_kv_heads = self.num_heads + else: + raise NotImplementedError("GQA has not been tested.") + self.num_kv_heads = num_kv_heads + self.num_kv_groups = num_heads // self.num_kv_heads + self.hidden_size = hidden_size + self.head_dim = self.hidden_size // self.num_heads + self.kv_dim = self.num_kv_heads * self.head_dim + self.kv_dim = self.num_kv_heads * self.head_dim + self.window_size = window_size + self.max_position_embeddings = max_position_embeddings + self.layer_idx = layer_idx + + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + if use_k_shift: + self.k_proj = ShiftLinear(self.hidden_size, self.kv_dim, self.num_heads, bias=False) + else: + self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) + + if use_v_shift: + self.v_proj = ShiftLinear(self.hidden_size, self.kv_dim, self.num_heads, bias=False) + else: + self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) + + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.use_k_shift = use_k_shift + self.use_v_shift = use_v_shift + + + device = next(self.parameters()).device + # Forget gate + assert fgate_type in ["full", "bias_only", "fixed", "none"] + self.fgate_type = fgate_type + self.fgate_bias_init = fgate_bias_init + if fgate_type == "full": + assert not fgate_bias_init + self.fgate_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True) + elif fgate_type == "bias_only": + self.fgate_bias = nn.Parameter(torch.zeros(size=(self.num_heads,), device=device)) + self.fgate_bias._no_weight_decay = True + elif fgate_type == "fixed": + assert fgate_bias_init, "You must set fgate_bias_init = True with fixed fgate" + fgate_bias = torch.zeros(size=(self.num_heads,), device=device) + self.register_buffer("fgate_bias", fgate_bias) + elif fgate_type == "none": + pass + else: + raise ValueError(f"Unknown fgate type {fgate_type}") + + + + # Forget gate intialization for data-independent and fixed forget gates + if fgate_bias_init: + assert decay_time_min is not None and decay_time_max is not None + assert decay_time_min > 0 and decay_time_max > 0 + with torch.no_grad(): + log_decay_time = torch.linspace(math.log(decay_time_min), math.log(decay_time_max), steps=self.num_heads) + decay_time = torch.exp(log_decay_time) + # Such that t = -1 / log(sigmoid(b)) + bias_init = -torch.log(torch.expm1(1 / decay_time)) + self.fgate_bias.copy_(bias_init) + else: + assert decay_time_min is None and decay_time_max is None + + if use_output_gate: + self.ogate_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.ogate_act = ogate_act + assert ogate_act in ["silu", "sigmoid"] + else: + self.ogate_proj = None 
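+
+        # Note on the forget-gate bias initialization above: with
+        # b = -log(expm1(1/T)) we get sigmoid(b) = exp(-1/T), hence
+        # T = -1 / log(sigmoid(b)); each head starts with a characteristic
+        # decay time spaced log-uniformly in [decay_time_min, decay_time_max].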
+ + if use_output_norm: + self.output_norm = GroupRMSNorm(num_groups=self.num_heads, hidden_size=self.hidden_size, eps=norm_eps) + else: + self.output_norm = None + + + if use_rope: + self.rotary = RotaryEmbedding(self.head_dim, base=rope_base) + else: + self.rotary = None + + + self.qk_norm = qk_norm + self.qk_norm_share_param_across_head = qk_norm_share_param_across_head + if qk_norm: + if self.qk_norm_share_param_across_head: + # This is an incorrect implemention kept just for backward compatibility + self.q_norm = RMSNorm(self.head_dim) + self.k_norm = RMSNorm(self.head_dim) + else: + self.q_norm = GroupRMSNorm(num_groups=self.num_heads, hidden_size=self.hidden_size) + self.k_norm = GroupRMSNorm(num_groups=self.num_heads, hidden_size=self.hidden_size) + + self.initializer_range = initializer_range + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + # This will actually be overwritten by outer init. + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=self.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + We assume that during decoding attention mask is always 1. Otherwise it won't work. + """ + batch_size, q_len, _ = hidden_states.size() + if use_cache: + key_shift_state = past_key_values.key_shift_cache[self.layer_idx] + value_shift_state = past_key_values.value_shift_cache[self.layer_idx] + else: + key_shift_state = value_shift_state = None + + # Shift states are updated in place + q = self.q_proj(hidden_states) + if self.use_k_shift: + k = self.k_proj(hidden_states, key_shift_state) + else: + k = self.k_proj(hidden_states) + if self.use_v_shift: + v = self.v_proj(hidden_states, value_shift_state) + else: + v = self.v_proj(hidden_states) + + if self.qk_norm and (not self.qk_norm_share_param_across_head): + q = self.q_norm(q).to(q.dtype) + k = self.k_norm(k).to(k.dtype) + + q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads) + k = rearrange(k, '... (h d) -> ... 
h d', h=self.num_kv_heads) + v = rearrange(v, 'b t (h d) -> b h t d', h=self.num_kv_heads) + + + if self.qk_norm and (self.qk_norm_share_param_across_head): + q = self.q_norm(q).to(q.dtype) + k = self.k_norm(k).to(k.dtype) + + + seqlen_offset, max_seqlen = 0, q.shape[1] + if past_key_values is not None: + seqlen_offset = past_key_values.get_seq_length(self.layer_idx) + max_seqlen = q.shape[1] + seqlen_offset + + if attention_mask is not None: + # to deliminate the offsets of padding tokens + seqlen_offset = (seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]) + max_seqlen = q.shape[1] + max(seqlen_offset) + + if self.max_position_embeddings is not None: + max_seqlen = max(max_seqlen, self.max_position_embeddings) + if self.rotary is not None: + q, k = self.rotary(q, k, seqlen_offset, max_seqlen) + + if self.fgate_type == "full": + fgate_logit = self.fgate_proj(hidden_states) + fgate_logit = rearrange(fgate_logit, "b t h -> b h t") + log_fgate = torch.nn.functional.logsigmoid(fgate_logit.float()) + elif self.fgate_type == "none": + log_fgate = torch.zeros((batch_size, self.num_heads, q_len), dtype=torch.float32, device=hidden_states.device) + else: + assert self.fgate_type in ["fixed", "bias_only"] + fgate_logit = torch.broadcast_to(self.fgate_bias, (batch_size, q_len, self.num_heads)) + fgate_logit = rearrange(fgate_logit, "b t h -> b h t") + log_fgate = torch.nn.functional.logsigmoid(fgate_logit.float()) + + k = rearrange(k, 'b t h d -> b h t d') + if past_key_values is not None: + k, v, log_fgate = past_key_values.update(k, v, log_fgate, self.layer_idx) + # k, v = rearrange(k, 'b h t d -> b t h d'), rearrange(v, 'b h t d -> b t h d') + q = rearrange(q, 'b t h d -> b h t d') + + if self.num_kv_groups > 1: + assert False + k = rearrange(k.unsqueeze(-2).repeat(1, 1, 1, self.num_kv_groups, 1), 'b t h g d -> b t (h g) d') + v = rearrange(v.unsqueeze(-2).repeat(1, 1, 1, self.num_kv_groups, 1), 'b t h g d -> b t (h g) d') + + # Contains at least one padding token in the sequence + if attention_mask is not None: + B, _, T = log_fgate.size() + assert attention_mask.size() == (B, T), ((B, T), attention_mask.size()) + seq_start = T - attention_mask.sum(dim=-1) + o = forgetting_attention( + q, k, v, + log_fgate, + head_first=True, + seq_start=seq_start, + sm_scale=1 / math.sqrt(self.head_dim), + ) + o = rearrange(o, "b h t d -> b t h d") + else: + o = forgetting_attention( + q, k, v, + log_fgate, + head_first=True, + sm_scale=1 / math.sqrt(self.head_dim), + ) + o = rearrange(o, "b h t d -> b t h d") + + o = o.reshape(batch_size, q_len, self.hidden_size) + + if self.output_norm is not None: + o = self.output_norm(o) + + if self.ogate_proj is not None: + # ogate = self.ogate act(self.ogate_proj(hidden_states)) + # o = o * ogate + # ogate = act_gate(self.ogate_proj(hidden_states), o) + ogate_logit = self.ogate_proj(hidden_states) + dtype = ogate_logit.dtype + if self.ogate_act == "silu": + o = swiglu_linear(ogate_logit, o, self.o_proj.weight.to(dtype), self.o_proj.bias.to(dtype) if self.o_proj.bias is not None else self.o_proj.bias) + elif self.ogate_act == "sigmoid": + o = glu_linear(ogate_logit, o, self.o_proj.weight.to(dtype), self.o_proj.bias.to(dtype) if self.o_proj.bias is not None else self.o_proj.bias) + else: + raise ValueError(f"Unknown ogate act {self.ogate_act}") + else: + o = self.o_proj(o) + + if not output_attentions: + attentions = None + else: + SAVE_HEADS = [0, 1, 2, 3] + # (B, H, T, T) + score = q[:, SAVE_HEADS] @ k[:, SAVE_HEADS].mT + log_lambda = 
torch.cumsum(log_fgate, dim=-1) + decay_bias = (log_lambda[:, SAVE_HEADS, :, None] - log_lambda[:, SAVE_HEADS, None, :]).to(torch.bfloat16) + # normalized_score = torch.softmax(score, dim=-1) + attentions = (score, decay_bias) + + return o, attentions, past_key_values + + def init_shift_state(self, batch_size: int): + param = next(self.parameters()) + state = dict() + try: + dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled("cuda") else torch.float32 + except TypeError: + # Support legacy torch version + dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else torch.float32 + if self.use_k_shift: + state['key_shift'] = param.new_zeros(batch_size, self.kv_dim, dtype=dtype) + else: + state['key_shift'] = None + if self.use_v_shift: + state['value_shift'] = param.new_zeros(batch_size, self.kv_dim, dtype=dtype) + else: + state['value_shift'] = None + return state + + +class ForgettingTransformerMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[float] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'swish' + ) -> ForgettingTransformerMLP: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + self.hidden_act = hidden_act + assert hidden_act in ["swish", "sigmoid"] + + def forward(self, x): + y = self.gate_proj(x) + gate, y = y.chunk(2, -1) + # TODO: maybe wrap swiglu_linear in custom_fwd/custom_bwd + if self.hidden_act == "swish": + return swiglu_linear( + gate, y, + self.down_proj.weight.to(y.dtype), + self.down_proj.bias.to(y.dtype) if self.down_proj.bias is not None else self.down_proj.bias + ) + elif self.hidden_act == "sigmoid": + return glu_linear( + gate, y, + self.down_proj.weight.to(y.dtype), + self.down_proj.bias.to(y.dtype) if self.down_proj.bias is not None else self.down_proj.bias + ) + else: + raise ValueError() + + +class ForgettingTransformerBlock(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.attn = ForgettingAttentionLayer( + hidden_size=config.hidden_size, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + window_size=config.window_size, + max_position_embeddings=config.max_position_embeddings, + rope_base=config.rope_base, + use_rope=config.use_rope, + use_output_gate=config.use_output_gate, + ogate_act=config.ogate_act, + fgate_type=config.fgate_type, + fgate_bias_init=config.fgate_bias_init, + decay_time_min=config.decay_time_min, + decay_time_max=config.decay_time_max, + use_output_norm = config.use_output_norm, + norm_eps=config.norm_eps, + qk_norm=config.qk_norm, + qk_norm_share_param_across_head=config.qk_norm_share_param_across_head, + use_k_shift=config.use_k_shift, + use_v_shift=config.use_v_shift, + 
+            initializer_range=config.initializer_range,
+            layer_idx=layer_idx
+        )
+        self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
+        self.mlp = ForgettingTransformerMLP(
+            hidden_size=config.hidden_size,
+            hidden_ratio=config.hidden_ratio,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act
+        )
+
+    def forward_attn(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        **kwargs,
+    ):
+        # The residual connection is handled outside of this method
+        # residual = hidden_states
+        hidden_states = self.attn_norm(hidden_states)
+        hidden_states, attentions, past_key_values = self.attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions
+        )
+        return hidden_states, attentions, past_key_values
+
+    def forward_mlp(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+    ):
+        hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        gradient_checkpointing: bool = False
+        # **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+
+        residual = hidden_states
+
+        if gradient_checkpointing:
+            forward_attn = partial(torch.utils.checkpoint.checkpoint, self.forward_attn, use_reentrant=False)
+            forward_mlp = partial(torch.utils.checkpoint.checkpoint, self.forward_mlp, use_reentrant=False)
+        else:
+            forward_attn = self.forward_attn
+            forward_mlp = self.forward_mlp
+
+        hidden_states, attentions, past_key_values = forward_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions
+        )
+
+        hidden_states = forward_mlp(
+            hidden_states,
+            residual,
+        )
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attentions,)
+
+        if use_cache:
+            outputs += (past_key_values,)
+
+        return outputs
+
+
+class ForgettingTransformerPreTrainedModel(PreTrainedModel):
+
+    config_class = ForgettingTransformerConfig
+    supports_gradient_checkpointing = True
+    _no_split_modules = ['ForgettingTransformerBlock']
+
+    def __init__(self, config, *inputs, **kwargs):
+        # Dynamically patch config_class so the model also loads with remote (Hub) code
+        if hasattr(config, '__class__'):
+            config_module = config.__class__.__module__
+            if 'transformers_modules' in config_module or config_module == 'configuration_forgetting_transformer':
+                self.__class__.config_class = config.__class__
+        super().__init__(config, *inputs, **kwargs)
+
+    def _init_weights(
+        self,
+        module: nn.Module,
+    ):
+        # if isinstance(module, (nn.Linear, nn.Conv1d)):
+        if isinstance(module, nn.Linear):
+            # Slightly different from the TF version, which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
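+            # Zero the padding row (below) so pad tokens start as exact zero
+            # vectors after the normal initialization above.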
if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class ForgettingTransformerModel(ForgettingTransformerPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([ForgettingTransformerBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, CausalLMOutputWithPast]: + # if output_attentions: + # warnings.warn( + # "`ForgettingTransformerModel` does not support output attention weights now, so `output_attentions` is set to `False`." + # ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if use_cache: + # use_legacy_cache = not isinstance(past_key_values, Cache) + # if use_legacy_cache: + # past_key_values = FgateDynamicCache.from_legacy_cache(past_key_values) + if past_key_values is None: + past_key_values = FgateDynamicCache() + for layer_idx, layer in enumerate(self.layers): + shift_state = layer.attn.init_shift_state( + batch_size=input_ids.size(0), + ) + past_key_values.update_shift_cache( + key_shift_state=shift_state["key_shift"], + value_shift_state=shift_state["value_shift"], + layer_idx=layer_idx + ) + else: + assert isinstance(past_key_values, FgateDynamicCache) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + # embed positions + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = {} if output_attentions else None + next_decoder_cache = None + + for layer_id, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + gradient_checkpointing=self.gradient_checkpointing and self.training + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + OUTPUT_ATTN_LAYERS = [0, 7, 15, 23] + if layer_id in OUTPUT_ATTN_LAYERS: + # all_attns += (layer_outputs[1],) + all_attns[layer_id] = layer_outputs[1] + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + # next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + next_cache = next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class ForgettingTransformerForCausalLM(ForgettingTransformerPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = ForgettingTransformerModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. + if past_key_values is not None: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. 
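+            # With a non-empty FgateDynamicCache only the latest token is fed
+            # through the model (see the slicing above); keys, values, and log
+            # forget gates for earlier positions are read from the cache.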
+ model_inputs = {'input_ids': input_ids.contiguous()} + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + }) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True, reduction='none') + else: + loss_fct = nn.CrossEntropyLoss(reduction='none') + logits = self.lm_head(hidden_states) + # Enable model parallelism + labels = labels.to(logits.device) + # labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + loss = loss.view(*labels.size()) + del logits + logits = None + else: + logits = self.lm_head(hidden_states) + + if not return_dict: + raise NotImplementedError + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/ops/.ipynb_checkpoints/forgetting_attention-checkpoint.py b/ops/.ipynb_checkpoints/forgetting_attention-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..d6cba53a9a3ae49a448247e739f7f3d644caeca5 --- /dev/null +++ b/ops/.ipynb_checkpoints/forgetting_attention-checkpoint.py @@ -0,0 +1,1138 @@ +""" +Implementation of Forgetting Attention. + +Our code is adapted from https://github.com/FlagOpen/FlagAttention/blob/ee91638dec6da8c00c4113d179f469e0ffcd5852/src/flag_attn/flash.py. The code is modified to implement Forgetting Attention. + +The original license info from FlagAttention: + +Copyright 2023 BAAI + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +import pytest +import math +import torch +import triton +import triton.language as tl +from einops import rearrange +from typing import Optional + + +__all__ = ["forgetting_attention"] + + +# File flash.py +def maybe_contiguous(x): + # only when the inner most dimension is contiguous can LDGSTS be used + # so inner-dimension contiguity is enforced. + return x.contiguous() if x.stride(-1) != 1 else x + +def rounded_multiple(a, b): + return (a + b - 1) // b * b + +# --------------------------- public API --------------------------- +class ForgettingAttention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, log_fgate, seq_start, causal, sm_scale, return_log_normalizer): + assert causal, "Only causal attention is supported" + Dq, Dk, Dv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Dq == Dk == Dv, "feature size of q, k, v should be equal" + assert Dk in {16, 32, 64, 128}, "We only support head dims in {16, 32, 64, 128}" + + B, H, M, D = q.shape + if seq_start is not None: + has_seq_start = True + assert seq_start.shape == (B,) + else: + has_seq_start = False + seq_start = torch.zeros((B,), device=q.device, dtype=torch.long) + N = k.shape[2] + assert log_fgate.shape == (B, H, N) + log_fgate = log_fgate.float() + if has_seq_start: + log_fgate = log_fgate.clone() + # We absolutely don't want masked value to affect result. If we + # don't do this then it could via affecting numerical precision of + # cumsum + mask_index = (torch.arange(N, device=q.device)[None, None, :] < seq_start[:, None, None]) + mask_index = torch.broadcast_to(mask_index, log_fgate.size()) + log_fgate[mask_index] = 0.0 + + log_lambda = torch.cumsum(log_fgate, dim=-1, dtype=log_fgate.dtype).float() + + Hk, Hv = k.shape[1], v.shape[1] + assert Hk == Hv, "num of heads in k and v should be equal" + assert H == Hk, "groupped query attention has not been tested. You can uncomment this if you know what you are doing." + assert H % Hk == 0, "number of heads in q must be a multiple of that in k & v" + num_groups = H // Hk + + P_SEQ = N - M + larger_m = M > N + assert (not larger_m), "The key/value tensors must be longer than the query tensor" + + if sm_scale is None: + sm_scale = 1. 
/ math.sqrt(D) + + # contiguity + q, k, v = maybe_contiguous(q), maybe_contiguous(k), maybe_contiguous(v) + + # to work around https://github.com/openai/triton/issues/2441 + device = torch.cuda.device_of(q) + + with torch.cuda.device(device): + + config = get_fwd_config(B, H, M, N, D, causal) + BLOCK_M, BLOCK_N, num_stages, num_warps = config + + divisible_m = M % BLOCK_M == 0 + divisible_n = N % BLOCK_N == 0 + # consider using 3d grid to avoid div & rem + grid = (triton.cdiv(M, BLOCK_M), H, B) + o = torch.empty_like(q) + L = torch.empty((B, H, M), device=q.device, dtype=torch.float32) + _fwd_kernel[grid]( + q, k, v, log_lambda, seq_start, sm_scale, + L, o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + log_lambda.stride(0), log_lambda.stride(1), log_lambda.stride(2), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + B, H, M, N, P_SEQ, num_groups, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D, + IS_CAUSAL=causal, LARGER_M=larger_m, HAS_SEQ_START=has_seq_start, + DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, + num_warps=num_warps, num_stages=num_stages, + ) + + # autograd context maintenance + ctx.save_for_backward(q, k, v, o, L, log_lambda, seq_start) + ctx.sm_scale = sm_scale + ctx.causal = causal + ctx.has_seq_start = has_seq_start + + has_extra_return = return_log_normalizer + if has_extra_return: + outs = ( + o, + L if return_log_normalizer else None, + ) + return outs + return o + + @staticmethod + def backward(ctx, do, *ignored): + q, k, v, o, L, log_lambda, seq_start = ctx.saved_tensors + sm_scale = ctx.sm_scale + causal = ctx.causal + has_seq_start = ctx.has_seq_start + + B, H, M, D = q.shape + N = k.shape[2] + Hk = k.shape[1] + num_groups = H // Hk + P_SEQ = N - M + larger_m = M > N + + if sm_scale is None: + sm_scale = 1. / math.sqrt(D) + + # to work around https://github.com/openai/triton/issues/2441 + device = torch.cuda.device_of(q) + with torch.cuda.device(device): + config = get_bwd_config(B, H, M, N, D, causal) + BLOCK_M, BLOCK_N, num_stages, num_warps = config + + divisible_m = M % BLOCK_M == 0 + divisible_n = N % BLOCK_N == 0 + + delta = torch.empty_like(L) + grid = (triton.cdiv(M, BLOCK_M), H, B) + _bwd_preprocess[grid]( + o, do, + delta, + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + do.stride(0), do.stride(1), do.stride(2), do.stride(3), + delta.stride(0), delta.stride(1), delta.stride(2), + M, + BLOCK_M=BLOCK_M, D_HEAD=D, + DIVISIBLE_M=divisible_m, + ) + + # NOTE that dk & dv always have the same number of heads as q, instead of q. 
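+            # (That is, H heads rather than Hk; the trailing "instead of q"
+            # should read "instead of k & v". The group dimension is summed out
+            # at the end via dk.reshape((B, Hk, num_groups, N, D)).sum(2).)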
+ BLOCK_M, BLOCK_N, num_stages, num_warps = get_bwd_kv_config(B, H, M, N, D, causal) + divisible_m = M % BLOCK_M == 0 + divisible_n = N % BLOCK_N == 0 + + dk = torch.empty((B, H, N, D), dtype=k.dtype, device=q.device) + dv = torch.empty((B, H, N, D), dtype=v.dtype, device=q.device) + dlog_lambda = torch.empty((B, H, N), dtype=log_lambda.dtype, device=q.device) + grid = (triton.cdiv(N, BLOCK_N), H, B) + _bwd_kv_kernel[grid]( + q, k, v, log_lambda, seq_start, sm_scale, do, + dk, dv, dlog_lambda, + L, delta, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + log_lambda.stride(0), log_lambda.stride(1), log_lambda.stride(2), + do.stride(0), do.stride(1), do.stride(2), do.stride(3), + dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3), + dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3), + dlog_lambda.stride(0), dlog_lambda.stride(1), dlog_lambda.stride(2), + B, H, M, N, P_SEQ, + num_groups, + BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, CAUSAL=causal, + DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, HAS_SEQ_START=has_seq_start, + num_stages=num_stages, num_warps=num_warps, + ) + + BLOCK_M, BLOCK_N, num_stages, num_warps = get_bwd_q_config(B, H, M, N, D, causal) + divisible_m = M % BLOCK_M == 0 + divisible_n = N % BLOCK_N == 0 + dq = torch.zeros_like(q) + grid = (triton.cdiv(M, BLOCK_M), H, B) + _bwd_q_kernel[grid]( + q, k, v, log_lambda, seq_start, sm_scale, do, + dq, dlog_lambda, + L, delta, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + log_lambda.stride(0), log_lambda.stride(1), log_lambda.stride(2), + do.stride(0), do.stride(1), do.stride(2), do.stride(3), + dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3), + dlog_lambda.stride(0), dlog_lambda.stride(1), dlog_lambda.stride(2), + B, H, M, N, P_SEQ, + num_groups, + BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, + CAUSAL=causal, LARGER_M=larger_m, HAS_SEQ_START=has_seq_start, + DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, + num_stages=num_stages, num_warps = num_warps, + ) + dk = dk.reshape((B, Hk, num_groups, N, D)).sum(2) + dv = dv.reshape((B, Hk, num_groups, N, D)).sum(2) + dcumsum = torch.cumsum(dlog_lambda, dim=-1, dtype=log_lambda.dtype) + dlog_fgate = dlog_lambda + dcumsum[..., -1:] - dcumsum + dlog_fgate = dlog_fgate.float() + return dq, dk, dv, dlog_fgate, None, None, None, None, None, None, None + + +def forgetting_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + log_fgate: torch.Tensor, + *, + head_first: bool = False, + seq_start: Optional[torch.Tensor] = None, + sm_scale: Optional[float] = None, +): + """ + A FlashAttention-based implementation of Forgetting Attention. + + Note: + - We recommand bfloat16/float16 for q, k, v and float32 for log_fgate. float32 for + q, k, v is also supported, but the kernel will not use tensor cores if q, k, v are + in float32 (which would be slow). + - We only support seqlen_q <= seqlen_k + - We only support causal attention + - Head dimension must be in one of {16, 32, 64, 128} + + Arguments: + - q: (batch_size, seqlen_q, num_heads, head_dim) unless head_first=True. + - k: (batch_size, seqlen_k, num_heads, head_dim) unless head_first=True. + - v: (batch_size, seqlen_k, num_heads, head_dim) unless head_first=True. + - log_fgate: (batch_size, seqlen_k, num_heads) unless head_first=True. 
+ This should be the **log** of the forget gates. This is typically the + output of torch.nn.functional.logsigmoid. + - head_first: if True, the order the num_heads and seqlen_* axis of the all + FloatTensor inputs and outputs should be (num_heads, seq_len_*) instead of + (seq_len_*, num_heads) + - seq_start: If not None, should be LongTensor with shape (batch_size,) + and range in [0, seq_len_k). For each batch index batch_id, no attention + will be allocated to tokens before the token index seq_start[batch_id]. + This is useful for left-padded inputs. + - sm_scale: The scaling of attention scores before applying softmax. If + None, it defaults to (1.0 / math.sqrt(head_dim)) + + Returns: + out (torch.Tensor): (batch_size, seqlen_q, num_heads, head_dim) unless head_first=True. + """ + if not head_first: + q, k, v = [rearrange(item, "b t h d -> b h t d") for item in (q, k, v)] + log_fgate = rearrange(log_fgate, "b t h -> b h t") + out = ForgettingAttention.apply(q, k, v, log_fgate, seq_start, True, sm_scale, False) + if not head_first: + out = rearrange(out, "b h t d -> b t h d") + return out + + +# --------------------------- Forward --------------------------- +# NOTE: this function can be overwritten at runtime to use your custom config +def get_fwd_config(B, H, M, N, D, causal): + assert causal + if torch.cuda.get_device_capability() == (8, 0): + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 32, 3, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 4, 4 + elif torch.cuda.get_device_capability() == (9, 0): + # H100 + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 8 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 2, 8 + elif torch.cuda.get_device_capability() == (8, 6): + if not causal: + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4 + else: # causal + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4 + elif torch.cuda.get_device_capability() == (8, 9): + # L40S + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 2, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 + return (BLOCK_M, BLOCK_N, num_stages, num_warps) + + +@triton.jit +def _fwd_kernel( + Q, K, V, LOG_LAMBDA, SEQ_START, sm_scale, + L, O, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vn, stride_vk, + stride_log_lambda_z, stride_log_lambda_h, stride_log_lambda_n, + stride_oz, stride_oh, stride_om, stride_ok, + Z, H, M, N, P_SEQ, + num_groups, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, HAS_SEQ_START: tl.constexpr, + DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, +): + input_dtype = Q.dtype.element_ty + # -- grid id -- + start_m = tl.program_id(0) + off_h = tl.program_id(1) + off_z = tl.program_id(2) + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + log2e: tl.constexpr = 1.4426950408889634 + loge2: tl.constexpr = 0.6931471805599453 + qk_scale = sm_scale * log2e + + # offset pointers for (batch, head) + off_hk = off_h // num_groups + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_hk * stride_kh + V += off_z * 
+
+
+# --------------------------- Forward ---------------------------
+# NOTE: this function can be overwritten at runtime to use your custom config
+def get_fwd_config(B, H, M, N, D, causal):
+    assert causal
+    if torch.cuda.get_device_capability() == (8, 0):
+        if D <= 64:
+            BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 32, 3, 4
+        else:
+            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 4, 4
+    elif torch.cuda.get_device_capability() == (9, 0):
+        # H100
+        if D <= 64:
+            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 8
+        else:
+            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 2, 8
+    elif torch.cuda.get_device_capability() == (8, 6):
+        if not causal:
+            if D <= 64:
+                BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
+            else:
+                BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
+        else:  # causal
+            if D <= 64:
+                BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 4
+            else:
+                BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
+    elif torch.cuda.get_device_capability() == (8, 9):
+        # L40S
+        if D <= 64:
+            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 2, 4
+        else:
+            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
+    else:
+        BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4
+    return (BLOCK_M, BLOCK_N, num_stages, num_warps)
+
+
+@triton.jit
+def _fwd_kernel(
+    Q, K, V, LOG_LAMBDA, SEQ_START, sm_scale,
+    L, O,
+    stride_qz, stride_qh, stride_qm, stride_qk,
+    stride_kz, stride_kh, stride_kn, stride_kk,
+    stride_vz, stride_vh, stride_vn, stride_vk,
+    stride_log_lambda_z, stride_log_lambda_h, stride_log_lambda_n,
+    stride_oz, stride_oh, stride_om, stride_ok,
+    Z, H, M, N, P_SEQ,
+    num_groups,
+    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
+    IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, HAS_SEQ_START: tl.constexpr,
+    DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
+):
+    input_dtype = Q.dtype.element_ty
+    # -- grid id --
+    start_m = tl.program_id(0)
+    off_h = tl.program_id(1)
+    off_z = tl.program_id(2)
+
+    # scale sm_scale by log_2(e) and use
+    # 2^x instead of exp in the loop because CSE and LICM
+    # don't work as expected with `exp` in the loop
+    log2e: tl.constexpr = 1.4426950408889634
+    loge2: tl.constexpr = 0.6931471805599453
+    qk_scale = sm_scale * log2e
+
+    # offset pointers for (batch, head)
+    off_hk = off_h // num_groups
+    Q += off_z * stride_qz + off_h * stride_qh
+    K += off_z * stride_kz + off_hk * stride_kh
+    V += off_z * stride_vz + off_hk * stride_vh
+    LOG_LAMBDA += off_z * stride_log_lambda_z + off_h * stride_log_lambda_h
+    O += off_z * stride_oz + off_h * stride_oh
+    L += (off_z * H + off_h) * M  # l's shape is (B, H, M)
+
+    offs_m_base = tl.arange(0, BLOCK_M)
+    offs_m = start_m * BLOCK_M + offs_m_base
+    offs_n_base = tl.arange(0, BLOCK_N)
+    offs_k = tl.arange(0, BLOCK_DMODEL)
+
+    # initialize pointers to value-like data
+    q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk)  # (BLOCK_M, BLOCK_DMODEL)
+    log_lambda_out_ptrs = LOG_LAMBDA + (P_SEQ + offs_m) * stride_log_lambda_n
+    o_ptrs = O + (offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok)  # (BLOCK_M, BLOCK_DMODEL)
+    l_ptrs = L + offs_m
+
+    # initialize pointers to m and l, fp32 for accumulators
+    m_i = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32)
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+
+    # load q
+    if DIVISIBLE_M:
+        q = tl.load(q_ptrs, cache_modifier=".cg")
+        log_lambda_out = tl.load(log_lambda_out_ptrs, cache_modifier=".cg")
+    else:
+        mask_m = offs_m < M
+        q = tl.load(q_ptrs, mask=mask_m[:, None], cache_modifier=".cg")
+        log_lambda_out = tl.load(log_lambda_out_ptrs, mask=mask_m, cache_modifier=".cg")
+
+    # Dot-I trick: placing q in registers saves shared memory
+    # if BLOCK_DMODEL < 128:
+    #     I = tl.where(offs_k[:, None] == offs_k,
+    #                  tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=input_dtype),
+    #                  tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=input_dtype))
+    #     q = tl.dot(q, I, input_precision="ieee").to(input_dtype)
+    # else:
+    #     I = tl.where(offs_m_base[:, None] == offs_m_base,
+    #                  tl.full((BLOCK_M, BLOCK_M), 1.0, dtype=input_dtype),
+    #                  tl.full((BLOCK_M, BLOCK_M), 0.0, dtype=input_dtype))
+    #     q = tl.dot(I, q, input_precision="ieee").to(input_dtype)
+
+    # NOTE: Loop-Bound-For-N
+    # The indices in the m-dimension that this block may access lie in
+    # `[start_m * BLOCK_M, (start_m + 1) * BLOCK_M)`.
+    # By the rule of causal masking, the max index in the n-dimension that this block
+    # may access is `P_SEQ + (start_m + 1) * BLOCK_M`.
+    # However, the upper bound of the index in the n-dimension should never exceed the
+    # sequence length of k/v (`P_SEQ + N_CTX`), and `P_SEQ + (start_m + 1) * BLOCK_M`
+    # may be larger than `N`.
+    # In this case, there would be illegal memory accesses when loading the k & v tiles
+    # if mask_n is not applied when loading (only when `DIVISIBLE_N` is true).
+ # See also https://github.com/FlagOpen/FlagAttention/pull/8 + if IS_CAUSAL: + hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M) + if LARGER_M: + hi = tl.maximum(0, hi) + else: + hi = N + + offs_n_init = offs_n_base + if HAS_SEQ_START: + SEQ_START += off_z + seq_start = tl.load(SEQ_START) + lo = tl.minimum(seq_start, hi) + lo = (lo // BLOCK_N) * BLOCK_N + offs_n_init += lo + else: + lo = 0 + seq_start = 0 + + # loop over k, v and update accumulators + k_ptrs = K + (offs_k[:, None] * stride_kk + offs_n_init[None, :] * stride_kn) # (BLOCK_DMODEL, BLOCK_N) + v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL) + log_lambda_in_ptrs = LOG_LAMBDA + (offs_n_init * stride_log_lambda_n) # (BLOCK_N, BLOCK_DMODEL) + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + offs_n = start_n + offs_n_base + + # -- load k, v -- + if DIVISIBLE_N: + k = tl.load(k_ptrs, cache_modifier=".cg") + v = tl.load(v_ptrs, cache_modifier=".cg") + log_lambda_in = tl.load(log_lambda_in_ptrs, cache_modifier=".cg") + else: + mask_n = offs_n < N + k = tl.load(k_ptrs, mask=mask_n[None, :], cache_modifier=".cg") + v = tl.load(v_ptrs, mask=mask_n[:, None], cache_modifier=".cg") + log_lambda_in = tl.load(log_lambda_in_ptrs, mask=mask_n, cache_modifier=".cg") + + # -- compute qk --- + # s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + s = tl.dot(q, k, input_precision="ieee") * qk_scale + decay_bias = log_lambda_out[:, None] - log_lambda_in[None, :] + s += decay_bias * log2e + + if not DIVISIBLE_N: + s = tl.where(mask_n[None, :], s, float("-inf")) + if IS_CAUSAL: + causal_mask = (P_SEQ + offs_m[:, None]) >= offs_n[None, :] + s = tl.where(causal_mask, s, float("-inf")) + if HAS_SEQ_START: + s = tl.where(offs_n[None, :] >= seq_start, s, float("-inf")) + + + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(s, 1)) + alpha = tl.math.exp2((m_i - m_i_new)) + p = tl.math.exp2(s - m_i_new[:, None]) + + # -- compute partial sumexpn before applying dropout + p_sum = tl.sum(p, 1) + + + # -- scale and update acc: acc *= alpha[:, None]-- + acc *= alpha[:, None] + acc += tl.dot(p.to(input_dtype), v, input_precision="ieee") + + # -- update m_i and l_i -- + l_i = l_i * alpha + p_sum + m_i = m_i_new + # update pointers + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vn + log_lambda_in_ptrs += BLOCK_N * stride_log_lambda_n + + # write back l & o + if IS_CAUSAL and (LARGER_M or HAS_SEQ_START): + is_empty_line = (offs_m + P_SEQ) < seq_start + acc = tl.where(is_empty_line[:, None], 0.0, acc * (1.0 / l_i[:, None])) + l = tl.where(is_empty_line, float("-inf"), m_i * loge2 + tl.log(l_i)) + else: + acc = acc * (1.0 / l_i[:, None]) + l = m_i * loge2 + tl.log(l_i) # log(normalizer) + + + if DIVISIBLE_M: + tl.store(l_ptrs, l, cache_modifier=".cg") + tl.store(o_ptrs, acc.to(input_dtype), cache_modifier=".cg") + else: + tl.store(l_ptrs, l, mask=mask_m, cache_modifier=".cg") + tl.store(o_ptrs, acc.to(input_dtype), mask=mask_m[:, None], cache_modifier=".cg") + + +# --------------------------- Backward --------------------------- +# NOTE: this function can be overwritten at runtime to use your custom config +def get_bwd_config(B, H, M, N, D, causal): + if torch.cuda.get_device_capability() == (9, 0): + if not causal: + BLOCK_M = 128 if D <= 64 else 64 + BLOCK_N = 64 + num_stages = 2 + num_warps = 4 + else: + BLOCK_M = 64 + BLOCK_N = 64 + num_stages = 3 if D <= 64 else 2 + num_warps = 4 + elif torch.cuda.get_device_capability() == (8, 0): + if 
not causal: + BLOCK_M = 128 if D <= 64 else 64 + BLOCK_N = 64 + num_stages = 2 + num_warps = 4 + else: + BLOCK_M = 64 + BLOCK_N = 64 + num_stages = 3 if D <= 64 else 2 + num_warps = 4 + elif torch.cuda.get_device_capability() == (8, 6): # tune for RTX-3090, device_capability(8, 6) + if not causal: + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 8 + else: + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 2, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4 + return (BLOCK_M, BLOCK_N, num_stages, num_warps) + +def get_bwd_kv_config(B, H, M, N, D, causal): + assert causal + if torch.cuda.get_device_capability() == (8, 0): # A100 + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 4, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 128, 4, 8 + elif torch.cuda.get_device_capability() == (8, 6): # tune for RTX-3090, device_capability(8, 6) + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 2, 4 + elif torch.cuda.get_device_capability() == (8, 9): # L40S + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 128, 4, 8 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 128, 2, 8 + elif torch.cuda.get_device_capability() == (9, 0): # H100 + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 + return (BLOCK_M, BLOCK_N, num_stages, num_warps) + +def get_bwd_q_config(B, H, M, N, D, causal): + assert causal + if torch.cuda.get_device_capability() == (8, 0): # A100 + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 4, 8 + elif torch.cuda.get_device_capability() == (8, 6): # tune for RTX-3090, device_capability(8, 6) + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 2, 4 + elif torch.cuda.get_device_capability() == (8, 9): # L40S + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 4, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 3, 4 + elif torch.cuda.get_device_capability() == (9, 0): # H100 + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 4, 8 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 2, 8 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 + return (BLOCK_M, BLOCK_N, num_stages, num_warps) + + +@triton.jit +def _bwd_preprocess( + Out, DO, + Delta, + stride_oz, stride_oh, stride_om, stride_ok, + stride_doz, stride_doh, stride_dom, stride_dok, + stride_dz, stride_dh, stride_dm, + M, + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, + DIVISIBLE_M: tl.constexpr, +): + off_h = tl.program_id(1) + off_z = tl.program_id(2) + Out += off_z * stride_oz + off_h * stride_oh + DO += off_z * stride_doz + off_h * stride_doh + Delta += off_z * stride_dz + off_h * stride_dh + + # compute (Out * Dout).sum() for vector interpretation + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + + # load + o_ptrs = Out + off_m[:, None] * stride_om + off_n[None, :] * stride_ok + do_ptrs = DO + off_m[:, None] * stride_dom + off_n[None, :] * stride_dok + + if DIVISIBLE_M: + o = tl.load(o_ptrs).to(tl.float32) + do = 
tl.load(do_ptrs).to(tl.float32) + else: + mask_m = off_m < M + o = tl.load(o_ptrs, mask=mask_m[:, None]).to(tl.float32) + do = tl.load(do_ptrs, mask=mask_m[:, None]).to(tl.float32) + + # compute + delta = tl.sum(o * do, axis=1) + + # write-back + d_ptrs = Delta + off_m * stride_dm + if DIVISIBLE_M: + tl.store(d_ptrs, delta) + else: + tl.store(d_ptrs, delta, mask=mask_m) + + +@triton.jit +def _bwd_kv_kernel( + Q, K, V, LOG_LAMBDA, SEQ_START, sm_scale, DO, + DK, DV, DLOG_LAMBDA, + L, + D, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vn, stride_vk, + stride_log_lambda_z, stride_log_lambda_h, stride_log_lambda_n, + stride_doz, stride_doh, stride_dom, stride_dok, + stride_dkz, stride_dkh, stride_dkn, stride_dkk, + stride_dvz, stride_dvh, stride_dvn, stride_dvk, + stride_dlog_lambda_z, stride_dlog_lambda_h, stride_dlog_lambda_n, + Z, H, M, N, P_SEQ, + num_groups, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + CAUSAL: tl.constexpr, + DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, HAS_SEQ_START: tl.constexpr, +): + input_dtype = Q.dtype.element_ty + # -- grid id -- + start_n = tl.program_id(0) + off_h = tl.program_id(1) + off_z = tl.program_id(2) + log2e: tl.constexpr = 1.4426950408889634 + qk_scale = sm_scale * log2e + + # offset pointers for (batch, head) + off_hk = off_h // num_groups + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_hk * stride_kh + V += off_z * stride_vz + off_hk * stride_vh + LOG_LAMBDA += off_z * stride_log_lambda_z + off_h * stride_log_lambda_h + DO += off_z * stride_doz + off_h * stride_doh + + # offset pointers for batch/head + DK += off_z * stride_dkz + off_h * stride_dkh + DV += off_z * stride_dvz + off_h * stride_dvh + DLOG_LAMBDA += off_z * stride_dlog_lambda_z + off_h * stride_dlog_lambda_h + + # offset pointers for batch/head + D += (off_z * H + off_h) * M + L += (off_z * H + off_h) * M + + if CAUSAL: + lo = tl.maximum(start_n * BLOCK_N - P_SEQ, 0) + lo = (lo // BLOCK_M) * BLOCK_M + else: + lo = 0 + + offs_m_init = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m_base = tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_DMODEL) + + # initialize pointers to value-like data + q_ptrs = Q + (offs_m_init[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL) + log_lambda_out_ptrs = LOG_LAMBDA + (P_SEQ + offs_m_init) * stride_log_lambda_n # (BLOCK_N, BLOCK_DMODEL) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL) + log_lambda_in_ptrs = LOG_LAMBDA + (offs_n * stride_log_lambda_n) # (BLOCK_N, BLOCK_DMODEL) + do_ptrs = DO + (offs_m_init[:, None] * stride_dom + offs_k[None, :] * stride_dok) # (BLOCK_M, BLOCK_DMODEL) + + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_k[None, :] * stride_dvk) # (BLOCK_N, BLOCK_DMODEL) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk) # (BLOCK_N, BLOCK_DMODEL) + dlog_lambda_in_ptrs = DLOG_LAMBDA + (offs_n * stride_dlog_lambda_n) # (BLOCK_N, BLOCK_DMODEL) + + # k and v stay in SRAM throughout + if DIVISIBLE_N: + v = tl.load(v_ptrs) + k = tl.load(k_ptrs) + log_lambda_in = tl.load(log_lambda_in_ptrs) + else: + mask_n = offs_n < N + v = tl.load(v_ptrs, mask=mask_n[:, None]) + k = tl.load(k_ptrs, mask=mask_n[:, None]) + log_lambda_in = tl.load(log_lambda_in_ptrs, 
mask=mask_n)
+
+    # If this N block lies entirely before seq_start, there is nothing to accumulate
+    if HAS_SEQ_START:
+        SEQ_START += off_z
+        seq_start = tl.load(SEQ_START)
+        hi = tl.where(start_n * BLOCK_N + BLOCK_N >= seq_start - 1, M, lo)
+    else:
+        hi = M
+
+    # initialize dk and dv
+    dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
+    dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
+    dlog_lambda_in = tl.zeros([BLOCK_N], dtype=tl.float32)
+
+    # loop over the m-blocks of this column
+    for start_m in range(lo, hi, BLOCK_M):
+        start_m = tl.multiple_of(start_m, BLOCK_M)
+        offs_m = start_m + offs_m_base
+        causal_mask = (P_SEQ + offs_m[None, :]) >= (offs_n[:, None])  # (BLOCK_N, BLOCK_M)
+
+        # load q, log_lambda_out on-chip
+        if DIVISIBLE_M:
+            q = tl.load(q_ptrs)
+            log_lambda_out = tl.load(log_lambda_out_ptrs)
+        else:
+            mask_m = offs_m < M
+            valid_mask = mask_m[None, :]  # & mask_n
+            q = tl.load(q_ptrs, mask=mask_m[:, None])
+            log_lambda_out = tl.load(log_lambda_out_ptrs, mask=mask_m)
+        # recompute p = softmax(qk * sm_scale, dim=-1)
+        # s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+        sT = tl.dot(k, tl.trans(q), input_precision="ieee") * qk_scale
+        decay_bias = log_lambda_out[None, :] - log_lambda_in[:, None]
+        sT += decay_bias * log2e
+        # NOTE: since softmax in the backward pass is pointwise and the normalizer was
+        # saved in the forward pass, masking on s is not needed.
+        # s = tl.where(valid_mask, s, float("-inf"))
+        # if CAUSAL:
+        #     s = tl.where(causal_mask, s, float("-inf"))
+
+        # -- recompute p ---
+        if DIVISIBLE_M:
+            l = tl.load(L + offs_m)
+        else:
+            l = tl.load(L + offs_m, mask=mask_m)
+        pT = tl.math.exp2(sT - l[None, :] * log2e)  # (BLOCK_N, BLOCK_M)
+
+        if not DIVISIBLE_M:
+            pT = tl.where(valid_mask, pT, 0.0)
+        if CAUSAL:
+            pT = tl.where(causal_mask, pT, 0.0)
+
+        # compute dv = dot(p, do)
+        if DIVISIBLE_M:
+            do = tl.load(do_ptrs)
+        else:
+            do = tl.load(do_ptrs, mask=mask_m[:, None])  # (BLOCK_M, BLOCK_DMODEL)
+
+        dv += tl.dot(pT.to(input_dtype), do, input_precision="ieee")  # (BLOCK_N, BLOCK_DMODEL)
+
+        # compute dp = dot(v, do)
+        if DIVISIBLE_M:
+            delta = tl.load(D + offs_m)
+        else:
+            delta = tl.load(D + offs_m, mask=mask_m)
+        # dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+        dpT = tl.dot(v, tl.trans(do), input_precision="ieee")
+
+        # compute ds = p * (dp - delta[:, None])
+        dsT = pT * (dpT - delta[None, :])  # (BLOCK_N, BLOCK_M)
+
+        if not DIVISIBLE_M:
+            dsT = tl.where(valid_mask, dsT, 0.0)
+        if CAUSAL:
+            dsT = tl.where(causal_mask, dsT, 0.0)
+
+        # compute dk = dot(ds.T, q)
+        dk += tl.dot(dsT.to(input_dtype), q, input_precision="ieee")
+        dlog_lambda_in += -tl.sum(dsT, axis=1)
+
+        # increment pointers
+        q_ptrs += BLOCK_M * stride_qm
+        log_lambda_out_ptrs += BLOCK_M * stride_log_lambda_n
+        do_ptrs += BLOCK_M * stride_dom
+
+    dk *= sm_scale
+    if HAS_SEQ_START:
+        # Mask out positions before seq_start
+        seq_mask = (offs_n >= seq_start)
+        dk = tl.where(seq_mask[:, None], dk, 0.0)
+        dv = tl.where(seq_mask[:, None], dv, 0.0)
+        dlog_lambda_in = tl.where(seq_mask, dlog_lambda_in, 0.0)
+    if DIVISIBLE_N:
+        tl.store(dk_ptrs, dk.to(input_dtype))  # (BLOCK_N, BLOCK_DMODEL)
+        tl.store(dv_ptrs, dv.to(input_dtype))  # (BLOCK_N, BLOCK_DMODEL)
+        tl.store(dlog_lambda_in_ptrs, dlog_lambda_in.to(tl.float32))  # (BLOCK_N,)
+    else:
+        tl.store(dk_ptrs, dk.to(input_dtype), mask=mask_n[:, None])  # (BLOCK_N, BLOCK_DMODEL)
+        tl.store(dv_ptrs, dv.to(input_dtype), mask=mask_n[:, None])  # (BLOCK_N, BLOCK_DMODEL)
+        tl.store(dlog_lambda_in_ptrs, dlog_lambda_in.to(tl.float32), mask=mask_n)  # (BLOCK_N,)
+
+
+@triton.jit
+def 
_bwd_q_kernel( + Q, K, V, LOG_LAMBDA, SEQ_START, sm_scale, DO, + DQ, DLOG_LAMBDA, + L, + D, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vn, stride_vk, + stride_log_lambda_z, stride_log_lambda_h, stride_log_lambda_n, + stride_doz, stride_doh, stride_dom, stride_dok, + stride_dqz, stride_dqh, stride_dqm, stride_dqk, + stride_dlog_lambda_z, stride_dlog_lambda_h, stride_dlog_lambda_n, + Z, H, M, N, P_SEQ, + num_groups, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, HAS_SEQ_START: tl.constexpr, + DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, +): + input_dtype = Q.dtype.element_ty + # -- grid id -- + start_m = tl.program_id(0) + off_h = tl.program_id(1) + off_z = tl.program_id(2) + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + log2e: tl.constexpr = 1.4426950408889634 + qk_scale = sm_scale * log2e + + # offset pointers for (batch, head) + off_hk = off_h // num_groups + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_hk * stride_kh + V += off_z * stride_vz + off_hk * stride_vh + LOG_LAMBDA += off_z * stride_log_lambda_z + off_h * stride_log_lambda_h + DO += off_z * stride_doz + off_h * stride_doh + D += (off_z * H + off_h) * M + L += (off_z * H + off_h) * M + + # offset pointers for batch/head + DQ += off_z * stride_dqz + off_h * stride_dqh + DLOG_LAMBDA += off_z * stride_dlog_lambda_z + off_h * stride_dlog_lambda_h + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_DMODEL) + + # initialize pointers to value-like data + q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL) + log_lambda_out_ptrs = LOG_LAMBDA + (P_SEQ + offs_m) * stride_log_lambda_n + + dq_ptrs = DQ + (offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk) # (BLOCK_M, BLOCK_DMODEL) + dlog_lambda_out_ptrs = DLOG_LAMBDA + (P_SEQ + offs_m) * stride_dlog_lambda_n + do_ptrs = DO + (offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok) # (BLOCK_M, BLOCK_DMODEL) + + # pointer to row-wise quantities in value-like data + d_ptrs = D + offs_m + l_ptrs = L + offs_m + + # load q: it will stay in SRAM throughout + if DIVISIBLE_M: + q = tl.load(q_ptrs) + do = tl.load(do_ptrs) + delta = tl.load(d_ptrs) + l = tl.load(l_ptrs) + log_lambda_out = tl.load(log_lambda_out_ptrs) + else: + mask_m = offs_m < M + q = tl.load(q_ptrs, mask=mask_m[:, None]) + do = tl.load(do_ptrs, mask=mask_m[:, None]) + delta = tl.load(d_ptrs, mask=mask_m) + l = tl.load(l_ptrs, mask=mask_m) + log_lambda_out = tl.load(log_lambda_out_ptrs, mask=mask_m) + + # initialize dq + dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dlog_lambda_out = tl.zeros([BLOCK_M], dtype=tl.float32) + + # loop over k, v and update accumulator + # see note "Loop-Bound-For-N" + if CAUSAL: + hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M) + if LARGER_M: + hi = tl.maximum(0, hi) + else: + hi = N + + offs_n_base = tl.arange(0, BLOCK_N) + offs_n_init = offs_n_base + if HAS_SEQ_START: + SEQ_START += off_z + seq_start = tl.load(SEQ_START) + lo = tl.minimum(seq_start, hi) + lo = (lo // BLOCK_N) * BLOCK_N + offs_n_init += lo + else: + lo = 0 + k_ptrs = K + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL) + v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * 
stride_vk) # (BLOCK_N, BLOCK_DMODEL) + log_lambda_in_ptrs = LOG_LAMBDA + (offs_n_init * stride_log_lambda_n) + + # loop over a row + for start_n in range(lo, hi, BLOCK_N): + offs_n = start_n + offs_n_base + + # load k1, k2, v on chip + if DIVISIBLE_N: + v = tl.load(v_ptrs) + k = tl.load(k_ptrs) + log_lambda_in = tl.load(log_lambda_in_ptrs) + else: + mask_n = offs_n < N + v = tl.load(v_ptrs, mask=mask_n[:, None]) + k = tl.load(k_ptrs, mask=mask_n[:, None]) + log_lambda_in = tl.load(log_lambda_in_ptrs, mask=mask_n) + + + # recompute p = softmax(qk * sm_scale, dim=-1) + if not DIVISIBLE_N: + valid_mask = mask_n[None, :] # & mask_m[:, None] + if CAUSAL: + causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N) + # s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + s = tl.dot(q, tl.trans(k), input_precision="ieee") * qk_scale + decay_bias = log_lambda_out[:, None] - log_lambda_in[None, :] + s += decay_bias * log2e + + # NOTE: since softmax in backward is pointwise, the normalizer has been saved in fwd) + # So masking on s is not needed. + # if CAUSAL: + # s = tl.where(causal_mask & valid_mask, s, float("-inf")) + # else: + # s = tl.where(valid_mask, s, float("-inf")) + p = tl.math.exp2(s - l[:, None] * log2e) # (BLOCK_M, BLOCK_N) + + # compute dp = dot(v, do) + # dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + dp = tl.dot(do.to(input_dtype), tl.trans(v), input_precision="ieee") + + + # no need to mask dp + # if CAUSAL: + # dp = tl.where(causal_mask & valid_mask, dp, 0.0) + # else: + # dp = tl.where(valid_mask, dp, 0.0) + + # compute ds = p * (dp - delta[:, None]) + # move scale out to dq at last + ds = p * (dp - delta[:, None]) # (BLOCK_M, BLOCK_N) + + # mask ds to ensure no small values + if not DIVISIBLE_N: + ds = tl.where(valid_mask, ds, 0.0) + if CAUSAL: + ds = tl.where(causal_mask, ds, 0.0) + if HAS_SEQ_START: + ds = tl.where(offs_n[None, :] >= seq_start, ds, 0.0) + + dq += tl.dot(ds.to(input_dtype), k, input_precision="ieee") + dlog_lambda_out += tl.sum(ds, axis=1) + + # increment pointers + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vn + log_lambda_in_ptrs += BLOCK_N * stride_log_lambda_n + + dq *= sm_scale + if DIVISIBLE_M: + tmp = tl.load(dlog_lambda_out_ptrs) + else: + tmp = tl.load(dlog_lambda_out_ptrs, mask=mask_m) + dlog_lambda_out += tmp + if DIVISIBLE_M: + tl.store(dq_ptrs, dq.to(input_dtype)) + tl.store(dlog_lambda_out_ptrs, dlog_lambda_out) + else: + tl.store(dq_ptrs, dq.to(input_dtype), mask=mask_m[:, None]) + tl.store(dlog_lambda_out_ptrs, dlog_lambda_out, mask=mask_m) + + + +@pytest.mark.parametrize("Z, H, M, N, HEAD_DIM", [(4, 2, 1020, 2098, 64), (4, 2, 1024, 2048, 64)]) +@pytest.mark.parametrize("causal", [True]) +def test_op(Z, H, M, N, HEAD_DIM, causal, dtype=torch.bfloat16): + torch.manual_seed(24) + q = (torch.empty((Z, H, M, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, N, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + fgate_logit = torch.empty((Z, H, N), dtype=torch.float32, device="cuda").uniform_(5, 10) + log_fgate = torch.nn.functional.logsigmoid(fgate_logit).requires_grad_() + seq_start = torch.randint(low=0, high=N, size=(Z,), dtype=torch.long, device="cuda") + # seq_start = torch.randint(low=0, high=10, size=(Z,), dtype=torch.long, device="cuda") + # seq_start = torch.full(fill_value=0, size=(Z,), 
dtype=torch.long, device="cuda")
+    sm_scale = 0.5
+    dout = torch.randn_like(q)
+    # reference implementation
+    P_SEQ = N - M
+    mask = torch.tril(torch.ones((M, N), device="cuda"), diagonal=P_SEQ)
+    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
+    p = p.float()
+
+    log_lambda = torch.cumsum(log_fgate, dim=-1)
+    decay_bias = log_lambda[..., -M:, None] - log_lambda[..., None, :]
+    p = p + decay_bias
+    if causal:
+        p[:, :, mask == 0] = float("-inf")
+
+    attention_mask = torch.arange(N, device="cuda") < seq_start[:, None, None, None]
+    p = torch.where(attention_mask, float("-inf"), p)
+    p = torch.softmax(p.float(), dim=-1).to(dtype)
+    p = p.clone()
+    p[torch.isnan(p)] = 0.0
+    # p = torch.exp(p)
+    ref_out = torch.matmul(p, v)
+    ref_out.backward(dout)
+    ref_dv, v.grad = v.grad.clone(), None
+    ref_dk, k.grad = k.grad.clone(), None
+    ref_dq, q.grad = q.grad.clone(), None
+    ref_dlog_fgate, log_fgate.grad = log_fgate.grad.clone(), None
+    # triton implementation
+    tri_out = forgetting_attention(q, k, v, log_fgate, head_first=True, seq_start=seq_start, sm_scale=sm_scale)
+    tri_out = tri_out.to(dtype)
+
+    tri_out.backward(dout)
+    tri_dv, v.grad = v.grad.clone(), None
+    tri_dk, k.grad = k.grad.clone(), None
+    tri_dq, q.grad = q.grad.clone(), None
+    tri_dlog_fgate, log_fgate.grad = log_fgate.grad.clone(), None
+    # compare
+    # assert torch.allclose(tri_log_normalizer[~torch.isnan(tri_log_normalizer)], ref_log_normalizer[~torch.isnan(ref_log_normalizer)], atol=1e-2, rtol=0)
+    assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0), (ref_out - tri_out).abs().max()
+    rtol = 0
+    # Relative tolerance workaround for a known hardware limitation of the MI200 GPU.
+    # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
+    # if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a":
+    #     rtol = 1e-2
+    assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol), (ref_dv - tri_dv).abs().max()
+    assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol), (ref_dk - tri_dk).abs().max()
+    assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol), (ref_dq - tri_dq).abs().max()
+    assert torch.allclose(ref_dlog_fgate, tri_dlog_fgate, atol=1e-2, rtol=rtol), (ref_dlog_fgate - tri_dlog_fgate).abs().max()
+
+try:
+    from flash_attn.flash_attn_interface import \
+        flash_attn_qkvpacked_func as flash_attn_func
+    HAS_FLASH = True
+except BaseException:
+    HAS_FLASH = False
+
+TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
+BATCH, N_HEADS, HEAD_DIM = 4, 32, 128
+# vary seq length for fixed head and batch=4
+configs = []
+for mode in ["fwd", "bwd"]:
+# for mode in ["bwd"]:
+    # for causal in [True, False]:
+    for causal in [True]:
+        if mode == "bwd" and not causal:
+            continue
+        configs.append(
+            triton.testing.Benchmark(
+                x_names=["N_CTX"],
+                # x_vals=[2**i for i in range(10, 15)],
+                x_vals=[2**i for i in range(14, 15)],
+                line_arg="provider",
+                # line_vals=["triton-fp16", "flag"] + (["flash"] if HAS_FLASH else []),
+                # line_names=["Triton [FP16]", "Flag"] + (["Flash-2"] if HAS_FLASH else []),
+                line_vals=["flag"] + (["flash"] if HAS_FLASH else []),
+                line_names=["Flag"] + (["Flash-2"] if HAS_FLASH else []),
+                styles=[("red", "-"), ("blue", "-"), ("green", "-")],
+                ylabel="TFLOP/s",  # the benchmark below returns total_flops / ms * 1e-9
+                plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}",
+                args={
+                    "H": N_HEADS,
+                    "BATCH": BATCH,
+                    "HEAD_DIM": HEAD_DIM,
+                    "mode": mode,
+                    "causal": causal,
+                },
+            ))
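The test above draws random seq_start values to exercise left-padding. In practice, seq_start is simply the number of leading pad tokens in each row; a minimal sketch (the attention-mask convention of 1 = real token, 0 = pad is an assumption, not from the source):

    import torch

    # left-padded batch: row 0 has two pad tokens, row 1 has none
    attention_mask = torch.tensor([[0, 0, 1, 1],
                                   [1, 1, 1, 1]], device="cuda")
    seq_start = (attention_mask == 0).sum(dim=-1)  # tensor([2, 0]) -> pass as seq_start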
+@triton.testing.perf_report(configs)
+def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"):
+    assert mode in ["fwd", "bwd"]
+    warmup = 25
+    rep = 100
+    dtype = torch.bfloat16
+    if "flag" in provider:
+        q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
+        k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
+        v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
+        fgate_logit = torch.empty((BATCH, H, N_CTX), dtype=torch.float32, device="cuda").uniform_(5, 10)
+        log_fgate = torch.nn.functional.logsigmoid(fgate_logit).requires_grad_()
+        # if mode == "fwd" and "fp8" in provider:
+        #     q = q.to(torch.float8_e5m2)
+        #     k = k.to(torch.float8_e5m2)
+        #     v = v.permute(0, 1, 3, 2).contiguous()
+        #     v = v.permute(0, 1, 3, 2)
+        #     v = v.to(torch.float8_e5m2)
+        sm_scale = 1.3
+        fn = lambda: forgetting_attention(q, k, v, log_fgate, head_first=True, sm_scale=sm_scale)
+        if mode == "bwd":
+            o = fn()
+            do = torch.randn_like(o)
+            fn = lambda: o.backward(do, retain_graph=True)
+        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+    if provider == "flash":
+        qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
+        fn = lambda: flash_attn_func(qkv, causal=causal)
+        if mode == "bwd":
+            o = fn()
+            do = torch.randn_like(o)
+            fn = lambda: o.backward(do, retain_graph=True)
+        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM
+    total_flops = 2 * flops_per_matmul
+    if causal:
+        total_flops *= 0.5
+    if mode == "bwd":
+        total_flops *= 2.5  # 2.0 (bwd) + 0.5 (recompute)
+    return total_flops / ms * 1e-9
+
+
+if __name__ == "__main__":
+    # only works on post-Ampere GPUs right now
+    bench_flash_attention.run(save_path=".", print_data=True)
diff --git a/ops/.ipynb_checkpoints/forgetting_attention_std-checkpoint.py b/ops/.ipynb_checkpoints/forgetting_attention_std-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..7763fbae480987a777057c67076272cfd8345af6
--- /dev/null
+++ b/ops/.ipynb_checkpoints/forgetting_attention_std-checkpoint.py
@@ -0,0 +1,72 @@
+"""
+Forgetting Attention - standard softmax version
+Add this function at the end of forgetting_attention.py
+"""
+
+import math
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from typing import Optional
+
+
+def forgetting_attention_std(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    log_fgate: torch.Tensor,
+    *,
+    head_first: bool = False,
+    seq_start: Optional[torch.Tensor] = None,
+    sm_scale: Optional[float] = None,
+) -> torch.Tensor:
+    """Standard-softmax version of Forgetting Attention"""
+
+    if not head_first:
+        q = rearrange(q, "b t h d -> b h t d")
+        k = rearrange(k, "b t h d -> b h t d")
+        v = rearrange(v, "b t h d -> b h t d")
+        log_fgate = rearrange(log_fgate, "b t h -> b h t")
+
+    B, H, T_q, D = q.shape
+    T_k = k.shape[2]
+
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    # Compute QK scores
+    scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale
+
+    # Handle seq_start: zero the log forget gates of padded positions so the
+    # cumulative decay is unaffected by them
+    log_fgate_masked = log_fgate.float()
+    if seq_start is not None:
+        mask_idx = torch.arange(T_k, device=q.device)[None, None, :] < seq_start[:, None, None]  # (B, 1, T_k)
+        log_fgate_masked = log_fgate_masked.masked_fill(mask_idx, 0.0)
+
+    # Compute the cumulative decay
+    log_lambda = torch.cumsum(log_fgate_masked, dim=-1)
+    # align with the last T_q query positions (matches the kernel's P_SEQ handling)
+    decay_bias = log_lambda[:, :, -T_q:, None] - log_lambda[:, :, None, :]
+    scores = scores + decay_bias
+
+    # Causal mask
+    P_SEQ = T_k - T_q
+    causal_mask = torch.triu(torch.ones((T_q, T_k), dtype=torch.bool, device=q.device), diagonal=P_SEQ + 1)
+    scores = scores.masked_fill(causal_mask[None, None, :, :], float('-inf'))
+
+    # seq_start mask
+    if seq_start is not None:
+        # (B, 1, 1, T_k) broadcast over heads and query positions
+        seq_mask = torch.arange(T_k, device=q.device)[None, None, None, :] < seq_start[:, None, None, None]
+        scores = scores.masked_fill(seq_mask, float('-inf'))
+
+    # Softmax
+    attn = F.softmax(scores, dim=-1)
+    attn = torch.nan_to_num(attn, 0.0)
+
+    # Compute the output
+    out = torch.matmul(attn.to(v.dtype), v)
+
+    if not head_first:
+        out = rearrange(out, "b h t d -> b t h d")
+
+    return out
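As a sanity check, this reference can be compared against the Triton kernel on random inputs (a minimal sketch assuming both functions are importable in one scope; the tolerance mirrors the test above):

    import torch
    import torch.nn.functional as F

    q = torch.randn(2, 64, 4, 64, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(2, 64, 4, 64, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(2, 64, 4, 64, dtype=torch.bfloat16, device="cuda")
    log_fgate = F.logsigmoid(torch.randn(2, 64, 4, dtype=torch.float32, device="cuda"))
    ref = forgetting_attention_std(q, k, v, log_fgate)
    tri = forgetting_attention(q, k, v, log_fgate)  # the Triton kernel
    assert torch.allclose(ref.float(), tri.float(), atol=1e-2, rtol=0)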
diff --git a/ops/.ipynb_checkpoints/geometric_attention_std-checkpoint.py b/ops/.ipynb_checkpoints/geometric_attention_std-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..071fadffcc62a02dd7528c6b793b00e86dbcb8f2
--- /dev/null
+++ b/ops/.ipynb_checkpoints/geometric_attention_std-checkpoint.py
@@ -0,0 +1,179 @@
+"""
+Geometric Attention - standard softmax version
+Based on the paper "The Neural Data Router" (Csordás et al., 2022)
+"""
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from typing import Optional
+
+
+def geometric_attention_std(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    *,
+    head_first: bool = False,
+    seq_start: Optional[torch.Tensor] = None,
+    sm_scale: Optional[float] = None,
+    normalize: bool = True,
+) -> torch.Tensor:
+    """
+    Standard-softmax version of Geometric Attention
+
+    Args:
+        q: Query tensor [B, T, H, D] or [B, H, T, D] if head_first
+        k: Key tensor [B, T, H, D] or [B, H, T, D] if head_first
+        v: Value tensor [B, T, H, D] or [B, H, T, D] if head_first
+        head_first: whether the head axis comes first
+        seq_start: per-batch sequence start positions [B]
+        sm_scale: scaling factor, defaults to 1/sqrt(D)
+        normalize: whether to normalize the attention weights
+
+    Returns:
+        output: [B, T, H, D] or [B, H, T, D] if head_first
+    """
+
+    # Rearrange to head_first format
+    if not head_first:
+        q = rearrange(q, "b t h d -> b h t d")
+        k = rearrange(k, "b t h d -> b h t d")
+        v = rearrange(v, "b t h d -> b h t d")
+
+    B, H, T_q, D = q.shape
+    T_k = k.shape[2]
+
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    # Step 1: compute content-based logits
+    logits = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale
+    # logits: [B, H, T_q, T_k]
+
+    # Step 2: mask the diagonal (positions may not attend to themselves)
+    if T_q == T_k:
+        diag_mask = torch.eye(T_q, dtype=torch.bool, device=q.device)
+        logits = logits.masked_fill(diag_mask[None, None, :, :], float('-inf'))
+
+    # Step 3: apply the seq_start mask
+    if seq_start is not None:
+        # (B, 1, 1, T_k) broadcast over heads and query positions
+        seq_mask = torch.arange(T_k, device=q.device)[None, None, None, :] < seq_start[:, None, None, None]
+        logits = logits.masked_fill(seq_mask, float('-inf'))
+
+    # Step 4: causal mask (if needed)
+    # Note: the geometric attention paper does not use a causal mask; uncomment if your task needs it
+    # P_SEQ = T_k - T_q
+    # causal_mask = torch.triu(torch.ones((T_q, T_k), dtype=torch.bool, device=q.device), diagonal=P_SEQ + 1)
+    # logits = logits.masked_fill(causal_mask[None, None, :, :], float('-inf'))
+
+    # Step 5: geometric weighting (the core of the algorithm)
+    attn_weights = geometric_weighting(logits, normalize=normalize)
+
+    # Step 6: apply attention to the values
+    out = torch.matmul(attn_weights.to(v.dtype), v)
+
+    if not head_first:
+        out = rearrange(out, "b h t d -> b t h d")
+
+    return out
+
+
+def geometric_weighting(
+    logits: torch.Tensor,
+    normalize: bool = True,
+) -> torch.Tensor:
+    """
+    Compute geometric attention weights
+
+    Implements Equation 7 of the paper:
+        A[i,j] = P[i,j] * ∏(1 - P[i,k]) for k closer to i than j
+
+    Args:
+        logits: [B, H, T_q, T_k] attention logits
+        normalize: whether to normalize
+
+    Returns:
+        weights: [B, H, T_q, T_k] attention weights
+    """
+    B, H, T_q, T_k = logits.shape
+
+    # Step 1: sigmoid to get matching probabilities
+    P = torch.sigmoid(logits)  # [B, H, T_q, T_k]
+
+    # Step 2: compute in log-space (numerically stable)
+    log_P = torch.log(P + 1e-10)
+    log_one_minus_P = torch.log(1.0 - P + 1e-10)
+
+    # Step 3: simplified version - implement the geometric distribution with a cumsum.
+    # This is an efficient approximation that avoids explicit loops.
+
+    # For each position i, accumulate log(1-P) over all positions to its left
+    log_decay_left = log_one_minus_P.cumsum(dim=-1)
+
+    # Compute the weights (simplified version)
+    # The full version selects intervals by distance; this is an efficient approximation
+    weights = torch.exp(log_P + log_decay_left.roll(1, dims=-1))
+
+    # Special-case the first position (it has no elements to its left);
+    # avoid in-place operations
+    weights_first = P[:, :, :, :1]  # take the first column
+    weights = torch.cat([weights_first, weights[:, :, :, 1:]], dim=-1)
+
+    # Step 4: normalization (optional)
+    if normalize:
+        weights = F.normalize(weights, p=1, dim=-1)
+
+    # Handle NaNs (when every position is -inf)
+    weights = torch.nan_to_num(weights, 0.0)
+
+    return weights
+
+
+def geometric_weighting_full(
+    logits: torch.Tensor,
+    normalize: bool = True,
+) -> torch.Tensor:
+    """
+    Full geometric weighting (slower but more accurate)
+
+    Use only when maximum accuracy is needed; for training,
+    prefer the simplified version above.
+    """
+    B, H, T_q, T_k = logits.shape
+    device = logits.device
+
+    P = torch.sigmoid(logits)
+    log_P = torch.log(P + 1e-10)
+    log_one_minus_P = torch.log(1.0 - P + 1e-10)
+
+    # Initialize the weights
+    weights = torch.zeros_like(P)
+
+    # Compute the geometric weight for each (i, j)
+    for i in range(T_q):
+        for j in range(T_k):
+            # Find all positions k that are closer to i than j is
+            if i < j:
+                # looking right: closer positions are [i+1, ..., j-1]
+                closer_positions = range(i + 1, j)
+            elif i > j:
+                # looking left: closer positions are [j+1, ..., i-1]
+                closer_positions = range(j + 1, i)
+            else:
+                # i == j (diagonal); already masked out by the caller
+                continue
+
+            # compute ∏(1 - P[i,k]) in log-space
+            log_prod = sum(log_one_minus_P[:, :, i, k] for k in closer_positions) if closer_positions else 0.0
+
+            # weights[i,j] = P[i,j] * ∏(1 - P[i,k])
+            weights[:, :, i, j] = torch.exp(log_P[:, :, i, j] + log_prod)
+
+    if normalize:
+        weights = F.normalize(weights, p=1, dim=-1)
+
+    weights = torch.nan_to_num(weights, 0.0)
+
+    return weights
\ No newline at end of file
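A tiny worked example of the loop-based version: with all logits zero, every matching probability is sigmoid(0) = 0.5, so the weight of the key two steps away is 0.5 * (1 - 0.5). A hedged sketch using only the function above:

    import torch

    logits = torch.zeros(1, 1, 3, 3)  # P[i, j] = 0.5 everywhere
    w = geometric_weighting_full(logits, normalize=False)
    # For i=0, j=2 the only closer position is k=1, so
    # w[0, 0, 0, 2] = P[0, 2] * (1 - P[0, 1]) = 0.5 * 0.5 = 0.25
    assert torch.allclose(w[0, 0, 0, 2], torch.tensor(0.25), atol=1e-4)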
diff --git a/ops/.ipynb_checkpoints/sliding_window_attention_std-checkpoint.py b/ops/.ipynb_checkpoints/sliding_window_attention_std-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..16c1551290ef8844ee8e3ca9793017ec90f3f896
--- /dev/null
+++ b/ops/.ipynb_checkpoints/sliding_window_attention_std-checkpoint.py
@@ -0,0 +1,88 @@
+"""
+Sliding Window / Hard Attention
+Based on "Context Limitations Make Neural Language Models More Human-Like"
+(Kuribayashi et al., 2022)
+"""
+
+import math
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from typing import Optional
+
+
+def sliding_window_attention_std(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    *,
+    head_first: bool = False,
+    seq_start: Optional[torch.Tensor] = None,
+    sm_scale: Optional[float] = None,
+    window_size: int = 2,  # default 2-gram (look at the previous 1 token)
+) -> torch.Tensor:
+    """
+    Sliding Window Attention
+
+    Hard cutoff: each query can only attend to the most recent window_size tokens
+    """
+
+    if not head_first:
+        q = rearrange(q, "b t h d -> b h t d")
+        k = rearrange(k, "b t h d -> b h t d")
+        v = rearrange(v, "b t h d -> b h t d")
+
+    B, H, T_q, D = q.shape
+    T_k = k.shape[2]
+
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    # Compute logits
+    logits = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale
+
+    # Create sliding window mask
+    mask = create_sliding_window_mask(T_q, T_k, window_size, device=q.device)
+    logits = logits.masked_fill(~mask, float('-inf'))
+
+    # Seq start mask
+    if seq_start is not None:
+        # (B, 1, 1, T_k) broadcast over heads and query positions
+        seq_mask = torch.arange(T_k, device=q.device)[None, None, None, :] < seq_start[:, None, None, None]
+        logits = logits.masked_fill(seq_mask, float('-inf'))
+
+    # Standard softmax
+    weights = F.softmax(logits, dim=-1)
+    weights = torch.nan_to_num(weights, 0.0)  # guard against fully-masked rows
+
+    # Apply to values
+    out = torch.matmul(weights.to(v.dtype), v)
+
+    if not head_first:
+        out = rearrange(out, "b h t d -> b t h d")
+
+    return out
+
+
+def create_sliding_window_mask(
+    T_q: int,
+    T_k: int,
+    window_size: int,
+    device: torch.device
+) -> torch.Tensor:
+    """
+    Create the sliding window mask
+
+    window_size=1: look at only the previous 1 token (2-gram)
+    window_size=2: look at only the previous 2 tokens (3-gram)
+    """
+    # Base causal mask
+    mask = torch.tril(torch.ones(T_q, T_k, dtype=torch.bool, device=device))
+
+    # Apply the window restriction
+    if window_size > 0 and window_size < T_k:
+        for i in range(T_q):
+            # keep only the range [i - window_size + 1, i]
+            start = max(0, i - window_size + 1)
+            if start > 0:
+                mask[i, :start] = False
+
+    return mask[None, None, :, :]  # [1, 1, T_q, T_k]
\ No newline at end of file
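The Python loop in create_sliding_window_mask runs on the host; an equivalent vectorized construction (a sketch, not part of the source) builds the same mask with index arithmetic:

    import torch

    def create_sliding_window_mask_vectorized(T_q, T_k, window_size, device):
        i = torch.arange(T_q, device=device)[:, None]
        j = torch.arange(T_k, device=device)[None, :]
        mask = j <= i  # base causal mask
        if 0 < window_size < T_k:
            mask &= j >= i - window_size + 1  # keep only the last window_size positions
        return mask[None, None, :, :]  # [1, 1, T_q, T_k]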
diff --git a/ops/.ipynb_checkpoints/stickbreaking_attention_std-checkpoint.py b/ops/.ipynb_checkpoints/stickbreaking_attention_std-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..b193d12b63bcfd41ef851f082a9048ce734359c7
--- /dev/null
+++ b/ops/.ipynb_checkpoints/stickbreaking_attention_std-checkpoint.py
@@ -0,0 +1,117 @@
+"""
+Stick-breaking Attention - ICLR 2025
+Based on the paper "Scaling Stick-Breaking Attention" (Tan et al., 2025)
+Simplified PyTorch implementation (no Triton)
+"""
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from typing import Optional
+
+
+def stickbreaking_attention_std(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    *,
+    head_first: bool = False,
+    seq_start: Optional[torch.Tensor] = None,
+    sm_scale: Optional[float] = None,
+    normalize: bool = True,
+    attend_current: bool = False,
+) -> torch.Tensor:
+    """
+    Stick-breaking attention
+
+    Based on the ICLR 2025 paper; simplified PyTorch implementation
+        A_{i,j} = exp(z_{i,j} - ∑_{k=i}^{j-1} softplus(z_{k,j}))
+
+    Args:
+        q: query [B, T, H, D] or [B, H, T, D] if head_first
+        k: key [B, T, H, D] or [B, H, T, D] if head_first
+        v: value [B, T, H, D] or [B, H, T, D] if head_first
+        attend_current: whether to attend to the current position
+        normalize: whether to normalize the attention weights
+    """
+
+    if not head_first:
+        q = rearrange(q, "b t h d -> b h t d")
+        k = rearrange(k, "b t h d -> b h t d")
+        v = rearrange(v, "b t h d -> b h t d")
+
+    B, H, T_q, D = q.shape
+    T_k = k.shape[2]
+
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    # Compute logits: QK^T / sqrt(d)
+    logits = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale
+    # [B, H, T_q, T_k]
+
+    # Causal mask (optional: mask the diagonal if not attend_current)
+    if T_q == T_k and not attend_current:
+        diag_mask = torch.eye(T_q, dtype=torch.bool, device=q.device)
+        logits = logits.masked_fill(diag_mask[None, None, :, :], float('-inf'))
+
+    # Seq start mask
+    if seq_start is not None:
+        # (B, 1, 1, T_k) broadcast over heads and query positions
+        seq_mask = torch.arange(T_k, device=q.device)[None, None, None, :] < seq_start[:, None, None, None]
+        logits = logits.masked_fill(seq_mask, float('-inf'))
+
+    # Stick-breaking weighting
+    attn_weights = stickbreaking_weighting(logits, normalize=normalize)
+
+    # Apply attention to values
+    out = torch.matmul(attn_weights.to(v.dtype), v)
+
+    if not head_first:
+        out = rearrange(out, "b h t d -> b t h d")
+
+    return out
+
+
+def stickbreaking_weighting(
+    logits: torch.Tensor,
+    normalize: bool = True,
+) -> torch.Tensor:
+    """
+    Compute stick-breaking attention weights
+
+    From paper Equation 4:
+        A_{i,j} = exp(z_{i,j} - ∑_{k=i}^{j-1} log(1 + exp(z_{k,j})))
+
+    where log(1 + exp(x)) is softplus(x)
+    """
+    B, H, T_q, T_k = logits.shape
+    device = logits.device
+
+    # Softplus: log(1 + exp(x))
+    # Numerically stable version from the paper (Equation 5)
+    def softplus_stable(x):
+        # softplus(x) = log(1 + exp(x))
+        # When x > 15, exp(x) is huge, so just return x
+        return torch.where(
+            x > 15.0,
+            x,
+            torch.log1p(torch.exp(torch.clamp(x, max=15.0)))
+        )
+
+    # Compute softplus for all logits
+    logits_sp = softplus_stable(logits)  # [B, H, T_q, T_k]
+
+    # For each query position, compute the cumulative sum;
+    # we need to accumulate from left to right (position i to j-1)
+    log_weights = torch.zeros_like(logits)
+
+    for i in range(T_q):
+        # For query i, we compute attention to all keys j
+        z_i = logits[:, :, i, :]        # [B, H, T_k]
+        z_sp_i = logits_sp[:, :, i, :]  # [B, H, T_k]
+
+        # Cumulative sum of softplus
+        # csum[j] = ∑_{k=0}^{j} softplus(z_{i,k})
+        csum = z_sp_i.cumsum(dim=-1)
\ No newline at end of file
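The checkpoint file ends mid-loop, so stickbreaking_weighting never forms or returns its weights. One way it could be completed, vectorized rather than looped (an assumption based on the docstring's Equation 4, not the authors' implementation; it assumes T_q == T_k and the strict j < i masking used when attend_current=False):

    import torch
    import torch.nn.functional as F

    def stickbreaking_weighting_sketch(logits: torch.Tensor, normalize: bool = True) -> torch.Tensor:
        # log(sigmoid(z)) = z - softplus(z) and log(1 - sigmoid(z)) = -softplus(z), so for j < i:
        #   log A[i, j] = z[i, j] - sum_{k=j}^{i-1} softplus(z[i, k])
        B, H, T_q, T_k = logits.shape
        sp = F.softplus(logits)                    # [B, H, T_q, T_k]
        csum = sp.cumsum(dim=-1)                   # csum[..., j] = sum_{k<=j} sp[..., k]
        i_idx = torch.arange(T_q, device=logits.device)
        # csum_to_i[..., i, 0] = csum[..., i, i-1]: softplus accumulated up to key i-1
        gather_idx = (i_idx.clamp(min=1) - 1).view(1, 1, T_q, 1).expand(B, H, T_q, 1)
        csum_to_i = csum.gather(-1, gather_idx)
        between = csum_to_i - csum + sp            # sum_{k=j}^{i-1} softplus(z[i, k])
        log_w = logits - between
        # strict causal mask (j < i); row i=0 has no valid keys and becomes all zeros
        causal = torch.tril(torch.ones(T_q, T_k, dtype=torch.bool, device=logits.device), diagonal=-1)
        w = torch.exp(log_w) * causal
        if normalize:
            w = F.normalize(w, p=1, dim=-1)
        return torch.nan_to_num(w, 0.0)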
diff --git a/ops/.ipynb_checkpoints/vanilla_attention_std-checkpoint.py b/ops/.ipynb_checkpoints/vanilla_attention_std-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e1eafefc101c388e3e633f51e442209b9802d14
--- /dev/null
+++ b/ops/.ipynb_checkpoints/vanilla_attention_std-checkpoint.py
@@ -0,0 +1,171 @@
+"""
+Standard softmax attention for the vanilla Transformer
+Used to replace the flash_attn implementation
+"""
+import math
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from typing import Optional, Tuple
+
+def vanilla_attention_std(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    causal: bool = True,
+    window_size: Optional[Tuple[int, int]] = None,
+    sm_scale: Optional[float] = None,
+) -> torch.Tensor:
+    """
+    Standard softmax attention, compatible with the flash_attn_func input format
+
+    Args:
+        q, k, v: [batch, seq_len, num_heads, head_dim]
+        causal: whether to apply a causal mask
+        window_size: sliding window size (left, right); (-1, -1) means unrestricted
+        sm_scale: softmax scaling factor
+
+    Returns:
+        output: [batch, seq_len, num_heads, head_dim]
+    """
+    B, T_q, H, D = q.shape
+    T_k = k.shape[1]
+
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    # Convert to [B, H, T, D] for the computation
+    q = rearrange(q, 'b t h d -> b h t d')
+    k = rearrange(k, 'b t h d -> b h t d')
+    v = rearrange(v, 'b t h d -> b h t d')
+
+    # Compute attention scores
+    scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale
+
+    # Causal mask
+    if causal:
+        P_SEQ = T_k - T_q  # handles the KV-cache case
+        causal_mask = torch.triu(
+            torch.ones((T_q, T_k), dtype=torch.bool, device=q.device),
+            diagonal=P_SEQ + 1
+        )
+        scores = scores.masked_fill(causal_mask[None, None, :, :], float('-inf'))
+
+    # Window mask (sliding window attention)
+    if window_size is not None and window_size != (-1, -1):
+        left_window, right_window = window_size
+        window_mask = torch.ones((T_q, T_k), dtype=torch.bool, device=q.device)
+        for i in range(T_q):
+            # valid window range for each query position
+            start = max(0, i - left_window)
+            end = min(T_k, i + right_window + 1)
+            window_mask[i, start:end] = False
+        scores = scores.masked_fill(window_mask[None, None, :, :], float('-inf'))
+
+    # Softmax
+    attn_weights = F.softmax(scores, dim=-1)
+    attn_weights = torch.nan_to_num(attn_weights, 0.0)
+
+    # Apply attention to values
+    output = torch.matmul(attn_weights.to(v.dtype), v)
+
+    # Convert back to [B, T, H, D]
+    output = rearrange(output, 'b h t d -> b t h d')
+
+    return output
+
+
+def vanilla_attention_varlen_std(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    causal: bool = True,
+    window_size: Optional[Tuple[int, int]] = None,
+    sm_scale: Optional[float] = None,
+) -> torch.Tensor:
+    """
+    Standard softmax attention for variable-length sequences,
+    compatible with flash_attn_varlen_func
+
+    Args:
+        q: [total_q_tokens, num_heads, head_dim]
+        k: [total_k_tokens, num_kv_heads, head_dim]
+        v: [total_k_tokens, num_kv_heads, head_dim]
+        cu_seqlens_q: cumulative sequence lengths [batch_size + 1]
+        cu_seqlens_k: cumulative sequence lengths [batch_size + 1]
+        max_seqlen_q: maximum query sequence length
+        max_seqlen_k: maximum key/value sequence length
+
+    Returns:
+        output: [total_q_tokens, num_heads, head_dim]
+    """
+    batch_size = cu_seqlens_q.shape[0] - 1
+    H = q.shape[1]
+    D = q.shape[2]
+
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    outputs = []
+
+    # Process one batch element at a time
+    for b in range(batch_size):
+        q_start, q_end = cu_seqlens_q[b].item(), cu_seqlens_q[b+1].item()
+        k_start, k_end = cu_seqlens_k[b].item(), cu_seqlens_k[b+1].item()
+
+        if q_start == q_end:  # empty sequence
+            continue
+
+        # Slice out q, k, v for the current batch element
+        q_b = q[q_start:q_end]  # [T_q, H, D]
+        k_b = k[k_start:k_end]  # [T_k, H, D]
+        v_b = v[k_start:k_end]  # [T_k, H, D]
+
+        T_q = q_b.shape[0]
+        T_k = k_b.shape[0]
+
+        # Convert to [H, T, D]
+        q_b = rearrange(q_b, 't h d -> h t d')
+        k_b = rearrange(k_b, 't h d -> h t d')
+        v_b = rearrange(v_b, 't h d -> h t d')
+
+        # Compute attention scores
+        scores = torch.matmul(q_b.float(), k_b.float().transpose(-2, -1)) * sm_scale
+
+        # Causal mask
+        if causal:
+            P_SEQ = T_k - T_q
+            causal_mask = torch.triu(
+                torch.ones((T_q, T_k), dtype=torch.bool, device=q.device),
+                diagonal=P_SEQ + 1
+            )
+            scores = scores.masked_fill(causal_mask[None, :, :], float('-inf'))
+
+        # Window mask
+        if window_size is not None and window_size != (-1, -1):
+            left_window, right_window = window_size
+            window_mask = torch.ones((T_q, T_k), dtype=torch.bool, device=q.device)
+            for i in range(T_q):
+                start = max(0, i - left_window)
+                end = min(T_k, i + right_window + 1)
+                window_mask[i, start:end] = False
+            scores = scores.masked_fill(window_mask[None, :, :], float('-inf'))
+
+        # Softmax
+        attn_weights = F.softmax(scores, dim=-1)
+        attn_weights = torch.nan_to_num(attn_weights, 0.0)
+
+        # Apply attention
+        output_b = torch.matmul(attn_weights.to(v_b.dtype), v_b)
+
+        # Convert back to [T, H, D]
+        output_b = rearrange(output_b, 'h t d -> t h d')
+        outputs.append(output_b)
+
+    # Concatenate the outputs of all batch elements
+    output = torch.cat(outputs, dim=0)
+
+    return output
diff --git a/ops/__init__.py b/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dea7022bf8589cc068109d7dcb7ad0c4aa2d090c
--- /dev/null
+++ b/ops/__init__.py
@@ -0,0 +1,3 @@
+
+# Framework mock for ndr compatibility
+from . 
import framework_mock diff --git a/ops/__pycache__/__init__.cpython-310.pyc b/ops/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..853622d01db2909b6b5ff4d495719103e8a99501 Binary files /dev/null and b/ops/__pycache__/__init__.cpython-310.pyc differ diff --git a/ops/__pycache__/direction_sensitive_geometric.cpython-310.pyc b/ops/__pycache__/direction_sensitive_geometric.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..092300b527a3d2fd6765e414c893814b1b57d9c5 Binary files /dev/null and b/ops/__pycache__/direction_sensitive_geometric.cpython-310.pyc differ diff --git a/ops/__pycache__/forgetting_attention.cpython-310.pyc b/ops/__pycache__/forgetting_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad6a7a7570d15fca0e96b33d3145d2c1fb5f55f7 Binary files /dev/null and b/ops/__pycache__/forgetting_attention.cpython-310.pyc differ diff --git a/ops/__pycache__/forgetting_attention_std.cpython-310.pyc b/ops/__pycache__/forgetting_attention_std.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99d0e742efdae15ccb011b6ee2a6d1584cd6012a Binary files /dev/null and b/ops/__pycache__/forgetting_attention_std.cpython-310.pyc differ diff --git a/ops/__pycache__/framework_mock.cpython-310.pyc b/ops/__pycache__/framework_mock.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cc5a17ea3788ed1115ea303a3f9bec72c420f25 Binary files /dev/null and b/ops/__pycache__/framework_mock.cpython-310.pyc differ diff --git a/ops/__pycache__/geometric_attention_final.cpython-310.pyc b/ops/__pycache__/geometric_attention_final.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d17960915ca0ffc9c0c614a8f1ea7048b4c0693 Binary files /dev/null and b/ops/__pycache__/geometric_attention_final.cpython-310.pyc differ diff --git a/ops/__pycache__/geometric_attention_std.cpython-310.pyc b/ops/__pycache__/geometric_attention_std.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ad9fcfee6f87bdaaef85fc8162ee8c84540489b Binary files /dev/null and b/ops/__pycache__/geometric_attention_std.cpython-310.pyc differ diff --git a/ops/__pycache__/layer_with_visualization.cpython-310.pyc b/ops/__pycache__/layer_with_visualization.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..caa68ee1d14b7e3840adc03d4eda66a53033a976 Binary files /dev/null and b/ops/__pycache__/layer_with_visualization.cpython-310.pyc differ diff --git a/ops/__pycache__/multi_head_attention.cpython-310.pyc b/ops/__pycache__/multi_head_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cd95bd30946237d7d6a82776964467d5638ca5d Binary files /dev/null and b/ops/__pycache__/multi_head_attention.cpython-310.pyc differ diff --git a/ops/__pycache__/multi_head_relative_pos_attention.cpython-310.pyc b/ops/__pycache__/multi_head_relative_pos_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e50d8f8b3129754fa4a61bc211f6cf08d068dc76 Binary files /dev/null and b/ops/__pycache__/multi_head_relative_pos_attention.cpython-310.pyc differ diff --git a/ops/__pycache__/sliding_window_attention_std.cpython-310.pyc b/ops/__pycache__/sliding_window_attention_std.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d183a1cabc4385af31d2c3d836806f3ba5e835e4 Binary files /dev/null 
and b/ops/__pycache__/sliding_window_attention_std.cpython-310.pyc differ diff --git a/ops/__pycache__/stickbreaking_attention_std.cpython-310.pyc b/ops/__pycache__/stickbreaking_attention_std.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a495be7146ee70b1fbc2dfea07fb71bc010b11d4 Binary files /dev/null and b/ops/__pycache__/stickbreaking_attention_std.cpython-310.pyc differ diff --git a/ops/__pycache__/vanilla_attention_std.cpython-310.pyc b/ops/__pycache__/vanilla_attention_std.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8277494474dd8814d01dec9d82109c097bebf57d Binary files /dev/null and b/ops/__pycache__/vanilla_attention_std.cpython-310.pyc differ diff --git a/ops/direction_sensitive_geometric.py b/ops/direction_sensitive_geometric.py new file mode 100644 index 0000000000000000000000000000000000000000..85b5db49c0e9bfd947a4deaf4a848db6b7c7ec12 --- /dev/null +++ b/ops/direction_sensitive_geometric.py @@ -0,0 +1,115 @@ +import torch +from forgetting_transformer.ops.multi_head_attention import AttentionMask, MultiHeadAttentionBase, AttentionMergeMixin +from typing import Optional +from forgetting_transformer.ops.geometric_attention import geometric_attention_activation +import math +from forgetting_transformer.ops.multi_head_relative_pos_attention import FixedRelativeMultiheadAttentionBase, shift + + +class DirectionSensitiveGeometricAttention(AttentionMergeMixin, FixedRelativeMultiheadAttentionBase): + def __init__(self, state_size: int, n_heads: int, dropout: float = 0.0, global_pos_bias: bool = True, + global_content_bias: bool = True, input_size: Optional[int] = None, + output_size: Optional[int] = None, normalize_score: bool = True): + super(AttentionMergeMixin, self).__init__(state_size, n_heads, dropout, input_size) + + self.data_to_kv = torch.nn.Linear(state_size, 2 * n_heads * self.projection_size, bias=False) + self.data_to_q = torch.nn.Linear(self.input_size, n_heads * self.projection_size, bias=False) + self.data_to_qp = torch.nn.Linear(self.input_size, n_heads * 2) + + self.global_content_bias = torch.nn.Parameter(torch.zeros([n_heads, self.projection_size])) \ + if global_content_bias else None + + self.s_bias = torch.nn.Parameter(torch.full([1], 0.0)) + self.scale = torch.nn.Parameter(torch.full([1], 1.0 / math.sqrt(self.projection_size))) + self.scale_pos = torch.nn.Parameter(torch.full([1], 1.0)) + self.normalize_score = normalize_score + + self.input_size = state_size if input_size is None else input_size + + print(f"DirectionSensitiveGeometricAttention: normalize score: {normalize_score}") + + super(DirectionSensitiveGeometricAttention, self).__init__(output_size) + self.reset_parameters() + + def get_attention_scores(self, mask: Optional[torch.Tensor], + q_content: torch.Tensor, k_content: torch.Tensor, + q_pos: torch.Tensor, + pos_offset: int) -> torch.Tensor: + + # content-content addressing + logits = torch.bmm(q_content, self.dropout(k_content).transpose(1, 2)) + + # directionality. Do scaling here, less flops. 
+ prefer_back, prefer_front = (q_pos * self.scale_pos).unsqueeze(-2).expand(-1,-1,logits.shape[-1],-1).unbind(-1) + fpos = prefer_front.triu(1 + pos_offset) + prefer_back.tril(-1 + pos_offset) + + logits = logits * self.scale + fpos + self.s_bias + + logits = self.apply_logit_masks(logits.view(logits.shape[0] // self.n_heads, self.n_heads, *logits.shape[1:]), mask).flatten(0,1) + + logits.masked_fill_(torch.eye(logits.shape[-1], device=logits.device, dtype=torch.bool)[pos_offset : pos_offset + logits.shape[-2]], float("-inf")) + + return geometric_attention_activation(logits, mask, pos_offset, normalize=self.normalize_score) + + def add_head_specific_bias(self, data: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: + # data [batch * n_heads, len, c] + # bias [n_heads, c] + return (data.view(-1, bias.shape[0], *data.shape[1:]) + bias.unsqueeze(1).type_as(data)).view_as(data) \ + if bias is not None else data + + def _attention(self, mask: Optional[torch.Tensor], + q_content: torch.Tensor, k_content: torch.Tensor, + q_pos: torch.Tensor, + v: torch.Tensor, pos_offset: int) -> [torch.Tensor, torch.Tensor]: + + scores = self.get_attention_scores(mask, q_content, k_content, q_pos, pos_offset) + + # Scores shape: [n_batch * n_heads, n_out, n_in] + return self._attention_read(mask, scores, v) + + def forward(self, curr_state: torch.Tensor, attend_to: torch.Tensor, mask: Optional[AttentionMask], + pos_offset: int = 0, need_weights: bool = False): + # curr_state: [batch_size, out_len, c] + # attend_to: [batch_size, in_len, c] + batch_size, in_len = attend_to.shape[0:2] + out_len = curr_state.shape[1] + + k_content, v = self.transform_data(attend_to, self.data_to_kv, 2) + q, = self.transform_data(curr_state, self.data_to_q, 1) + q_pos, = self.transform_data(curr_state, self.data_to_qp, 1) + + q_content = self.add_head_specific_bias(q, self.global_content_bias) + + data, scores = self.merged_attention(batch_size, out_len, mask, q_content, k_content, q_pos, v, + pos_offset, need_weights=need_weights) + + if need_weights: + return data, scores + else: + return data + + def reset_parameters(self): + torch.nn.init.xavier_uniform_(self.data_to_q.weight) + torch.nn.init.xavier_uniform_(self.pos_to_pq.weight) + torch.nn.init.xavier_uniform_(self.data_to_kv.weight[:self.projection_size * self.n_heads]) + torch.nn.init.xavier_uniform_(self.data_to_kv.weight[self.projection_size * self.n_heads:]) + + if self.global_content_bias is not None: + self.global_content_bias.data.fill_(0) + + +class DirectionSensitiveGeometricAttentionMyInit(DirectionSensitiveGeometricAttention): + def xavier_manual_(self, tensor: torch.Tensor, fan_in: int, fan_out: int, gain: float = 1) -> torch.Tensor: + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + + return torch.nn.init._no_grad_uniform_(tensor, -a, a) + + def reset_parameters(self): + self.xavier_manual_(self.data_to_q.weight, self.state_size, self.projection_size) + self.xavier_manual_(self.pos_to_pq.weight, self.state_size, 2) + self.xavier_manual_(self.data_to_kv.weight, self.state_size, self.projection_size) + self.xavier_manual_(self.multi_head_merge.weight, self.projection_size, self.state_size) + + if self.global_content_bias is not None: + self.global_content_bias.data.fill_(0) diff --git a/ops/direction_sensitive_geometric.py.bak b/ops/direction_sensitive_geometric.py.bak new file mode 100644 index 
0000000000000000000000000000000000000000..0980451ccfcb531a6b0007c867bb93007b9cd66a --- /dev/null +++ b/ops/direction_sensitive_geometric.py.bak @@ -0,0 +1,115 @@ +import torch +from .multi_head_attention import AttentionMask, MultiHeadAttentionBase, AttentionMergeMixin +from typing import Optional +from .geometric_attention import geometric_attention_activation +import math +from .multi_head_relative_pos_attention import FixedRelativeMultiheadAttentionBase, shift + + +class DirectionSensitiveGeometricAttention(AttentionMergeMixin, FixedRelativeMultiheadAttentionBase): + def __init__(self, state_size: int, n_heads: int, dropout: float = 0.0, global_pos_bias: bool = True, + global_content_bias: bool = True, input_size: Optional[int] = None, + output_size: Optional[int] = None, normalize_score: bool = True): + super(AttentionMergeMixin, self).__init__(state_size, n_heads, dropout, input_size) + + self.data_to_kv = torch.nn.Linear(state_size, 2 * n_heads * self.projection_size, bias=False) + self.data_to_q = torch.nn.Linear(self.input_size, n_heads * self.projection_size, bias=False) + self.data_to_qp = torch.nn.Linear(self.input_size, n_heads * 2) + + self.global_content_bias = torch.nn.Parameter(torch.zeros([n_heads, self.projection_size])) \ + if global_content_bias else None + + self.s_bias = torch.nn.Parameter(torch.full([1], 0.0)) + self.scale = torch.nn.Parameter(torch.full([1], 1.0 / math.sqrt(self.projection_size))) + self.scale_pos = torch.nn.Parameter(torch.full([1], 1.0)) + self.normalize_score = normalize_score + + self.input_size = state_size if input_size is None else input_size + + print(f"DirectionSensitiveGeometricAttention: normalize score: {normalize_score}") + + super(DirectionSensitiveGeometricAttention, self).__init__(output_size) + self.reset_parameters() + + def get_attention_scores(self, mask: Optional[torch.Tensor], + q_content: torch.Tensor, k_content: torch.Tensor, + q_pos: torch.Tensor, + pos_offset: int) -> torch.Tensor: + + # content-content addressing + logits = torch.bmm(q_content, self.dropout(k_content).transpose(1, 2)) + + # directionality. Do scaling here, less flops. 
+ prefer_back, prefer_front = (q_pos * self.scale_pos).unsqueeze(-2).expand(-1,-1,logits.shape[-1],-1).unbind(-1) + fpos = prefer_front.triu(1 + pos_offset) + prefer_back.tril(-1 + pos_offset) + + logits = logits * self.scale + fpos + self.s_bias + + logits = self.apply_logit_masks(logits.view(logits.shape[0] // self.n_heads, self.n_heads, *logits.shape[1:]), mask).flatten(0,1) + + logits.masked_fill_(torch.eye(logits.shape[-1], device=logits.device, dtype=torch.bool)[pos_offset : pos_offset + logits.shape[-2]], float("-inf")) + + return geometric_attention_activation(logits, mask, pos_offset, normalize=self.normalize_score) + + def add_head_specific_bias(self, data: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: + # data [batch * n_heads, len, c] + # bias [n_heads, c] + return (data.view(-1, bias.shape[0], *data.shape[1:]) + bias.unsqueeze(1).type_as(data)).view_as(data) \ + if bias is not None else data + + def _attention(self, mask: Optional[torch.Tensor], + q_content: torch.Tensor, k_content: torch.Tensor, + q_pos: torch.Tensor, + v: torch.Tensor, pos_offset: int) -> [torch.Tensor, torch.Tensor]: + + scores = self.get_attention_scores(mask, q_content, k_content, q_pos, pos_offset) + + # Scores shape: [n_batch * n_heads, n_out, n_in] + return self._attention_read(mask, scores, v) + + def forward(self, curr_state: torch.Tensor, attend_to: torch.Tensor, mask: Optional[AttentionMask], + pos_offset: int = 0, need_weights: bool = False): + # curr_state: [batch_size, out_len, c] + # attend_to: [batch_size, in_len, c] + batch_size, in_len = attend_to.shape[0:2] + out_len = curr_state.shape[1] + + k_content, v = self.transform_data(attend_to, self.data_to_kv, 2) + q, = self.transform_data(curr_state, self.data_to_q, 1) + q_pos, = self.transform_data(curr_state, self.data_to_qp, 1) + + q_content = self.add_head_specific_bias(q, self.global_content_bias) + + data, scores = self.merged_attention(batch_size, out_len, mask, q_content, k_content, q_pos, v, + pos_offset, need_weights=need_weights) + + if need_weights: + return data, scores + else: + return data + + def reset_parameters(self): + torch.nn.init.xavier_uniform_(self.data_to_q.weight) + torch.nn.init.xavier_uniform_(self.pos_to_pq.weight) + torch.nn.init.xavier_uniform_(self.data_to_kv.weight[:self.projection_size * self.n_heads]) + torch.nn.init.xavier_uniform_(self.data_to_kv.weight[self.projection_size * self.n_heads:]) + + if self.global_content_bias is not None: + self.global_content_bias.data.fill_(0) + + +class DirectionSensitiveGeometricAttentionMyInit(DirectionSensitiveGeometricAttention): + def xavier_manual_(self, tensor: torch.Tensor, fan_in: int, fan_out: int, gain: float = 1) -> torch.Tensor: + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + + return torch.nn.init._no_grad_uniform_(tensor, -a, a) + + def reset_parameters(self): + self.xavier_manual_(self.data_to_q.weight, self.state_size, self.projection_size) + self.xavier_manual_(self.pos_to_pq.weight, self.state_size, 2) + self.xavier_manual_(self.data_to_kv.weight, self.state_size, self.projection_size) + self.xavier_manual_(self.multi_head_merge.weight, self.projection_size, self.state_size) + + if self.global_content_bias is not None: + self.global_content_bias.data.fill_(0) diff --git a/ops/forgetting_attention.py b/ops/forgetting_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..d6cba53a9a3ae49a448247e739f7f3d644caeca5 --- /dev/null +++ 
b/ops/forgetting_attention.py @@ -0,0 +1,1138 @@
+"""
+Implementation of Forgetting Attention.
+
+Our code is adapted from https://github.com/FlagOpen/FlagAttention/blob/ee91638dec6da8c00c4113d179f469e0ffcd5852/src/flag_attn/flash.py. The code is modified to implement Forgetting Attention.
+
+The original license info from FlagAttention:
+
+Copyright 2023 BAAI
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import pytest
+import math
+import torch
+import triton
+import triton.language as tl
+from einops import rearrange
+from typing import Optional
+
+
+__all__ = ["forgetting_attention"]
+
+
+# File flash.py
+def maybe_contiguous(x):
+    # only when the innermost dimension is contiguous can LDGSTS be used,
+    # so inner-dimension contiguity is enforced.
+    return x.contiguous() if x.stride(-1) != 1 else x
+
+def rounded_multiple(a, b):
+    return (a + b - 1) // b * b
+
+# --------------------------- public API ---------------------------
+class ForgettingAttention(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, k, v, log_fgate, seq_start, causal, sm_scale, return_log_normalizer):
+        assert causal, "Only causal attention is supported"
+        Dq, Dk, Dv = q.shape[-1], k.shape[-1], v.shape[-1]
+        assert Dq == Dk == Dv, "feature size of q, k, v should be equal"
+        assert Dk in {16, 32, 64, 128}, "We only support head dims in {16, 32, 64, 128}"
+
+        B, H, M, D = q.shape
+        if seq_start is not None:
+            has_seq_start = True
+            assert seq_start.shape == (B,)
+        else:
+            has_seq_start = False
+            seq_start = torch.zeros((B,), device=q.device, dtype=torch.long)
+        N = k.shape[2]
+        assert log_fgate.shape == (B, H, N)
+        log_fgate = log_fgate.float()
+        if has_seq_start:
+            log_fgate = log_fgate.clone()
+            # We absolutely don't want masked values to affect the result. Without
+            # this they could, by affecting the numerical precision of the cumsum.
+            mask_index = (torch.arange(N, device=q.device)[None, None, :] < seq_start[:, None, None])
+            mask_index = torch.broadcast_to(mask_index, log_fgate.size())
+            log_fgate[mask_index] = 0.0
+
+        log_lambda = torch.cumsum(log_fgate, dim=-1, dtype=log_fgate.dtype).float()
+
+        Hk, Hv = k.shape[1], v.shape[1]
+        assert Hk == Hv, "num of heads in k and v should be equal"
+        assert H == Hk, "grouped query attention has not been tested. You can comment this out if you know what you are doing."
+        assert H % Hk == 0, "number of heads in q must be a multiple of that in k & v"
+        num_groups = H // Hk
+
+        P_SEQ = N - M
+        larger_m = M > N
+        assert (not larger_m), "The key/value tensors must be at least as long as the query tensor"
+
+        if sm_scale is None:
+            sm_scale = 1. / math.sqrt(D)
+
+        # contiguity
+        q, k, v = maybe_contiguous(q), maybe_contiguous(k), maybe_contiguous(v)
+
+        # to work around https://github.com/openai/triton/issues/2441
+        device = torch.cuda.device_of(q)
+
+        with torch.cuda.device(device):
+
+            config = get_fwd_config(B, H, M, N, D, causal)
+            BLOCK_M, BLOCK_N, num_stages, num_warps = config
+
+            divisible_m = M % BLOCK_M == 0
+            divisible_n = N % BLOCK_N == 0
+            # consider using 3d grid to avoid div & rem
+            grid = (triton.cdiv(M, BLOCK_M), H, B)
+            o = torch.empty_like(q)
+            L = torch.empty((B, H, M), device=q.device, dtype=torch.float32)
+            _fwd_kernel[grid](
+                q, k, v, log_lambda, seq_start, sm_scale,
+                L, o,
+                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+                log_lambda.stride(0), log_lambda.stride(1), log_lambda.stride(2),
+                o.stride(0), o.stride(1), o.stride(2), o.stride(3),
+                B, H, M, N, P_SEQ, num_groups,
+                BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D,
+                IS_CAUSAL=causal, LARGER_M=larger_m, HAS_SEQ_START=has_seq_start,
+                DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
+                num_warps=num_warps, num_stages=num_stages,
+            )
+
+        # autograd context maintenance
+        ctx.save_for_backward(q, k, v, o, L, log_lambda, seq_start)
+        ctx.sm_scale = sm_scale
+        ctx.causal = causal
+        ctx.has_seq_start = has_seq_start
+
+        has_extra_return = return_log_normalizer
+        if has_extra_return:
+            outs = (
+                o,
+                L if return_log_normalizer else None,
+            )
+            return outs
+        return o
+
+    @staticmethod
+    def backward(ctx, do, *ignored):
+        q, k, v, o, L, log_lambda, seq_start = ctx.saved_tensors
+        sm_scale = ctx.sm_scale
+        causal = ctx.causal
+        has_seq_start = ctx.has_seq_start
+
+        B, H, M, D = q.shape
+        N = k.shape[2]
+        Hk = k.shape[1]
+        num_groups = H // Hk
+        P_SEQ = N - M
+        larger_m = M > N
+
+        if sm_scale is None:
+            sm_scale = 1. / math.sqrt(D)
+
+        # to work around https://github.com/openai/triton/issues/2441
+        device = torch.cuda.device_of(q)
+        with torch.cuda.device(device):
+            config = get_bwd_config(B, H, M, N, D, causal)
+            BLOCK_M, BLOCK_N, num_stages, num_warps = config
+
+            divisible_m = M % BLOCK_M == 0
+            divisible_n = N % BLOCK_N == 0
+
+            delta = torch.empty_like(L)
+            grid = (triton.cdiv(M, BLOCK_M), H, B)
+            _bwd_preprocess[grid](
+                o, do,
+                delta,
+                o.stride(0), o.stride(1), o.stride(2), o.stride(3),
+                do.stride(0), do.stride(1), do.stride(2), do.stride(3),
+                delta.stride(0), delta.stride(1), delta.stride(2),
+                M,
+                BLOCK_M=BLOCK_M, D_HEAD=D,
+                DIVISIBLE_M=divisible_m,
+            )
+
+            # NOTE that dk & dv always have the same number of heads as q, instead of k & v.
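+            # (Added note: for grouped-query attention, the per-query-head dk/dv
+            # computed below are reshaped to (B, Hk, num_groups, N, D) and summed
+            # over the group axis to recover gradients for the shared k/v heads.)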
+            BLOCK_M, BLOCK_N, num_stages, num_warps = get_bwd_kv_config(B, H, M, N, D, causal)
+            divisible_m = M % BLOCK_M == 0
+            divisible_n = N % BLOCK_N == 0
+
+            dk = torch.empty((B, H, N, D), dtype=k.dtype, device=q.device)
+            dv = torch.empty((B, H, N, D), dtype=v.dtype, device=q.device)
+            dlog_lambda = torch.empty((B, H, N), dtype=log_lambda.dtype, device=q.device)
+            grid = (triton.cdiv(N, BLOCK_N), H, B)
+            _bwd_kv_kernel[grid](
+                q, k, v, log_lambda, seq_start, sm_scale, do,
+                dk, dv, dlog_lambda,
+                L, delta,
+                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+                log_lambda.stride(0), log_lambda.stride(1), log_lambda.stride(2),
+                do.stride(0), do.stride(1), do.stride(2), do.stride(3),
+                dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),
+                dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3),
+                dlog_lambda.stride(0), dlog_lambda.stride(1), dlog_lambda.stride(2),
+                B, H, M, N, P_SEQ,
+                num_groups,
+                BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, CAUSAL=causal,
+                DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, HAS_SEQ_START=has_seq_start,
+                num_stages=num_stages, num_warps=num_warps,
+            )
+
+            BLOCK_M, BLOCK_N, num_stages, num_warps = get_bwd_q_config(B, H, M, N, D, causal)
+            divisible_m = M % BLOCK_M == 0
+            divisible_n = N % BLOCK_N == 0
+            dq = torch.zeros_like(q)
+            grid = (triton.cdiv(M, BLOCK_M), H, B)
+            _bwd_q_kernel[grid](
+                q, k, v, log_lambda, seq_start, sm_scale, do,
+                dq, dlog_lambda,
+                L, delta,
+                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+                log_lambda.stride(0), log_lambda.stride(1), log_lambda.stride(2),
+                do.stride(0), do.stride(1), do.stride(2), do.stride(3),
+                dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3),
+                dlog_lambda.stride(0), dlog_lambda.stride(1), dlog_lambda.stride(2),
+                B, H, M, N, P_SEQ,
+                num_groups,
+                BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
+                CAUSAL=causal, LARGER_M=larger_m, HAS_SEQ_START=has_seq_start,
+                DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
+                num_stages=num_stages, num_warps=num_warps,
+            )
+        dk = dk.reshape((B, Hk, num_groups, N, D)).sum(2)
+        dv = dv.reshape((B, Hk, num_groups, N, D)).sum(2)
+        dcumsum = torch.cumsum(dlog_lambda, dim=-1, dtype=log_lambda.dtype)
+        dlog_fgate = dlog_lambda + dcumsum[..., -1:] - dcumsum
+        dlog_fgate = dlog_fgate.float()
+        # one gradient per forward input (q, k, v, log_fgate, seq_start, causal,
+        # sm_scale, return_log_normalizer)
+        return dq, dk, dv, dlog_fgate, None, None, None, None
+
+
+def forgetting_attention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    log_fgate: torch.Tensor,
+    *,
+    head_first: bool = False,
+    seq_start: Optional[torch.Tensor] = None,
+    sm_scale: Optional[float] = None,
+):
+    """
+    A FlashAttention-based implementation of Forgetting Attention.
+
+    Note:
+      - We recommend bfloat16/float16 for q, k, v and float32 for log_fgate. float32 for
+        q, k, v is also supported, but the kernel will not use tensor cores if q, k, v are
+        in float32 (which would be slow).
+      - We only support seqlen_q <= seqlen_k
+      - We only support causal attention
+      - Head dimension must be in one of {16, 32, 64, 128}
+
+    Arguments:
+      - q: (batch_size, seqlen_q, num_heads, head_dim) unless head_first=True.
+      - k: (batch_size, seqlen_k, num_heads, head_dim) unless head_first=True.
+      - v: (batch_size, seqlen_k, num_heads, head_dim) unless head_first=True.
+      - log_fgate: (batch_size, seqlen_k, num_heads) unless head_first=True.
+        This should be the **log** of the forget gates. This is typically the
+        output of torch.nn.functional.logsigmoid.
+      - head_first: if True, the num_heads and seqlen_* axes of all the
+        FloatTensor inputs and outputs are swapped, i.e. they should be
+        (batch_size, num_heads, seqlen_*, ...) instead of
+        (batch_size, seqlen_*, num_heads, ...).
+      - seq_start: If not None, should be a LongTensor with shape (batch_size,)
+        and range in [0, seq_len_k). For each batch index batch_id, no attention
+        will be allocated to tokens before the token index seq_start[batch_id].
+        This is useful for left-padded inputs.
+      - sm_scale: The scaling of attention scores before applying softmax. If
+        None, it defaults to (1.0 / math.sqrt(head_dim)).
+
+    Returns:
+        out (torch.Tensor): (batch_size, seqlen_q, num_heads, head_dim) unless head_first=True.
+    """
+    if not head_first:
+        q, k, v = [rearrange(item, "b t h d -> b h t d") for item in (q, k, v)]
+        log_fgate = rearrange(log_fgate, "b t h -> b h t")
+    out = ForgettingAttention.apply(q, k, v, log_fgate, seq_start, True, sm_scale, False)
+    if not head_first:
+        out = rearrange(out, "b h t d -> b t h d")
+    return out
+
+
+# --------------------------- Forward ---------------------------
+# NOTE: this function can be overwritten at runtime to use your custom config
+def get_fwd_config(B, H, M, N, D, causal):
+    assert causal
+    if torch.cuda.get_device_capability() == (8, 0):
+        if D <= 64:
+            BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 32, 3, 4
+        else:
+            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 4, 4
+    elif torch.cuda.get_device_capability() == (9, 0):
+        # H100
+        if D <= 64:
+            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 8
+        else:
+            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 2, 8
+    elif torch.cuda.get_device_capability() == (8, 6):
+        if not causal:
+            if D <= 64:
+                BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
+            else:
+                BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
+        else:  # causal
+            if D <= 64:
+                BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 4
+            else:
+                BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
+    elif torch.cuda.get_device_capability() == (8, 9):
+        # L40S
+        if D <= 64:
+            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 2, 4
+        else:
+            BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
+    else:
+        BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4
+    return (BLOCK_M, BLOCK_N, num_stages, num_warps)
+
+
+@triton.jit
+def _fwd_kernel(
+    Q, K, V, LOG_LAMBDA, SEQ_START, sm_scale,
+    L, O,
+    stride_qz, stride_qh, stride_qm, stride_qk,
+    stride_kz, stride_kh, stride_kn, stride_kk,
+    stride_vz, stride_vh, stride_vn, stride_vk,
+    stride_log_lambda_z, stride_log_lambda_h, stride_log_lambda_n,
+    stride_oz, stride_oh, stride_om, stride_ok,
+    Z, H, M, N, P_SEQ,
+    num_groups,
+    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
+    IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, HAS_SEQ_START: tl.constexpr,
+    DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
+):
+    input_dtype = Q.dtype.element_ty
+    # -- grid id --
+    start_m = tl.program_id(0)
+    off_h = tl.program_id(1)
+    off_z = tl.program_id(2)
+
+    # scale sm_scale by log_2(e) and use
+    # 2^x instead of exp in the loop because CSE and LICM
+    # don't work as expected with `exp` in the loop
+    log2e: tl.constexpr = 1.4426950408889634
+    loge2: tl.constexpr = 0.6931471805599453
+    qk_scale = sm_scale * log2e
+
+    # offset pointers for (batch, head)
+    off_hk = off_h // num_groups
+    Q += off_z * stride_qz + off_h * stride_qh
+    K += off_z * stride_kz + off_hk * stride_kh
+    V += off_z * stride_vz + off_hk * stride_vh
+    LOG_LAMBDA += off_z * stride_log_lambda_z + off_h * stride_log_lambda_h
+    O += off_z * stride_oz + off_h * stride_oh
+    L += (off_z * H + off_h) * M  # L's shape is (B, H, M)
+
+    offs_m_base = tl.arange(0, BLOCK_M)
+    offs_m = start_m * BLOCK_M + offs_m_base
+    offs_n_base = tl.arange(0, BLOCK_N)
+    offs_k = tl.arange(0, BLOCK_DMODEL)
+
+    # initialize pointers to value-like data
+    q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk)  # (BLOCK_M, BLOCK_DMODEL)
+    log_lambda_out_ptrs = LOG_LAMBDA + (P_SEQ + offs_m) * stride_log_lambda_n
+    o_ptrs = O + (offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok)  # (BLOCK_M, BLOCK_DMODEL)
+    l_ptrs = L + offs_m
+
+    # initialize pointers to m and l, fp32 for accumulators
+    m_i = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32)
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+
+    # load q
+    if DIVISIBLE_M:
+        q = tl.load(q_ptrs, cache_modifier=".cg")
+        log_lambda_out = tl.load(log_lambda_out_ptrs, cache_modifier=".cg")
+    else:
+        mask_m = offs_m < M
+        q = tl.load(q_ptrs, mask=mask_m[:, None], cache_modifier=".cg")
+        log_lambda_out = tl.load(log_lambda_out_ptrs, mask=mask_m, cache_modifier=".cg")
+
+    # Dot-I trick: placing q in registers saves shared memory
+    # if BLOCK_DMODEL < 128:
+    #     I = tl.where(offs_k[:, None] == offs_k,
+    #                  tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=input_dtype),
+    #                  tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=input_dtype))
+    #     q = tl.dot(q, I, input_precision="ieee").to(input_dtype)
+    # else:
+    #     I = tl.where(offs_m_base[:, None] == offs_m_base,
+    #                  tl.full((BLOCK_M, BLOCK_M), 1.0, dtype=input_dtype),
+    #                  tl.full((BLOCK_M, BLOCK_M), 0.0, dtype=input_dtype))
+    #     q = tl.dot(I, q, input_precision="ieee").to(input_dtype)
+
+    # NOTE: Loop-Bound-For-N
+    # The indices in the m-dimension that this block may access are in
+    # `[start_m * BLOCK_M, (start_m + 1) * BLOCK_M)`.
+    # According to the rule of causal masking, the max index in the n-dimension
+    # that this block may access is `P_SEQ + (start_m + 1) * BLOCK_M`.
+    # However, the upper bound of the n-index must never exceed the sequence
+    # length of k/v (`P_SEQ + N_CTX`), and `P_SEQ + (start_m + 1) * BLOCK_M` may
+    # be larger than `N`.
+    # In that case there would be illegal memory accesses when loading the k & v
+    # tiles if mask_n were not applied (it is skipped only when `DIVISIBLE_N` is true).
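+    # (Concrete example, added for illustration: with BLOCK_M = 64, P_SEQ = 0
+    # and M = N = 100, the last query block would probe n-indices up to 128 > N,
+    # which is why `hi` is clamped with tl.minimum below.)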
+ # See also https://github.com/FlagOpen/FlagAttention/pull/8 + if IS_CAUSAL: + hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M) + if LARGER_M: + hi = tl.maximum(0, hi) + else: + hi = N + + offs_n_init = offs_n_base + if HAS_SEQ_START: + SEQ_START += off_z + seq_start = tl.load(SEQ_START) + lo = tl.minimum(seq_start, hi) + lo = (lo // BLOCK_N) * BLOCK_N + offs_n_init += lo + else: + lo = 0 + seq_start = 0 + + # loop over k, v and update accumulators + k_ptrs = K + (offs_k[:, None] * stride_kk + offs_n_init[None, :] * stride_kn) # (BLOCK_DMODEL, BLOCK_N) + v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL) + log_lambda_in_ptrs = LOG_LAMBDA + (offs_n_init * stride_log_lambda_n) # (BLOCK_N, BLOCK_DMODEL) + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + offs_n = start_n + offs_n_base + + # -- load k, v -- + if DIVISIBLE_N: + k = tl.load(k_ptrs, cache_modifier=".cg") + v = tl.load(v_ptrs, cache_modifier=".cg") + log_lambda_in = tl.load(log_lambda_in_ptrs, cache_modifier=".cg") + else: + mask_n = offs_n < N + k = tl.load(k_ptrs, mask=mask_n[None, :], cache_modifier=".cg") + v = tl.load(v_ptrs, mask=mask_n[:, None], cache_modifier=".cg") + log_lambda_in = tl.load(log_lambda_in_ptrs, mask=mask_n, cache_modifier=".cg") + + # -- compute qk --- + # s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + s = tl.dot(q, k, input_precision="ieee") * qk_scale + decay_bias = log_lambda_out[:, None] - log_lambda_in[None, :] + s += decay_bias * log2e + + if not DIVISIBLE_N: + s = tl.where(mask_n[None, :], s, float("-inf")) + if IS_CAUSAL: + causal_mask = (P_SEQ + offs_m[:, None]) >= offs_n[None, :] + s = tl.where(causal_mask, s, float("-inf")) + if HAS_SEQ_START: + s = tl.where(offs_n[None, :] >= seq_start, s, float("-inf")) + + + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(s, 1)) + alpha = tl.math.exp2((m_i - m_i_new)) + p = tl.math.exp2(s - m_i_new[:, None]) + + # -- compute partial sumexpn before applying dropout + p_sum = tl.sum(p, 1) + + + # -- scale and update acc: acc *= alpha[:, None]-- + acc *= alpha[:, None] + acc += tl.dot(p.to(input_dtype), v, input_precision="ieee") + + # -- update m_i and l_i -- + l_i = l_i * alpha + p_sum + m_i = m_i_new + # update pointers + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vn + log_lambda_in_ptrs += BLOCK_N * stride_log_lambda_n + + # write back l & o + if IS_CAUSAL and (LARGER_M or HAS_SEQ_START): + is_empty_line = (offs_m + P_SEQ) < seq_start + acc = tl.where(is_empty_line[:, None], 0.0, acc * (1.0 / l_i[:, None])) + l = tl.where(is_empty_line, float("-inf"), m_i * loge2 + tl.log(l_i)) + else: + acc = acc * (1.0 / l_i[:, None]) + l = m_i * loge2 + tl.log(l_i) # log(normalizer) + + + if DIVISIBLE_M: + tl.store(l_ptrs, l, cache_modifier=".cg") + tl.store(o_ptrs, acc.to(input_dtype), cache_modifier=".cg") + else: + tl.store(l_ptrs, l, mask=mask_m, cache_modifier=".cg") + tl.store(o_ptrs, acc.to(input_dtype), mask=mask_m[:, None], cache_modifier=".cg") + + +# --------------------------- Backward --------------------------- +# NOTE: this function can be overwritten at runtime to use your custom config +def get_bwd_config(B, H, M, N, D, causal): + if torch.cuda.get_device_capability() == (9, 0): + if not causal: + BLOCK_M = 128 if D <= 64 else 64 + BLOCK_N = 64 + num_stages = 2 + num_warps = 4 + else: + BLOCK_M = 64 + BLOCK_N = 64 + num_stages = 3 if D <= 64 else 2 + num_warps = 4 + elif torch.cuda.get_device_capability() == (8, 0): + if 
not causal: + BLOCK_M = 128 if D <= 64 else 64 + BLOCK_N = 64 + num_stages = 2 + num_warps = 4 + else: + BLOCK_M = 64 + BLOCK_N = 64 + num_stages = 3 if D <= 64 else 2 + num_warps = 4 + elif torch.cuda.get_device_capability() == (8, 6): # tune for RTX-3090, device_capability(8, 6) + if not causal: + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 8 + else: + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 2, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4 + return (BLOCK_M, BLOCK_N, num_stages, num_warps) + +def get_bwd_kv_config(B, H, M, N, D, causal): + assert causal + if torch.cuda.get_device_capability() == (8, 0): # A100 + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 4, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 128, 4, 8 + elif torch.cuda.get_device_capability() == (8, 6): # tune for RTX-3090, device_capability(8, 6) + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 2, 4 + elif torch.cuda.get_device_capability() == (8, 9): # L40S + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 128, 4, 8 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 128, 2, 8 + elif torch.cuda.get_device_capability() == (9, 0): # H100 + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 + return (BLOCK_M, BLOCK_N, num_stages, num_warps) + +def get_bwd_q_config(B, H, M, N, D, causal): + assert causal + if torch.cuda.get_device_capability() == (8, 0): # A100 + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 4, 8 + elif torch.cuda.get_device_capability() == (8, 6): # tune for RTX-3090, device_capability(8, 6) + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 2, 4 + elif torch.cuda.get_device_capability() == (8, 9): # L40S + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 4, 4 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 3, 4 + elif torch.cuda.get_device_capability() == (9, 0): # H100 + if D <= 64: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 4, 8 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 2, 8 + else: + BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 + return (BLOCK_M, BLOCK_N, num_stages, num_warps) + + +@triton.jit +def _bwd_preprocess( + Out, DO, + Delta, + stride_oz, stride_oh, stride_om, stride_ok, + stride_doz, stride_doh, stride_dom, stride_dok, + stride_dz, stride_dh, stride_dm, + M, + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, + DIVISIBLE_M: tl.constexpr, +): + off_h = tl.program_id(1) + off_z = tl.program_id(2) + Out += off_z * stride_oz + off_h * stride_oh + DO += off_z * stride_doz + off_h * stride_doh + Delta += off_z * stride_dz + off_h * stride_dh + + # compute (Out * Dout).sum() for vector interpretation + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + + # load + o_ptrs = Out + off_m[:, None] * stride_om + off_n[None, :] * stride_ok + do_ptrs = DO + off_m[:, None] * stride_dom + off_n[None, :] * stride_dok + + if DIVISIBLE_M: + o = tl.load(o_ptrs).to(tl.float32) + do = 
tl.load(do_ptrs).to(tl.float32) + else: + mask_m = off_m < M + o = tl.load(o_ptrs, mask=mask_m[:, None]).to(tl.float32) + do = tl.load(do_ptrs, mask=mask_m[:, None]).to(tl.float32) + + # compute + delta = tl.sum(o * do, axis=1) + + # write-back + d_ptrs = Delta + off_m * stride_dm + if DIVISIBLE_M: + tl.store(d_ptrs, delta) + else: + tl.store(d_ptrs, delta, mask=mask_m) + + +@triton.jit +def _bwd_kv_kernel( + Q, K, V, LOG_LAMBDA, SEQ_START, sm_scale, DO, + DK, DV, DLOG_LAMBDA, + L, + D, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vn, stride_vk, + stride_log_lambda_z, stride_log_lambda_h, stride_log_lambda_n, + stride_doz, stride_doh, stride_dom, stride_dok, + stride_dkz, stride_dkh, stride_dkn, stride_dkk, + stride_dvz, stride_dvh, stride_dvn, stride_dvk, + stride_dlog_lambda_z, stride_dlog_lambda_h, stride_dlog_lambda_n, + Z, H, M, N, P_SEQ, + num_groups, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + CAUSAL: tl.constexpr, + DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, HAS_SEQ_START: tl.constexpr, +): + input_dtype = Q.dtype.element_ty + # -- grid id -- + start_n = tl.program_id(0) + off_h = tl.program_id(1) + off_z = tl.program_id(2) + log2e: tl.constexpr = 1.4426950408889634 + qk_scale = sm_scale * log2e + + # offset pointers for (batch, head) + off_hk = off_h // num_groups + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_hk * stride_kh + V += off_z * stride_vz + off_hk * stride_vh + LOG_LAMBDA += off_z * stride_log_lambda_z + off_h * stride_log_lambda_h + DO += off_z * stride_doz + off_h * stride_doh + + # offset pointers for batch/head + DK += off_z * stride_dkz + off_h * stride_dkh + DV += off_z * stride_dvz + off_h * stride_dvh + DLOG_LAMBDA += off_z * stride_dlog_lambda_z + off_h * stride_dlog_lambda_h + + # offset pointers for batch/head + D += (off_z * H + off_h) * M + L += (off_z * H + off_h) * M + + if CAUSAL: + lo = tl.maximum(start_n * BLOCK_N - P_SEQ, 0) + lo = (lo // BLOCK_M) * BLOCK_M + else: + lo = 0 + + offs_m_init = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m_base = tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_DMODEL) + + # initialize pointers to value-like data + q_ptrs = Q + (offs_m_init[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL) + log_lambda_out_ptrs = LOG_LAMBDA + (P_SEQ + offs_m_init) * stride_log_lambda_n # (BLOCK_N, BLOCK_DMODEL) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL) + log_lambda_in_ptrs = LOG_LAMBDA + (offs_n * stride_log_lambda_n) # (BLOCK_N, BLOCK_DMODEL) + do_ptrs = DO + (offs_m_init[:, None] * stride_dom + offs_k[None, :] * stride_dok) # (BLOCK_M, BLOCK_DMODEL) + + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_k[None, :] * stride_dvk) # (BLOCK_N, BLOCK_DMODEL) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk) # (BLOCK_N, BLOCK_DMODEL) + dlog_lambda_in_ptrs = DLOG_LAMBDA + (offs_n * stride_dlog_lambda_n) # (BLOCK_N, BLOCK_DMODEL) + + # k and v stay in SRAM throughout + if DIVISIBLE_N: + v = tl.load(v_ptrs) + k = tl.load(k_ptrs) + log_lambda_in = tl.load(log_lambda_in_ptrs) + else: + mask_n = offs_n < N + v = tl.load(v_ptrs, mask=mask_n[:, None]) + k = tl.load(k_ptrs, mask=mask_n[:, None]) + log_lambda_in = tl.load(log_lambda_in_ptrs, 
mask=mask_n) + + # If the N block doesn't contain seq_start, no need to loop + if HAS_SEQ_START: + SEQ_START += off_z + seq_start = tl.load(SEQ_START) + hi = tl.where(start_n * BLOCK_N + BLOCK_N >= seq_start - 1, M, lo) + else: + hi = M + + # initialize dk amd dv + dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) + dlog_lambda_in = tl.zeros([BLOCK_N], dtype=tl.float32) + + # loop over a col + for start_m in range(lo, hi, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + offs_m = start_m + offs_m_base + causal_mask = (P_SEQ + offs_m[None, :]) >= (offs_n[:, None]) # (BLOCK_M, BLOCK_N) + + # load q1, k1, q2, k2, v, do on-chip + if DIVISIBLE_M: + q = tl.load(q_ptrs) + log_lambda_out = tl.load(log_lambda_out_ptrs) + else: + mask_m = offs_m < M + valid_mask = mask_m[None, :] # & mask_n + q = tl.load(q_ptrs, mask=mask_m[:, None]) + log_lambda_out = tl.load(log_lambda_out_ptrs, mask=mask_m) + # recompute p = softmax(qk * sm_scale, dim=-1) + # s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + sT = tl.dot(k, tl.trans(q), input_precision="ieee") * qk_scale + decay_bias = log_lambda_out[None, :] - log_lambda_in[:, None] + sT += decay_bias * log2e + # NOTE: since softmax in backward is pointwise, the normalizer has been saved in fwd) + # So masking on s is not needed. + # s = tl.where(valid_mask, s , float("-inf")) + # if CAUSAL: + # s = tl.where(causal_mask, s, float("-inf")) + + # -- recompute p --- + if DIVISIBLE_M: + l = tl.load(L + offs_m) + else: + l = tl.load(L + offs_m, mask=mask_m) + pT = tl.math.exp2(sT - l[None, :] * log2e) # (BLOCK_M, BLOCK_N) + + if not DIVISIBLE_M: + pT = tl.where(valid_mask, pT, 0.0) + if CAUSAL: + pT = tl.where(causal_mask, pT, 0.0) + + # compute dv = dot(p, do) + if DIVISIBLE_M: + do = tl.load(do_ptrs) + else: + do = tl.load(do_ptrs, mask=mask_m[:, None]) # (BLOCK_M, BLOCK_DMODEL) + + + dv += tl.dot(pT.to(input_dtype), do, input_precision="ieee") # (BLOCK_N, BLOCK_DMODEL) # still correct + + # compute dp = dot(v, do) + if DIVISIBLE_M: + delta = tl.load(D + offs_m) + else: + delta = tl.load(D + offs_m, mask=mask_m) + # dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + dpT = tl.dot(v, tl.trans(do), input_precision="ieee") + + + # compute ds = p * (dp - delta[:, None]) + dsT = pT * (dpT - delta[None, :]) # (BLOCK_M, BLOCK_N) + + if not DIVISIBLE_M: + dsT = tl.where(valid_mask, dsT, 0.0) + if CAUSAL: + dsT = tl.where(causal_mask, dsT, 0.0) + + # compute dk = dot(ds.T, q) masking + dk += tl.dot(dsT.to(input_dtype), q, input_precision="ieee") + dlog_lambda_in += -tl.sum(dsT, axis=1) + + # increment pointers + q_ptrs += BLOCK_M * stride_qm + log_lambda_out_ptrs += BLOCK_M * stride_log_lambda_n + do_ptrs += BLOCK_M * stride_dom + + dk *= sm_scale + if HAS_SEQ_START: + # Mask out + seq_mask = (offs_n >= seq_start) + dk = tl.where(seq_mask[:, None], dk, 0.0) + dv = tl.where(seq_mask[:, None], dv, 0.0) + dlog_lambda_in = tl.where(seq_mask, dlog_lambda_in, 0.0) + if DIVISIBLE_N: + tl.store(dk_ptrs, dk.to(input_dtype)) # (BLOCK_N, BLOCK_DMODEL) + tl.store(dv_ptrs, dv.to(input_dtype)) # (BLOCK_N, BLOCK_DMODEL,) + tl.store(dlog_lambda_in_ptrs, dlog_lambda_in.to(tl.float32)) # (BLOCK_N, BLOCK_DMODEL,) + else: + tl.store(dk_ptrs, dk.to(input_dtype), mask=mask_n[:, None]) # (BLOCK_N, BLOCK_DMODEL) + tl.store(dv_ptrs, dv.to(input_dtype), mask=mask_n[:, None]) # (BLOCK_N, BLOCK_DMODEL) + tl.store(dlog_lambda_in_ptrs, dlog_lambda_in.to(tl.float32), mask=mask_n) # (BLOCK_N, BLOCK_DMODEL,) + + +@triton.jit +def 
_bwd_q_kernel( + Q, K, V, LOG_LAMBDA, SEQ_START, sm_scale, DO, + DQ, DLOG_LAMBDA, + L, + D, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vn, stride_vk, + stride_log_lambda_z, stride_log_lambda_h, stride_log_lambda_n, + stride_doz, stride_doh, stride_dom, stride_dok, + stride_dqz, stride_dqh, stride_dqm, stride_dqk, + stride_dlog_lambda_z, stride_dlog_lambda_h, stride_dlog_lambda_n, + Z, H, M, N, P_SEQ, + num_groups, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, HAS_SEQ_START: tl.constexpr, + DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, +): + input_dtype = Q.dtype.element_ty + # -- grid id -- + start_m = tl.program_id(0) + off_h = tl.program_id(1) + off_z = tl.program_id(2) + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + log2e: tl.constexpr = 1.4426950408889634 + qk_scale = sm_scale * log2e + + # offset pointers for (batch, head) + off_hk = off_h // num_groups + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_hk * stride_kh + V += off_z * stride_vz + off_hk * stride_vh + LOG_LAMBDA += off_z * stride_log_lambda_z + off_h * stride_log_lambda_h + DO += off_z * stride_doz + off_h * stride_doh + D += (off_z * H + off_h) * M + L += (off_z * H + off_h) * M + + # offset pointers for batch/head + DQ += off_z * stride_dqz + off_h * stride_dqh + DLOG_LAMBDA += off_z * stride_dlog_lambda_z + off_h * stride_dlog_lambda_h + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_DMODEL) + + # initialize pointers to value-like data + q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL) + log_lambda_out_ptrs = LOG_LAMBDA + (P_SEQ + offs_m) * stride_log_lambda_n + + dq_ptrs = DQ + (offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk) # (BLOCK_M, BLOCK_DMODEL) + dlog_lambda_out_ptrs = DLOG_LAMBDA + (P_SEQ + offs_m) * stride_dlog_lambda_n + do_ptrs = DO + (offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok) # (BLOCK_M, BLOCK_DMODEL) + + # pointer to row-wise quantities in value-like data + d_ptrs = D + offs_m + l_ptrs = L + offs_m + + # load q: it will stay in SRAM throughout + if DIVISIBLE_M: + q = tl.load(q_ptrs) + do = tl.load(do_ptrs) + delta = tl.load(d_ptrs) + l = tl.load(l_ptrs) + log_lambda_out = tl.load(log_lambda_out_ptrs) + else: + mask_m = offs_m < M + q = tl.load(q_ptrs, mask=mask_m[:, None]) + do = tl.load(do_ptrs, mask=mask_m[:, None]) + delta = tl.load(d_ptrs, mask=mask_m) + l = tl.load(l_ptrs, mask=mask_m) + log_lambda_out = tl.load(log_lambda_out_ptrs, mask=mask_m) + + # initialize dq + dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dlog_lambda_out = tl.zeros([BLOCK_M], dtype=tl.float32) + + # loop over k, v and update accumulator + # see note "Loop-Bound-For-N" + if CAUSAL: + hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M) + if LARGER_M: + hi = tl.maximum(0, hi) + else: + hi = N + + offs_n_base = tl.arange(0, BLOCK_N) + offs_n_init = offs_n_base + if HAS_SEQ_START: + SEQ_START += off_z + seq_start = tl.load(SEQ_START) + lo = tl.minimum(seq_start, hi) + lo = (lo // BLOCK_N) * BLOCK_N + offs_n_init += lo + else: + lo = 0 + k_ptrs = K + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL) + v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * 
stride_vk) # (BLOCK_N, BLOCK_DMODEL) + log_lambda_in_ptrs = LOG_LAMBDA + (offs_n_init * stride_log_lambda_n) + + # loop over a row + for start_n in range(lo, hi, BLOCK_N): + offs_n = start_n + offs_n_base + + # load k1, k2, v on chip + if DIVISIBLE_N: + v = tl.load(v_ptrs) + k = tl.load(k_ptrs) + log_lambda_in = tl.load(log_lambda_in_ptrs) + else: + mask_n = offs_n < N + v = tl.load(v_ptrs, mask=mask_n[:, None]) + k = tl.load(k_ptrs, mask=mask_n[:, None]) + log_lambda_in = tl.load(log_lambda_in_ptrs, mask=mask_n) + + + # recompute p = softmax(qk * sm_scale, dim=-1) + if not DIVISIBLE_N: + valid_mask = mask_n[None, :] # & mask_m[:, None] + if CAUSAL: + causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N) + # s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + s = tl.dot(q, tl.trans(k), input_precision="ieee") * qk_scale + decay_bias = log_lambda_out[:, None] - log_lambda_in[None, :] + s += decay_bias * log2e + + # NOTE: since softmax in backward is pointwise, the normalizer has been saved in fwd) + # So masking on s is not needed. + # if CAUSAL: + # s = tl.where(causal_mask & valid_mask, s, float("-inf")) + # else: + # s = tl.where(valid_mask, s, float("-inf")) + p = tl.math.exp2(s - l[:, None] * log2e) # (BLOCK_M, BLOCK_N) + + # compute dp = dot(v, do) + # dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + dp = tl.dot(do.to(input_dtype), tl.trans(v), input_precision="ieee") + + + # no need to mask dp + # if CAUSAL: + # dp = tl.where(causal_mask & valid_mask, dp, 0.0) + # else: + # dp = tl.where(valid_mask, dp, 0.0) + + # compute ds = p * (dp - delta[:, None]) + # move scale out to dq at last + ds = p * (dp - delta[:, None]) # (BLOCK_M, BLOCK_N) + + # mask ds to ensure no small values + if not DIVISIBLE_N: + ds = tl.where(valid_mask, ds, 0.0) + if CAUSAL: + ds = tl.where(causal_mask, ds, 0.0) + if HAS_SEQ_START: + ds = tl.where(offs_n[None, :] >= seq_start, ds, 0.0) + + dq += tl.dot(ds.to(input_dtype), k, input_precision="ieee") + dlog_lambda_out += tl.sum(ds, axis=1) + + # increment pointers + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vn + log_lambda_in_ptrs += BLOCK_N * stride_log_lambda_n + + dq *= sm_scale + if DIVISIBLE_M: + tmp = tl.load(dlog_lambda_out_ptrs) + else: + tmp = tl.load(dlog_lambda_out_ptrs, mask=mask_m) + dlog_lambda_out += tmp + if DIVISIBLE_M: + tl.store(dq_ptrs, dq.to(input_dtype)) + tl.store(dlog_lambda_out_ptrs, dlog_lambda_out) + else: + tl.store(dq_ptrs, dq.to(input_dtype), mask=mask_m[:, None]) + tl.store(dlog_lambda_out_ptrs, dlog_lambda_out, mask=mask_m) + + + +@pytest.mark.parametrize("Z, H, M, N, HEAD_DIM", [(4, 2, 1020, 2098, 64), (4, 2, 1024, 2048, 64)]) +@pytest.mark.parametrize("causal", [True]) +def test_op(Z, H, M, N, HEAD_DIM, causal, dtype=torch.bfloat16): + torch.manual_seed(24) + q = (torch.empty((Z, H, M, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, N, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + fgate_logit = torch.empty((Z, H, N), dtype=torch.float32, device="cuda").uniform_(5, 10) + log_fgate = torch.nn.functional.logsigmoid(fgate_logit).requires_grad_() + seq_start = torch.randint(low=0, high=N, size=(Z,), dtype=torch.long, device="cuda") + # seq_start = torch.randint(low=0, high=10, size=(Z,), dtype=torch.long, device="cuda") + # seq_start = torch.full(fill_value=0, size=(Z,), 
dtype=torch.long, device="cuda") + sm_scale = 0.5 + dout = torch.randn_like(q) + # reference implementation + P_SEQ = N - M + mask = torch.tril(torch.ones((M, N), device="cuda"), diagonal=P_SEQ) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + p = p.float() + + log_lambda = torch.cumsum(log_fgate, dim=-1) + decay_bias = log_lambda[..., -M:, None] - log_lambda[..., None, :] + p = p + decay_bias + if causal: + p[:, :, mask == 0] = float("-inf") + + attention_mask = torch.arange(N, device="cuda") < seq_start[:, None, None, None] + p = torch.where(attention_mask, float("-inf"), p) + p = torch.softmax(p.float(), dim=-1).to(dtype) + p = p.clone() + p[torch.isnan(p)] = 0.0 + # p = torch.exp(p) + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + ref_dlog_fgate, log_fgate.grad = log_fgate.grad.clone(), None + # triton implementation + tri_out = forgetting_attention(q, k, v, log_fgate, head_first=True, seq_start=seq_start, sm_scale=sm_scale) + tri_out = tri_out.to(dtype) + + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + tri_dlog_fgate, log_fgate.grad = log_fgate.grad.clone(), None + # compare + # assert torch.allclose(tri_log_normalizer[~torch.isnan(tri_log_normalizer)], ref_log_normalizer[~torch.isnan(ref_log_normalizer)], atol=1e-2, rtol=0) + assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0), (ref_out - tri_out).abs().max() + rtol = 0 + # Relative tolerance workaround for known hardware limitation of MI200 GPU. + # For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + # if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a": + # rtol = 1e-2 + assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=rtol), (ref_dv - tri_dv).abs().max() + assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=rtol), (ref_dk - tri_dk).abs().max() + assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=rtol), (ref_dq - tri_dq).abs().max() + assert torch.allclose(ref_dlog_fgate, tri_dlog_fgate, atol=1e-2, rtol=rtol), (ref_dlog_fgate - tri_dlog_fgate).abs().max() + +try: + from flash_attn.flash_attn_interface import \ + flash_attn_qkvpacked_func as flash_attn_func + HAS_FLASH = True +except BaseException: + HAS_FLASH = False + +TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') +BATCH, N_HEADS, HEAD_DIM = 4, 32, 128 +# vary seq length for fixed head and batch=4 +configs = [] +for mode in ["fwd", "bwd"]: +# for mode in ["bwd"]: + # for causal in [True, False]: + for causal in [True]: + if mode == "bwd" and not causal: + continue + configs.append( + triton.testing.Benchmark( + x_names=["N_CTX"], + # x_vals=[2**i for i in range(10, 15)], + x_vals=[2**i for i in range(14, 15)], + line_arg="provider", + # line_vals=["triton-fp16", "flag"] + (["flash"] if HAS_FLASH else []), + # line_names=["Triton [FP16]", "Flag"] + (["Flash-2"] if HAS_FLASH else []), + line_vals=["flag"] + (["flash"] if HAS_FLASH else []), + line_names=["Flag"] + (["Flash-2"] if HAS_FLASH else []), + styles=[("red", "-"), ("blue", "-"), ("green", "-")], + ylabel="ms", + plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}", + args={ + "H": N_HEADS, + "BATCH": BATCH, + "HEAD_DIM": HEAD_DIM, + "mode": mode, + "causal": causal, + }, + )) + + 
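+# A minimal usage sketch of the public API (added for illustration; shapes and
+# values are assumptions, not taken from the benchmark below):
+#
+#   q = torch.randn(2, 512, 8, 64, dtype=torch.bfloat16, device="cuda")
+#   k = torch.randn(2, 512, 8, 64, dtype=torch.bfloat16, device="cuda")
+#   v = torch.randn(2, 512, 8, 64, dtype=torch.bfloat16, device="cuda")
+#   log_fgate = torch.nn.functional.logsigmoid(
+#       torch.randn(2, 512, 8, dtype=torch.float32, device="cuda"))
+#   out = forgetting_attention(q, k, v, log_fgate)  # (2, 512, 8, 64)
+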
+@triton.testing.perf_report(configs)
+def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, mode, provider, device="cuda"):
+    assert mode in ["fwd", "bwd"]
+    warmup = 25
+    rep = 100
+    dtype = torch.bfloat16
+    if "flag" in provider:
+        q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
+        k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
+        v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
+        fgate_logit = torch.empty((BATCH, H, N_CTX), dtype=torch.float32, device="cuda").uniform_(5, 10)
+        log_fgate = torch.nn.functional.logsigmoid(fgate_logit).requires_grad_()
+        # if mode == "fwd" and "fp8" in provider:
+        #     q = q.to(torch.float8_e5m2)
+        #     k = k.to(torch.float8_e5m2)
+        #     v = v.permute(0, 1, 3, 2).contiguous()
+        #     v = v.permute(0, 1, 3, 2)
+        #     v = v.to(torch.float8_e5m2)
+        sm_scale = 1.3
+        fn = lambda: forgetting_attention(q, k, v, log_fgate, head_first=True, sm_scale=sm_scale)
+        if mode == "bwd":
+            o = fn()
+            do = torch.randn_like(o)
+            fn = lambda: o.backward(do, retain_graph=True)
+        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+    if provider == "flash":
+        qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
+        fn = lambda: flash_attn_func(qkv, causal=causal)
+        if mode == "bwd":
+            o = fn()
+            do = torch.randn_like(o)
+            fn = lambda: o.backward(do, retain_graph=True)
+        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM
+    total_flops = 2 * flops_per_matmul
+    if causal:
+        total_flops *= 0.5
+    if mode == "bwd":
+        total_flops *= 2.5  # 2.0 (bwd) + 0.5 (recompute)
+    return total_flops / ms * 1e-9
+
+
+if __name__ == "__main__":
+    # only works on post-Ampere GPUs right now
+    bench_flash_attention.run(save_path=".", print_data=True)
diff --git a/ops/forgetting_attention_std.py b/ops/forgetting_attention_std.py
new file mode 100644
index 0000000000000000000000000000000000000000..7763fbae480987a777057c67076272cfd8345af6
--- /dev/null
+++ b/ops/forgetting_attention_std.py
@@ -0,0 +1,72 @@
+"""
+Forgetting Attention - standard softmax reference version.
+Append this function at the end of forgetting_attention.py.
+"""
+
+import math
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from typing import Optional
+
+
+def forgetting_attention_std(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    log_fgate: torch.Tensor,
+    *,
+    head_first: bool = False,
+    seq_start: Optional[torch.Tensor] = None,
+    sm_scale: Optional[float] = None,
+) -> torch.Tensor:
+    """Standard-softmax (pure PyTorch) version of Forgetting Attention."""
+
+    if not head_first:
+        q = rearrange(q, "b t h d -> b h t d")
+        k = rearrange(k, "b t h d -> b h t d")
+        v = rearrange(v, "b t h d -> b h t d")
+        log_fgate = rearrange(log_fgate, "b t h -> b h t")
+
+    B, H, T_q, D = q.shape
+    T_k = k.shape[2]
+
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    # compute QK scores
+    scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale
+
+    # handle seq_start
+    log_fgate_masked = log_fgate.float()
+    if seq_start is not None:
+        log_fgate_masked = log_fgate_masked.clone()
+        mask_idx = torch.arange(T_k, device=q.device)[None, None, :] < seq_start[:, None, None]
+        mask_idx = torch.broadcast_to(mask_idx, log_fgate_masked.size())
+        log_fgate_masked[mask_idx] = 0.0
+
+    # cumulative decay; queries correspond to the *last* T_q key positions
+    log_lambda = torch.cumsum(log_fgate_masked, dim=-1)
+    decay_bias = log_lambda[:, :, -T_q:, None] - log_lambda[:, :, None, :]
+    scores = scores + decay_bias
+
+    # causal mask
+    P_SEQ = T_k - T_q
+    causal_mask = torch.triu(torch.ones((T_q, T_k), dtype=torch.bool, device=q.device), diagonal=P_SEQ + 1)
+    scores = scores.masked_fill(causal_mask[None, None, :, :], float('-inf'))
+
+    # seq_start mask (note the batch axis must come first to broadcast against
+    # the (B, H, T_q, T_k) scores)
+    if seq_start is not None:
+        seq_mask = torch.arange(T_k, device=q.device)[None, None, None, :] < seq_start[:, None, None, None]
+        scores = scores.masked_fill(seq_mask, float('-inf'))
+
+    # softmax
+    attn = F.softmax(scores, dim=-1)
+    attn = torch.nan_to_num(attn, 0.0)
+
+    # compute the output
+    out = torch.matmul(attn.to(v.dtype), v)
+
+    if not head_first:
+        out = rearrange(out, "b h t d -> b t h d")
+
+    return out
diff --git a/ops/framework_mock.py b/ops/framework_mock.py
new file mode 100644
index 0000000000000000000000000000000000000000..15389c862a4bf4b8aeff8d0c3bfd1ffaf05c3fab
--- /dev/null
+++ b/ops/framework_mock.py
@@ -0,0 +1,25 @@
+"""
+Mock framework module for the ndr geometric attention.
+Only the necessary parts are kept.
+"""
+import torch
+from typing import Optional, Any
+
+class visualize:
+    """Mock visualize class"""
+    @staticmethod
+    def attention(*args, **kwargs):
+        """Dummy attention visualization"""
+        pass
+
+    @staticmethod
+    def plot(*args, **kwargs):
+        """Dummy plot"""
+        pass
+
+# Mock other functionality that may be needed
+def get_logger(name: str):
+    """Mock logger"""
+    import logging
+    return logging.getLogger(name)
+
diff --git a/ops/geometric_attention/__init__.py b/ops/geometric_attention/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6b40ea5964a343581d47eeb07e17b0919b8d55a
--- /dev/null
+++ b/ops/geometric_attention/__init__.py
@@ -0,0 +1 @@
+from .cuda_interface import geometric_attention_activation
diff --git a/ops/geometric_attention/__pycache__/__init__.cpython-310.pyc b/ops/geometric_attention/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0f62f98d60c6072c8c5028b031cbcb750351893
Binary files /dev/null and b/ops/geometric_attention/__pycache__/__init__.cpython-310.pyc differ
diff --git a/ops/geometric_attention/__pycache__/cuda_interface.cpython-310.pyc b/ops/geometric_attention/__pycache__/cuda_interface.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15147f591c204599c3ae68cfdd06ee425e2ef158
Binary files /dev/null and b/ops/geometric_attention/__pycache__/cuda_interface.cpython-310.pyc differ
diff --git a/ops/geometric_attention/cuda_interface.cu b/ops/geometric_attention/cuda_interface.cu
new file mode 100644
index 0000000000000000000000000000000000000000..4dd21cae0bc68d5bc1b990698d149849740eb507
--- /dev/null
+++ b/ops/geometric_attention/cuda_interface.cu
@@ -0,0 +1,177 @@
+#include <torch/extension.h>
+
+__global__ void k_cuda_log_sigmoid_forward(int N, float * t, float *out_sigm, float *out_one_minus_sigm){
+    int i = threadIdx.x + blockIdx.x * blockDim.x;
+    if (i < N){
+        // Body reconstructed (the original lines were lost to angle-bracket
+        // stripping): numerically stable log(sigmoid(x)) and log(1 - sigmoid(x)).
+        float x = t[i];
+        float denom = log1pf(expf(-fabsf(x)));
+        out_sigm[i] = fminf(x, 0.0f) - denom;
+        out_one_minus_sigm[i] = -fmaxf(x, 0.0f) - denom;
+    }
+}
+
+__global__ void k_cuda_log_sigmoid_backward(int N, float *t, float *grad_sigm, float *grad_one_minus_sigm, float *grad_out){
+    int i = threadIdx.x + blockIdx.x * blockDim.x;
+    if (i < N){
+        float x = t[i];
+        // coeff = sigmoid(-|x|), stable in both branches (reconstructed line)
+        float coeff = 1.0f / (1.0f + expf(fabsf(x)));
+        float r_one_minus = (x > 0) ? (coeff - 1) : (-coeff);
+        float r = (x < 0) ? (coeff - 1) : (-coeff);
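+        // Derivation sketch (added note, not from the original file): with
+        // s = sigmoid(x), d/dx log(s) = 1 - s and d/dx log(1 - s) = -s;
+        // -r and r_one_minus evaluate exactly these via coeff = sigmoid(-|x|).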
+        grad_out[i] = - grad_sigm[i] * r + grad_one_minus_sigm[i] * r_one_minus;
+    }
+}
+
+std::vector<torch::Tensor> cuda_log_sigmoid_forward(torch::Tensor input){
+    auto o1 = torch::empty_like(input);
+    auto o2 = torch::empty_like(input);
+    auto inf = input.flatten();
+
+    const int N = inf.size(0);
+
+    const int threads = 256;
+    const int blocks = (N + threads - 1) / threads;
+
+    k_cuda_log_sigmoid_forward<<<blocks, threads>>>(N,
+        input.data<float>(),
+        o1.data<float>(),
+        o2.data<float>());
+
+    return {o1, o2};
+}
+
+std::vector<torch::Tensor> cuda_log_sigmoid_backward(torch::Tensor input, torch::Tensor grad_sigm, torch::Tensor grad_one_minus_sigm){
+    auto output = torch::empty_like(input);
+    auto N = input.flatten().size(0);
+
+    const int threads = 256;
+    const int blocks = (N + threads - 1) / threads;
+
+    k_cuda_log_sigmoid_backward<<<blocks, threads>>>(N,
+        input.data<float>(),
+        grad_sigm.data<float>(),
+        grad_one_minus_sigm.data<float>(),
+        output.data<float>());
+
+    return {output};
+}
+
+
+typedef torch::PackedTensorAccessor32<float, 3, torch::RestrictPtrTraits> float_accessor;
+
+__global__ void k_cuda_window_sum_forward(float_accessor csum, float_accessor out, int offset){
+    const int in_p = threadIdx.z + blockIdx.z * blockDim.z;
+    const int out_p_mem = threadIdx.y + blockIdx.y * blockDim.y;
+    const int batch = threadIdx.x + blockIdx.x * blockDim.x;
+
+    const int out_p = out_p_mem + offset;
+
+    if (batch < out.size(0) & out_p_mem < out.size(1) & in_p < out.size(2)){
+        float res;
+        if (in_p == out_p){
+            res = 0;
+        } else {
+            const int offset = abs(out_p - in_p);
+            int p_i = out_p + offset - int(in_p > out_p);
+            const int n_i = out_p - offset;
+
+            p_i = min(p_i, out.size(2) - 1);
+
+            float d_n = (n_i >= 0) ? (csum[batch][out_p_mem][n_i]) : 0.0;
+            res = (csum[batch][out_p_mem][p_i]) - d_n;
+        }
+
+        out[batch][out_p_mem][in_p] = res;
+    }
+
+}
+
+__global__ void k_cuda_window_sum_backward(float_accessor grad_in, float_accessor grad_out, int offset){
+    const int in_p = threadIdx.z + blockIdx.z * blockDim.z;
+    const int out_p_mem = threadIdx.y + blockIdx.y * blockDim.y;
+    const int batch = threadIdx.x + blockIdx.x * blockDim.x;
+
+    const int out_p = out_p_mem + offset;
+
+    if (batch < grad_out.size(0) & out_p_mem < grad_out.size(1) & in_p < grad_out.size(2)){
+        const int other = 2 * out_p - in_p;
+
+        float res;
+        if (in_p == grad_out.size(2) - 1){
+            res = 0;
+            for (int i = 0; i < other + int(in_p != out_p); ++i){
+                res += grad_in[batch][out_p_mem][i];
+            }
+        } else if (in_p == out_p){
+            res = grad_in[batch][out_p_mem][min(in_p + 1, grad_out.size(2) - 1)];
+        } else if (in_p < out_p){
+            res = -grad_in[batch][out_p_mem][in_p];
+            if (other < grad_in.size(2))
+                res -= grad_in[batch][out_p_mem][other];
+        } else {
+            res = grad_in[batch][out_p_mem][in_p + 1];
+            if (other >= 0)
+                res += grad_in[batch][out_p_mem][other];
+        }
+
+        grad_out[batch][out_p_mem][in_p] = res;
+    }
+}
+
+dim3 get_grid_size(torch::Tensor target, dim3 block_dim){
+    return dim3(
+        (target.size(0) + block_dim.x - 1) / block_dim.x,
+        (target.size(1) + block_dim.y - 1) / block_dim.y,
+        (target.size(2) + block_dim.z - 1) / block_dim.z
+    );
+}
+
+torch::Tensor cuda_window_sum_forward(torch::Tensor input, int offset){
+    auto out = torch::empty_like(input);
+
+    // launch configuration reconstructed from get_grid_size above
+    dim3 block_size(2, 2, 32);
+    k_cuda_window_sum_forward<<<get_grid_size(out, block_size), block_size>>>(
+        input.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
+        out.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
+        offset
+    );
+
+    return out;
+}
+
+torch::Tensor cuda_window_sum_backward(torch::Tensor grad_in, int offset){
+    auto out = torch::empty_like(grad_in);
+
+    dim3 block_size(2, 2, 32);
+    k_cuda_window_sum_backward<<<get_grid_size(out, block_size), block_size>>>(
+        grad_in.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
+        out.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
+        offset
+    );
+
+    return out;
+}
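+// (Added note: the four host functions above are the entry points bound below;
+// cuda_interface.py JIT-compiles this file with torch.utils.cpp_extension.load
+// and wraps the raw functions in torch.autograd.Functions.)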
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def(
+        "cuda_log_sigmoid_forward",
+        &cuda_log_sigmoid_forward,
+        "Log sigmoid, forward pass"
+    );
+    m.def(
+        "cuda_log_sigmoid_backward",
+        &cuda_log_sigmoid_backward,
+        "Log sigmoid, backward pass"
+    );
+    m.def(
+        "cuda_window_sum_forward",
+        &cuda_window_sum_forward,
+        "Window sum, forward pass"
+    );
+    m.def(
+        "cuda_window_sum_backward",
+        &cuda_window_sum_backward,
+        "Window sum, backward pass"
+    );
+}
\ No newline at end of file
diff --git a/ops/geometric_attention/cuda_interface.py b/ops/geometric_attention/cuda_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..7544d0ea763c2125ddff15a76f06821ca4a35188
--- /dev/null
+++ b/ops/geometric_attention/cuda_interface.py
@@ -0,0 +1,93 @@
+import os
+import torch
+import multiprocessing
+from typing import Tuple, Optional
+import torch.nn.functional as F
+import filelock  # used in place of framework.utils.LockFile
+
+# Just-in-time import
+# https://pytorch.org/tutorials/advanced/cpp_extension
+
+dirname = os.path.dirname(__file__)
+filename = os.path.join(dirname, 'cuda_interface.cu')
+outdir = "./cache/geometric_attention"
+os.makedirs(outdir, exist_ok=True)
+
+cuda_log_sigmoid_backward = None
+cuda_log_sigmoid_forward = None
+cuda_window_sum_forward = None
+cuda_window_sum_backward = None
+
+def load_extension():
+    global cuda_log_sigmoid_forward, cuda_log_sigmoid_backward
+    global cuda_window_sum_forward, cuda_window_sum_backward
+    if cuda_log_sigmoid_forward is not None:
+        return
+
+    # use filelock instead of framework.utils.LockFile
+    lock = filelock.FileLock(outdir + "/lock.lock")
+    with lock:
+        from torch.utils.cpp_extension import load
+
+        os.environ["MAX_JOBS"] = str(multiprocessing.cpu_count())
+        ext = load(
+            extra_cuda_cflags=['--ftemplate-depth=1024'],
+            name="geometric_attention_cuda_interface",
+            sources=[filename], verbose=True)
+
+        cuda_log_sigmoid_forward = ext.cuda_log_sigmoid_forward
+        cuda_log_sigmoid_backward = ext.cuda_log_sigmoid_backward
+        cuda_window_sum_forward = ext.cuda_window_sum_forward
+        cuda_window_sum_backward = ext.cuda_window_sum_backward
+
+
+class LogSigmoidFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        x = x.detach().contiguous()
+        ctx.save_for_backward(x)
+        a, b = cuda_log_sigmoid_forward(x)
+        return a, b
+
+    @staticmethod
+    def backward(ctx, grad_in_sigm: torch.Tensor, grad_in_one_minus: torch.Tensor) -> torch.Tensor:
+        xf, = ctx.saved_tensors
+        ga = grad_in_sigm.contiguous()
+        gb = grad_in_one_minus.contiguous()
+        return cuda_log_sigmoid_backward(xf, ga, gb)[0]
+
+
+class WindowSumFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, csum: torch.Tensor, offset: int) -> torch.Tensor:
+        ctx.saved_offset = offset
+        c2 = csum.detach().contiguous().flatten(end_dim=-3)
+        res = cuda_window_sum_forward(c2, offset)
+        return res.view_as(csum)
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
+        offset = ctx.saved_offset
+        go = grad_output.contiguous().flatten(end_dim=-3)
+        res = cuda_window_sum_backward(go, offset)
+        return res.view_as(grad_output), None
+
+
+def window_sum(x: torch.Tensor, offset: int) -> torch.Tensor:
+    load_extension()
+    return WindowSumFunction.apply(x, offset)
+
+
+def log_sigmoid(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    load_extension()
+    return LogSigmoidFunction.apply(x)
+
+
+def geometric_attention_activation(logits: torch.Tensor, mask: Optional[torch.Tensor] = None, pos_offset: int = 0,
+                                   normalize: bool = True) -> torch.Tensor:
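+    # (Added note, a sketch of the semantics as in NDR-style geometric
+    # attention: each key j is weighted by roughly sigmoid(logit_j) times the
+    # product of (1 - sigmoid) over keys scanned before j; `p` holds
+    # log(sigmoid) and the window_sum of the cumsum of log(1 - sigmoid)
+    # supplies the product term in log space.)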
+
+
+def geometric_attention_activation(logits: torch.Tensor, mask: Optional[torch.Tensor] = None, pos_offset: int = 0,
+                                   normalize: bool = True) -> torch.Tensor:
+    p, one_minus_p = log_sigmoid(logits)
+    not_previous = window_sum(one_minus_p.cumsum(-1), pos_offset)
+
+    probs = (not_previous + p).exp()
+
+    return F.normalize(probs, 1, -1) if normalize else probs
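+
+
+# Hypothetical shape sketch (not in the original file): `logits` holds per-position
+# match scores over the last axis, and the result keeps the same shape.
+# probs = geometric_attention_activation(torch.randn(2, 8, 16, 16, device="cuda"))
+# assert probs.shape == (2, 8, 16, 16)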
diff --git a/ops/geometric_attention_final.py b/ops/geometric_attention_final.py
new file mode 100644
index 0000000000000000000000000000000000000000..487977dabb515fe280854aefcdec44dd5abbfbc2
--- /dev/null
+++ b/ops/geometric_attention_final.py
@@ -0,0 +1,109 @@
+"""
+Geometric Attention - CUDA-accelerated version (supports FP16)
+"""
+
+import math
+import torch
+from einops import rearrange
+from typing import Optional
+
+# Try to import the CUDA version
+try:
+    from forgetting_transformer.ops.geometric_attention.cuda_interface import (
+        load_extension,
+        geometric_attention_activation,
+    )
+    load_extension()
+    HAS_CUDA = True
+    print("✅ Using CUDA geometric attention (with FP16 support)")
+except Exception as e:
+    HAS_CUDA = False
+    print(f"⚠️ CUDA not available: {e}")
+
+
+def geometric_attention_cuda(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    *,
+    head_first: bool = False,
+    seq_start: Optional[torch.Tensor] = None,
+    sm_scale: Optional[float] = None,
+    normalize: bool = True,
+) -> torch.Tensor:
+    if not HAS_CUDA:
+        raise RuntimeError("CUDA not available")
+
+    # Remember the original dtype
+    original_dtype = q.dtype
+    needs_cast = original_dtype == torch.float16
+
+    # The CUDA kernel runs in FP32, so cast FP16 inputs up
+    if needs_cast:
+        q = q.float()
+        k = k.float()
+        v = v.float()
+
+    # Rearrange
+    if not head_first:
+        q = rearrange(q, "b t h d -> b h t d")
+        k = rearrange(k, "b t h d -> b h t d")
+        v = rearrange(v, "b t h d -> b h t d")
+
+    B, H, T_q, D = q.shape
+
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    # Attention scores
+    logits = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
+
+    # CUDA kernel (FP32)
+    attn_weights = geometric_attention_activation(
+        logits, mask=None, pos_offset=0, normalize=normalize
+    )
+
+    # Apply to values
+    output = torch.matmul(attn_weights, v)
+
+    # Rearrange back
+    if not head_first:
+        output = rearrange(output, "b h t d -> b t h d")
+
+    # Cast back to the original dtype
+    if needs_cast:
+        output = output.to(original_dtype)
+
+    return output
+
+
+def geometric_attention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    *,
+    head_first: bool = False,
+    seq_start: Optional[torch.Tensor] = None,
+    sm_scale: Optional[float] = None,
+    normalize: bool = True,
+) -> torch.Tensor:
+    """Automatically pick the CUDA path, falling back to Python"""
+
+    if HAS_CUDA and q.is_cuda:
+        try:
+            return geometric_attention_cuda(
+                q, k, v, head_first=head_first,
+                seq_start=seq_start, sm_scale=sm_scale,
+                normalize=normalize
+            )
+        except Exception as e:
+            # Do not print a warning on every call; it would flood the log
+            pass
+
+    # Fallback
+    from forgetting_transformer.ops.geometric_attention_std import geometric_attention_std
+    return geometric_attention_std(
+        q, k, v, head_first=head_first,
+        seq_start=seq_start, sm_scale=sm_scale,
+        normalize=normalize
+    )
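+
+
+# Hypothetical usage sketch (not in the original file); shapes follow the
+# [batch, time, heads, dim] convention of the functions above:
+# q = k = v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
+# out = geometric_attention(q, k, v)  # silently falls back to the Python path on failure
+# assert out.shape == q.shape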
diff --git a/ops/geometric_attention_std.py b/ops/geometric_attention_std.py
new file mode 100644
index 0000000000000000000000000000000000000000..071fadffcc62a02dd7528c6b793b00e86dbcb8f2
--- /dev/null
+++ b/ops/geometric_attention_std.py
@@ -0,0 +1,179 @@
+"""
+Geometric Attention - standard softmax version
+Based on the paper "The Neural Data Router" (Csordás et al., 2022)
+"""
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from typing import Optional
+
+
+def geometric_attention_std(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    *,
+    head_first: bool = False,
+    seq_start: Optional[torch.Tensor] = None,
+    sm_scale: Optional[float] = None,
+    normalize: bool = True,
+) -> torch.Tensor:
+    """
+    Standard softmax version of Geometric Attention
+
+    Args:
+        q: Query tensor [B, T, H, D] or [B, H, T, D] if head_first
+        k: Key tensor [B, T, H, D] or [B, H, T, D] if head_first
+        v: Value tensor [B, T, H, D] or [B, H, T, D] if head_first
+        head_first: whether the head dimension comes before time
+        seq_start: sequence start positions [B]
+        sm_scale: scaling factor, defaults to 1/sqrt(D)
+        normalize: whether to normalize the attention weights
+
+    Returns:
+        output: [B, T, H, D] or [B, H, T, D] if head_first
+    """
+
+    # Rearrange to head_first format
+    if not head_first:
+        q = rearrange(q, "b t h d -> b h t d")
+        k = rearrange(k, "b t h d -> b h t d")
+        v = rearrange(v, "b t h d -> b h t d")
+
+    B, H, T_q, D = q.shape
+    T_k = k.shape[2]
+
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    # Step 1: content-based logits
+    logits = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale
+    # logits: [B, H, T_q, T_k]
+
+    # Step 2: mask the diagonal (a position may not attend to itself)
+    if T_q == T_k:
+        diag_mask = torch.eye(T_q, dtype=torch.bool, device=q.device)
+        logits = logits.masked_fill(diag_mask[None, None, :, :], float('-inf'))
+
+    # Step 3: seq_start mask
+    if seq_start is not None:
+        seq_mask = torch.arange(T_k, device=q.device)[None, None, None, :] < seq_start[None, :, None, None]
+        logits = logits.masked_fill(seq_mask, float('-inf'))
+
+    # Step 4: causal mask (if needed)
+    # Note: the geometric attention paper is not causal; uncomment if your task needs it
+    # P_SEQ = T_k - T_q
+    # causal_mask = torch.triu(torch.ones((T_q, T_k), dtype=torch.bool, device=q.device), diagonal=P_SEQ + 1)
+    # logits = logits.masked_fill(causal_mask[None, None, :, :], float('-inf'))
+
+    # Step 5: geometric weighting (the core of the algorithm)
+    attn_weights = geometric_weighting(logits, normalize=normalize)
+
+    # Step 6: apply the attention weights to the values
+    out = torch.matmul(attn_weights.to(v.dtype), v)
+
+    if not head_first:
+        out = rearrange(out, "b h t d -> b t h d")
+
+    return out
+
+
+def geometric_weighting(
+    logits: torch.Tensor,
+    normalize: bool = True,
+) -> torch.Tensor:
+    """
+    Compute geometric attention weights
+
+    Implements Equation 7 of the paper:
+    A[i,j] = P[i,j] * ∏(1 - P[i,k]) for k closer to i than j
+
+    Args:
+        logits: [B, H, T_q, T_k] attention logits
+        normalize: whether to normalize
+
+    Returns:
+        weights: [B, H, T_q, T_k] attention weights
+    """
+    B, H, T_q, T_k = logits.shape
+
+    # Step 1: sigmoid to get matching probabilities
+    P = torch.sigmoid(logits)  # [B, H, T_q, T_k]
+
+    # Step 2: work in log-space (numerically stable)
+    log_P = torch.log(P + 1e-10)
+    log_one_minus_P = torch.log(1.0 - P + 1e-10)
+
+    # Step 3: simplified variant - realize the geometric decay with a cumsum
+    # This is an efficient approximation that avoids explicit loops
+
+    # For each position i, accumulate log(1-P) over all positions to its left
+    log_decay_left = log_one_minus_P.cumsum(dim=-1)
+
+    # Compute the weights (simplified variant)
+    # The full version picks the interval dynamically based on distance;
+    # this is an efficient approximation
+    weights = torch.exp(log_P + log_decay_left.roll(1, dims=-1))
+
+    # Special-case the first position (it has no left neighbor)
+    # Avoid in-place operations
+    weights_first = P[:, :, :, :1]  # take the first column
+    weights = torch.cat([weights_first, weights[:, :, :, 1:]], dim=-1)
+
+    # Step 4: normalization (optional)
+    if normalize:
+        weights = F.normalize(weights, p=1, dim=-1)
+
+    # Zero out NaNs (e.g. when every position was masked to -inf)
+    weights = torch.nan_to_num(weights, 0.0)
+
+    return weights
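+
+
+# Quick sanity sketch (not in the original file): with normalize=True each row of the
+# returned weights is L1-normalized, which can be checked on random logits:
+# w = geometric_weighting(torch.randn(1, 2, 8, 8))
+# assert torch.allclose(w.sum(-1), torch.ones(1, 2, 8), atol=1e-5)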
+
+
+def geometric_weighting_full(
+    logits: torch.Tensor,
+    normalize: bool = True,
+) -> torch.Tensor:
+    """
+    Full version of geometric weighting (slower but exact)
+
+    Use only when maximum accuracy is required; the simplified version above
+    is recommended for training
+    """
+    B, H, T_q, T_k = logits.shape
+    device = logits.device
+
+    P = torch.sigmoid(logits)
+    log_P = torch.log(P + 1e-10)
+    log_one_minus_P = torch.log(1.0 - P + 1e-10)
+
+    # Initialize the weights
+    weights = torch.zeros_like(P)
+
+    # Compute the geometric weight for every (i, j)
+    for i in range(T_q):
+        for j in range(T_k):
+            # Find all positions k that are closer to i than j is
+            if i < j:
+                # Looking right: closer positions are [i+1, ..., j-1]
+                closer_positions = range(i + 1, j)
+            elif i > j:
+                # Looking left: closer positions are [j+1, ..., i-1]
+                closer_positions = range(j + 1, i)
+            else:
+                # i == j (the diagonal) was already masked out by the caller
+                continue
+
+            # Compute ∏(1 - P[i,k]) in log-space
+            log_prod = sum(log_one_minus_P[:, :, i, k] for k in closer_positions) if closer_positions else 0.0
+
+            # weights[i,j] = P[i,j] * ∏(1 - P[i,k])
+            weights[:, :, i, j] = torch.exp(log_P[:, :, i, j] + log_prod)
+
+    if normalize:
+        weights = F.normalize(weights, p=1, dim=-1)
+
+    weights = torch.nan_to_num(weights, 0.0)
+
+    return weights
\ No newline at end of file
diff --git a/ops/layer_with_visualization.py b/ops/layer_with_visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4fe057a2ada96dbaaedec64d252ca73ce89f1a0
--- /dev/null
+++ b/ops/layer_with_visualization.py
@@ -0,0 +1,43 @@
+import torch
+import torch.nn
+from typing import Dict, Any
+
+
+class LayerWithVisualization(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.visualization_enabled = False
+
+    def prepare(self):
+        # Should be called before the training step
+        pass
+
+    def plot(self, options: Dict[str, Any]) -> Dict[str, Any]:
+        raise NotImplementedError()
+
+
+class LayerVisualizer:
+    def __init__(self, module: torch.nn.Module, options: Dict[str, Any] = {}):
+        self.modules = []
+        self.options = options
+        self.curr_options = None
+        for n, m in module.named_modules():
+            if isinstance(m, LayerWithVisualization):
+                self.modules.append((n, m))
+
+    def plot(self) -> Dict[str, Any]:
+        res = {}
+        for n, m in self.modules:
+            res.update({f"{n}/{k}": v for k, v in m.plot(self.curr_options).items()})
+            m.visualization_enabled = False
+
+        self.curr_options = None
+        return res
+
+    def prepare(self, options: Dict[str, Any] = {}):
+        self.curr_options = self.options.copy()
+        self.curr_options.update(options)
+
+        for _, m in self.modules:
+            m.prepare()
+            m.visualization_enabled = True
diff --git a/ops/multi_head_attention.py b/ops/multi_head_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..e425b8b590cadbee85e0c9a34c6894511d97764d
--- /dev/null
+++ b/ops/multi_head_attention.py
@@ -0,0 +1,149 @@
+import torch
+import torch.nn
+import torch.nn.functional as F
+import math
+from typing import Optional, Callable, List, Union, Tuple, Dict, Any
+from dataclasses import dataclass
+from forgetting_transformer.ops.layer_with_visualization import LayerWithVisualization
+import forgetting_transformer.ops.framework_mock as framework
+
+
+@dataclass
+class AttentionMask:
+    src_length_mask: Optional[torch.Tensor]
+    position_mask: Optional[torch.Tensor]
+
+
+class MultiHeadAttentionBase(LayerWithVisualization):
+    def __init__(self, state_size: int, n_heads: int, dropout: float=0.1, projection_size: Optional[int] = None):
+        assert state_size % n_heads == 0
+        super().__init__()
+        self.attention_to_visualize = []
+
+        self.state_size = state_size
+        self.projection_size = projection_size or (state_size // n_heads)
+        self.n_heads = n_heads
+        self.scale = 1.0 / math.sqrt(self.projection_size)
+
+        self.dropout = torch.nn.Dropout(dropout)
+
+    @staticmethod
+    def apply_logit_masks(logits: torch.Tensor, mask: Optional[AttentionMask], val: float = float("-inf")) -> torch.Tensor:
+        if mask.position_mask is not None:
+            # [..., N_out, N_in], broadcast works
+            logits = logits.masked_fill(mask.position_mask, val)
+
+        if mask.src_length_mask is not None:
+            # [B, ...., N_in], needs manual shaping
+            b, i = mask.src_length_mask.shape
+            pad_dims = logits.ndim - 2
+            logits = 
logits.masked_fill(mask.src_length_mask.view([b] + [1] * pad_dims + [i]), val) + + return logits + + def _masked_softmax(self, logits: torch.Tensor, mask: Optional[AttentionMask]) -> torch.Tensor: + if mask is None or (mask.src_length_mask is None and mask.position_mask is None): + return F.softmax(logits, -1) + + # Output shape: [n_batch * n_heads, n_time_dest, n_time_src] + bb, n_time_dest, n_time_src = logits.shape + + logits = logits.view(bb // self.n_heads, self.n_heads, n_time_dest, n_time_src) + logits = self.apply_logit_masks(logits, mask) + + logits = F.softmax(logits, -1) + return logits.view(bb, n_time_dest, n_time_src) + + def _attention_read(self, mask: Optional[AttentionMask], scores: torch.Tensor, v: torch.Tensor) -> \ + Tuple[torch.Tensor, torch.Tensor]: + # scores: [n_batch * n_heads, n_out, n_in] + # v: [n_nbatch * n_heads, n_in] + # Output data shape [n_batch * n_heads, n_time_dest, data_size] + # Out attention score shape: [n_batch, n_heads, n_time_dest, n_time_src] + s_reshape = scores.view(-1, self.n_heads, *scores.shape[1:]) + # scores = self.dropout(scores) + if self.visualization_enabled: + self.attention_to_visualize.append(s_reshape[0]) + return torch.bmm(scores, v), s_reshape + + def transform_data(self, input: torch.Tensor, proj: Callable[[torch.Tensor], torch.Tensor], + n_projs: int) -> List[torch.Tensor]: + # Input shape: [n_batch, n_steps, n_channels] + # Output: Tuple of n_projs tensors of dimension: [n_batch * n_heads, n_steps, projection_size] + n_batch, n_steps, _ = input.shape + transformed = proj(input).view(n_batch, n_steps, self.n_heads, n_projs, -1). \ + permute(0, 2, 1, 3, 4).contiguous().view(n_batch * self.n_heads, n_steps, n_projs, -1) + return transformed.unbind(dim=2) + + def plot(self, options: Dict[str, Any]) -> Dict[str, Any]: + r = {} + marks = options.get("steplabel") + if options.get("mha.plot_head_details") and self.attention_to_visualize[0].shape[0] > 1: + for head in range(self.attention_to_visualize[0].shape[0]): + r[f"head_{head}"] = framework.visualize.plot.AnimatedHeatmap( + torch.stack([layer[head] for _, layer in enumerate(self.attention_to_visualize)], 0), + ylabel="dest", xlabel="src", textval=False, x_marks=marks, y_marks=marks, ignore_wrong_marks=True) + + r["attention_max"] = framework.visualize.plot.AnimatedHeatmap( + torch.stack([layer.max(0)[0] for _, layer in enumerate(self.attention_to_visualize)], 0), + ylabel="dest", xlabel="src", textval=False, x_marks=marks, y_marks=marks, ignore_wrong_marks=True) + self.attention_to_visualize = [] + return r + + +class AttentionMergeMixin: + def __init__(self, out_size: Optional[int]) -> None: + self.multi_head_merge = torch.nn.Linear(self.n_heads * self.projection_size, out_size or self.state_size, + bias=False) + + def merged_attention(self, n_batch: int, n_out_steps: int, *args, need_weights: bool = False, **kwargs) -> \ + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + data, scores = self._attention(*args, **kwargs) + + data = data.view(n_batch, self.n_heads, n_out_steps, -1).permute(0, 2, 1, 3).contiguous().\ + view(n_batch, n_out_steps, -1) + + return self.multi_head_merge(data), scores + + +class AbsPosAttentionBase(MultiHeadAttentionBase): + def get_attention_scores(self, mask: Optional[torch.Tensor], q: torch.Tensor, k: torch.Tensor) -> torch.Tensor: + logits = torch.bmm(q, k.transpose(1, 2)) + return self._masked_softmax(logits * self.scale, mask) + + def _attention(self, mask: Optional[torch.Tensor], q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> \ + 
torch.Tensor:
+        # all inputs should have a shape of [n_batch, n_steps, data_size]
+        # Output shape [n_batch * n_heads, n_time_dest, data_size]
+        scores = self.get_attention_scores(mask, q, k)
+        return self._attention_read(mask, scores, v)
+
+
+class MultiHeadAttention(AttentionMergeMixin, AbsPosAttentionBase):
+    def __init__(self, state_size: int, n_heads: int, dropout: float = 0.1, input_size: Optional[int] = None,
+                 out_size: Optional[int] = None):
+        super(AttentionMergeMixin, self).__init__(state_size, n_heads, dropout)
+
+        self.data_to_kv = torch.nn.Linear(state_size, 2 * n_heads * self.projection_size, bias=False)
+        self.data_to_q = torch.nn.Linear(input_size or state_size, n_heads * self.projection_size, bias=False)
+
+        super(MultiHeadAttention, self).__init__(out_size)
+        self.reset_parameters()
+
+    def forward(self, curr_state: torch.Tensor, attend_to: torch.Tensor, mask: Optional[AttentionMask],
+                need_weights: bool = False):
+        # Input and output shape: [n_batch, n_steps, data_size]
+        k, v = self.transform_data(attend_to, self.data_to_kv, 2)
+        q, = self.transform_data(curr_state, self.data_to_q, 1)
+
+        data, scores = self.merged_attention(curr_state.shape[0], q.shape[1], mask, q, k, v)
+        if need_weights:
+            return data, scores
+        else:
+            return data
+
+    def reset_parameters(self):
+        torch.nn.init.xavier_uniform_(self.data_to_q.weight)
+        torch.nn.init.xavier_uniform_(self.data_to_kv.weight)
diff --git a/ops/multi_head_relative_pos_attention.py b/ops/multi_head_relative_pos_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..3af13d41ac2650dab6d533c4e8aaec4f990ae694
--- /dev/null
+++ b/ops/multi_head_relative_pos_attention.py
@@ -0,0 +1,185 @@
+import torch
+import torch.nn
+import torch.nn.functional as F
+from typing import Optional, Dict, Any
+from forgetting_transformer.ops.multi_head_attention import AttentionMask, MultiHeadAttentionBase, AttentionMergeMixin
+import forgetting_transformer.ops.framework_mock as framework
+import math
+from matplotlib import cm
+
+def shift(posmat: torch.Tensor) -> torch.Tensor:
+    # Slice out a matrix diagonally. Each successive row is sliced one position to the left compared to the previous one.
+ # shape: [n_batch, n_head, n_out, n_in * 2 - 1] + # return: [n_batch, n_head, n_out, n_in] + p = F.pad(posmat, (0, 1, 0, 1)).flatten(-2) # [n_batch, n_head, (n_out + 1) * n_in * 2] + p = p.narrow(-1, posmat.shape[-1] // 2, posmat.shape[-1] * posmat.shape[-2]).view_as(posmat) + + return p.narrow(-1, 0, (posmat.shape[-1] + 1) // 2) + + +class RelativeAttentionBase(MultiHeadAttentionBase): + def __init__(self, state_size: int, n_heads: int, dropout: float, projection_size: Optional[int] = None): + super().__init__(state_size, n_heads, dropout=dropout, projection_size=projection_size) + self.scale = torch.nn.Parameter(torch.tensor([self.scale])) + self.s_bias = torch.nn.Parameter(torch.tensor([0.0])) + self.vis_pos_vs_content = [] + + def get_attention_scores(self, mask: Optional[torch.Tensor], + q_content: torch.Tensor, k_content: torch.Tensor, + q_pos: torch.Tensor, k_pos: torch.Tensor, + pos_offset: int, ar_gate: Optional[torch.Tensor] = None) -> torch.Tensor: + + # shape of q_content, q_pos, k_pos: [n_batch * n_heads, n_steps, data_size] + # k_pos: [n_heads, n_in * 2 - 1, data_size] + # ar_gate: [n_batch*n_heads, n_out, 1] + # Output shape [n_batch * n_heads, n_out, data_size] + + n_batch = q_content.shape[0] // self.n_heads + n_out_steps = q_content.shape[1] + + # content-content addressing + content = torch.bmm(q_content, self.dropout(k_content).transpose(1, 2)) + + # content-pos addressing. + pos = torch.matmul(q_pos.view(n_batch, self.n_heads, n_out_steps, -1), self.dropout(k_pos).transpose(-1, -2)) # [n_batch, n_head, n_out, n_in * 2 - 1] + fpos = shift(pos).flatten(0, 1) + if ar_gate is not None: + fpos = fpos * ar_gate + pos.flatten(0, 1)[..., fpos.shape[-1] - 1:] * (1 - ar_gate) + + # return self._masked_softmax((fpos) * self.scale, mask) + if self.visualization_enabled: + self.vis_pos_vs_content.append((content.view(n_batch, self.n_heads, *content.shape[1:])[0] * self.scale, + fpos.view(n_batch, self.n_heads, *fpos.shape[1:])[0] * self.scale)) + + return self._masked_softmax((content + fpos) * self.scale, mask) + + def _attention(self, mask: Optional[torch.Tensor], + q_content: torch.Tensor, k_content: torch.Tensor, + q_pos: torch.Tensor, k_pos: torch.Tensor, + v: torch.Tensor, pos_offset: int, + ar_gate: Optional[torch.Tensor] = None) -> [torch.Tensor, torch.Tensor]: + + scores = self.get_attention_scores(mask, q_content, k_content, q_pos, k_pos, pos_offset, ar_gate) + + # Scores shape: [n_batch * n_heads, n_out, n_in] + return self._attention_read(mask, scores, v) + + def _get_pos_subset(self, pos_encoding: torch.Tensor, length: int, offset: int) -> torch.Tensor: + l_slice = 2 * length - 1 + assert pos_encoding.shape[0] > l_slice + return pos_encoding.narrow(0, pos_encoding.shape[0] // 2 - length + 1 - offset, 2 * length - 1) + + def plot(self, options: Dict[str, Any]) -> Dict[str, Any]: + r = {} + marks = options.get("steplabel") + if options.get("mha.plot_head_details") and self.vis_pos_vs_content: + for head in range(self.vis_pos_vs_content[0][0].shape[0]): + cont = torch.stack([layer[0][head] for _, layer in enumerate(self.vis_pos_vs_content)], 0) + pos = torch.stack([layer[1][head] for _, layer in enumerate(self.vis_pos_vs_content)], 0) + i = torch.stack([layer[head] for _, layer in enumerate(self.attention_to_visualize)], 0) + content = torch.stack([cont, pos], -1).softmax(-1)[..., 0] + + color = cm.get_cmap("brg")(content.cpu().numpy()) + color[..., -1] = (i * 0.95 + 0.05).cpu().numpy() + + r[f"content_vs_pos_{head}"] = framework.visualize.plot.AnimatedHeatmap(color, 
ylabel="dest", + xlabel="src", textval=False, x_marks=marks, y_marks=marks, cmap="brg", colorbar=True, + colorbar_ticks=[0, 0.99], colorbar_labels=["pos", "con"], ignore_wrong_marks=True) + + # r["attention_max"] = framework.visualize.plot.AnimatedHeatmap( + # torch.stack([layer.max(0)[0] for _, layer in enumerate(self.attention_to_visualize)], 0), + # ylabel="dest", xlabel="src", textval=False, x_marks=marks, y_marks=marks) + self.vis_pos_vs_content = [] + + r.update(super().plot(options)) + return r + + + +class FixedRelativeMultiheadAttentionBase(RelativeAttentionBase): + def __init__(self, state_size: int, n_heads: int, dropout: float = 0.0, input_size: Optional[int] = None, + projection_size: Optional[int] = None): + super().__init__(state_size, n_heads, dropout, projection_size) + + self.input_size = state_size if input_size is None else input_size + + self.pos_to_pq = torch.nn.Linear(state_size, self.n_heads * self.projection_size, bias=False) + self.register_buffer("pos_encoding", self._create_buffer(1000)) + + def _create_buffer(self, max_len: int): + return framework.layers.sinusoidal_pos_embedding(self.state_size, 2 * max_len - 1, -max_len + 1, + device=self.pos_to_pq.weight.device) + + def get_pos(self, l: int, offset: int) -> torch.Tensor: + if self.pos_encoding.shape[0] < 2 * (l + offset) - 1: + self.pos_encoding = self._create_buffer(int(2**math.ceil(math.log2(2 * (l + offset) - 1)))) + + return self.pos_to_pq(self._get_pos_subset(self.pos_encoding, l, offset)) + + +class FixedRelativeMultiheadAttention(AttentionMergeMixin, FixedRelativeMultiheadAttentionBase): + def __init__(self, state_size: int, n_heads: int, dropout: float = 0.0, global_pos_bias: bool = True, + global_content_bias: bool = True, input_size: Optional[int] = None, absolute_gate: bool = False, + projection_size: Optional[int] = None, output_size: Optional[int] = None): + super(AttentionMergeMixin, self).__init__(state_size, n_heads, dropout, input_size, projection_size=projection_size) + + self.data_to_kv = torch.nn.Linear(state_size, 2 * n_heads * self.projection_size, bias=False) + self.data_to_q = torch.nn.Linear(self.input_size, n_heads * self.projection_size, bias=False) + self.data_to_absgate = torch.nn.Linear(self.input_size, n_heads) \ + if absolute_gate else None + + self.global_content_bias = torch.nn.Parameter(torch.zeros([n_heads, self.projection_size])) \ + if global_content_bias else None + self.global_pos_bias = torch.nn.Parameter(torch.zeros([n_heads, self.projection_size])) \ + if global_pos_bias else None + + super(FixedRelativeMultiheadAttention, self).__init__(output_size) + self.reset_parameters() + + def add_head_specific_bias(self, data: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: + # data [batch * n_heads, len, c] + # bias [n_heads, c] + return (data.view(-1, bias.shape[0], *data.shape[1:]) + bias.unsqueeze(1).type_as(data)).view_as(data) \ + if bias is not None else data + + def forward(self, curr_state: torch.Tensor, attend_to: torch.Tensor, mask: Optional[AttentionMask], + pos_offset: int = 0, need_weights: bool = False): + # curr_state: [§size, out_len, c] + # attend_to: [batch_size, in_len, c] + batch_size, in_len = attend_to.shape[0:2] + out_len = curr_state.shape[1] + + k_content, v = self.transform_data(attend_to, self.data_to_kv, 2) + q, = self.transform_data(curr_state, self.data_to_q, 1) + + k_pos = self.get_pos(in_len, pos_offset).view(-1, self.n_heads, self.projection_size).\ + transpose(0, 1) # n_heads, 2*in_len -1 , projection_size + + q_content = 
self.add_head_specific_bias(q, self.global_content_bias) + q_pos = self.add_head_specific_bias(q, self.global_pos_bias) + + + absgate = torch.sigmoid(self.transform_data(curr_state, self.data_to_absgate, 1)[0]) \ + if self.data_to_absgate is not None else None + + data, scores = self.merged_attention(batch_size, out_len, mask, q_content, k_content, q_pos, k_pos, v, + pos_offset, ar_gate=absgate, need_weights=need_weights) + + if need_weights: + return data, scores + else: + return data + + def reset_parameters(self): + # # super().reset_parameters() + + torch.nn.init.xavier_uniform_(self.data_to_q.weight) + torch.nn.init.xavier_uniform_(self.pos_to_pq.weight) + torch.nn.init.xavier_uniform_(self.data_to_kv.weight) + torch.nn.init.xavier_uniform_(self.data_to_kv.weight) + + if self.global_content_bias is not None: + self.global_content_bias.data.fill_(0) + + if self.global_pos_bias is not None: + self.global_pos_bias.data.fill_(0) diff --git a/ops/multi_head_relative_pos_attention.py.bak b/ops/multi_head_relative_pos_attention.py.bak new file mode 100644 index 0000000000000000000000000000000000000000..9a1bc9c4b48523b026895c370ea7c310188ae947 --- /dev/null +++ b/ops/multi_head_relative_pos_attention.py.bak @@ -0,0 +1,185 @@ +import torch +import torch.nn +import torch.nn.functional as F +from typing import Optional, Dict, Any +from .multi_head_attention import AttentionMask, MultiHeadAttentionBase, AttentionMergeMixin +import framework +import math +from matplotlib import cm + +def shift(posmat: torch.Tensor) -> torch.Tensor: + # Slice out a matrix diagonally. Each successive row is sliced one position to the left compared. + # shape: [n_batch, n_head, n_out, n_in * 2 - 1] + # return: [n_batch, n_head, n_out, n_in] + p = F.pad(posmat, (0, 1, 0, 1)).flatten(-2) # [n_batch, n_head, (n_out + 1) * n_in * 2] + p = p.narrow(-1, posmat.shape[-1] // 2, posmat.shape[-1] * posmat.shape[-2]).view_as(posmat) + + return p.narrow(-1, 0, (posmat.shape[-1] + 1) // 2) + + +class RelativeAttentionBase(MultiHeadAttentionBase): + def __init__(self, state_size: int, n_heads: int, dropout: float, projection_size: Optional[int] = None): + super().__init__(state_size, n_heads, dropout=dropout, projection_size=projection_size) + self.scale = torch.nn.Parameter(torch.tensor([self.scale])) + self.s_bias = torch.nn.Parameter(torch.tensor([0.0])) + self.vis_pos_vs_content = [] + + def get_attention_scores(self, mask: Optional[torch.Tensor], + q_content: torch.Tensor, k_content: torch.Tensor, + q_pos: torch.Tensor, k_pos: torch.Tensor, + pos_offset: int, ar_gate: Optional[torch.Tensor] = None) -> torch.Tensor: + + # shape of q_content, q_pos, k_pos: [n_batch * n_heads, n_steps, data_size] + # k_pos: [n_heads, n_in * 2 - 1, data_size] + # ar_gate: [n_batch*n_heads, n_out, 1] + # Output shape [n_batch * n_heads, n_out, data_size] + + n_batch = q_content.shape[0] // self.n_heads + n_out_steps = q_content.shape[1] + + # content-content addressing + content = torch.bmm(q_content, self.dropout(k_content).transpose(1, 2)) + + # content-pos addressing. 
+ pos = torch.matmul(q_pos.view(n_batch, self.n_heads, n_out_steps, -1), self.dropout(k_pos).transpose(-1, -2)) # [n_batch, n_head, n_out, n_in * 2 - 1] + fpos = shift(pos).flatten(0, 1) + if ar_gate is not None: + fpos = fpos * ar_gate + pos.flatten(0, 1)[..., fpos.shape[-1] - 1:] * (1 - ar_gate) + + # return self._masked_softmax((fpos) * self.scale, mask) + if self.visualization_enabled: + self.vis_pos_vs_content.append((content.view(n_batch, self.n_heads, *content.shape[1:])[0] * self.scale, + fpos.view(n_batch, self.n_heads, *fpos.shape[1:])[0] * self.scale)) + + return self._masked_softmax((content + fpos) * self.scale, mask) + + def _attention(self, mask: Optional[torch.Tensor], + q_content: torch.Tensor, k_content: torch.Tensor, + q_pos: torch.Tensor, k_pos: torch.Tensor, + v: torch.Tensor, pos_offset: int, + ar_gate: Optional[torch.Tensor] = None) -> [torch.Tensor, torch.Tensor]: + + scores = self.get_attention_scores(mask, q_content, k_content, q_pos, k_pos, pos_offset, ar_gate) + + # Scores shape: [n_batch * n_heads, n_out, n_in] + return self._attention_read(mask, scores, v) + + def _get_pos_subset(self, pos_encoding: torch.Tensor, length: int, offset: int) -> torch.Tensor: + l_slice = 2 * length - 1 + assert pos_encoding.shape[0] > l_slice + return pos_encoding.narrow(0, pos_encoding.shape[0] // 2 - length + 1 - offset, 2 * length - 1) + + def plot(self, options: Dict[str, Any]) -> Dict[str, Any]: + r = {} + marks = options.get("steplabel") + if options.get("mha.plot_head_details") and self.vis_pos_vs_content: + for head in range(self.vis_pos_vs_content[0][0].shape[0]): + cont = torch.stack([layer[0][head] for _, layer in enumerate(self.vis_pos_vs_content)], 0) + pos = torch.stack([layer[1][head] for _, layer in enumerate(self.vis_pos_vs_content)], 0) + i = torch.stack([layer[head] for _, layer in enumerate(self.attention_to_visualize)], 0) + content = torch.stack([cont, pos], -1).softmax(-1)[..., 0] + + color = cm.get_cmap("brg")(content.cpu().numpy()) + color[..., -1] = (i * 0.95 + 0.05).cpu().numpy() + + r[f"content_vs_pos_{head}"] = framework.visualize.plot.AnimatedHeatmap(color, ylabel="dest", + xlabel="src", textval=False, x_marks=marks, y_marks=marks, cmap="brg", colorbar=True, + colorbar_ticks=[0, 0.99], colorbar_labels=["pos", "con"], ignore_wrong_marks=True) + + # r["attention_max"] = framework.visualize.plot.AnimatedHeatmap( + # torch.stack([layer.max(0)[0] for _, layer in enumerate(self.attention_to_visualize)], 0), + # ylabel="dest", xlabel="src", textval=False, x_marks=marks, y_marks=marks) + self.vis_pos_vs_content = [] + + r.update(super().plot(options)) + return r + + + +class FixedRelativeMultiheadAttentionBase(RelativeAttentionBase): + def __init__(self, state_size: int, n_heads: int, dropout: float = 0.0, input_size: Optional[int] = None, + projection_size: Optional[int] = None): + super().__init__(state_size, n_heads, dropout, projection_size) + + self.input_size = state_size if input_size is None else input_size + + self.pos_to_pq = torch.nn.Linear(state_size, self.n_heads * self.projection_size, bias=False) + self.register_buffer("pos_encoding", self._create_buffer(1000)) + + def _create_buffer(self, max_len: int): + return framework.layers.sinusoidal_pos_embedding(self.state_size, 2 * max_len - 1, -max_len + 1, + device=self.pos_to_pq.weight.device) + + def get_pos(self, l: int, offset: int) -> torch.Tensor: + if self.pos_encoding.shape[0] < 2 * (l + offset) - 1: + self.pos_encoding = self._create_buffer(int(2**math.ceil(math.log2(2 * (l + offset) - 
1)))) + + return self.pos_to_pq(self._get_pos_subset(self.pos_encoding, l, offset)) + + +class FixedRelativeMultiheadAttention(AttentionMergeMixin, FixedRelativeMultiheadAttentionBase): + def __init__(self, state_size: int, n_heads: int, dropout: float = 0.0, global_pos_bias: bool = True, + global_content_bias: bool = True, input_size: Optional[int] = None, absolute_gate: bool = False, + projection_size: Optional[int] = None, output_size: Optional[int] = None): + super(AttentionMergeMixin, self).__init__(state_size, n_heads, dropout, input_size, projection_size=projection_size) + + self.data_to_kv = torch.nn.Linear(state_size, 2 * n_heads * self.projection_size, bias=False) + self.data_to_q = torch.nn.Linear(self.input_size, n_heads * self.projection_size, bias=False) + self.data_to_absgate = torch.nn.Linear(self.input_size, n_heads) \ + if absolute_gate else None + + self.global_content_bias = torch.nn.Parameter(torch.zeros([n_heads, self.projection_size])) \ + if global_content_bias else None + self.global_pos_bias = torch.nn.Parameter(torch.zeros([n_heads, self.projection_size])) \ + if global_pos_bias else None + + super(FixedRelativeMultiheadAttention, self).__init__(output_size) + self.reset_parameters() + + def add_head_specific_bias(self, data: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: + # data [batch * n_heads, len, c] + # bias [n_heads, c] + return (data.view(-1, bias.shape[0], *data.shape[1:]) + bias.unsqueeze(1).type_as(data)).view_as(data) \ + if bias is not None else data + + def forward(self, curr_state: torch.Tensor, attend_to: torch.Tensor, mask: Optional[AttentionMask], + pos_offset: int = 0, need_weights: bool = False): + # curr_state: [§size, out_len, c] + # attend_to: [batch_size, in_len, c] + batch_size, in_len = attend_to.shape[0:2] + out_len = curr_state.shape[1] + + k_content, v = self.transform_data(attend_to, self.data_to_kv, 2) + q, = self.transform_data(curr_state, self.data_to_q, 1) + + k_pos = self.get_pos(in_len, pos_offset).view(-1, self.n_heads, self.projection_size).\ + transpose(0, 1) # n_heads, 2*in_len -1 , projection_size + + q_content = self.add_head_specific_bias(q, self.global_content_bias) + q_pos = self.add_head_specific_bias(q, self.global_pos_bias) + + + absgate = torch.sigmoid(self.transform_data(curr_state, self.data_to_absgate, 1)[0]) \ + if self.data_to_absgate is not None else None + + data, scores = self.merged_attention(batch_size, out_len, mask, q_content, k_content, q_pos, k_pos, v, + pos_offset, ar_gate=absgate, need_weights=need_weights) + + if need_weights: + return data, scores + else: + return data + + def reset_parameters(self): + # # super().reset_parameters() + + torch.nn.init.xavier_uniform_(self.data_to_q.weight) + torch.nn.init.xavier_uniform_(self.pos_to_pq.weight) + torch.nn.init.xavier_uniform_(self.data_to_kv.weight) + torch.nn.init.xavier_uniform_(self.data_to_kv.weight) + + if self.global_content_bias is not None: + self.global_content_bias.data.fill_(0) + + if self.global_pos_bias is not None: + self.global_pos_bias.data.fill_(0) diff --git a/ops/sliding_window_attention_std.py b/ops/sliding_window_attention_std.py new file mode 100644 index 0000000000000000000000000000000000000000..16c1551290ef8844ee8e3ca9793017ec90f3f896 --- /dev/null +++ b/ops/sliding_window_attention_std.py @@ -0,0 +1,88 @@ +""" +Sliding Window / Hard Attention +Based on "Context Limitations Make Neural Language Models More Human-Like" +(Kuribayashi et al., 2022) +""" + +import math +import torch +import torch.nn.functional as F 
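+
+
+# Hypothetical usage sketch (not in the original file); assumes framework_mock
+# provides layers.sinusoidal_pos_embedding:
+# mha = FixedRelativeMultiheadAttention(state_size=256, n_heads=4)
+# x = torch.randn(2, 32, 256)
+# out = mha(x, x, mask=None)  # self-attention, out: [2, 32, 256]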
diff --git a/ops/sliding_window_attention_std.py b/ops/sliding_window_attention_std.py
new file mode 100644
index 0000000000000000000000000000000000000000..16c1551290ef8844ee8e3ca9793017ec90f3f896
--- /dev/null
+++ b/ops/sliding_window_attention_std.py
@@ -0,0 +1,88 @@
+"""
+Sliding Window / Hard Attention
+Based on "Context Limitations Make Neural Language Models More Human-Like"
+(Kuribayashi et al., 2022)
+"""
+
+import math
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from typing import Optional
+
+
+def sliding_window_attention_std(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    *,
+    head_first: bool = False,
+    seq_start: Optional[torch.Tensor] = None,
+    sm_scale: Optional[float] = None,
+    window_size: int = 2,  # default: the current token plus one previous token (2-gram)
+) -> torch.Tensor:
+    """
+    Sliding Window Attention
+
+    Hard cutoff: each position can only attend to the most recent window_size tokens
+    """
+
+    if not head_first:
+        q = rearrange(q, "b t h d -> b h t d")
+        k = rearrange(k, "b t h d -> b h t d")
+        v = rearrange(v, "b t h d -> b h t d")
+
+    B, H, T_q, D = q.shape
+    T_k = k.shape[2]
+
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    # Compute logits
+    logits = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale
+
+    # Create sliding window mask
+    mask = create_sliding_window_mask(T_q, T_k, window_size, device=q.device)
+    logits = logits.masked_fill(~mask, float('-inf'))
+
+    # Seq start mask
+    if seq_start is not None:
+        seq_mask = torch.arange(T_k, device=q.device)[None, None, None, :] < seq_start[None, :, None, None]
+        logits = logits.masked_fill(seq_mask, float('-inf'))
+
+    # Standard softmax
+    weights = F.softmax(logits, dim=-1)
+
+    # Apply to values
+    out = torch.matmul(weights, v)
+
+    if not head_first:
+        out = rearrange(out, "b h t d -> b t h d")
+
+    return out
+
+
+def create_sliding_window_mask(
+    T_q: int,
+    T_k: int,
+    window_size: int,
+    device: torch.device
+) -> torch.Tensor:
+    """
+    Create the sliding window mask
+
+    window_size=1: only the current token
+    window_size=2: the current token plus the previous one
+    """
+    # Base causal mask
+    mask = torch.tril(torch.ones(T_q, T_k, dtype=torch.bool, device=device))
+
+    # Apply the window restriction
+    if window_size > 0 and window_size < T_k:
+        for i in range(T_q):
+            # Keep only the range [i-window_size+1, i]
+            start = max(0, i - window_size + 1)
+            if start > 0:
+                mask[i, :start] = False
+
+    return mask[None, None, :, :]  # [1, 1, T_q, T_k]
\ No newline at end of file
diff --git a/ops/stickbreaking_attention_std.py b/ops/stickbreaking_attention_std.py
new file mode 100644
index 0000000000000000000000000000000000000000..92d8be26c43829aabd11a01d1a181cb1bcee1b1d
--- /dev/null
+++ b/ops/stickbreaking_attention_std.py
@@ -0,0 +1,46 @@
+"""
+Stick-breaking Attention - official Triton implementation
+"""
+
+from stickbreaking_attention.sb_attn import sb_attn
+import math
+import torch
+from einops import rearrange
+from typing import Optional
+
+
+def stickbreaking_attention_std(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    *,
+    head_first: bool = False,
+    seq_start: Optional[torch.Tensor] = None,
+    sm_scale: Optional[float] = None,
+    normalize: bool = True,
+    attend_current: bool = False,
+) -> torch.Tensor:
+    """Stick-breaking attention using the official Triton implementation"""
+
+    if not head_first:
+        q = rearrange(q, "b t h d -> b h t d")
+        k = rearrange(k, "b t h d -> b h t d")
+        v = rearrange(v, "b t h d -> b h t d")
+
+    B, H, T_q, D = q.shape
+
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    # Official Triton implementation; returns (output, remainder)
+    out, rem = sb_attn(
+        q, k, v,
+        inv_temp=sm_scale,
+        attend_current=attend_current
+    )
+
+    if not head_first:
+        out = rearrange(out, "b h t d -> b t h d")
+
+    return out
diff --git a/ops/transformer.py b/ops/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..39e9e122612a62bcbc685f818bf0810e87f5bc72
--- /dev/null
+++ b/ops/transformer.py
@@ -0,0 +1,165 @@
+import torch
+import torch.nn
+import torch.nn.functional as F
+from .multi_head_attention import MultiHeadAttention, AttentionMask
+from typing 
import Optional, Callable, Dict +from dataclasses import dataclass +# This file is based on PyTorch's internal implementation + +ActivationFunction = Callable[[torch.Tensor], torch.Tensor] + + +class TransformerEncoderLayer(torch.nn.Module): + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation: ActivationFunction = F.relu, + attention_dropout=0): + super(TransformerEncoderLayer, self).__init__() + self.self_attn = MultiHeadAttention(d_model, nhead, dropout=attention_dropout) + self.linear1 = torch.nn.Linear(d_model, dim_feedforward) + self.dropout = torch.nn.Dropout(dropout) + self.linear2 = torch.nn.Linear(dim_feedforward, d_model) + + self.norm1 = torch.nn.LayerNorm(d_model) + self.norm2 = torch.nn.LayerNorm(d_model) + self.dropout1 = torch.nn.Dropout(dropout) + self.dropout2 = torch.nn.Dropout(dropout) + + self.activation = activation + self.reset_parameters() + + def forward(self, src: torch.Tensor, mask: Optional[AttentionMask] = None) -> torch.Tensor: + src2 = self.self_attn(src, src, mask) + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def reset_parameters(self): + torch.nn.init.xavier_uniform_(self.linear1.weight, gain=torch.nn.init.calculate_gain('relu') \ + if self.activation is F.relu else 1.0) + torch.nn.init.xavier_uniform_(self.linear2.weight) + + +class TransformerDecoderLayer(torch.nn.Module): + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation: ActivationFunction = F.relu, + attention_dropout=0): + super(TransformerDecoderLayer, self).__init__() + + self.self_attn = MultiHeadAttention(d_model, nhead, dropout=attention_dropout) + self.multihead_attn = MultiHeadAttention(d_model, nhead, dropout=attention_dropout) + # Implementation of Feedforward model + self.linear1 = torch.nn.Linear(d_model, dim_feedforward) + self.dropout = torch.nn.Dropout(dropout) + self.linear2 = torch.nn.Linear(dim_feedforward, d_model) + + self.norm1 = torch.nn.LayerNorm(d_model) + self.norm2 = torch.nn.LayerNorm(d_model) + self.norm3 = torch.nn.LayerNorm(d_model) + self.dropout1 = torch.nn.Dropout(dropout) + self.dropout2 = torch.nn.Dropout(dropout) + self.dropout3 = torch.nn.Dropout(dropout) + + self.activation = activation + self.reset_parameters() + + def forward(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + full_target: Optional[torch.Tensor] = None, pos_offset: int = 0) -> torch.Tensor: + + assert pos_offset == 0 or tgt_mask is None + tgt2 = self.self_attn(tgt, tgt if full_target is None else full_target, mask=AttentionMask(None, tgt_mask)) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(tgt, memory, mask=AttentionMask(memory_key_padding_mask, None)) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def reset_parameters(self): + torch.nn.init.xavier_uniform_(self.linear1.weight, gain=torch.nn.init.calculate_gain('relu') \ + if self.activation is F.relu else 1.0) + torch.nn.init.xavier_uniform_(self.linear2.weight) + + +class TransformerDecoderBase(torch.nn.Module): + @dataclass + class State: + step: int + state: Dict[int, torch.Tensor] + + def __init__(self, d_model: int): + 
super().__init__()
+        self.d_model = d_model
+
+    def create_state(self, batch_size: int, max_length: int, device: torch.device) -> State:
+        return self.State(0, {i: torch.empty([batch_size, max_length, self.d_model], device=device)
+                              for i in range(len(self.layers))})
+
+    def one_step_forward(self, state: State, data: torch.Tensor, *args, **kwargs):
+        assert data.shape[1] == 1, f"For one-step forward should have one timestep, but shape is {data.shape}"
+        assert state.step < state.state[0].shape[1]
+
+        for i, l in enumerate(self.layers):
+            state.state[i][:, state.step:state.step + 1] = data
+            data = l(data, *args, **kwargs, full_target=state.state[i][:, :state.step + 1],
+                     pos_offset=state.step)
+
+        state.step += 1
+        return data
+
+
+class TransformerEncoder(torch.nn.Module):
+    def __init__(self, layer, n_layers: int, *args, **kwargs):
+        super().__init__()
+        self.layers = torch.nn.ModuleList([layer(*args, **kwargs) for _ in range(n_layers)])
+
+    def forward(self, data: torch.Tensor, *args, **kwargs):
+        for l in self.layers:
+            data = l(data, *args, **kwargs)
+        return data
+
+
+class TransformerDecoder(TransformerDecoderBase):
+    def __init__(self, layer, n_layers: int, d_model: int, *args, **kwargs):
+        super().__init__(d_model)
+        self.layers = torch.nn.ModuleList([layer(d_model, *args, **kwargs) for _ in range(n_layers)])
+
+    def forward(self, data: torch.Tensor, *args, **kwargs):
+        for l in self.layers:
+            data = l(data, *args, **kwargs)
+        return data
+
+
+def TransformerEncoderWithLayer(layer = TransformerEncoderLayer):
+    return lambda *args, **kwargs: TransformerEncoder(layer, *args, **kwargs)
+
+
+def TransformerDecoderWithLayer(layer = TransformerDecoderLayer):
+    return lambda *args, **kwargs: TransformerDecoder(layer, *args, **kwargs)
+
+
+class Transformer(torch.nn.Module):
+    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
+                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
+                 activation: ActivationFunction = F.relu, encoder_layer=TransformerEncoderWithLayer(),
+                 decoder_layer=TransformerDecoderWithLayer(), attention_dropout: float = 0):
+        super().__init__()
+
+        self.encoder = encoder_layer(num_encoder_layers, d_model, nhead, dim_feedforward,
+                                     dropout, activation, attention_dropout)
+        self.decoder = decoder_layer(num_decoder_layers, d_model, nhead, dim_feedforward,
+                                     dropout, activation, attention_dropout)
+
+    def forward(self, src: torch.Tensor, tgt: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None,
+                src_mask: Optional[AttentionMask] = None):
+
+        memory = self.encoder(src, src_mask)
+        return self.decoder(tgt, memory, tgt_mask, src_mask.src_length_mask if src_mask is not None else None)
+
+    @staticmethod
+    def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
+        return torch.triu(torch.ones(sz, sz, dtype=torch.bool, device=device), diagonal=1)
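+
+
+# Hypothetical usage sketch (not in the original file):
+# model = Transformer(d_model=64, nhead=4, num_encoder_layers=2, num_decoder_layers=2)
+# src, tgt = torch.randn(2, 10, 64), torch.randn(2, 7, 64)
+# tgt_mask = Transformer.generate_square_subsequent_mask(7, src.device)
+# out = model(src, tgt, tgt_mask)  # [2, 7, 64]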
""" + 标准 Softmax Attention,兼容 flash_attn_func 的输入格式 + + Args: + q, k, v: [batch, seq_len, num_heads, head_dim] 格式 + causal: 是否使用因果mask + window_size: 滑动窗口大小 (left, right),(-1, -1) 表示无限制 + sm_scale: softmax 缩放因子 + + Returns: + output: [batch, seq_len, num_heads, head_dim] 格式 + """ + B, T_q, H, D = q.shape + T_k = k.shape[1] + + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(D) + + # 转换为 [B, H, T, D] 格式进行计算 + q = rearrange(q, 'b t h d -> b h t d') + k = rearrange(k, 'b t h d -> b h t d') + v = rearrange(v, 'b t h d -> b h t d') + + # 计算 attention scores + scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale + + # Causal mask + if causal: + P_SEQ = T_k - T_q # 处理 KV cache 的情况 + causal_mask = torch.triu( + torch.ones((T_q, T_k), dtype=torch.bool, device=q.device), + diagonal=P_SEQ + 1 + ) + scores = scores.masked_fill(causal_mask[None, None, :, :], float('-inf')) + + # Window mask (sliding window attention) + if window_size is not None and window_size != (-1, -1): + left_window, right_window = window_size + window_mask = torch.ones((T_q, T_k), dtype=torch.bool, device=q.device) + for i in range(T_q): + # 计算每个查询位置的有效窗口范围 + start = max(0, i - left_window) + end = min(T_k, i + right_window + 1) + window_mask[i, start:end] = False + scores = scores.masked_fill(window_mask[None, None, :, :], float('-inf')) + + # Softmax + attn_weights = F.softmax(scores, dim=-1) + attn_weights = torch.nan_to_num(attn_weights, 0.0) + + # Apply attention to values + output = torch.matmul(attn_weights.to(v.dtype), v) + + # 转换回 [B, T, H, D] 格式 + output = rearrange(output, 'b h t d -> b t h d') + + return output + + +def vanilla_attention_varlen_std( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + causal: bool = True, + window_size: Optional[Tuple[int, int]] = None, + sm_scale: Optional[float] = None, +) -> torch.Tensor: + """ + 变长序列的标准 Softmax Attention,兼容 flash_attn_varlen_func + + Args: + q: [total_q_tokens, num_heads, head_dim] + k: [total_k_tokens, num_kv_heads, head_dim] + v: [total_k_tokens, num_kv_heads, head_dim] + cu_seqlens_q: 累积序列长度 [batch_size + 1] + cu_seqlens_k: 累积序列长度 [batch_size + 1] + max_seqlen_q: 最大查询序列长度 + max_seqlen_k: 最大键值序列长度 + + Returns: + output: [total_q_tokens, num_heads, head_dim] + """ + batch_size = cu_seqlens_q.shape[0] - 1 + H = q.shape[1] + D = q.shape[2] + + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(D) + + outputs = [] + + # 逐批次处理 + for b in range(batch_size): + q_start, q_end = cu_seqlens_q[b].item(), cu_seqlens_q[b+1].item() + k_start, k_end = cu_seqlens_k[b].item(), cu_seqlens_k[b+1].item() + + if q_start == q_end: # 空序列 + continue + + # 提取当前批次的 q, k, v + q_b = q[q_start:q_end] # [T_q, H, D] + k_b = k[k_start:k_end] # [T_k, H, D] + v_b = v[k_start:k_end] # [T_k, H, D] + + T_q = q_b.shape[0] + T_k = k_b.shape[0] + + # 转换为 [H, T, D] 格式 + q_b = rearrange(q_b, 't h d -> h t d') + k_b = rearrange(k_b, 't h d -> h t d') + v_b = rearrange(v_b, 't h d -> h t d') + + # 计算 attention scores + scores = torch.matmul(q_b.float(), k_b.float().transpose(-2, -1)) * sm_scale + + # Causal mask + if causal: + P_SEQ = T_k - T_q + causal_mask = torch.triu( + torch.ones((T_q, T_k), dtype=torch.bool, device=q.device), + diagonal=P_SEQ + 1 + ) + scores = scores.masked_fill(causal_mask[None, :, :], float('-inf')) + + # Window mask + if window_size is not None and window_size != (-1, -1): + left_window, right_window = window_size + window_mask = torch.ones((T_q, T_k), 
+
+
+def vanilla_attention_varlen_std(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    causal: bool = True,
+    window_size: Optional[Tuple[int, int]] = None,
+    sm_scale: Optional[float] = None,
+) -> torch.Tensor:
+    """
+    Standard softmax attention for variable-length sequences, compatible with flash_attn_varlen_func
+
+    Args:
+        q: [total_q_tokens, num_heads, head_dim]
+        k: [total_k_tokens, num_kv_heads, head_dim]
+        v: [total_k_tokens, num_kv_heads, head_dim]
+        cu_seqlens_q: cumulative sequence lengths [batch_size + 1]
+        cu_seqlens_k: cumulative sequence lengths [batch_size + 1]
+        max_seqlen_q: maximum query sequence length
+        max_seqlen_k: maximum key/value sequence length
+
+    Returns:
+        output: [total_q_tokens, num_heads, head_dim]
+    """
+    batch_size = cu_seqlens_q.shape[0] - 1
+    H = q.shape[1]
+    D = q.shape[2]
+
+    if sm_scale is None:
+        sm_scale = 1.0 / math.sqrt(D)
+
+    outputs = []
+
+    # Process one batch element at a time
+    for b in range(batch_size):
+        q_start, q_end = cu_seqlens_q[b].item(), cu_seqlens_q[b+1].item()
+        k_start, k_end = cu_seqlens_k[b].item(), cu_seqlens_k[b+1].item()
+
+        if q_start == q_end:  # empty sequence
+            continue
+
+        # Slice out q, k, v for the current batch element
+        q_b = q[q_start:q_end]  # [T_q, H, D]
+        k_b = k[k_start:k_end]  # [T_k, H, D]
+        v_b = v[k_start:k_end]  # [T_k, H, D]
+
+        T_q = q_b.shape[0]
+        T_k = k_b.shape[0]
+
+        # Convert to [H, T, D]
+        q_b = rearrange(q_b, 't h d -> h t d')
+        k_b = rearrange(k_b, 't h d -> h t d')
+        v_b = rearrange(v_b, 't h d -> h t d')
+
+        # Compute attention scores
+        scores = torch.matmul(q_b.float(), k_b.float().transpose(-2, -1)) * sm_scale
+
+        # Causal mask
+        if causal:
+            P_SEQ = T_k - T_q
+            causal_mask = torch.triu(
+                torch.ones((T_q, T_k), dtype=torch.bool, device=q.device),
+                diagonal=P_SEQ + 1
+            )
+            scores = scores.masked_fill(causal_mask[None, :, :], float('-inf'))
+
+        # Window mask
+        if window_size is not None and window_size != (-1, -1):
+            left_window, right_window = window_size
+            window_mask = torch.ones((T_q, T_k), dtype=torch.bool, device=q.device)
+            for i in range(T_q):
+                start = max(0, i - left_window)
+                end = min(T_k, i + right_window + 1)
+                window_mask[i, start:end] = False
+            scores = scores.masked_fill(window_mask[None, :, :], float('-inf'))
+
+        # Softmax
+        attn_weights = F.softmax(scores, dim=-1)
+        attn_weights = torch.nan_to_num(attn_weights, 0.0)
+
+        # Apply attention
+        output_b = torch.matmul(attn_weights.to(v_b.dtype), v_b)
+
+        # Convert back to [T, H, D]
+        output_b = rearrange(output_b, 'h t d -> t h d')
+        outputs.append(output_b)
+
+    # Concatenate the per-batch outputs
+    output = torch.cat(outputs, dim=0)
+
+    return output
diff --git a/token_shift.py b/token_shift.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e5550546bfd8df0e5be818e0838c78e4327716d
--- /dev/null
+++ b/token_shift.py
@@ -0,0 +1,314 @@
+import torch
+
+import triton
+import triton.language as tl
+import pytest
+
+def maybe_contiguous(x):
+    # only when the innermost dimension is contiguous can LDGSTS be used,
+    # so inner-dimension contiguity is enforced.
+    return x.contiguous() if x.stride(-1) != 1 else x
+
+@triton.jit
+def shift_fwd_kernel(
+    X_PTR,
+    PREV_WEIGHT_PTR,
+    CURR_WEIGHT_PTR,
+    OUT_PTR,
+
+    stride_x_b, stride_x_t, stride_x_h, stride_x_d,
+    stride_weight_b, stride_weight_t, stride_weight_h,
+    T: tl.constexpr, D: tl.constexpr,
+    BLOCK_T: tl.constexpr,
+):
+    """
+    x is (B, T, H, D); the weights are (B, T, H)
+    """
+    b_offset = tl.program_id(axis=0).to(tl.int64)
+    t_offset = tl.program_id(axis=1).to(tl.int64) * BLOCK_T
+    h_offset = tl.program_id(axis=2).to(tl.int64)
+
+
+    x_ptr_offset = b_offset * stride_x_b + t_offset * stride_x_t + h_offset * stride_x_h
+    X_PTR += x_ptr_offset
+    OUT_PTR += x_ptr_offset
+
+    weight_ptr_offset = b_offset * stride_weight_b + t_offset * stride_weight_t + h_offset * stride_weight_h
+    CURR_WEIGHT_PTR += weight_ptr_offset
+    PREV_WEIGHT_PTR += weight_ptr_offset
+
+    x_ptr = X_PTR + tl.arange(0, BLOCK_T)[:, None] * stride_x_t + tl.arange(0, D)[None, :] * stride_x_d
+    t_offset_block = t_offset + tl.arange(0, BLOCK_T)[:, None]
+    x_mask = t_offset_block < T
+
+    # Yeah this is correct
+    x_prev_ptr = x_ptr - stride_x_t
+    t_prev_offset_block = t_offset_block - 1
+    x_prev_mask = ((t_prev_offset_block) < T) & (t_prev_offset_block >= 0)
+
+    curr_weight_ptr = CURR_WEIGHT_PTR + tl.arange(0, BLOCK_T)[:, None] * stride_weight_t
+    prev_weight_ptr = PREV_WEIGHT_PTR + tl.arange(0, BLOCK_T)[:, None] * stride_weight_t
+
+
+    x = tl.load(x_ptr, mask=x_mask, other=0.0)
+    x_prev = tl.load(x_prev_ptr, mask=x_prev_mask, other=0.0)
+    curr_weight = tl.load(curr_weight_ptr, mask=x_mask, other=0.0)
+    prev_weight = tl.load(prev_weight_ptr, mask=x_mask, other=0.0)
+
+    result = x * curr_weight.to(tl.float32) + x_prev * prev_weight.to(tl.float32)
+    result = result.to(x.dtype)
+
+    out_ptr = OUT_PTR + tl.arange(0, BLOCK_T)[:, None] * stride_x_t + tl.arange(0, D)[None, :] * stride_x_d
+    tl.store(out_ptr, result, mask=x_mask)
+
+
+@triton.jit
+def shift_bwd_kernel(
+    X_PTR,
+    PREV_WEIGHT_PTR,
+    CURR_WEIGHT_PTR,
+
+    DOUT_PTR,
+    DX_PTR,
+    DPREV_WEIGHT_PTR,
+    DCURR_WEIGHT_PTR,
+
+    stride_x_b, stride_x_t, stride_x_h, stride_x_d,
+    stride_weight_b, stride_weight_t, stride_weight_h,
+    T: tl.constexpr, D: tl.constexpr,
+    BLOCK_T: tl.constexpr,
+):
+    """
+    x is (B, T, H, D); the weights are (B, T, H)
+    """
+    b_offset = tl.program_id(axis=0).to(tl.int64)
+    t_offset = tl.program_id(axis=1).to(tl.int64) * BLOCK_T
+    h_offset = tl.program_id(axis=2).to(tl.int64)
+
+
+    x_ptr_offset = b_offset * stride_x_b + t_offset * stride_x_t + h_offset * stride_x_h
+    X_PTR += x_ptr_offset
+    DX_PTR 
+= x_ptr_offset + DOUT_PTR += x_ptr_offset + + weight_ptr_offset = b_offset * stride_weight_b + t_offset * stride_weight_t + h_offset * stride_weight_h + CURR_WEIGHT_PTR += weight_ptr_offset + PREV_WEIGHT_PTR += weight_ptr_offset + DCURR_WEIGHT_PTR += weight_ptr_offset + DPREV_WEIGHT_PTR += weight_ptr_offset + + x_ptr = X_PTR + tl.arange(0, BLOCK_T)[:, None] * stride_x_t + tl.arange(0, D)[None, :] * stride_x_d + t_offset_block = t_offset + tl.arange(0, BLOCK_T)[:, None] + x_mask = t_offset_block < T + + dout_ptr = DOUT_PTR + tl.arange(0, BLOCK_T)[:, None] * stride_x_t + tl.arange(0, D)[None, :] * stride_x_d + + # Yeah this is correct + dout_next_ptr = dout_ptr + stride_x_t + t_next_offset_block = t_offset_block + 1 + x_next_mask = (t_next_offset_block) < T + + + # Yeah this is correct + x_prev_ptr = x_ptr - stride_x_t + t_prev_offset_block = t_offset_block - 1 + x_prev_mask = ((t_prev_offset_block) < T) & (t_prev_offset_block >= 0) + + curr_weight_ptr = CURR_WEIGHT_PTR + tl.arange(0, BLOCK_T)[:, None] * stride_weight_t + prev_weight_ptr = PREV_WEIGHT_PTR + tl.arange(0, BLOCK_T)[:, None] * stride_weight_t + next_prev_weight_ptr = prev_weight_ptr + stride_weight_t + + + x = tl.load(x_ptr, mask=x_mask, other=0.0) + x_prev = tl.load(x_prev_ptr, mask=x_prev_mask, other=0.0) + dout = tl.load(dout_ptr, mask=x_mask, other=0.0) + dout_next= tl.load(dout_next_ptr, mask=x_next_mask, other=0.0) + + curr_weight = tl.load(curr_weight_ptr, mask=x_mask, other=0.0) + next_prev_weight = tl.load(next_prev_weight_ptr, mask=x_next_mask, other=0.0) + + dx = dout * curr_weight.to(tl.float32) + dout_next * next_prev_weight.to(tl.float32) + dx = dx.to(x.dtype) + + dcurr_weight = tl.sum(dout.to(tl.float32) * x, axis=1, keep_dims=True) + dprev_weight = tl.sum(dout.to(tl.float32) * x_prev, axis=1, keep_dims=True) + + dx_ptr = DX_PTR + tl.arange(0, BLOCK_T)[:, None] * stride_x_t + tl.arange(0, D)[None, :] * stride_x_d + tl.store(dx_ptr, dx, mask=x_mask) + dcurr_weight_ptr = DCURR_WEIGHT_PTR + tl.arange(0, BLOCK_T)[:, None] * stride_weight_t + tl.store(dcurr_weight_ptr, dcurr_weight, mask=x_mask) + dprev_weight_ptr = DPREV_WEIGHT_PTR + tl.arange(0, BLOCK_T)[:, None] * stride_weight_t + tl.store(dprev_weight_ptr, dprev_weight, mask=x_mask) + + + +class TokenShift(torch.autograd.Function): + + @staticmethod + def forward(ctx, x: torch.Tensor, prev_weight: torch.Tensor, curr_weight: torch.Tensor): + + B, T, H, D = x.size() + assert D in {16, 32, 64, 128} + assert prev_weight.size() == curr_weight.size() == (B, T, H) + assert prev_weight.stride() == curr_weight.stride() + x = maybe_contiguous(x) + out = torch.empty_like(x) + + BLOCK_T = triton.next_power_of_2(min(64, T)) + + grid = lambda meta: (B, triton.cdiv(T, meta["BLOCK_T"]), H) + # NOTE: + # - Each torch.tensor object is implicitly converted into a pointer to its first element. + # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel. + # - Don't forget to pass meta-parameters as keywords arguments. + shift_fwd_kernel[grid]( + x, + prev_weight, + curr_weight, + out, + *x.stride(), + *curr_weight.stride(), + T=T, D=D, + BLOCK_T=BLOCK_T, + ) + ctx.save_for_backward(x, prev_weight, curr_weight) + # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still + # running asynchronously at this point. 
+    @staticmethod
+    def backward(ctx, dout: torch.Tensor):
+        x, prev_weight, curr_weight = ctx.saved_tensors
+        B, T, H, D = x.size()
+        assert D in {16, 32, 64, 128}
+        assert prev_weight.size() == curr_weight.size() == (B, T, H)
+        assert prev_weight.stride() == curr_weight.stride()
+        x = maybe_contiguous(x)
+        # `x` was made contiguous in forward; bring `dout` to the same layout
+        # so one set of strides can serve both in the kernel launch.
+        if dout.stride() != x.stride():
+            dout = dout.contiguous()
+        dx = torch.empty_like(x)
+        dcurr_weight = torch.empty_like(curr_weight)
+        dprev_weight = torch.empty_like(prev_weight)
+
+        BLOCK_T = triton.next_power_of_2(min(64, T))
+
+        grid = lambda meta: (B, triton.cdiv(T, meta["BLOCK_T"]), H)
+        # See the NOTE in forward on how tensors and meta-parameters are passed.
+        shift_bwd_kernel[grid](
+            x,
+            prev_weight,
+            curr_weight,
+            dout,
+            dx,
+            dprev_weight,
+            dcurr_weight,
+            *x.stride(),
+            *curr_weight.stride(),
+            T=T,
+            D=D,
+            BLOCK_T=BLOCK_T,
+        )
+        # The kernel may still be running asynchronously here; the returned
+        # gradients correspond positionally to forward's (x, prev_weight, curr_weight).
+        return dx, dprev_weight, dcurr_weight
+
+def token_shift(x, prev_weight, curr_weight):
+    """Fused token shift: out[t] = curr_weight[t] * x[t] + prev_weight[t] * x[t - 1]."""
+    return TokenShift.apply(x, prev_weight, curr_weight)
+
+
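+# Example usage (editor's sketch; the shapes and the sigmoid-based weight
+# construction are illustrative, not prescribed by this module):
+#
+#   x = torch.randn(2, 128, 8, 64, device="cuda")
+#   curr = torch.sigmoid(torch.randn(2, 128, 8, device="cuda"))
+#   out = token_shift(x, 1.0 - curr, curr)
+#
+# Row t of `out` mixes the current and previous tokens per head:
+# out[:, t] == curr[:, t, :, None] * x[:, t] + (1 - curr)[:, t, :, None] * x[:, t - 1],
+# with x[:, t - 1] taken as zeros at t == 0.
+
+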
+@pytest.mark.parametrize("B, T, H, D", [(4, 2048, 12, 128), (4, 2088, 12, 128)])
+def test_op(B, T, H, D, dtype=torch.float32):
+    torch.manual_seed(24)
+    x = torch.randn(B, T, H, D, device="cuda", dtype=dtype, requires_grad=True)
+    dout = torch.randn(B, T, H, D, device="cuda", dtype=dtype)
+    curr_weight = torch.rand(B, T, H, device="cuda", requires_grad=True)
+    prev_weight = 1.0 - curr_weight
+
+    # Reference implementation in plain PyTorch.
+    x_prev = torch.roll(x, shifts=1, dims=1)
+    x_prev[:, 0, :, :] = 0.0
+    ref_out = (x_prev * prev_weight[..., None] + x * curr_weight[..., None]).to(dtype)
+
+    ref_out.backward(dout)
+    ref_dx, x.grad = x.grad.clone(), None
+    ref_dcurr_weight, curr_weight.grad = curr_weight.grad.clone(), None
+
+    # Rebuild prev_weight so the second backward does not retraverse the
+    # graph already freed by the first backward.
+    prev_weight = 1.0 - curr_weight
+
+    # Triton kernel under test.
+    tri_out = token_shift(x, prev_weight, curr_weight)
+    tri_out.backward(dout)
+    tri_dx, x.grad = x.grad.clone(), None
+    tri_dcurr_weight, curr_weight.grad = curr_weight.grad.clone(), None
+
+    assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0), (ref_out - tri_out).abs().max()
+    assert torch.allclose(ref_dx, tri_dx, atol=1e-2, rtol=0), (ref_dx - tri_dx).abs().max()
+    assert torch.allclose(ref_dcurr_weight, tri_dcurr_weight, atol=1e-2, rtol=0), (ref_dcurr_weight - tri_dcurr_weight).abs().max()
+
+if __name__ == "__main__":
+    torch.manual_seed(0)
+    B, T, H, D = 4, 2088, 12, 128
+    x = torch.randn(B, T, H, D, device="cuda")
+    curr_weight = torch.rand(B, T, H, device="cuda")
+    prev_weight = 1.0 - curr_weight
+    result = token_shift(x, prev_weight, curr_weight)
+    print(result[0, :, 0, 0])