diff --git a/fla2/models/linear_attn/configuration_linear_attn.py b/fla2/models/linear_attn/configuration_linear_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed4bae518434b978a725e1f2437b11751cf3d644
--- /dev/null
+++ b/fla2/models/linear_attn/configuration_linear_attn.py
@@ -0,0 +1,72 @@
+# -*- coding: utf-8 -*-
+
+from typing import Optional
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class LinearAttentionConfig(PretrainedConfig):
+
+    model_type = 'linear_attn'
+    keys_to_ignore_at_inference = ['past_key_values']
+
+    def __init__(
+        self,
+        vocab_size: int = 32000,
+        hidden_size: int = 2048,
+        expand_k: int = 1,
+        expand_v: int = 1,
+        hidden_ratio: Optional[int] = 4,
+        intermediate_size: Optional[int] = None,
+        num_hidden_layers: int = 24,
+        num_heads: int = 4,
+        num_kv_heads: Optional[int] = None,
+        attn_mode: str = "fused_chunk",
+        feature_map: str = "elementwise_product",
+        tie_feature_map_qk: bool = False,
+        norm_q: bool = False,
+        norm_k: bool = False,
+        norm_feature_map: bool = False,
+        hidden_act: str = "swish",
+        max_position_embeddings: int = 2048,
+        elementwise_affine: Optional[bool] = True,
+        norm_eps: float = 1e-6,
+        use_cache: bool = True,
+        pad_token_id: Optional[int] = None,
+        bos_token_id: int = 1,
+        eos_token_id: int = 2,
+        tie_word_embeddings: bool = False,
+        initializer_range: float = 0.02,
+        fuse_cross_entropy: bool = True,
+        **kwargs
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.expand_k = expand_k
+        self.expand_v = expand_v
+        self.hidden_ratio = hidden_ratio
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.attn_mode = attn_mode
+        self.feature_map = feature_map
+        self.tie_feature_map_qk = tie_feature_map_qk
+        self.norm_q = norm_q
+        self.norm_k = norm_k
+        self.norm_feature_map = norm_feature_map
+        self.hidden_act = hidden_act
+        self.elementwise_affine = elementwise_affine
+        self.norm_eps = norm_eps
+        self.use_cache = use_cache
+        self.initializer_range = initializer_range
+        self.fuse_cross_entropy = fuse_cross_entropy
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
diff --git a/fla2/models/mamba/__init__.py b/fla2/models/mamba/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0eff2ea26f3a11bcf2333002509686eca2289aa
--- /dev/null
+++ b/fla2/models/mamba/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+from fla2.models.mamba.configuration_mamba import MambaConfig
+from fla2.models.mamba.modeling_mamba import (MambaBlock, MambaForCausalLM,
+                                              MambaModel)
+
+AutoConfig.register(MambaConfig.model_type, MambaConfig, True)
+AutoModel.register(MambaConfig, MambaModel, True)
+AutoModelForCausalLM.register(MambaConfig, MambaForCausalLM, True)
+
+
+__all__ = ['MambaConfig', 'MambaForCausalLM', 'MambaModel', 'MambaBlock']
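Editor's note: with the `Auto*` registrations above, the new classes resolve through the standard `transformers` factories. A minimal usage sketch, assuming `fla2` is importable and that the (not shown in this diff) `MambaConfig.model_type` is `'mamba'`:

    from transformers import AutoConfig, AutoModelForCausalLM

    import fla2.models.mamba  # noqa: F401 -- executes the Auto* registrations above

    config = AutoConfig.for_model('mamba')            # the registered MambaConfig, default fields
    model = AutoModelForCausalLM.from_config(config)  # randomly initialized MambaForCausalLM
    print(type(model).__name__)                       # 'MambaForCausalLM'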
diff --git a/fla2/models/mamba/modeling_mamba.py b/fla2/models/mamba/modeling_mamba.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f8c44a3c389dfb2bd2209a5f422e9eae7b728cf
--- /dev/null
+++ b/fla2/models/mamba/modeling_mamba.py
@@ -0,0 +1,606 @@
+# coding=utf-8
+# Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch MAMBA model.""" + +import math +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput, logging + +from fla.models.mamba.configuration_mamba import MambaConfig +from fla.modules import FusedCrossEntropyLoss, RMSNorm + +logger = logging.get_logger(__name__) + +try: + from mamba_ssm.ops.selective_scan_interface import (mamba_inner_fn, + selective_scan_fn) + from mamba_ssm.ops.triton.selective_state_update import \ + selective_state_update +except ImportError: + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +except ImportError: + causal_conv1d_update, causal_conv1d_fn = None, None + +is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) +) + + +class MambaCache: + def __init__(self, config, batch_size, dtype=torch.float16, device=None): + self.seqlen_offset = 0 + self.dtype = dtype + intermediate_size = config.intermediate_size + ssm_state_size = config.state_size + conv_kernel_size = config.conv_kernel + + self.conv_states = { + i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) + for i in range(config.num_hidden_layers) + } + self.ssm_states = { + i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype) + for i in range(config.num_hidden_layers) + } + + +class MambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__(self, config, layer_idx): + super().__init__() + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = config.intermediate_size + self.time_step_rank = config.time_step_rank + self.layer_idx = layer_idx + self.use_conv_bias = config.use_conv_bias + self.conv1d = nn.Conv1d( + in_channels=self.intermediate_size, + out_channels=self.intermediate_size, + bias=config.use_conv_bias, + kernel_size=config.conv_kernel, + groups=self.intermediate_size, + padding=config.conv_kernel - 1, + ) + + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + + # projection of the input hidden states + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias) + # selective projection used to make dt, B and C input dependant + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) + # time step projection (discretization) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. 
+
+
+class MambaMixer(nn.Module):
+    """
+    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
+    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
+    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
+    and is why Mamba is called **selective** state spaces)
+    """
+
+    def __init__(self, config, layer_idx):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.ssm_state_size = config.state_size
+        self.conv_kernel_size = config.conv_kernel
+        self.intermediate_size = config.intermediate_size
+        self.time_step_rank = config.time_step_rank
+        self.layer_idx = layer_idx
+        self.use_conv_bias = config.use_conv_bias
+        self.conv1d = nn.Conv1d(
+            in_channels=self.intermediate_size,
+            out_channels=self.intermediate_size,
+            bias=config.use_conv_bias,
+            kernel_size=config.conv_kernel,
+            groups=self.intermediate_size,
+            padding=config.conv_kernel - 1,
+        )
+
+        self.activation = config.hidden_act
+        self.act = ACT2FN[config.hidden_act]
+
+        # projection of the input hidden states
+        self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
+        # selective projection used to make dt, B and C input dependent
+        self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
+        # time step projection (discretization)
+        self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
+
+        # S4D real initialization. These are not discretized!
+        # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
+        A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
+        A = A.expand(self.intermediate_size, -1).contiguous()
+
+        self.A_log = nn.Parameter(torch.log(A))
+        self.D = nn.Parameter(torch.ones(self.intermediate_size))
+        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
+        self.use_bias = config.use_bias
+
+        if not is_fast_path_available:
+            logger.warning_once(
+                "The fast path is not available because one of "
+                "`(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
+                " is None. Falling back to the naive implementation. "
+                "To install follow https://github.com/state-spaces/mamba/#installation and"
+                " https://github.com/Dao-AILab/causal-conv1d"
+            )
+
+    def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: Optional[MambaCache] = None):
+        # 1. Gated MLP's linear projection
+        projected_states = self.in_proj(hidden_states).transpose(1, 2)
+
+        if self.training and cache_params is None:  # Doesn't support outputting the states -> used for training
+            contextualized_states = mamba_inner_fn(
+                projected_states,
+                self.conv1d.weight,
+                self.conv1d.bias if self.use_conv_bias else None,
+                self.x_proj.weight,
+                self.dt_proj.weight,
+                self.out_proj.weight,
+                self.out_proj.bias.float() if self.use_bias else None,
+                -torch.exp(self.A_log.float()),
+                None,  # input-dependent B
+                None,  # input-dependent C
+                self.D.float(),
+                delta_bias=self.dt_proj.bias.float(),
+                delta_softplus=True,
+            )
+
+        else:
+            hidden_states, gate = projected_states.chunk(2, dim=1)
+
+            # 2. Convolution sequence transformation
+            conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
+            if cache_params is not None and cache_params.seqlen_offset > 0:
+                hidden_states = causal_conv1d_update(
+                    hidden_states.squeeze(-1),
+                    cache_params.conv_states[self.layer_idx],
+                    conv_weights,
+                    self.conv1d.bias,
+                    self.activation,
+                )
+                hidden_states = hidden_states.unsqueeze(-1)
+            else:
+                if cache_params is not None:
+                    conv_states = nn.functional.pad(
+                        hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
+                    )
+                    cache_params.conv_states[self.layer_idx].copy_(conv_states)
+                hidden_states = causal_conv1d_fn(
+                    hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
+                )
+
+            # 3. State Space Model sequence transformation
+            # 3.a. input varying initialization of time_step, B and C
+            ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
+            time_step, B, C = torch.split(
+                ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+            )
+            discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
+
+            A = -torch.exp(self.A_log.float())
+            # 3.c perform the recurrence y ← SSM(A, B, C)(x)
+            time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
+            if cache_params is not None and cache_params.seqlen_offset > 0:
+                scan_outputs = selective_state_update(
+                    cache_params.ssm_states[self.layer_idx],
+                    hidden_states[..., 0],
+                    discrete_time_step[..., 0],
+                    A,
+                    B[:, 0],
+                    C[:, 0],
+                    self.D,
+                    gate[..., 0],
+                    time_proj_bias,
+                    dt_softplus=True,
+                ).unsqueeze(-1)
+            else:
+                scan_outputs, ssm_state = selective_scan_fn(
+                    hidden_states,
+                    discrete_time_step,
+                    A,
+                    B.transpose(1, 2),
+                    C.transpose(1, 2),
+                    self.D.float(),
+                    gate,
+                    time_proj_bias,
+                    delta_softplus=True,
+                    return_last_state=True,
+                )
+                if ssm_state is not None and cache_params is not None:
+                    cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
+
+            # 4. Final linear projection
+            contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
+        return contextualized_states
+
+    # fmt: off
+    def slow_forward(self, input_states, cache_params: Optional[MambaCache] = None):
+        batch_size, seq_len, _ = input_states.shape
+        dtype = input_states.dtype
+        # 1. Gated MLP's linear projection
+        # [batch, 2 * intermediate_size, seq_len]
+        projected_states = self.in_proj(input_states).transpose(1, 2)
+        hidden_states, gate = projected_states.chunk(2, dim=1)
+
+        # 2. Convolution sequence transformation
+        if cache_params is not None:
+            ssm_state = cache_params.ssm_states[self.layer_idx].clone()
+            if cache_params.seqlen_offset > 0:
+                # [batch, intermediate_size, conv_kernel_size]
+                conv_state = cache_params.conv_states[self.layer_idx]
+                conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
+                conv_state[:, :, -1] = hidden_states[:, :, 0]
+                cache_params.conv_states[self.layer_idx].copy_(conv_state)
+                hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
+                if self.use_conv_bias:
+                    hidden_states += self.conv1d.bias
+                # [batch, intermediate_size, 1] : decoding
+                hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1)
+            else:
+                conv_state = nn.functional.pad(
+                    hidden_states,
+                    (self.conv_kernel_size - hidden_states.shape[-1], 0)
+                )
+                cache_params.conv_states[self.layer_idx].copy_(conv_state)
+                # [batch, intermediate_size, seq_len]
+                hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
+        else:
+            ssm_state = torch.zeros(
+                (batch_size, self.intermediate_size, self.ssm_state_size),
+                device=hidden_states.device, dtype=dtype
+            )
+            # [batch, intermediate_size, seq_len]
+            hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
+
+        # 3. State Space Model sequence transformation
+        # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
+        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
+        time_step, B, C = torch.split(
+            ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+        )
+        # [batch, seq_len, intermediate_size]
+        discrete_time_step = self.dt_proj(time_step)
+        # [batch, intermediate_size, seq_len]
+        discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2)
+
+        # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
+        # [intermediate_size, ssm_state_size]
+        A = -torch.exp(self.A_log.float())
+        # [batch, intermediate_size, seq_len, ssm_state_size]
+        discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])
+        # [batch, intermediate_size, seq_len, ssm_state_size]
+        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()
+        deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
+
+        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
+        scan_outputs = []
+        for i in range(seq_len):
+            # [batch, intermediate_size, ssm_state]
+            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
+            # [batch, intermediate_size, 1]
+            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))
+            scan_outputs.append(scan_output[:, :, 0])
+        # [batch, intermediate_size, seq_len]
+        scan_output = torch.stack(scan_outputs, dim=-1)
+        scan_output = scan_output + (hidden_states * self.D[None, :, None])
+        scan_output = (scan_output * self.act(gate))
+
+        if cache_params is not None:
+            cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
+
+        # 4. Final linear projection
+        # [batch, seq_len, hidden_size]
+        contextualized_states = self.out_proj(scan_output.transpose(1, 2))
+        return contextualized_states
+    # fmt: on
+
+    def forward(self, hidden_states, cache_params: Optional[MambaCache] = None):
+        if is_fast_path_available and "cuda" in self.x_proj.weight.device.type:
+            return self.cuda_kernels_forward(hidden_states, cache_params)
+        return self.slow_forward(hidden_states, cache_params)
+
+
+class MambaBlock(nn.Module):
+    def __init__(self, config, layer_idx):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.residual_in_fp32 = config.residual_in_fp32
+        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+        self.mixer = MambaMixer(config, layer_idx=layer_idx)
+
+    def forward(self, hidden_states, cache_params: Optional[MambaCache] = None):
+        residual = hidden_states
+        hidden_states = self.norm(hidden_states)
+        # if self.residual_in_fp32:
+        #     residual = residual.to(torch.float32)
+        hidden_states = self.mixer(hidden_states, cache_params=cache_params)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+class MambaPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+ """ + + config_class = MambaConfig + base_model_prefix = "backbone" + _no_split_modules = ["MambaBlock"] + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, MambaMixer): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale + if self.config.time_step_init_scheme == "constant": + nn.init.constant_(module.dt_proj.weight, dt_init_std) + elif self.config.time_step_init_scheme == "random": + nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std) + + dt = torch.exp( + torch.rand(self.config.intermediate_size) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) + ).clamp(min=self.config.time_step_floor) + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + module.dt_proj.bias.copy_(inv_dt) + module.dt_proj.bias._no_reinit = True + + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=self.config.initializer_range) + + if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(self.config.num_layers) + + +@dataclass +class MambaOutput(ModelOutput): + """ + Class for the MAMBA model outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + cache_params (`MambaCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + cache_params: Optional[MambaCache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MambaCausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + cache_params (`MambaCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + cache_params: Optional[MambaCache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class MambaModel(MambaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + cache_params: Optional[MambaCache] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, # `attention_mask` is passed by the tokenizer and we don't want it + ) -> Union[Tuple, MambaOutput]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + if cache_params is None and use_cache: + cache_params = MambaCache( + self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + + hidden_states = inputs_embeds + all_hidden_states = () if 
+        for mixer_block in self.layers:
+            if self.gradient_checkpointing and self.training:
+                hidden_states = self._gradient_checkpointing_func(mixer_block.__call__, hidden_states, cache_params)
+            else:
+                hidden_states = mixer_block(hidden_states, cache_params=cache_params)
+
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if use_cache:
+            cache_params.seqlen_offset += inputs_embeds.shape[1]
+
+        hidden_states = self.norm_f(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
+
+        return MambaOutput(
+            last_hidden_state=hidden_states,
+            cache_params=cache_params if use_cache else None,
+            hidden_states=all_hidden_states,
+        )
+
+
+class MambaForCausalLM(MambaPreTrainedModel):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.backbone = MambaModel(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def get_input_embeddings(self):
+        return self.backbone.get_input_embeddings()
+
+    def set_input_embeddings(self, new_embeddings):
+        return self.backbone.set_input_embeddings(new_embeddings)
+
+    def _update_model_kwargs_for_generation(
+        self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
+    ) -> Dict[str, Any]:
+        model_kwargs["cache_params"] = outputs.get("cache_params", None)
+        return model_kwargs
+
+    def prepare_inputs_for_generation(
+        self, input_ids, cache_params: Optional[MambaCache] = None, inputs_embeds=None, attention_mask=None, **kwargs
+    ):
+        # only last token for input_ids if the state is passed along.
+        if cache_params is not None:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+
+        if inputs_embeds is not None and cache_params is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs["cache_params"] = cache_params
+        return model_inputs
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,  # noqa
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_params: Optional[MambaCache] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        use_cache: Optional[bool] = None,
+        **kwargs,  # for now we need this for generation
+    ) -> Union[Tuple, MambaCausalLMOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        mamba_outputs = self.backbone(
+            input_ids,
+            cache_params=cache_params,
+            inputs_embeds=inputs_embeds,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            use_cache=use_cache,
+        )
+        hidden_states = mamba_outputs[0]
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            if self.config.fuse_cross_entropy:
+                loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
+            else:
+                loss_fct = nn.CrossEntropyLoss()
+            # Enable model parallelism
+            labels = labels.to(logits.device)
+            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
+            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + mamba_outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return MambaCausalLMOutput(
+            loss=loss,
+            logits=logits,
+            cache_params=mamba_outputs.cache_params,
+            hidden_states=mamba_outputs.hidden_states,
+        )
diff --git a/fla2/models/mamba2/__init__.py b/fla2/models/mamba2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b8ac62a700590e06d1e524979b2f21353aa5188
--- /dev/null
+++ b/fla2/models/mamba2/__init__.py
@@ -0,0 +1,13 @@
+# -*- coding: utf-8 -*-
+
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+from fla2.models.mamba2.configuration_mamba2 import Mamba2Config
+from fla2.models.mamba2.modeling_mamba2 import Mamba2ForCausalLM, Mamba2Model
+
+AutoConfig.register(Mamba2Config.model_type, Mamba2Config, True)
+AutoModel.register(Mamba2Config, Mamba2Model, True)
+AutoModelForCausalLM.register(Mamba2Config, Mamba2ForCausalLM, True)
+
+
+__all__ = ['Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model']
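Editor's note: `prepare_inputs_for_generation` and `_update_model_kwargs_for_generation` above thread `cache_params` through decoding, so once a state exists only the last token is fed back in. A hedged sketch of that contract (it assumes `MambaConfig`, whose source is not part of this diff, provides workable defaults):

    import torch
    from fla2.models.mamba import MambaConfig, MambaForCausalLM

    model = MambaForCausalLM(MambaConfig()).eval()

    input_ids = torch.randint(0, model.config.vocab_size, (1, 8))
    out = model(input_ids, use_cache=True)       # prefill builds a MambaCache
    next_id = out.logits[:, -1:].argmax(-1)
    step = model(next_id, cache_params=out.cache_params, use_cache=True)  # O(1) decode step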
diff --git a/fla2/models/mamba2/configuration_mamba2.py b/fla2/models/mamba2/configuration_mamba2.py
new file mode 100644
index 0000000000000000000000000000000000000000..81cc3451135d8917da98fab213914a16c268e1f7
--- /dev/null
+++ b/fla2/models/mamba2/configuration_mamba2.py
@@ -0,0 +1,172 @@
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MAMBA2 configuration"""
+
+import math
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class Mamba2Config(PretrainedConfig):
+    """
+    This is the configuration class to store the configuration of a [`Mamba2Model`]. It is used to instantiate a MAMBA2
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the MAMBA2
+    [state-spaces/mamba2-2.8b](https://huggingface.co/state-spaces/mamba2-2.8b) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        num_heads (`int`, *optional*, defaults to 64):
+            Number of heads for the evolution matrices of mamba 2.
+        head_dim (`int`, *optional*, defaults to 64):
+            Dimension of each head.
+        vocab_size (`int`, *optional*, defaults to 32000):
+            Vocabulary size of the MAMBA2 model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`Mamba2Model`].
+        hidden_size (`int`, *optional*, defaults to 2048):
+            Dimensionality of the embeddings and hidden states.
+        state_size (`int`, *optional*, defaults to 64):
+            Shape of the state space latents.
+        num_hidden_layers (`int`, *optional*, defaults to 48):
+            Number of hidden layers in the model.
+        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
+            The epsilon to use in the layer normalization layers.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            The id of the beginning of sentence token in the vocabulary.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the end of sentence token in the vocabulary.
+        expand (`int`, *optional*, defaults to 2):
+            Expanding factor used to determine the intermediate size.
+        conv_kernel (`int`, *optional*, defaults to 4):
+            Size of the convolution kernel.
+        n_groups (`int`, *optional*, defaults to 8):
+            Number of groups for the evolution matrices of mamba 2.
+        use_bias (`bool`, *optional*, defaults to `False`):
+            Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block.
+        use_conv_bias (`bool`, *optional*, defaults to `True`):
+            Whether or not to use bias in the convolution layer of the mixer block.
+        hidden_act (`str`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        initializer_range (`float`, *optional*, defaults to 0.1):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        residual_in_fp32 (`bool`, *optional*, defaults to `True`):
+            Whether or not residuals should be in `float32`.
+            If set to `False` residuals will keep the same `dtype` as the rest of the model.
+        time_step_rank (`Union[int, str]`, *optional*, defaults to `"auto"`):
+            Rank of the discretization projection matrix.
+            `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`.
+        time_step_min (`float`, *optional*, defaults to 0.001):
+            Minimum `time_step` used to bound `dt_proj.bias`.
+        time_step_max (`float`, *optional*, defaults to 0.1):
+            Maximum `time_step` used to bound `dt_proj.bias`.
+        time_step_floor (`float`, *optional*, defaults to 0.0001):
+            Minimum clamping value of the `dt_proj.bias` layer initialization.
+        time_step_limit (`tuple`, *optional*, defaults to `(0.0, inf)`):
+            Accepted range of time step values.
+        rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
+            Whether or not to rescale `out_proj` weights when initializing.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the cache should be used.
+        norm_before_gate (`bool`, *optional*, defaults to `True`):
+            Option of the CUDA kernels: whether to normalize before the gate or not.
+        rms_norm (`bool`, *optional*, defaults to `True`):
+            Whether to use RMS norm or not.
+        chunk_size (`int`, *optional*, defaults to 256):
+            Size of the chunks that will comprise the sequence.
+        fuse_cross_entropy (`bool`, *optional*, defaults to `True`):
+            Whether to use the fused cross-entropy loss.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie word embeddings or not.
+ """ + + model_type = "mamba2" + + def __init__( + self, + num_heads: int = 64, + head_dim: int = 64, + vocab_size: int = 32000, + hidden_size: int = 2048, + state_size: int = 64, + num_hidden_layers: int = 48, + layer_norm_epsilon: float = 1e-5, + pad_token_id: int = 0, + bos_token_id: int = 1, + eos_token_id: int = 2, + expand: int = 2, + conv_kernel: int = 4, + n_groups: int = 8, + use_bias: bool = False, + use_conv_bias: bool = True, + hidden_act: str = "silu", + initializer_range: float = 0.1, + residual_in_fp32: bool = True, + time_step_rank: str = "auto", + time_step_min: float = 0.001, + time_step_max: float = 0.1, + time_step_floor: float = 1e-4, + time_step_limit=(0.0, float("inf")), + rescale_prenorm_residual: bool = False, + use_cache: bool = True, + norm_before_gate: bool = True, + rms_norm: bool = True, + chunk_size: int = 256, + fuse_cross_entropy: bool = True, + tie_word_embeddings: bool = False, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.state_size = state_size + self.num_hidden_layers = num_hidden_layers + self.layer_norm_epsilon = layer_norm_epsilon + self.conv_kernel = conv_kernel + self.expand = expand + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.use_bias = use_bias + self.use_conv_bias = use_conv_bias + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.time_step_rank = ( + math.ceil(self.hidden_size / 16) + if time_step_rank == "auto" + else time_step_rank + ) + self.time_step_min = time_step_min + self.time_step_max = time_step_max + self.time_step_floor = time_step_floor + self.rescale_prenorm_residual = rescale_prenorm_residual + self.residual_in_fp32 = residual_in_fp32 + self.use_cache = use_cache + self.n_groups = n_groups + self.num_heads = num_heads + self.head_dim = head_dim + self.norm_before_gate = norm_before_gate + self.rms_norm = rms_norm + self.state_size = state_size + self.chunk_size = chunk_size + self.time_step_limit = time_step_limit + self.fuse_cross_entropy = fuse_cross_entropy + self.tie_word_embeddings = tie_word_embeddings + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla2/models/mamba2/modeling_mamba2.py b/fla2/models/mamba2/modeling_mamba2.py new file mode 100644 index 0000000000000000000000000000000000000000..c4605113fc0b032350eb774ee004f14d66e1c429 --- /dev/null +++ b/fla2/models/mamba2/modeling_mamba2.py @@ -0,0 +1,1077 @@ +# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch MAMBA2 model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput, logging + +from fla.models.mamba2.configuration_mamba2 import Mamba2Config +from fla.modules import FusedCrossEntropyLoss, FusedRMSNormSwishGate, RMSNorm + +logger = logging.get_logger(__name__) + +try: + from mamba_ssm.ops.triton.selective_state_update import \ + selective_state_update + from mamba_ssm.ops.triton.ssd_combined import ( + mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined) +except ImportError: + ( + selective_state_update, + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + ) = (None, None, None) + +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +except ImportError: + causal_conv1d_update, causal_conv1d_fn = None, None + +is_fast_path_available = all( + (selective_state_update, causal_conv1d_fn, causal_conv1d_update) +) + + +def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and + simultaneously splitting it into chunk sequences. + + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] + input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> + # [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape( + input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] + ) + + +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.size(-1) + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, 0) + # 3. compute actual cumsum + tensor_segsum = torch.cumsum(input_tensor, dim=-2) + + # 4. 
+    mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
+    tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
+    return tensor_segsum
+
+
+class Mamba2Cache:
+    """
+    Arguments:
+        config: Mamba2Config
+        batch_size: int
+        dtype: torch.dtype
+        device: torch.device
+
+    Attributes:
+        seqlen_offset: int
+        dtype: torch.dtype
+        conv_states: Dict[int, torch.Tensor]
+            layer_idx -> [batch_size, intermediate_size + 2 * n_groups * state_size, conv_kernel_size]
+        ssm_states: Dict[int, torch.Tensor]
+            layer_idx -> [batch_size, num_heads, head_dim, state_size]
+    """
+
+    def __init__(
+        self,
+        config: Mamba2Config,
+        batch_size: int,
+        dtype: torch.dtype = torch.float16,
+        device: Optional[str] = None,
+    ):
+        self.seqlen_offset = 0
+        self.dtype = dtype
+        self.conv_kernel_size = config.conv_kernel
+        self.intermediate_size = int(config.expand * config.hidden_size)
+
+        self.conv_states = {
+            i: torch.zeros(
+                batch_size,
+                self.intermediate_size + 2 * config.n_groups * config.state_size,
+                self.conv_kernel_size,
+                device=device,
+                dtype=dtype,
+            )
+            for i in range(config.num_hidden_layers)
+        }
+        self.ssm_states = {
+            i: torch.zeros(
+                batch_size,
+                config.num_heads,
+                config.head_dim,
+                config.state_size,
+                device=device,
+                dtype=dtype,
+            )
+            for i in range(config.num_hidden_layers)
+        }
+        self.activation = config.hidden_act
+        self.act = ACT2FN[config.hidden_act]
+
+    def update_conv_state(
+        self,
+        layer_idx: int,
+        new_conv_state: torch.Tensor,
+        cache_position: torch.LongTensor,
+    ) -> torch.Tensor:
+        conv_state = self.conv_states[layer_idx]
+        cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
+
+        conv_state = conv_state.roll(shifts=-1, dims=-1)
+        conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
+        self.conv_states[layer_idx].zero_()
+        self.conv_states[layer_idx] += conv_state
+        return self.conv_states[layer_idx]
+
+    def reset(self):
+        # `conv_states`/`ssm_states` are dicts of per-layer tensors, so zero each entry
+        for state in self.conv_states.values():
+            state.zero_()
+        for state in self.ssm_states.values():
+            state.zero_()
+class Mamba2Mixer(nn.Module): + """ + Compute ∆, A, B, C, and D, the state-space parameters, and compute the `contextualized_states`. + A, D are input-independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time-invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__(self, config: Mamba2Config, layer_idx: int): + super().__init__() + self.num_heads = config.num_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = int(config.expand * self.hidden_size) + self.time_step_rank = int(config.time_step_rank) + self.layer_idx = layer_idx + self.use_conv_bias = config.use_conv_bias + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + + self.norm_before_gate = config.norm_before_gate + self.layer_norm_epsilon = config.layer_norm_epsilon + self.rms_norm = config.rms_norm + + self.n_groups = config.n_groups + self.head_dim = config.head_dim + self.chunk_size = config.chunk_size + + self.time_step_limit = config.time_step_limit + self.time_step_min = config.time_step_min + self.time_step_max = config.time_step_max + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.use_conv_bias, + kernel_size=config.conv_kernel, + groups=self.conv_dim, + padding=config.conv_kernel - 1, + ) + + # projection of the input hidden state + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=config.use_bias, + ) + # selective projection used to make dt, B and C input-dependent + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PreTrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + self.norm = FusedRMSNormSwishGate( + self.intermediate_size, eps=self.layer_norm_epsilon + ) + + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + + self.out_proj = nn.Linear( + self.intermediate_size, self.hidden_size, bias=config.use_bias + ) + self.use_bias = config.use_bias + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because one of " + "`(selective_state_update, causal_conv1d_fn, causal_conv1d_update)` is None. " + "Falling back to the naive implementation. 
" + "To install follow https://github.com/state-spaces/mamba/#installation and" + "https://github.com/Dao-AILab/causal-conv1d" + ) + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads + # getting projected states from cache if it exists + if cache_params is not None and cache_params.seqlen_offset > 0: + in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D) + d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2 + split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads] + _, _, gate, hidden_states_B_C, dt = torch.split(in_projected_states, split_projection_dim, dim=-1) + + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [ + self.intermediate_size, + groups_time_state_size, + groups_time_state_size, + ], + dim=-1, + ) + A = -torch.exp(self.A_log.float()) # (nheads,) + + A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) + + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view( + batch_size, self.num_heads * self.head_dim + ) + hidden_states = self.norm(hidden_states, o=gate) + + out = self.out_proj(hidden_states)[:, None, ...] + # if no cache is found, calling the kernel + else: + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + # 1. 
Gated MLP's linear projection + projected_states = self.in_proj(hidden_states) + A = -torch.exp( + self.A_log.float() + ) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = ( + {} + if self.time_step_limit == (0.0, float("inf")) + else {"dt_limit": self.time_step_limit} + ) + + if self.training and cache_params is None: + out, ssm_state = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, + activation=self.activation, + rmsnorm_weight=self.norm.weight, + rmsnorm_eps=self.norm.eps, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=self.norm_before_gate, + return_final_states=True, + **dt_limit_kwargs, + ) + + else: + gate, hidden_states_B_C, time_step = torch.split( + projected_states, + [self.intermediate_size, self.conv_dim, self.num_heads], + dim=-1, + ) + + time_step = nn.functional.softplus(time_step + self.dt_bias) + # 1D Convolution + if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] + ) # (B, L, self.d_inner + 2 * ngroups * d_state) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2)[:, :seq_len] + hidden_states, B, C = torch.split( + hidden_states_B_C, + [ + self.intermediate_size, + groups_time_state_size, + groups_time_state_size, + ], + dim=-1, + ) + + if ( + attention_mask is not None + and attention_mask.shape[1] > 1 + and attention_mask.shape[0] > 1 + ): + # zero out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view( + batch_size, + seq_len, + -1, + self.head_dim, + ), + time_step, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=None, + return_final_states=True, + **dt_limit_kwargs, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + scan_output = scan_output.view( + batch_size, seq_len, -1 + ) + # Multiply "gate" branch and apply extra normalization layer + scan_output = self.norm(scan_output, o=gate) + out = self.out_proj(scan_output) + return out + + # fmt: off + def torch_forward( + self, + input_states, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None + ): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + # Gated MLP's linear projection + projected_states = self.in_proj(input_states.squeeze(1)) + d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 + * self.n_groups * self.ssm_state_size - self.num_heads) // 2 + _, _, gate, hidden_states, dt = projected_states.split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # Convolution sequence transformation + if cache_params is not None: + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = 
ssm_state.to(hidden_states.device) + if cache_params.seqlen_offset > 0: + conv_state = cache_params.conv_states[self.layer_idx] # [batch, conv_dim, conv_kernel_size] + conv_state = torch.roll(conv_state, shifts=-1, dims=-1) + # handle batched generation - states are copied through + conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states + cache_params.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding + else: + hidden_states = hidden_states.transpose(1, 2) + conv_state = nn.functional.pad( + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = self.act(self.conv1d( + hidden_states).transpose(1, 2))[:, :seq_len, :] # [batch, seq_len, intermediate_size] + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + # zero out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + else: + ssm_state = torch.zeros( + (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), + device=hidden_states.device, dtype=dtype + ) + hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, + self.n_groups * self.ssm_state_size], dim=-1) + A = -torch.exp(self.A_log.float()) # [num_heads] + if cache_params is not None and cache_params.seqlen_offset > 0: + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] 
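# NOTE (reviewer annotation, not part of this diff): the single-step decode path below
# performs a zero-order-hold style discretization with step size dt = softplus(dt + dt_bias):
#     dA  = exp(dt * A)        # per-head state decay
#     dBx = (dt * B) * x       # input injection
# and then advances the recurrence  ssm_state <- ssm_state * dA + dBx,  reading the output
# as  y = ssm_state @ C + D * x.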
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_min) # upper clamp (self.time_step_max) intentionally not applied + A = A[..., None, None].expand(self.num_heads, self.head_dim, + self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = torch.exp(dt[..., None] * A) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads + // self.n_groups, B.shape[-1]).contiguous() + + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = dB * hidden_states[..., None] + + # State calculation + cache_params.ssm_states[self.layer_idx].copy_( + cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads + // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] + else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_min) # upper clamp (self.time_step_max) intentionally not applied + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + 
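# NOTE (reviewer annotation, not part of this diff): the naive SSD path below follows the
# Mamba-2 chunked decomposition. Within each chunk the output is computed exactly through an
# attention-like weight matrix L = exp(segment_sum(A)) ("diagonal blocks"); across chunks the
# recurrent state is propagated with decay factors derived from A_cumsum and contracted with C
# ("off-diagonal blocks"). The two contributions are summed at the end as Y_diag + Y_off.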
+ # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # Step 1: contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + # Step 2: Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Step 3: Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3) + + # (right term of low-rank factorization of off-diagonal blocks; B terms) + + decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) + B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] + # permute back B * decay states + states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] + * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) + + if cache_params is not None and cache_params.seqlen_offset > 0: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + + states_permuted = states.permute(0, 2, 1, 3, 4) + result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2) + new_states = result.permute(0, 2, 1, 3, 4) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + # compute Y_off + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + + y = y.reshape(batch_size, seq_len, -1) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + scan_output = self.norm(y, o=gate) + + # 4. 
Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward( + hidden_states, cache_params, cache_position, attention_mask + ) + dtype = hidden_states.dtype + if ( + attention_mask is not None + and attention_mask.shape[1] > 1 + and attention_mask.shape[0] > 1 + ): + # zero out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.torch_forward( + hidden_states, cache_params, cache_position, attention_mask + ) + + +class Mamba2Block(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.residual_in_fp32 = config.residual_in_fp32 + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.mixer = Mamba2Mixer(config, layer_idx=layer_idx) + + def forward( + self, + hidden_states, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + residual = hidden_states + hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + hidden_states = self.mixer( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + return hidden_states + + +class Mamba2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Mamba2Config + base_model_prefix = "backbone" + _no_split_modules = ["Mamba2Block"] + supports_gradient_checkpointing = True + _is_stateful = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, Mamba2Mixer): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + dt = torch.exp( + torch.rand(self.config.num_heads) + * ( + math.log(self.config.time_step_max) + - math.log(self.config.time_step_min) + ) + + math.log(self.config.time_step_min) + ).clamp(min=self.config.time_step_floor) + + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + module.dt_bias.copy_(inv_dt) + module.dt_bias._no_reinit = True + + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=self.config.initializer_range) + 
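# NOTE (reviewer annotation, not part of this diff): `inv_dt` above is the inverse of
# softplus, so that softplus(dt_bias) recovers the sampled dt at initialization. Quick check:
import torch
import torch.nn.functional as F

dt = torch.rand(8).clamp(min=1e-4)
inv_dt = dt + torch.log(-torch.expm1(-dt))  # softplus^{-1}(dt)
assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-5)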
+ if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(self.config.num_hidden_layers) + + +@dataclass +# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2 +class Mamba2Output(ModelOutput): + """ + Class for the MAMBA2 model outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + cache_params (`Mamba2Cache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the state-space model states after the selective scan and the convolutional states. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + cache_params: Optional[Mamba2Cache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->Mamba2 +class Mamba2CausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + cache_params (`Mamba2Cache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the state-space model states after the selective scan and the convolutional states. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + cache_params: Optional[Mamba2Cache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class Mamba2Model(Mamba2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [ + Mamba2Block(config, layer_idx=idx) + for idx in range(config.num_hidden_layers) + ] + ) + + self.gradient_checkpointing = False + self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + # Initialize weights and apply final processing + self._register_load_state_dict_pre_hook(self.load_hook) + self.post_init() + + def load_hook(self, state_dict, prefix, *args): + for k in state_dict: + if "embedding." in k: + state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k) + break + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + cache_params: Optional[Mamba2Cache] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[Tuple, Mamba2Output]: + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = ( + use_cache + if use_cache is not None + else (self.config.use_cache if not self.training else False) + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + if use_cache: + if cache_params is None: + cache_params = Mamba2Cache( + self.config, + inputs_embeds.size(0), + device=inputs_embeds.device, + dtype=inputs_embeds.dtype, + ) + cache_position = torch.arange( + 0, self.config.conv_kernel, device=inputs_embeds.device + ) + elif cache_position is None: + # cases when we do manual forward instead of using `model.generate` which will initiate + # `cache_position` and makes sure it is not None, throw error here instead of doing some + # hack to conjecture the current cache position + raise ValueError( + "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, " + "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will " + "be initialized for you automatically" + ) + else: + cache_params = None + + hidden_states = inputs_embeds + all_hidden_states = () if output_hidden_states else None + for mixer_block in self.layers: + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + mixer_block.__call__, + hidden_states, + cache_params, + cache_position, + attention_mask, + ) + else: + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) + 
+ if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if use_cache: + cache_params.seqlen_offset += inputs_embeds.shape[1] + + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, cache_params, all_hidden_states] + if v is not None + ) + + return Mamba2Output( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + ) + + +class Mamba2ForCausalLM(Mamba2PreTrainedModel): + _tied_weights_keys = [] + + def __init__(self, config): + super().__init__(config) + self.backbone = Mamba2Model(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.backbone.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.backbone.set_input_embeddings(new_embeddings) + + def prepare_inputs_for_generation( + self, + input_ids, + inputs_embeds=None, + use_cache=None, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + if input_ids.shape[1] == 0: + past_len = inputs_embeds.shape[1] + else: + past_len = input_ids.shape[1] + if use_cache: + # `cache_position` should have been initialized in `generate` + if cache_position is None: + raise ValueError( + "`cache_position` should not be None as it should have been initialized in " + "`model.generate`; you are responsible for passing in a valid `cache_position` if " + "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`" + ) + # how do we detect that we are in decoding without cache? 
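# NOTE (reviewer annotation, not part of this diff): during cached decoding, `generate`
# advances `cache_position`, so `cache_position[0] > 0` identifies the decode phase; only
# the last token (and the last attention-mask column) is fed, since the conv/SSM states
# already summarize the prefix.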
+ if cache_position[0] > 0: + input_ids = input_ids[:, -1][..., None] + attention_mask = attention_mask[:, -1][..., None] + else: + # we initialize the `cache_position` to the full size of `conv_states` at the prefill stage: + # padding is applied when the input length is shorter, and truncation when it is longer, so + # this is equivalent to always matching the length of `cache_params.conv_states`, + # which is `config.conv_kernel` + cache_position = torch.arange( + 0, past_len, device=input_ids.device + ) + # if the cache is not used, we also have to extend the attention mask here + # TODO there is likely a cleverer way to do this + extended_mask = torch.ones( + attention_mask.size(0), + past_len - attention_mask.shape[1], + device=attention_mask.device, + ) + attention_mask = torch.cat([attention_mask, extended_mask], dim=1) + cache_params = None + if attention_mask.shape[1] < past_len: + # we have to manually update the attention mask if we are decoding + # without cache and don't have position_ids here + # TODO: we should be able to use cache_position for this at a later time + extended_mask = torch.ones( + attention_mask.size(0), + past_len - attention_mask.shape[1], + device=attention_mask.device, + ) + attention_mask = torch.cat([attention_mask, extended_mask], dim=1) + if inputs_embeds is not None and cache_params is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "attention_mask": attention_mask, + "cache_params": cache_params, + "use_cache": use_cache, + "cache_position": cache_position, + } + ) + return model_inputs + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_params: Optional[Mamba2Cache] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, # for now we need this for generation + ) -> Union[Tuple, Mamba2CausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100` + are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + mamba2_outputs = self.backbone( + input_ids, + cache_params=cache_params, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + cache_position=cache_position, + attention_mask=attention_mask, + ) + hidden_states = mamba2_outputs[0] + + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + mamba2_outputs[1:] + return (loss,) + output if loss is not None else output + + return Mamba2CausalLMOutput( + loss=loss, + logits=logits, + cache_params=mamba2_outputs.cache_params, + hidden_states=mamba2_outputs.hidden_states, + ) diff --git a/fla2/models/mask_deltanet/__init__.py b/fla2/models/mask_deltanet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2dbfc8c1eeebd06ff9a520a17301919b7c63c92c --- /dev/null +++ b/fla2/models/mask_deltanet/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from .configuration_mask_deltanet import mask_deltanetConfig +from .modeling_mask_deltanet import mask_deltanetForCausalLM, mask_deltanetModel + +AutoConfig.register(mask_deltanetConfig.model_type, mask_deltanetConfig) +AutoModel.register(mask_deltanetConfig, mask_deltanetModel) +AutoModelForCausalLM.register(mask_deltanetConfig, mask_deltanetForCausalLM) + +__all__ = ['mask_deltanetConfig', 'mask_deltanetForCausalLM', 'mask_deltanetModel'] diff --git a/fla2/models/mask_deltanet/__pycache__/__init__.cpython-310.pyc b/fla2/models/mask_deltanet/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb04d4cbd462d6d3fe60eb3a1a2b999ab21e6a73 Binary files /dev/null and b/fla2/models/mask_deltanet/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla2/models/mask_deltanet/__pycache__/__init__.cpython-312.pyc b/fla2/models/mask_deltanet/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..170a6a16369ad351692aadabdd6449b303039c66 Binary files /dev/null and b/fla2/models/mask_deltanet/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/models/mask_deltanet/__pycache__/configuration_emdeltanet.cpython-310.pyc b/fla2/models/mask_deltanet/__pycache__/configuration_emdeltanet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53e484b37cea42c743a3291b90823b3bcf722e90 Binary files /dev/null and b/fla2/models/mask_deltanet/__pycache__/configuration_emdeltanet.cpython-310.pyc differ diff --git a/fla2/models/mask_deltanet/__pycache__/configuration_emdeltanet.cpython-312.pyc b/fla2/models/mask_deltanet/__pycache__/configuration_emdeltanet.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..936c59211678c79d020b9af30bc180bc1c407e02 Binary files /dev/null and b/fla2/models/mask_deltanet/__pycache__/configuration_emdeltanet.cpython-312.pyc differ diff --git a/fla2/models/mask_deltanet/__pycache__/configuration_emgla.cpython-310.pyc b/fla2/models/mask_deltanet/__pycache__/configuration_emgla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9c664fd9d4914d4f96f85972368580a45bea817 Binary files /dev/null and b/fla2/models/mask_deltanet/__pycache__/configuration_emgla.cpython-310.pyc differ diff --git a/fla2/models/mask_deltanet/__pycache__/configuration_mask_deltanet.cpython-310.pyc b/fla2/models/mask_deltanet/__pycache__/configuration_mask_deltanet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f45736723c78c01e9c9d081eb5a4b6042245abc Binary files /dev/null and b/fla2/models/mask_deltanet/__pycache__/configuration_mask_deltanet.cpython-310.pyc differ diff --git a/fla2/models/mask_deltanet/__pycache__/configuration_mask_deltanet.cpython-312.pyc b/fla2/models/mask_deltanet/__pycache__/configuration_mask_deltanet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..510e81814dd514f3837f8bcce1bcdfb1f309d843 Binary files /dev/null and b/fla2/models/mask_deltanet/__pycache__/configuration_mask_deltanet.cpython-312.pyc differ diff --git a/fla2/models/mask_deltanet/__pycache__/modeling_emdeltanet.cpython-310.pyc b/fla2/models/mask_deltanet/__pycache__/modeling_emdeltanet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4239c45100ed9a6e42879bf733538c9f87e6527 Binary files /dev/null and b/fla2/models/mask_deltanet/__pycache__/modeling_emdeltanet.cpython-310.pyc differ diff --git a/fla2/models/mask_deltanet/__pycache__/modeling_emdeltanet.cpython-312.pyc b/fla2/models/mask_deltanet/__pycache__/modeling_emdeltanet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b86c5c70db19112ea71f2a008b7d4a72942ccd28 Binary files /dev/null and b/fla2/models/mask_deltanet/__pycache__/modeling_emdeltanet.cpython-312.pyc differ diff --git a/fla2/models/mask_deltanet/__pycache__/modeling_emgla.cpython-310.pyc b/fla2/models/mask_deltanet/__pycache__/modeling_emgla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b15c2e90086bd98e7bb2c3a4dbc409f774ac8649 Binary files /dev/null and b/fla2/models/mask_deltanet/__pycache__/modeling_emgla.cpython-310.pyc differ diff --git a/fla2/models/mask_deltanet/__pycache__/modeling_mask_deltanet.cpython-310.pyc b/fla2/models/mask_deltanet/__pycache__/modeling_mask_deltanet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..096ff1482ec0e88f40c9913e7951036852d934cb Binary files /dev/null and b/fla2/models/mask_deltanet/__pycache__/modeling_mask_deltanet.cpython-310.pyc differ diff --git a/fla2/models/mask_deltanet/__pycache__/modeling_mask_deltanet.cpython-312.pyc b/fla2/models/mask_deltanet/__pycache__/modeling_mask_deltanet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..929c222cc31c60e6d8cc45fe331dd0f78c0a74f0 Binary files /dev/null and b/fla2/models/mask_deltanet/__pycache__/modeling_mask_deltanet.cpython-312.pyc differ diff --git a/fla2/models/mask_deltanet/configuration_mask_deltanet.py b/fla2/models/mask_deltanet/configuration_mask_deltanet.py new file mode 100644 index 0000000000000000000000000000000000000000..cc0c99a5e1720bfff804397c4c1ea313c735ada8 
--- /dev/null +++ b/fla2/models/mask_deltanet/configuration_mask_deltanet.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + + +class mask_deltanetConfig(PretrainedConfig): + + model_type = 'mask_deltanet' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + attn_mode: str = "chunk", + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 1, + use_gate: bool = False, + use_short_conv: bool = True, + conv_size: int = 4, + use_beta: bool = True, + use_output_norm: bool = True, + num_heads: int = 16, + qk_norm: str = 'l2', + qk_activation: str = 'silu', + max_position_embeddings: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + num_hidden_layers: int = 24, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: Optional[int] = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + aux_loss_scale: float = 0.01, + ratio: int = 2, + topk: int = 1, + **kwargs + ): + self.attn_mode = attn_mode + self.aux_loss_scale = aux_loss_scale + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_beta = use_beta + self.use_output_norm = use_output_norm + self.num_heads = num_heads + self.qk_norm = qk_norm + self.qk_activation = qk_activation + self.max_position_embeddings = max_position_embeddings + + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + self.ratio = ratio + self.topk = topk + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) 
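# NOTE (reviewer annotation, not part of this diff): `attn` configures optional hybrid
# softmax-attention layers. Assuming the defaults above, a config that places full attention
# at layers 2 and 5 of an 8-layer model might look like:
#     config = mask_deltanetConfig(
#         num_hidden_layers=8,
#         attn={'layers': [2, 5], 'num_heads': 16},
#     )
# num_kv_heads, qkv_bias, window_size and rope_theta then fall back to the defaults
# filled in just above.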
+ + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla2/models/mask_deltanet/modeling_mask_deltanet.py b/fla2/models/mask_deltanet/modeling_mask_deltanet.py new file mode 100644 index 0000000000000000000000000000000000000000..95d0537bb6f8eb5f5d1cf9ec43e77aaa6b55f138 --- /dev/null +++ b/fla2/models/mask_deltanet/modeling_mask_deltanet.py @@ -0,0 +1,535 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, Iterable +from einops import rearrange +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from ...layers.attn import Attention +from ...layers.mask_deltanet import mask_deltanet +from ...models.mask_deltanet.configuration_mask_deltanet import mask_deltanetConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as mask_deltanetMLP +from ...modules import RMSNorm +from ...modules import RotaryEmbedding +logger = logging.get_logger(__name__) + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +class mask_deltanetBlock(nn.Module): + def __init__(self, config: mask_deltanetConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = mask_deltanet( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + use_gate=config.use_gate, + use_short_conv=config.use_short_conv, + use_output_norm=config.use_output_norm, + conv_size=config.conv_size, + norm_eps=config.norm_eps, + ratio=config.ratio, + topk=config.topk, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = mask_deltanetMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values, router_logits 
= self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values, router_logits) + + return outputs + + +class mask_deltanetPreTrainedModel(PreTrainedModel): + + config_class = mask_deltanetConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['mask_deltanetBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class mask_deltanetModel(mask_deltanetPreTrainedModel): + + def __init__(self, config: mask_deltanetConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([mask_deltanetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`mask_deltanetModel` does not support `output_attentions` for now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + all_router_logits = () + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values, router_logits = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values, router_logits = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + all_router_logits += (router_logits,) + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + # keep router logits last so downstream code can read them via outputs[-1] + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns, all_router_logits] if i is not None) + return mask_deltanetOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns, + router_logits=all_router_logits + ) + + +@dataclass +class mask_deltanetOutputWithPast(BaseModelOutputWithPast): + router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class mask_deltanetCausalLMOutputWithPast(CausalLMOutputWithPast): + aux_loss: Optional[torch.FloatTensor] = None + router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple], + num_memories: int = None, + top_k=2, + use_layer_wise_balance=False, +) -> torch.FloatTensor: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, sequence_length, num_memories]. + num_memories (`int`, *optional*): + Number of experts. + + Returns: + The auxiliary loss. + """ + if gate_logits is None or ( + isinstance(gate_logits, Iterable) and len(gate_logits) == 0 + ): + return 0 + 
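# NOTE (reviewer annotation, not part of this diff): per layer, the auxiliary loss computed
# below is  mean_e(f_e * P_e) * E**2,  where E = num_memories, f_e is the fraction of tokens
# routed to expert e (from the top-k one-hot mask) and P_e is the mean router probability
# assigned to e; the loss is minimized when routing is uniform across experts.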
+ # ✨ Here is the fix for balance loss in Mixtral. + # We should calculate the balance loss in a layer-wise manner, otherwise it may lead to degenerate solutions. + if use_layer_wise_balance: + if not isinstance(gate_logits, Iterable): + gate_logits = (gate_logits,) + else: + if isinstance(gate_logits, Iterable): + gate_logits = (torch.cat(gate_logits, dim=0),) + else: + gate_logits = (gate_logits,) + + all_balance_losses = [] + + for logits in gate_logits: + if logits.dim() == 4: + logits = rearrange(logits, 'b h l r -> (b h) l r') + routing_weights, selected_experts = torch.topk(logits, top_k, dim=-1) + routing_weights = routing_weights.softmax(dim=-1).to(logits.dtype) + routing_weights_full = torch.zeros_like(logits).scatter(-1, selected_experts, routing_weights) + + # cast the expert indices to int64, otherwise one-hot encoding will fail + if selected_experts.dtype != torch.int64: + selected_experts = selected_experts.to(torch.int64) + + if len(selected_experts.shape) == 2: + selected_experts = selected_experts.unsqueeze(2) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_memories) + + # For a given token, determine if it was routed to a given expert. + expert_mask = torch.max(expert_mask, dim=-2).values + + # cast to float32 otherwise mean will fail + expert_mask = expert_mask.to(torch.float32) + tokens_per_group_and_expert = torch.mean(expert_mask, dim=-2) + + router_prob_per_group_and_expert = torch.mean(routing_weights_full, dim=-2) + + # ✨ balance loss for this layer + balance_loss = torch.mean( + tokens_per_group_and_expert * router_prob_per_group_and_expert + ) * (num_memories**2) + all_balance_losses.append(balance_loss.reshape(1)) + + all_balance_losses = torch.cat(all_balance_losses).mean() # ✨ + + return all_balance_losses + + +class mask_deltanetForCausalLM(mask_deltanetPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = mask_deltanetModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `input_ids` if the `past_key_values` is not empty. 
+ if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and (past_key_values is None or len(past_key_values) == 0): + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + valid_router_logits = tuple( + logits + for logits in (outputs.router_logits if return_dict else outputs[-1]) + if logits is not None + ) + aux_loss = load_balancing_loss_func( + valid_router_logits, + self.config.ratio, + self.config.topk, + use_layer_wise_balance=True, + ) + aux_loss *= self.config.aux_loss_scale + # the auxiliary loss only contributes when a primary LM loss exists + if loss is not None and aux_loss: + loss = loss + aux_loss + + if not return_dict: + output = (logits,) + outputs[1:] +
return (loss,) + output if loss is not None else output + + # return CausalLMOutputWithPast( + # loss=loss, + # logits=logits, + # past_key_values=outputs.past_key_values, + # hidden_states=outputs.hidden_states, + # attentions=outputs.attentions, + # ) + + return mask_deltanetCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + aux_loss=aux_loss + ) \ No newline at end of file diff --git a/fla2/models/mask_gdn/__init__.py b/fla2/models/mask_gdn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..70cdcb220bc669e65a44fc472aabb2436191cec1 --- /dev/null +++ b/fla2/models/mask_gdn/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from .configuration_mask_gdn import mask_gdnConfig +from .modeling_mask_gdn import mask_gdnForCausalLM, mask_gdnModel + +AutoConfig.register(mask_gdnConfig.model_type, mask_gdnConfig) +AutoModel.register(mask_gdnConfig, mask_gdnModel) +AutoModelForCausalLM.register(mask_gdnConfig, mask_gdnForCausalLM) + +__all__ = ['mask_gdnConfig', 'mask_gdnForCausalLM', 'mask_gdnModel'] diff --git a/fla2/models/mask_gdn/__pycache__/__init__.cpython-310.pyc b/fla2/models/mask_gdn/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..036c830d4ca419b7608a9d6f927c084ef613dd15 Binary files /dev/null and b/fla2/models/mask_gdn/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla2/models/mask_gdn/__pycache__/__init__.cpython-312.pyc b/fla2/models/mask_gdn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1b48d814333954ddc0203fc7e59127dd38a8f61 Binary files /dev/null and b/fla2/models/mask_gdn/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/models/mask_gdn/__pycache__/configuration_mask_gdn.cpython-310.pyc b/fla2/models/mask_gdn/__pycache__/configuration_mask_gdn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b86502f2fc011d8d5b073f319b87b9189bbd02e Binary files /dev/null and b/fla2/models/mask_gdn/__pycache__/configuration_mask_gdn.cpython-310.pyc differ diff --git a/fla2/models/mask_gdn/__pycache__/configuration_mask_gdn.cpython-312.pyc b/fla2/models/mask_gdn/__pycache__/configuration_mask_gdn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14296dc861ea1e2953eb771a7b6d65c1e4f1791b Binary files /dev/null and b/fla2/models/mask_gdn/__pycache__/configuration_mask_gdn.cpython-312.pyc differ diff --git a/fla2/models/mask_gdn/__pycache__/modeling_mask_gdn.cpython-310.pyc b/fla2/models/mask_gdn/__pycache__/modeling_mask_gdn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c396adf57fec057981f19d6dd8c7056df05333a3 Binary files /dev/null and b/fla2/models/mask_gdn/__pycache__/modeling_mask_gdn.cpython-310.pyc differ diff --git a/fla2/models/mask_gdn/__pycache__/modeling_mask_gdn.cpython-312.pyc b/fla2/models/mask_gdn/__pycache__/modeling_mask_gdn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66b14c3b3b70438a1474e5c3a064e369af6e2ccf Binary files /dev/null and b/fla2/models/mask_gdn/__pycache__/modeling_mask_gdn.cpython-312.pyc differ diff --git a/fla2/models/mask_gdn/configuration_mask_gdn.py b/fla2/models/mask_gdn/configuration_mask_gdn.py new file mode 100644 
index 0000000000000000000000000000000000000000..7bbbc61106c46f4fdda38414cbe3b63e761573b9 --- /dev/null +++ b/fla2/models/mask_gdn/configuration_mask_gdn.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + + +class mask_gdnConfig(PretrainedConfig): + + model_type = 'mask_gdn' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + attn_mode: str = "chunk", + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 1, + use_gate: bool = False, + use_short_conv: bool = True, + conv_size: int = 4, + use_beta: bool = True, + use_output_norm: bool = True, + num_heads: int = 16, + qk_norm: str = 'l2', + qk_activation: str = 'silu', + max_position_embeddings: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + num_hidden_layers: int = 24, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: Optional[int] = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + aux_loss_scale: float = 0.01, + ratio: int = 4, + topk: int = 1, + **kwargs + ): + self.attn_mode = attn_mode + self.aux_loss_scale = aux_loss_scale + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_beta = use_beta + self.use_output_norm = use_output_norm + self.num_heads = num_heads + self.qk_norm = qk_norm + self.qk_activation = qk_activation + self.max_position_embeddings = max_position_embeddings + + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + self.ratio = ratio + self.topk = topk + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.)
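+ # Example (editorial, illustrative values only): a hybrid stack with softmax + # attention in layers 0 and 12 could pass + # attn={'layers': [0, 12], 'num_heads': 16} + # and any keys left unspecified fall back to the defaults filled in above.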
+ + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla2/models/mask_gdn/modeling_mask_gdn.py b/fla2/models/mask_gdn/modeling_mask_gdn.py new file mode 100644 index 0000000000000000000000000000000000000000..8bbfe871090ccfe1018149467330af457c4cde04 --- /dev/null +++ b/fla2/models/mask_gdn/modeling_mask_gdn.py @@ -0,0 +1,535 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, Iterable + +from einops import rearrange +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from ...layers.attn import Attention +from ...layers.mask_gdn import mask_gdn +from ...models.mask_gdn.configuration_mask_gdn import mask_gdnConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as mask_gdnMLP +from ...modules import RMSNorm +from ...modules import RotaryEmbedding + +logger = logging.get_logger(__name__) + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +class mask_gdnBlock(nn.Module): + def __init__(self, config: mask_gdnConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = mask_gdn( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + use_gate=config.use_gate, + use_short_conv=config.use_short_conv, + use_output_norm=config.use_output_norm, + conv_size=config.conv_size, + norm_eps=config.norm_eps, + ratio=config.ratio, + topk=config.topk, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = mask_gdnMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values, router_logits = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, +
past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values,router_logits) + + return outputs + + +class mask_gdnPreTrainedModel(PreTrainedModel): + + config_class = mask_gdnConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['mask_gdnBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class mask_gdnModel(mask_gdnPreTrainedModel): + + def __init__(self, config: mask_gdnConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([mask_gdnBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`mask_gdnModel` does not support `output_attentions` for now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing.
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + all_router_logits = () + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values, router_logits = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values, router_logits = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + all_router_logits += (router_logits,) + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + # `all_router_logits` must be part of this tuple: the causal LM head reads it back via `outputs[-1]`. + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns, all_router_logits] if i is not None) + # return BaseModelOutputWithPast( + # last_hidden_state=hidden_states, + # past_key_values=past_key_values, + # hidden_states=all_hidden_states, + # attentions=all_attns + # ) + return mask_gdnOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns, + router_logits=all_router_logits + ) + + +from dataclasses import dataclass + + +@dataclass +class mask_gdnOutputWithPast(BaseModelOutputWithPast): + router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class mask_gdnCausalLMOutputWithPast(CausalLMOutputWithPast): + aux_loss: Optional[torch.FloatTensor] = None + router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple], + num_memories: Optional[int] = None, + top_k: int = 2, + use_layer_wise_balance: bool = False, +) -> torch.FloatTensor: + r""" + Computes the auxiliary load balancing loss as in Switch Transformer - implemented in PyTorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[`torch.Tensor`]]): + Logits from the `gate`, should be a tuple of tensors, each of shape [batch_size, sequence_length, num_memories]. + num_memories (`int`, *optional*): + Number of experts. + + Returns: + The auxiliary loss. + """ + if gate_logits is None or ( + isinstance(gate_logits, Iterable) and len(gate_logits) == 0 + ): + return 0 + + # ✨ Here is the fix for the balance loss in Mixtral. + # We should calculate the balance loss in a layer-wise manner; otherwise it may lead to degenerate solutions.
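+ # Editorial note (illustrative, not part of the original logic): with + # use_layer_wise_balance=True each layer's gate logits are scored separately and + # the per-layer losses are averaged at the end; with False, logits from all layers + # are concatenated first, so over-use of an expert in one layer can be hidden by + # under-use in another -- the degenerate solutions mentioned above.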
+ if use_layer_wise_balance: + if not isinstance(gate_logits, Iterable): + gate_logits = (gate_logits,) + else: + if isinstance(gate_logits, Iterable): + gate_logits = (torch.cat(gate_logits, dim=0),) + else: + gate_logits = (gate_logits,) + + all_balance_losses = [] + + for logits in gate_logits: + if logits.dim() == 4: + logits = rearrange(logits, 'b h l r -> (b h) l r') + routing_weights, selected_experts = torch.topk(logits, top_k, dim=-1) + routing_weights = routing_weights.softmax(dim=-1).to(logits.dtype) + routing_weights_full = torch.zeros_like(logits).scatter(-1, selected_experts, routing_weights) + + # cast the expert indices to int64, otherwise one-hot encoding will fail + if selected_experts.dtype != torch.int64: + selected_experts = selected_experts.to(torch.int64) + + if len(selected_experts.shape) == 2: + selected_experts = selected_experts.unsqueeze(2) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_memories) + + # For a given token, determine if it was routed to a given expert. + expert_mask = torch.max(expert_mask, dim=-2).values + + # cast to float32, otherwise mean will fail + expert_mask = expert_mask.to(torch.float32) + tokens_per_group_and_expert = torch.mean(expert_mask, dim=-2) + + router_prob_per_group_and_expert = torch.mean(routing_weights_full, dim=-2) + + # ✨ balance loss for this layer + balance_loss = torch.mean( + tokens_per_group_and_expert * router_prob_per_group_and_expert + ) * (num_memories**2) + all_balance_losses.append(balance_loss.reshape(1)) + + all_balance_losses = torch.cat(all_balance_losses).mean() # ✨ + + return all_balance_losses + + +class mask_gdnForCausalLM(mask_gdnPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = mask_gdnModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only feed the last token of `input_ids` if `past_key_values` is not empty.
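+ # Editorial note: when set, `logits_to_keep` is forwarded below so that forward() + # applies the LM head only to the trailing hidden-state slice + # (hidden_states[:, -logits_to_keep:]), avoiding a full-vocabulary projection over + # the entire prompt during generation.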
+ if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and (past_key_values is None or len(past_key_values) == 0): + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + valid_router_logits = tuple( + logits + for logits in (outputs.router_logits if return_dict else outputs[-1]) + if logits is not None + ) + aux_loss = load_balancing_loss_func( + valid_router_logits, + self.config.ratio, + self.config.topk, + use_layer_wise_balance=True, + ) + aux_loss *= self.config.aux_loss_scale + # the auxiliary loss only contributes when a primary LM loss exists + if loss is not None and aux_loss: + loss = loss + aux_loss + + if not return_dict: + output = (logits,) + outputs[1:] +
return (loss,) + output if loss is not None else output + + # return CausalLMOutputWithPast( + # loss=loss, + # logits=logits, + # past_key_values=outputs.past_key_values, + # hidden_states=outputs.hidden_states, + # attentions=outputs.attentions, + # ) + + return mask_gdnCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + aux_loss=aux_loss + ) \ No newline at end of file diff --git a/fla2/models/retnet/__init__.py b/fla2/models/retnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad7d9e9da930819a2a6728e3e189090651b82a2e --- /dev/null +++ b/fla2/models/retnet/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.retnet.configuration_retnet import RetNetConfig +from fla.models.retnet.modeling_retnet import RetNetForCausalLM, RetNetModel + +AutoConfig.register(RetNetConfig.model_type, RetNetConfig) +AutoModel.register(RetNetConfig, RetNetModel) +AutoModelForCausalLM.register(RetNetConfig, RetNetForCausalLM) + + +__all__ = ['RetNetConfig', 'RetNetForCausalLM', 'RetNetModel'] diff --git a/fla2/models/retnet/__pycache__/__init__.cpython-312.pyc b/fla2/models/retnet/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d8e47150730466621f2324060b48fbdf747885c Binary files /dev/null and b/fla2/models/retnet/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/models/retnet/__pycache__/__init__.cpython-38.pyc b/fla2/models/retnet/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e080755e210c89e956f2dd664589f8e690ff56c2 Binary files /dev/null and b/fla2/models/retnet/__pycache__/__init__.cpython-38.pyc differ diff --git a/fla2/models/retnet/__pycache__/__init__.cpython-39.pyc b/fla2/models/retnet/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b42e5e64554ac35faa5b5833c66a6b5c89660e4 Binary files /dev/null and b/fla2/models/retnet/__pycache__/__init__.cpython-39.pyc differ diff --git a/fla2/models/retnet/__pycache__/configuration_retnet.cpython-312.pyc b/fla2/models/retnet/__pycache__/configuration_retnet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..973de055d9b0bef184e3ff2dd46a08617e83db43 Binary files /dev/null and b/fla2/models/retnet/__pycache__/configuration_retnet.cpython-312.pyc differ diff --git a/fla2/models/retnet/__pycache__/configuration_retnet.cpython-38.pyc b/fla2/models/retnet/__pycache__/configuration_retnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c102daaa1b1c0c50ce67530eb4a898c0e0dddd3 Binary files /dev/null and b/fla2/models/retnet/__pycache__/configuration_retnet.cpython-38.pyc differ diff --git a/fla2/models/retnet/__pycache__/configuration_retnet.cpython-39.pyc b/fla2/models/retnet/__pycache__/configuration_retnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6641ab6db66742770e9f37f5a40c1e68d974f56 Binary files /dev/null and b/fla2/models/retnet/__pycache__/configuration_retnet.cpython-39.pyc differ diff --git a/fla2/models/retnet/__pycache__/modeling_retnet.cpython-312.pyc b/fla2/models/retnet/__pycache__/modeling_retnet.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..59da8c5080e2b1deb7b509e98cca705d895333b3 Binary files /dev/null and b/fla2/models/retnet/__pycache__/modeling_retnet.cpython-312.pyc differ diff --git a/fla2/models/retnet/__pycache__/modeling_retnet.cpython-38.pyc b/fla2/models/retnet/__pycache__/modeling_retnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef8d0bf4e5f5387b87e1fcc32a71f9bb2e5babf6 Binary files /dev/null and b/fla2/models/retnet/__pycache__/modeling_retnet.cpython-38.pyc differ diff --git a/fla2/models/retnet/__pycache__/modeling_retnet.cpython-39.pyc b/fla2/models/retnet/__pycache__/modeling_retnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c993f5d41f41ef8965c9d8a1588b7bffd6c976f Binary files /dev/null and b/fla2/models/retnet/__pycache__/modeling_retnet.cpython-39.pyc differ diff --git a/fla2/models/retnet/configuration_retnet.py b/fla2/models/retnet/configuration_retnet.py new file mode 100644 index 0000000000000000000000000000000000000000..264e1d9a06b72d17b8effec8994bdad0e781fe3f --- /dev/null +++ b/fla2/models/retnet/configuration_retnet.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class RetNetConfig(PretrainedConfig): + + model_type = 'retnet' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + vocab_size: int = 32000, + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 2, + hidden_ratio: Optional[int] = 2, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 8, + num_kv_heads: Optional[int] = None, + feature_map: Optional[str] = None, + attn_mode: str = "chunk", + hidden_act: str = "swish", + use_short_conv: bool = False, + conv_size: int = 4, + use_output_gate: bool = True, + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + **kwargs + ) -> RetNetConfig: + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.feature_map = feature_map + self.attn_mode = attn_mode + self.hidden_act = hidden_act + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_output_gate = use_output_gate + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_cross_entropy = fuse_cross_entropy + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla2/models/retnet/modeling_retnet.py b/fla2/models/retnet/modeling_retnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c34542a580679ce68c648bdf04afe96a21b97ea3 --- /dev/null +++ b/fla2/models/retnet/modeling_retnet.py @@ -0,0 +1,407 @@ +# -*- coding: utf-8 
-*- + +from __future__ import annotations + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from fla.layers.multiscale_retention import MultiScaleRetention +from fla.models.retnet.configuration_retnet import RetNetConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, RMSNorm +from fla.modules.activations import swiglu_linear + +logger = logging.get_logger(__name__) + + +class RetNetMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'swish' + ) -> RetNetMLP: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + y = self.gate_proj(x) + gate, y = y.chunk(2, -1) + return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) + + +class RetNetBlock(nn.Module): + def __init__(self, config: RetNetConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.attn = MultiScaleRetention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + feature_map=config.feature_map, + use_output_gate=config.use_output_gate, + gate_fn=config.hidden_act, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.mlp = RetNetMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + hidden_states = self.mlp(hidden_states) + hidden_states = 
residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class RetNetPreTrainedModel(PreTrainedModel): + + config_class = RetNetConfig + supports_gradient_checkpointing = True + _no_split_modules = ['RetNetBlock'] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["o_proj.weight", "down_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class RetNetModel(RetNetPreTrainedModel): + + def __init__(self, config: RetNetConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [RetNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn( + "`RetNetModel` does not support output attention weights now, so `output_attentions` is set to `False`." 
+ ) + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_len = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_len = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache: + if past_key_values is None: + past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers] + if not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class RetNetForCausalLM(RetNetPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = RetNetModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + # Expected exception: 
"AttributeError: '(object name)' object has no attribute 'past_key_values'" + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. + if past_key_values is not None: + if not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1) + input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + }) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss 
is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla2/models/rwkv6/__init__.py b/fla2/models/rwkv6/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..942c6dc203bf6c867ffd5111e7f2ae1e7c060386 --- /dev/null +++ b/fla2/models/rwkv6/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config +from fla.models.rwkv6.modeling_rwkv6 import RWKV6ForCausalLM, RWKV6Model + +AutoConfig.register(RWKV6Config.model_type, RWKV6Config) +AutoModel.register(RWKV6Config, RWKV6Model) +AutoModelForCausalLM.register(RWKV6Config, RWKV6ForCausalLM) + + +__all__ = ['RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model'] diff --git a/fla2/models/rwkv6/__pycache__/__init__.cpython-312.pyc b/fla2/models/rwkv6/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b687e052495460691dc5163c0f58caaef8297c7 Binary files /dev/null and b/fla2/models/rwkv6/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/models/rwkv6/__pycache__/__init__.cpython-38.pyc b/fla2/models/rwkv6/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1fca736666bc0c6437188a92d3701b56b0e392d Binary files /dev/null and b/fla2/models/rwkv6/__pycache__/__init__.cpython-38.pyc differ diff --git a/fla2/models/rwkv6/__pycache__/__init__.cpython-39.pyc b/fla2/models/rwkv6/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b6a0c5c1cf39c5dee56da7c8044447cd11c5713 Binary files /dev/null and b/fla2/models/rwkv6/__pycache__/__init__.cpython-39.pyc differ diff --git a/fla2/models/rwkv6/__pycache__/configuration_rwkv6.cpython-312.pyc b/fla2/models/rwkv6/__pycache__/configuration_rwkv6.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e345539e7b64464be13d6a4193e7dd6ca72e79ca Binary files /dev/null and b/fla2/models/rwkv6/__pycache__/configuration_rwkv6.cpython-312.pyc differ diff --git a/fla2/models/rwkv6/__pycache__/configuration_rwkv6.cpython-38.pyc b/fla2/models/rwkv6/__pycache__/configuration_rwkv6.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d99de37ad25c71391f5e5d2e7d05da6a29d6d71 Binary files /dev/null and b/fla2/models/rwkv6/__pycache__/configuration_rwkv6.cpython-38.pyc differ diff --git a/fla2/models/rwkv6/__pycache__/configuration_rwkv6.cpython-39.pyc b/fla2/models/rwkv6/__pycache__/configuration_rwkv6.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d659519b3b102d211ae7cb35d38db3acf5f5ddc Binary files /dev/null and b/fla2/models/rwkv6/__pycache__/configuration_rwkv6.cpython-39.pyc differ diff --git a/fla2/models/rwkv6/__pycache__/modeling_rwkv6.cpython-312.pyc b/fla2/models/rwkv6/__pycache__/modeling_rwkv6.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14a2188a946ee27550622c04d7e3b45ef4a00226 Binary files /dev/null and b/fla2/models/rwkv6/__pycache__/modeling_rwkv6.cpython-312.pyc differ diff --git a/fla2/models/rwkv6/__pycache__/modeling_rwkv6.cpython-38.pyc b/fla2/models/rwkv6/__pycache__/modeling_rwkv6.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..c6de7d40b2fadd15af00d0b9e1b5d511f8347e69 Binary files /dev/null and b/fla2/models/rwkv6/__pycache__/modeling_rwkv6.cpython-38.pyc differ diff --git a/fla2/models/rwkv6/__pycache__/modeling_rwkv6.cpython-39.pyc b/fla2/models/rwkv6/__pycache__/modeling_rwkv6.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c46eb54fbb7137b7cd3f3f01748ea7fad763a4c6 Binary files /dev/null and b/fla2/models/rwkv6/__pycache__/modeling_rwkv6.cpython-39.pyc differ diff --git a/fla2/models/rwkv6/configuration_rwkv6.py b/fla2/models/rwkv6/configuration_rwkv6.py new file mode 100644 index 0000000000000000000000000000000000000000..87007da115c01d53c0ee7a167a992de14e5d601d --- /dev/null +++ b/fla2/models/rwkv6/configuration_rwkv6.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class RWKV6Config(PretrainedConfig): + + model_type = 'rwkv6' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + attn_mode: str = "chunk", + vocab_size: int = 32000, + hidden_size: int = 2048, + expand_k: float = 0.5, + expand_v: int = 1, + hidden_ratio: Optional[float] = 3.5, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 4, + proj_low_rank_dim: int = 32, + gate_low_rank_dim: int = 64, + hidden_act: str = "sqrelu", + max_position_embeddings: int = 2048, + norm_first: bool = True, + norm_bias: bool = True, + norm_eps: float = 1e-5, + use_cache: bool = True, + pad_token_id: Optional[int] = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + **kwargs + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.norm_first = norm_first + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.proj_low_rank_dim = proj_low_rank_dim + self.gate_low_rank_dim = gate_low_rank_dim + self.attn_mode = attn_mode + self.hidden_act = hidden_act + self.norm_bias = norm_bias + self.norm_eps = norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_cross_entropy = fuse_cross_entropy + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla2/models/rwkv6/modeling_rwkv6.py b/fla2/models/rwkv6/modeling_rwkv6.py new file mode 100644 index 0000000000000000000000000000000000000000..93a716870163c86ef73d33ede361ff1981f9a22c --- /dev/null +++ b/fla2/models/rwkv6/modeling_rwkv6.py @@ -0,0 +1,435 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from fla.layers.rwkv6 import LerpLinear, RWKV6Attention +from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss,
LayerNorm +from fla.modules.activations import ACT2FN + +logger = logging.get_logger(__name__) + + +class RWKV6FeedForward(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'sqrelu', + layer_idx: int = None + ) -> RWKV6FeedForward: + super().__init__() + + self.hidden_size = hidden_size + if hidden_ratio is None: + hidden_ratio = 3.5 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio) + intermediate_size = 32 * ((intermediate_size + 32 - 1) // 32) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + self.key = LerpLinear(hidden_size, intermediate_size) + self.value = nn.Linear(intermediate_size, hidden_size, bias=False) + self.receptance = LerpLinear(hidden_size, hidden_size) + self.act_fn = ACT2FN[hidden_act] + + self.layer_idx = layer_idx + + def forward( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + state: Optional[Cache] = None + ) -> torch.Tensor: + if attention_mask is not None: + x = x.mul_(attention_mask.unsqueeze(-1)) + if x.shape[1] == 1 and state is not None: + shifted = state[self.layer_idx][-1].unsqueeze(1) + else: + shifted = self.time_shift(x) + if state is not None: + shifted[:, 0] = state[self.layer_idx][-1] + delta = shifted - x + key = self.act_fn(self.key(x, delta)) + value = self.value(key) + receptance = self.receptance(x, delta) + + if state is not None: + state[self.layer_idx][-1] = x[:, -1] + return receptance.sigmoid() * value, state + + def init_state(self, batch_size: Optional[int] = None) -> Tuple[torch.Tensor]: + param = next(self.parameters()) + state = [param.new_zeros(batch_size, self.hidden_size)] + return state + + +class RWKV6Block(nn.Module): + def __init__(self, config: RWKV6Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.config = config + self.layer_idx = layer_idx + + if config.norm_first and layer_idx == 0: + self.pre_norm = LayerNorm(hidden_size=config.hidden_size, bias=config.norm_bias, eps=config.norm_eps) + self.attn_norm = LayerNorm(hidden_size=config.hidden_size, bias=config.norm_bias, eps=config.norm_eps) + self.attn = RWKV6Attention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + proj_low_rank_dim=config.proj_low_rank_dim, + gate_low_rank_dim=config.gate_low_rank_dim, + norm_eps=config.norm_eps, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + self.ffn_norm = LayerNorm(hidden_size=config.hidden_size, bias=config.norm_bias, eps=config.norm_eps) + self.ffn = RWKV6FeedForward( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + layer_idx=layer_idx + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = self.pre_norm(hidden_states) if hasattr(self, 'pre_norm') else hidden_states + hidden_states = self.attn_norm(residual) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + 
past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + hidden_states, residual = self.ffn_norm(hidden_states, residual, True) + hidden_states, past_key_values = self.ffn(hidden_states, attention_mask, past_key_values) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + def init_state(self, **kwargs) -> Tuple[torch.Tensor]: + state = [] + if callable(getattr(self.attn, 'init_state', None)): + state += self.attn.init_state(**kwargs) + if callable(getattr(self.ffn, 'init_state', None)): + state += self.ffn.init_state(**kwargs) + return state + + +class RWKV6PreTrainedModel(PreTrainedModel): + + config_class = RWKV6Config + supports_gradient_checkpointing = True + _no_split_modules = ['RWKV6Block'] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Parameter): + nn.init.normal_(module, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["o_proj.weight", "down_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class RWKV6Model(RWKV6PreTrainedModel): + + def __init__(self, config: RWKV6Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([RWKV6Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = LayerNorm(config.hidden_size, bias=config.norm_bias, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`RWKV6Model` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache: + if past_key_values is None: + past_key_values = [layer.init_state(batch_size=batch_size) for layer in self.layers] + if not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class RWKV6ForCausalLM(RWKV6PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = RWKV6Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. + if past_key_values is not None: + if not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1) + input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. 
+ # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + }) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla2/models/samba/__init__.py b/fla2/models/samba/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..244913e776944de23878781f4be7bd037fac89ab --- /dev/null +++ b/fla2/models/samba/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.samba.configuration_samba import SambaConfig +from fla.models.samba.modeling_samba import (SambaBlock, SambaForCausalLM, + SambaModel) + +AutoConfig.register(SambaConfig.model_type, SambaConfig, True) +AutoModel.register(SambaConfig, SambaModel, True) +AutoModelForCausalLM.register(SambaConfig, SambaForCausalLM, True) + + +__all__ = ['SambaConfig', 'SambaForCausalLM', 'SambaModel', 'SambaBlock']
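As `SambaBlock` below makes explicit, Samba interleaves layer types: even layer indices hold a `MambaMixer`, odd indices a sliding-window `Attention`. A hedged construction sketch (illustrative only; assumes the package imports under the `fla` name these files use and that the Mamba/attention dependencies are installed):

from transformers import AutoModelForCausalLM

from fla.models.samba import SambaConfig  # importing this package runs the Auto* registrations above

config = SambaConfig(hidden_size=256, num_hidden_layers=4, num_heads=4, num_kv_heads=4)
model = AutoModelForCausalLM.from_config(config)  # resolves to SambaForCausalLM
# The alternating pattern set up in SambaBlock.__init__:
print(type(model.backbone.layers[0].mixer).__name__)  # MambaMixer
print(type(model.backbone.layers[1].mixer).__name__)  # Attention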
diff --git a/fla2/models/samba/configuration_samba.py b/fla2/models/samba/configuration_samba.py new file mode 100644 index 0000000000000000000000000000000000000000..2f4d5b1e340ead152dec5a8231dd1d2023877c7d --- /dev/null +++ b/fla2/models/samba/configuration_samba.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- + +import math +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class SambaConfig(PretrainedConfig): + + model_type = "samba" + + def __init__( + self, + vocab_size: int = 32000, + hidden_size: int = 2304, + state_size: int = 16, + num_hidden_layers: int = 18, + norm_eps: float = 1e-5, + pad_token_id: int = 0, + bos_token_id: int = 1, + eos_token_id: int = 2, + expand: int = 2, + conv_kernel: int = 4, + use_bias: bool = False, + use_conv_bias: bool = True, + hidden_act: str = "silu", + initializer_range: float = 0.02, + residual_in_fp32: bool = False, + time_step_rank: str = "auto", + time_step_scale: float = 1.0, + time_step_min: float = 0.001, + time_step_max: float = 0.1, + time_step_init_scheme: str =
"random", + time_step_floor: float = 1e-4, + num_heads: int = 18, + num_kv_heads: int = 18, + window_size: int = 2048, + max_position_embeddings: int = 2048, + hidden_ratio: Optional[int] = 4, + rescale_prenorm_residual: bool = False, + use_cache: bool = True, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + tie_word_embeddings: bool = False, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.state_size = state_size + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.conv_kernel = conv_kernel + self.expand = expand + self.intermediate_size = int(expand * self.hidden_size) + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.use_bias = use_bias + self.use_conv_bias = use_conv_bias + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank + self.time_step_scale = time_step_scale + self.time_step_min = time_step_min + self.time_step_max = time_step_max + self.time_step_init_scheme = time_step_init_scheme + self.time_step_floor = time_step_floor + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.window_size = window_size + self.max_position_embeddings = max_position_embeddings + self.hidden_ratio = hidden_ratio + self.rescale_prenorm_residual = rescale_prenorm_residual + self.residual_in_fp32 = residual_in_fp32 + self.use_cache = use_cache + self.fuse_cross_entropy = fuse_cross_entropy + self.fuse_norm = fuse_norm + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs + ) diff --git a/fla2/models/samba/modeling_samba.py b/fla2/models/samba/modeling_samba.py new file mode 100644 index 0000000000000000000000000000000000000000..9b265745923df570ade6d61696d25d22bf40fbb0 --- /dev/null +++ b/fla2/models/samba/modeling_samba.py @@ -0,0 +1,388 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput, logging + +from fla.layers.attn import Attention +from fla.models.mamba.modeling_mamba import MambaCache, MambaMixer +from fla.models.samba.configuration_samba import SambaConfig +from fla.modules import FusedCrossEntropyLoss, RMSNorm +from fla.modules.activations import swiglu_linear + +logger = logging.get_logger(__name__) + + +class SambaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + hidden_act: str = 'swish' + ) -> SambaMLP: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + self.hidden_ratio = hidden_ratio + + self.intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + self.intermediate_size = 256 * ((self.intermediate_size + 256 - 1) // 256) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + 
self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + y = self.gate_proj(x) + gate, y = y.chunk(2, -1) + return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) + + +class SambaBlock(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + + self.config = config + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.mixer_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + if self.layer_idx % 2 == 0: + self.mixer = MambaMixer(config, layer_idx=layer_idx) + else: + self.mixer = Attention( + hidden_size=config.hidden_size, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + window_size=config.window_size, + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.mlp = SambaMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + hidden_act=config.hidden_act + ) + + def forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[Tuple[torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + hidden_states = self.mixer_norm(hidden_states) + if isinstance(self.mixer, MambaMixer): + hidden_states = self.mixer(hidden_states, cache_params=cache_params) + else: + hidden_states, _, cache_params = self.mixer(hidden_states=hidden_states, past_key_values=cache_params) + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class SambaPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SambaConfig + base_model_prefix = "backbone" + _no_split_modules = ["SambaBlock"] + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, MambaMixer): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale + if self.config.time_step_init_scheme == "constant": + nn.init.constant_(module.dt_proj.weight, dt_init_std) + elif self.config.time_step_init_scheme == "random": + nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std) + + dt = torch.exp( + torch.rand(self.config.intermediate_size) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) + ).clamp(min=self.config.time_step_floor) + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + module.dt_proj.bias.copy_(inv_dt) + module.dt_proj.bias._no_reinit = True + + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=self.config.initializer_range) + + if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. 
Scale + > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(self.config.num_hidden_layers) + + +@dataclass +class SambaOutput(ModelOutput): + """ + Class for the Samba model outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + cache_params (`MambaCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + cache_params: Optional[MambaCache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class SambaCausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + cache_params (`MambaCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + cache_params: Optional[MambaCache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class SambaModel(SambaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([SambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + self.norm_f = RMSNorm(config.hidden_size, eps=config.norm_eps) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + cache_params: Optional[MambaCache] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, # `attention_mask` is passed by the tokenizer and we don't want it + ) -> Union[Tuple, SambaOutput]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + if cache_params is None and use_cache: + cache_params = MambaCache( + self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + + hidden_states = inputs_embeds + all_hidden_states = () if output_hidden_states else None + for mixer_block in self.layers: + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func(mixer_block.__call__, hidden_states, cache_params) + else: + hidden_states = mixer_block(hidden_states, cache_params=cache_params) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if use_cache: + cache_params.seqlen_offset += inputs_embeds.shape[1] + + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) + + return SambaOutput( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + ) + + +class SambaForCausalLM(SambaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.backbone = SambaModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return 
self.backbone.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.backbone.set_input_embeddings(new_embeddings) + + def _update_model_kwargs_for_generation( + self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs + ) -> Dict[str, Any]: + model_kwargs["cache_params"] = outputs.get("cache_params", None) + return model_kwargs + + def prepare_inputs_for_generation( + self, input_ids, cache_params: Optional[MambaCache] = None, inputs_embeds=None, attention_mask=None, **kwargs + ): + # only last token for inputs_ids if the state is passed along. + if cache_params is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + + if inputs_embeds is not None and cache_params is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs["cache_params"] = cache_params + return model_inputs + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_params: Optional[MambaCache] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + **kwargs, # for now we need this for generation + ) -> Union[Tuple, SambaCausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + samba_outputs = self.backbone( + input_ids, + cache_params=cache_params, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + ) + hidden_states = samba_outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + samba_outputs[1:] + return (loss,) + output if loss is not None else output + + return SambaCausalLMOutput( + loss=loss, + logits=logits, + cache_params=samba_outputs.cache_params, + hidden_states=samba_outputs.hidden_states, + ) diff --git a/fla2/models/transformer/__init__.py b/fla2/models/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c508f35f61e3d176eb78b659606de924e508f65c --- /dev/null +++ b/fla2/models/transformer/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from .configuration_transformer import TransformerConfig +from .modeling_transformer import ( + TransformerForCausalLM, TransformerModel) + +AutoConfig.register(TransformerConfig.model_type, TransformerConfig) +AutoModel.register(TransformerConfig, TransformerModel) +AutoModelForCausalLM.register(TransformerConfig, 
TransformerForCausalLM) + + __all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel']
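The `__init__.py` above registers the vanilla Transformer baseline; its `TransformerConfig` (defined below) exposes grouped-query attention via `num_kv_heads` and an optional sliding window. A small construction sketch (illustrative only; assumes the package imports under the `fla` name these files use):

from transformers import AutoModelForCausalLM

from fla.models.transformer import TransformerConfig  # importing this package runs the Auto* registrations above

config = TransformerConfig(
    hidden_size=1024,
    num_hidden_layers=4,
    num_heads=16,
    num_kv_heads=4,   # 4 KV heads shared across 16 query heads (GQA); None keeps full MHA
    window_size=512,  # sliding-window attention; None disables the window
)
model = AutoModelForCausalLM.from_config(config)  # resolves to TransformerForCausalLM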
diff --git a/fla2/models/transformer/configuration_transformer.py b/fla2/models/transformer/configuration_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..6602c6c536080616e54662fffa58ef75c7a6a462 --- /dev/null +++ b/fla2/models/transformer/configuration_transformer.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class TransformerConfig(PretrainedConfig): + + model_type = 'transformer' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + vocab_size: int = 32000, + hidden_size: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 32, + num_kv_heads: Optional[int] = None, + hidden_act: str = "swish", + window_size: Optional[int] = None, + max_position_embeddings: int = 2048, + initializer_range: float = 0.02, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: Optional[int] = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + attention_bias: bool = False, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.window_size = window_size + self.max_position_embeddings = max_position_embeddings + + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + self.attention_bias = attention_bias + self.fuse_cross_entropy = fuse_cross_entropy + self.fuse_norm = fuse_norm + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla2/models/transformer/modeling_transformer.py b/fla2/models/transformer/modeling_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d0576f8a18c63a513a1adb94bc1b21afe7812da6 --- /dev/null +++ b/fla2/models/transformer/modeling_transformer.py @@ -0,0 +1,389 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import
List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from ...layers.attn import Attention +from ...models.transformer.configuration_transformer import TransformerConfig +from ...modules import FusedCrossEntropyLoss, RMSNorm +from ...modules.activations import swiglu_linear + +logger = logging.get_logger(__name__) + + +class TransformerMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'swish' + ) -> TransformerMLP: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + y = self.gate_proj(x) + gate, y = y.chunk(2, -1) + return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) + + +class TransformerBlock(nn.Module): + def __init__(self, config: TransformerConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + window_size=config.window_size, + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.mlp = TransformerMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attentions,) + + if use_cache: + outputs += (past_key_values,) + + return outputs + + +class TransformerPreTrainedModel(PreTrainedModel): + + config_class = TransformerConfig + 
supports_gradient_checkpointing = True + _no_split_modules = ['TransformerBlock'] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["o_proj.weight", "down_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class TransformerModel(TransformerPreTrainedModel): + + def __init__(self, config: TransformerConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([TransformerBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, CausalLMOutputWithPast]: + if output_attentions: + warnings.warn( + "`TransformerModel` does not support output attention weights now, so `output_attentions` is set to `False`." 
+ ) + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + # embed positions + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + next_decoder_cache = None + + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + output_attentions, + use_cache + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class TransformerForCausalLM(TransformerPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = TransformerModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = 
None, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. + if past_key_values is not None: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + }) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla2/modules/__init__.py b/fla2/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4450e4ca2fd200525eacc7a5d549bc972eadb7d6 --- /dev/null +++ b/fla2/modules/__init__.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- + +from ..modules.convolution import (ImplicitLongConvolution, LongConvolution, + ShortConvolution) +from ..modules.fused_cross_entropy import FusedCrossEntropyLoss +from ..modules.fused_norm_gate import (FusedLayerNormSwishGate, + FusedLayerNormSwishGateLinear, + FusedRMSNormSwishGate, + FusedRMSNormSwishGateLinear) +from ..modules.layernorm import (GroupNorm, GroupNormLinear, LayerNorm, + 
LayerNormLinear, RMSNorm, RMSNormLinear) +from ..modules.rotary import RotaryEmbedding + +__all__ = [ + 'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution', + 'FusedCrossEntropyLoss', + 'GroupNorm', 'GroupNormLinear', 'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear', + 'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear', + 'RotaryEmbedding' +] diff --git a/fla2/modules/__pycache__/__init__.cpython-310.pyc b/fla2/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b62ed998e042dbf172faa552a8bb64ae9619c971 Binary files /dev/null and b/fla2/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla2/modules/__pycache__/__init__.cpython-312.pyc b/fla2/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fe0a3d2bd81699b6b25a08c9e2c7bd09ec20fc3 Binary files /dev/null and b/fla2/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/modules/__pycache__/__init__.cpython-38.pyc b/fla2/modules/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9979da8a268e64bf92a8930fcf2d05f5f5e9d38b Binary files /dev/null and b/fla2/modules/__pycache__/__init__.cpython-38.pyc differ diff --git a/fla2/modules/__pycache__/__init__.cpython-39.pyc b/fla2/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b01d652ed626332fec95ff265c49450012af1134 Binary files /dev/null and b/fla2/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/fla2/modules/__pycache__/activations.cpython-310.pyc b/fla2/modules/__pycache__/activations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d683366881c10cfa91b772923dd2a5e01984648b Binary files /dev/null and b/fla2/modules/__pycache__/activations.cpython-310.pyc differ diff --git a/fla2/modules/__pycache__/activations.cpython-312.pyc b/fla2/modules/__pycache__/activations.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff376a1e0a9f543e0abd53d4b28e0de0f523fafe Binary files /dev/null and b/fla2/modules/__pycache__/activations.cpython-312.pyc differ diff --git a/fla2/modules/__pycache__/activations.cpython-38.pyc b/fla2/modules/__pycache__/activations.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25dd1303b1ce95f85ea76b7fa3cb2e8f673080d3 Binary files /dev/null and b/fla2/modules/__pycache__/activations.cpython-38.pyc differ diff --git a/fla2/modules/__pycache__/activations.cpython-39.pyc b/fla2/modules/__pycache__/activations.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..807b1d8c0725492d13bf818a101a4e0d1856c2ad Binary files /dev/null and b/fla2/modules/__pycache__/activations.cpython-39.pyc differ diff --git a/fla2/modules/__pycache__/convolution.cpython-310.pyc b/fla2/modules/__pycache__/convolution.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e08e73c28aac239f4658c8c48d9e91c308245d32 Binary files /dev/null and b/fla2/modules/__pycache__/convolution.cpython-310.pyc differ diff --git a/fla2/modules/__pycache__/convolution.cpython-312.pyc b/fla2/modules/__pycache__/convolution.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03f42c36b933d0f1431984f26c34462eb5e7c792 Binary files /dev/null and 
b/fla2/modules/__pycache__/convolution.cpython-312.pyc differ
diff --git a/fla2/modules/__pycache__/convolution.cpython-38.pyc b/fla2/modules/__pycache__/convolution.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1cc51aa4dbdff5a14dff465bf933bb2f0904bd2
Binary files /dev/null and b/fla2/modules/__pycache__/convolution.cpython-38.pyc differ
diff --git a/fla2/modules/__pycache__/convolution.cpython-39.pyc b/fla2/modules/__pycache__/convolution.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2a31e5f35448b27cd46ffa5b1071ac2da23b4cf0
Binary files /dev/null and b/fla2/modules/__pycache__/convolution.cpython-39.pyc differ
diff --git a/fla2/modules/__pycache__/feature_map.cpython-312.pyc b/fla2/modules/__pycache__/feature_map.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14ad08b9471e5106ffe18591ae2ac4e1ed96e3a5
Binary files /dev/null and b/fla2/modules/__pycache__/feature_map.cpython-312.pyc differ
diff --git a/fla2/modules/__pycache__/feature_map.cpython-39.pyc b/fla2/modules/__pycache__/feature_map.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b0cb193923389214d358ca1000ced8c6bb77fdc2
Binary files /dev/null and b/fla2/modules/__pycache__/feature_map.cpython-39.pyc differ
diff --git a/fla2/modules/activations.py b/fla2/modules/activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d15645fd8141442a75b59343156f66ff9df196a
--- /dev/null
+++ b/fla2/modules/activations.py
@@ -0,0 +1,394 @@
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2023-2024, Tri Dao, Yu Zhang, Songlin Yang.
+
+import torch
+import torch.nn.functional as F
+import triton
+import triton.language as tl
+
+from ..utils import contiguous
+
+sigmoid_fwd_codestring = """
+template <typename T> T sigmoid_fwd(T x) {
+    return 1.0f / (1.0f + ::exp(-float(x)));
+}
+"""
+sigmoid_bwd_codestring = """
+template <typename T> T sigmoid_bwd(T x, T g) {
+    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
+    return float(g) * x_sigmoid * (1.0f - x_sigmoid);
+}
+"""
+
+sigmoid_fwd = torch.cuda.jiterator._create_jit_fn(sigmoid_fwd_codestring)
+sigmoid_bwd = torch.cuda.jiterator._create_jit_fn(sigmoid_bwd_codestring)
+
+
+class SigmoidFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x):
+        ctx.save_for_backward(x)
+        return sigmoid_fwd(x)
+
+    @staticmethod
+    def backward(ctx, dout):
+        x, = ctx.saved_tensors
+        return sigmoid_bwd(x, dout)
+
+
+sigmoid = SigmoidFunction.apply
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({'BT': BT}, num_warps=num_warps)
+        for BT in [16, 32, 64, 128, 256]
+        for num_warps in [2, 4, 8]
+    ],
+    key=['D']
+)
+@triton.jit
+def logsigmoid_fwd_kernel(
+    x,
+    y,
+    T: tl.constexpr,
+    D: tl.constexpr,
+    BT: tl.constexpr
+):
+    i = tl.program_id(0)
+    o_i = i * BT + tl.arange(0, BT)
+
+    p_x = x + o_i
+    p_y = y + o_i
+    mask = o_i < T
+
+    # [BT,]
+    b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)
+    b_m = tl.minimum(0., b_x)
+    b_z = 1. + tl.exp(-tl.abs(b_x))
+    b_y = b_m - tl.log(b_z)
+    tl.store(p_y, b_y.to(p_y.dtype.element_ty), mask=mask)
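+
+
+# The kernel above evaluates log-sigmoid through the numerically stable identity
+#     log(sigmoid(x)) = min(x, 0) - log(1 + exp(-|x|)),
+# which never exponentiates a positive number. For example, for x = -100 the
+# naive -log(1 + exp(100)) overflows in fp32, while the identity gives
+# min(-100, 0) - log(1 + exp(-100)) = -100 - ~0 = -100.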
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({'BT': BT}, num_warps=num_warps)
+        for BT in [16, 32, 64, 128, 256]
+        for num_warps in [2, 4, 8]
+    ],
+    key=['D']
+)
+@triton.jit
+def logsigmoid_bwd_kernel(
+    x,
+    dx,
+    dy,
+    T: tl.constexpr,
+    D: tl.constexpr,
+    BT: tl.constexpr
+):
+    i = tl.program_id(0)
+    o_i = i * BT + tl.arange(0, BT)
+
+    p_x = x + o_i
+    p_dx = dx + o_i
+    p_dy = dy + o_i
+    mask = o_i < T
+
+    # [BT,]
+    b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32)
+    b_dy = tl.load(p_dy, mask=mask, other=0.).to(tl.float32)
+    b_dx = b_dy * (1. - tl.sigmoid(b_x))
+    tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
+
+
+class LogSigmoidFunction(torch.autograd.Function):
+
+    @staticmethod
+    @contiguous
+    def forward(ctx, x):
+        T, D = x.numel(), x.shape[-1]
+        y = torch.empty_like(x)
+        # each program handles one `BT`-sized block of the flattened input
+        logsigmoid_fwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['BT']),)](x, y, T=T, D=D)
+        ctx.save_for_backward(x,)
+        return y
+
+    @staticmethod
+    @contiguous
+    def backward(ctx, dy):
+        x, = ctx.saved_tensors
+        T, D = x.numel(), x.shape[-1]
+        dx = torch.empty_like(x)
+        logsigmoid_bwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['BT']),)](x, dx, dy, T=T, D=D)
+        return dx
+
+
+logsigmoid = LogSigmoidFunction.apply
+
+swish_fwd_codestring = """
+template <typename T> T swish_fwd(T x) {
+    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
+    return float(x) * x_sigmoid;
+}
+"""
+swish_bwd_codestring = """
+template <typename T> T swish_bwd(T x, T g) {
+    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
+    return float(g) * x_sigmoid * (1.0f - float(x) * x_sigmoid + float(x));
+}
+"""
+
+swish_fwd = torch.cuda.jiterator._create_jit_fn(swish_fwd_codestring)
+swish_bwd = torch.cuda.jiterator._create_jit_fn(swish_bwd_codestring)
+
+
+class SwishFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x):
+        ctx.save_for_backward(x)
+        return swish_fwd(x)
+
+    @staticmethod
+    def backward(ctx, dout):
+        x, = ctx.saved_tensors
+        return swish_bwd(x, dout)
+
+
+swish = SwishFunction.apply
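+
+
+# A quick numerical sanity check for the custom autograd functions above
+# (a sketch, not part of the library API; it assumes a CUDA device, since both
+# the jiterator and Triton paths are CUDA-only):
+#
+#     x = torch.randn(32, 64, device='cuda', requires_grad=True)
+#     torch.testing.assert_close(logsigmoid(x), F.logsigmoid(x), rtol=1e-4, atol=1e-4)
+#     torch.testing.assert_close(swish(x), F.silu(x), rtol=1e-4, atol=1e-4)
+#     logsigmoid(x).sum().backward()   # exercises logsigmoid_bwd_kernel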
+
+# 1/sqrt(2*pi) -> 0.3989423
+# 1/sqrt(2)    -> 0.70710678
+# sqrt(2/pi)   -> 0.79788456
+
+
+# this function is the tanh approximation of gelu
+# actual gelu is:
+# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
+@torch.jit.script
+def bias_gelu(y, bias):
+    x = bias + y
+    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)
+
+
+# gradient of the tanh approximation of gelu
+# gradient of actual gelu is:
+# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
+@torch.jit.script
+def bias_gelu_bwd(g, y, bias):
+    """Assume that y has shape (B, D) and bias has shape (D)"""
+    x = bias + y
+    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
+    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
+        1 + tanh_out
+    )
+    grad_y = ff * g
+    return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
+
+
+class GeLUFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, input, bias):
+        ctx.save_for_backward(input, bias)
+        return bias_gelu(input, bias)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, bias = ctx.saved_tensors
+        # `bias_gelu_bwd` returns (grad wrt input, grad wrt bias)
+        grad_input, grad_bias = bias_gelu_bwd(grad_output, input, bias)
+        return grad_input, grad_bias
+
+
+bias_gelu_impl = GeLUFunction.apply
+
+
+# this function is the tanh approximation of gelu
+# actual gelu is:
+# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
+@torch.jit.script
+def gelu_fwd(x):
+    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)
+
+
+# gradient of the tanh approximation of gelu
+# gradient of actual gelu is:
+# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
+@torch.jit.script
+def gelu_bwd(g, x):
+    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
+    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
+        1 + tanh_out
+    )
+    return (ff * g).to(dtype=x.dtype)
+
+
+class FastGeLUFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, input):
+        ctx.save_for_backward(input)
+        return gelu_fwd(input)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (input,) = ctx.saved_tensors
+        tmp = gelu_bwd(grad_output, input)
+        return tmp
+
+
+fast_gelu_impl = FastGeLUFunction.apply
+
+
+@torch.jit.script
+def relu_bwd(g, x):
+    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
+
+
+@torch.jit.script
+def sqrelu_fwd(x):
+    r = F.relu(x)
+    return (r * r).to(dtype=x.dtype)
+
+
+@torch.jit.script
+def sqrelu_bwd(g, x):
+    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
+
+
+class SquaredReLUFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, input):
+        ctx.save_for_backward(input)
+        return sqrelu_fwd(input)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, = ctx.saved_tensors
+        return sqrelu_bwd(grad_output, input)
+
+
+sqrelu = SquaredReLUFunction.apply
+
+
+swiglu_fwd_codestring = """
+template <typename T> T swiglu_fwd(T x, T y) {
+    return float(x) * float(y) / (1.0f + ::exp(-float(x)));
+}
+"""
+swiglu_bwd_codestring = """
+template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
+    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
+    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
+    dy = float(x) * x_sigmoid * float(g);
+}
+"""
+
+swiglu_bwd_with_output_codestring = """
+template <typename T> T swiglu_bwd_with_output(T x, T y, T g, T& dx, T& dy, T& z) {
+    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
+    float x_swish = float(x) * x_sigmoid;
+    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
+    dy = x_swish * float(g);
+    z = x_swish * float(y);
+}
+"""
+
+swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
+swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)
+swiglu_bwd_with_output = 
torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_with_output_codestring, num_outputs=3) + + +class SwiGLUFunction(torch.autograd.Function): + r""" + Swish-Gated Linear Unit (SwiGLU) function. + + .. math:: + \text{SwiGLU}(x, y) = swish(x) * y = \frac{x}{1 + \exp(-x)} * y + """ + + @staticmethod + def forward(ctx, x, y): + ctx.save_for_backward(x, y) + return swiglu_fwd(x, y) + + @staticmethod + def backward(ctx, dout): + x, y = ctx.saved_tensors + return swiglu_bwd(x, y, dout) + + +class SwiGLULinearFunction(torch.autograd.Function): + r""" + Swish-Gated Linear Unit (SwiGLU) function followed by a linear transformation. + + .. math:: + \text{SwiGLULinear}(x, y, W, b) = (swish(x) * y) W + b + + This simple wrap discards the intermediate results of SwiGLU(x, y) to save memory. + """ + + @staticmethod + def forward(ctx, x, y, weight, bias): + z = swiglu_fwd(x, y) + out = F.linear(z.to(weight.dtype), weight, bias) + # We don't store z, will be recomputed in the backward pass to save memory + ctx.save_for_backward(x, y, weight) + ctx.linear_bias_is_none = bias is None + return out + + @staticmethod + def backward(ctx, dout, *args): + x, y, weight = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dz = F.linear(dout, weight.t()).view_as(x) + dx, dy, z = swiglu_bwd_with_output(x, y, dz) + dlinear_weight = torch.einsum("bo,bi->oi", dout, z.reshape(-1, z.shape[-1])) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + return dx, dy, dlinear_weight, dlinear_bias + + +swiglu = SwiGLUFunction.apply + +swiglu_linear = SwiGLULinearFunction.apply + +ACT2FN = { + 'relu': F.relu, + 'sigmoid': sigmoid, + 'logsigmoid': logsigmoid, + 'silu': swish, + 'swish': swish, + 'sqrelu': sqrelu, + 'gelu': fast_gelu_impl, + 'bias_gelu': bias_gelu_impl, +} diff --git a/fla2/modules/convolution.py b/fla2/modules/convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..731367983684b54ddcc162b1dfbdbab3a001d874 --- /dev/null +++ b/fla2/modules/convolution.py @@ -0,0 +1,345 @@ +# -*- coding: utf-8 -*- + +# from https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/convolution.py + +import math +import warnings +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from ..modules.activations import ACT2FN +from ..utils import checkpoint + +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +except ImportError: + causal_conv1d_fn = None + causal_conv1d_update = None + + +def fft_conv(u, k, dropout_mask, gelu=True, k_rev=None): + seqlen = u.shape[-1] + fft_size = 2 * seqlen + k_f = torch.fft.rfft(k, n=fft_size) / fft_size + if k_rev is not None: + k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size + k_f = k_f + k_rev_f.conj() + u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) + + if len(u.shape) > 3: + k_f = k_f.unsqueeze(1) + y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen] + + out = y + u + if gelu: + out = F.gelu(out) + if dropout_mask is not None: + return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype) + else: + return out.to(dtype=u.dtype) + + +@checkpoint +def proj_then_conv1d( + x: torch.Tensor, + proj_weight: torch.Tensor, + conv1d_weight: torch.Tensor, + conv1d_bias: Optional[torch.Tensor] = None, + cache: Optional[torch.Tensor] = None +) -> torch.Tensor: + # We do matmul and transpose BLH -> HBL at the same time + x = rearrange(proj_weight @ rearrange(x, "b l d -> d (b l)"), "d (b l) -> b d l", 
l=x.shape[-2])
+
+    if causal_conv1d_fn is None:
+        raise ImportError("`causal_conv1d_fn` is not available. Please install `causal-conv1d` first.")
+    if cache is None:
+        x = causal_conv1d_fn(
+            x=x,
+            weight=rearrange(conv1d_weight, "d 1 w -> d w"),
+            bias=conv1d_bias,
+            activation="silu",
+        ).transpose(1, 2)
+    else:
+        assert x.shape[-1] == 1, "Only support decoding with 1 token at a time for now"
+        x = x.squeeze(-1)
+        x = causal_conv1d_update(
+            x=x,
+            conv_state=cache,
+            weight=rearrange(conv1d_weight, "d 1 w -> d w"),
+            bias=conv1d_bias,
+            activation="silu",
+        )
+    return x
+
+
+class ShortConvolution(nn.Conv1d):
+    """
+    Simple wrapper around `nn.Conv1d` that accepts dimension last.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        kernel_size: int,
+        bias: bool = False,
+        activation: Optional[str] = 'silu',
+        use_fast_conv1d: Optional[bool] = True
+    ):
+        super().__init__(
+            in_channels=hidden_size,
+            out_channels=hidden_size,
+            kernel_size=kernel_size,
+            groups=hidden_size,
+            bias=bias,
+            padding=kernel_size - 1
+        )
+
+        self.hidden_size = hidden_size
+        self.activation = None
+        if activation is not None:
+            assert activation in ['silu', 'swish'], f"Activation `{activation}` not supported yet."
+            self.activation = activation
+
+        if causal_conv1d_fn is None:
+            if use_fast_conv1d:
+                raise RuntimeError(
+                    "Please either install `causal-conv1d>=1.4.0` to enable the fast causal short convolution CUDA kernel "
+                    "or set `use_fast_conv1d` to False"
+                )
+            else:
+                warnings.warn(
+                    "The naive PyTorch version is very slow in practice, "
+                    "please run `pip install causal-conv1d>=1.4.0` to install the fast causal short convolution CUDA kernel"
+                )
+        self.use_fast_conv1d = use_fast_conv1d
+
+    def extra_repr(self):
+        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
+             ', stride={stride}')
+        if self.padding != (0,) * len(self.padding):
+            s += ', padding={padding}'
+        if self.dilation != (1,) * len(self.dilation):
+            s += ', dilation={dilation}'
+        if self.output_padding != (0,) * len(self.output_padding):
+            s += ', output_padding={output_padding}'
+        if self.groups != 1:
+            s += ', groups={groups}'
+        if self.bias is None:
+            s += ', bias=False'
+        if self.padding_mode != 'zeros':
+            s += ', padding_mode={padding_mode}'
+        if self.activation is not None:
+            s += ', activation={activation}'
+        if not self.use_fast_conv1d:
+            s += ', use_fast_conv1d={use_fast_conv1d}'
+        return s.format(**self.__dict__)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        cache: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        """
+        Args:
+            x (`torch.Tensor`):
+                Tensor of shape `[batch_size, seq_len, hidden_size]`.
+            mask (`Optional[torch.Tensor]`):
+                Attention mask dealing with padded positions.
+            cache (`Optional[torch.Tensor]`):
+                Previous cache tensor of shape `[batch_size, hidden_size, kernel_size]`.
+        Returns:
+            Tensor of shape `[batch_size, seq_len, hidden_size]`. The `cache` (if provided) is updated in place.
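+
+        Example (a minimal sketch using the slow PyTorch path, which avoids the
+        CUDA-only kernels; shapes follow the conventions above):
+            >>> conv = ShortConvolution(hidden_size=128, kernel_size=4, activation=None, use_fast_conv1d=False)
+            >>> x = torch.randn(2, 16, 128)
+            >>> conv(x).shape
+            torch.Size([2, 16, 128])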
+ """ + + if mask is not None: + x = x.mul_(mask.unsqueeze(-1)) + if cache is not None and x.shape[1] == 1: + return self.step(x, cache) + x = rearrange(x, "b l d -> b d l") + # Update state (B D W) + if cache is not None: + cache.copy_(F.pad(x, (self.kernel_size[0] - x.shape[-1], 0))) + if self.use_fast_conv1d: + x = causal_conv1d_fn( + x=x, + weight=rearrange(self.weight, "d 1 w -> d w"), + bias=self.bias, + activation=self.activation, + ) + else: + x = self._conv_forward(x, self.weight, self.bias)[..., :x.shape[-1]] + if self.activation is not None: + x = ACT2FN[self.activation](x) + return rearrange(x, "b d l -> b l d") + + def step( + self, + x: torch.Tensor, + cache: torch.Tensor + ): + assert x.shape[1] == 1, "Only support decoding with 1 token at a time for now" + + x = x.squeeze(1) + if self.use_fast_conv1d: + x = causal_conv1d_update( + x=x, + conv_state=cache, + weight=rearrange(self.weight, "d 1 w -> d w"), + bias=self.bias, + activation=self.activation, + ) + else: + dtype = x.dtype + cache.copy_(torch.roll(cache, shifts=-1, dims=-1)) + cache[:, :, -1] = x + x = torch.sum(cache * rearrange(self.weight, "d 1 w -> d w"), dim=-1) + if self.bias is not None: + x = x + self.bias + if self.activation is not None: + x = ACT2FN[self.activation](x).to(dtype=dtype) + return x.unsqueeze(1) + + @property + def state_size(self) -> int: + return self.hidden_size * self.kernel_size + + +class LongConvolution(nn.Module): + """ + LongConvolution applies a convolution operation on the input tensor using a fixed + filter of length l_max. + The filter is learned during training and is applied using FFT convolution. + Args: + hidden_size (int): The number of expected features in the input and output. + l_max (int): The maximum sequence length. + Returns: + y: (b, l, d) tensor + """ + + def __init__( + self, + hidden_size: int, + l_max: int, + **kwargs, + ): + """ + Initializes the LongConvolution module. + Args: + hidden_size (int): The number of expected features in the input and output. + l_max (int): The maximum sequence length. + """ + super().__init__() + self.hidden_size = hidden_size + self.filter = nn.Parameter(torch.randn(self.hidden_size, l_max), requires_grad=True) + + def forward(self, x: torch.Tensor, *args, **kwargs): + """ + Applies the LongConvolution operation on the input tensor. + Args: + x: (b, l, d) tensor + Returns: + y: (b, l, d) tensor + """ + x = x.transpose(1, 2) + y = fft_conv(x, self.filter, dropout_mask=None, gelu=False) + y = y.transpose(1, 2) + return y.to(dtype=x.dtype) + + +class PositionalEmbedding(nn.Module): + def __init__(self, emb_dim: int, seq_len: int, **kwargs): + """Complex exponential positional embeddings for implicit long convolution filters.""" + super().__init__() + + self.seq_len = seq_len + # The time embedding fed to the filteres is normalized so that t_f = 1 + t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1 + + if emb_dim > 1: + bands = (emb_dim - 1) // 2 + # To compute the right embeddings we use the "proper" linspace + t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None] + w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1 + + f = torch.linspace(1e-4, bands - 1, bands)[None, None] + z = torch.exp(-1j * f * w) + z = torch.cat([t, z.real, z.imag], dim=-1) + self.z = nn.Parameter(z, requires_grad=False) + + def forward(self, L): + return self.z[:, :L] + + +class ImplicitLongConvolution(nn.Module): + """ + Long convolution with implicit filter parameterized by an MLP. 
+ + Args: + hidden_size (int): + The number of expected features in the input and output. + l_max (int): + The maximum sequence length. + d_emb (Optional[int]): + The dimension of the positional embeddings. Must be odd and greater or equal to 3 (time, sine and cosine). + Defaults to 3. + d_hidden (Optional[int]): + The number of features in the hidden layer of the MLP. Defaults to 16. + + Attributes: + pos_emb (`PositionalEmbedding`): The positional embedding layer. + mlp (`nn.Sequential`): The MLP that parameterizes the implicit filter. + + """ + + def __init__( + self, + hidden_size: int, + l_max: int, + d_emb: int = 3, + d_hidden: int = 16, + **kwargs, + ): + """ + Long convolution with implicit filter parameterized by an MLP. + + + """ + super().__init__() + self.hidden_size = hidden_size + self.d_emb = d_emb + + assert ( + d_emb % 2 != 0 and d_emb >= 3 + ), "d_emb must be odd and greater or equal to 3 (time, sine and cosine)" + self.pos_emb = PositionalEmbedding(d_emb, l_max) + + # final linear layer + self.mlp = nn.Sequential( + nn.Linear(d_emb, d_hidden), + torch.nn.ReLU(), + nn.Linear(d_hidden, hidden_size), + ) + + def filter(self, seq_len: int, *args, **kwargs): + k = self.mlp(self.pos_emb(seq_len)) + + return k.transpose(1, 2) + + def forward(self, x: torch.Tensor, *args, **kwargs): + """ + Args: + x: (b, l, d) tensor + Returns: + y: (b, l, d) tensor + """ + x = x.transpose(1, 2) + k = self.filter(x.shape[-1]) + y = fft_conv(x, k, dropout_mask=None, gelu=False) + + y = y.transpose(1, 2) + return y.to(dtype=x.dtype) diff --git a/fla2/modules/feature_map.py b/fla2/modules/feature_map.py new file mode 100644 index 0000000000000000000000000000000000000000..e4dbd3c37e0f5e2955b112d5c4a677b479cc4f40 --- /dev/null +++ b/fla2/modules/feature_map.py @@ -0,0 +1,300 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ..modules.activations import fast_gelu_impl, sigmoid, sqrelu, swish +from ..modules.layernorm import layer_norm_fn +from ..utils import checkpoint + + +@checkpoint +def flatten_diag_outer_product(x, y): + z = torch.einsum("...i,...j->...ij", x, y) + N = z.size(-1) + indicies = torch.triu_indices(N, N) + return z[..., indicies[0], indicies[1]] + + +@checkpoint +def flatten_diag_outer_product_off1(x, y): + z = torch.einsum("...i,...j->...ij", x, y) + N = z.size(-1) + indicies = torch.triu_indices(N, N, 1) + indices2 = torch.arange(0, N) + return z[..., indicies[0], indicies[1]], z[..., indices2, indices2] + + +def is_power_of_2(n): + return (n & (n - 1) == 0) and n != 0 + + +class HedgehogFeatureMap(nn.Module): + + r""" + Hedgehog feature map as introduced in + `The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry `_ + """ + + def __init__( + self, + head_dim: int + ) -> HedgehogFeatureMap: + super().__init__() + # Trainable map + self.layer = nn.Linear(head_dim, head_dim) + self.init_weights_() + + def init_weights_(self): + """Initialize trainable map as identity""" + with torch.no_grad(): + identity = torch.eye(*self.layer.weight.shape[-2:], dtype=torch.float) + self.layer.weight.copy_(identity.to(self.layer.weight)) + nn.init.zeros_(self.layer.bias) + + def forward(self, x: torch.Tensor): + x = self.layer(x) # shape b, h, l, d + return torch.cat([2*x, -2*x], dim=-1).softmax(-1) + + +class T2RFeatureMap(nn.Module): + + r""" + Simple linear mapping feature map as in + `Finetuning Pretrained Transformers into RNNs `_ 
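+
+    The map is simply `phi(x) = relu(x W)`, optionally projecting the head
+    dimension to a different dot-product dimension `dot_dim`.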
+ """ + + def __init__( + self, + head_dim: int, + dot_dim: int = None, + bias: Optional[bool] = False + ) -> T2RFeatureMap: + super().__init__() + # Trainable map + if dot_dim is None: + dot_dim = head_dim + + self.head_dim = head_dim + self.dot_dim = dot_dim + self.bias = bias + + self.layer = nn.Linear(head_dim, dot_dim, bias=bias) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(head_dim={self.head_dim}, dot_dim={self.dot_dim}, bias={self.bias})" + + def forward(self, x: torch.Tensor): + return self.layer(x).relu() + + +class DPFPFeatureMap(nn.Module): + + r""" + Deterministic Parameter-Free Projection (DPFP) feature map in + `Linear Transformers Are Secretly Fast Weight Programmers `_ + """ + + def __init__( + self, + head_dim: int, + nu: int = 4 + ) -> DPFPFeatureMap: + super().__init__() + self.nu = nu + + def forward(self, x: torch.Tensor): + x = torch.cat([x.relu(), -x.relu()], dim=-1) + x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, self.nu+1)], dim=-1) + x_repeat = torch.cat([x] * self.nu, dim=-1) + return x_repeat * x_rolled + + +class HadamardFeatureMap(nn.Module): + def __init__( + self, + head_dim: int + ) -> HadamardFeatureMap: + super().__init__() + # Trainable map + self.layer1 = nn.Linear(head_dim, head_dim) + self.layer2 = nn.Linear(head_dim, head_dim) + + def forward(self, x: torch.Tensor): + return self.layer1(x) * self.layer2(x) + + +class LearnableOuterProductFeatureMap(nn.Module): + def __init__( + self, + head_dim: int, + feature_dim: int + ) -> LearnableOuterProductFeatureMap: + super().__init__() + # Trainable map + self.layer1 = nn.Linear(head_dim, feature_dim, bias=False) + self.layer2 = nn.Linear(head_dim, feature_dim, bias=False) + self.normalizer = feature_dim ** -0.5 + + def forward(self, x: torch.Tensor): + return flatten_diag_outer_product(self.layer1(x), self.layer2(x)) + + +class LearnablePolySketchNonNegativeFeatureMap(nn.Module): + + def __init__( + self, + head_dim: int, + sketch_size: Optional[int] = None, + degree: Optional[int] = 2 + ) -> LearnablePolySketchNonNegativeFeatureMap: + super().__init__() + + assert is_power_of_2(degree) and degree >= 2, f"The degree {degree} must be a power of 2" + + self.head_dim = head_dim + self.sketch_size = sketch_size if sketch_size is not None else head_dim + self.degree = degree + + self.gamma = nn.Parameter(torch.ones(head_dim)) + self.beta = nn.Parameter(torch.zeros(head_dim)) + # NOTE: the sketch layers defined here are quite different from the original paper + # currently we simply use linear layers without any non-linear activations + self.sketches1 = nn.ModuleList([ + nn.Linear(head_dim, sketch_size, bias=False), + *[nn.Linear(sketch_size, sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)] + ]) + self.sketches2 = nn.ModuleList([ + nn.Linear(head_dim, sketch_size, bias=False), + *[nn.Linear(sketch_size, sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)] + ]) + + def forward(self, x: torch.Tensor): + # Section 2.1 + x = layer_norm_fn(x, self.gamma, self.beta) + # first map the input to sketch size with learnable parameters + x = self.sketches1[0](x) * self.sketches2[0](x) * self.head_dim ** -0.5 + for i in range(1, int(math.log2(self.degree)) - 1): + x = self.sketches1[i](x) * self.sketches2[i](x) * self.head_dim ** -0.5 + # do sketch mapping for log2(p) - 1 times in total + # do p=2 mapping to ensure non-negativity + return flatten_diag_outer_product(x, x) + + +class TaylorFeatureMap(nn.Module): + def __init__( + self, + 
head_dim: int + ) -> TaylorFeatureMap: + super().__init__() + self.head_dim = head_dim + self.r2 = math.sqrt(2) + self.rd = math.sqrt(self.head_dim) + self.rrd = math.sqrt(self.rd) + + def forward(self, x: torch.Tensor): + x2_1, x2_2 = flatten_diag_outer_product_off1(x, x) + return torch.cat([torch.ones_like(x[..., 0:1]), x / self.rrd, x2_2 / (self.rd * self.r2), x2_1 / self.rd], dim=-1) + + +class RebasedFeatureMap(nn.Module): + + def __init__( + self, + head_dim: int, + use_gamma: Optional[bool] = True, + use_beta: Optional[bool] = True, + normalize: Optional[bool] = True + ) -> RebasedFeatureMap: + super().__init__() + + self.head_dim = head_dim + self.use_gamma = use_gamma + self.use_beta = use_beta + self.normalize = normalize + + self.gamma = None + self.beta = None + if use_gamma: + self.gamma = nn.Parameter(torch.ones(head_dim)) + if use_beta: + self.beta = nn.Parameter(torch.zeros(head_dim)) + + def forward(self, x: torch.Tensor, flatten: Optional[bool] = True): + if self.use_beta and self.use_gamma and self.normalize: + x = layer_norm_fn(x, self.gamma, self.beta) + elif self.normalize: + x = F.layer_norm(x, (self.head_dim,), self.gamma, self.beta) + elif self.use_gamma and self.use_beta: + x = torch.addcmul(self.beta, x, self.gamma) + elif self.use_gamma: + x = x.mul(self.gamma) + else: + raise RuntimeError(f"Not supported combination of `use_gamma`, `use_beta` and `normalize`, " + f"which is currentlt set as (`{self.use_gamma}`, `{self.use_beta}`, `{self.normalize}`)") + if not flatten: + return x + x2_1, x2_2 = flatten_diag_outer_product_off1(x, x) + # rebased use learnable parameters to approximate any quadratic function + return torch.cat([x2_2 * self.head_dim ** -0.5, x2_1 * (2 / self.head_dim) ** 0.5], dim=-1) + + +class ReLUFeatureMap(nn.Module): + + def __init__( + self, + ) -> ReLUFeatureMap: + super().__init__() + + def forward(self, x: torch.Tensor): + return F.relu(x) + + +class SquaredReLUFeatureMap(nn.Module): + + def __init__( + self, + ) -> SquaredReLUFeatureMap: + super().__init__() + + def forward(self, x: torch.Tensor): + return sqrelu(x) + + +class GELUFeatureMap(nn.Module): + + def __init__( + self, + ) -> GELUFeatureMap: + super().__init__() + + def forward(self, x: torch.Tensor): + return fast_gelu_impl(x) + + +class SwishFeatureMap(nn.Module): + + def __init__( + self, + ) -> SwishFeatureMap: + super().__init__() + + def forward(self, x: torch.Tensor): + return swish(x) + + +class SigmoidFeatureMap(nn.Module): + + def __init__( + self, + ) -> SigmoidFeatureMap: + super().__init__() + + def forward(self, x: torch.Tensor): + return sigmoid(x) diff --git a/fla2/modules/fused_bitlinear.py b/fla2/modules/fused_bitlinear.py new file mode 100644 index 0000000000000000000000000000000000000000..a3350420612046faae92663f10cb5cef22783a0c --- /dev/null +++ b/fla2/modules/fused_bitlinear.py @@ -0,0 +1,575 @@ +# -*- coding: utf-8 -*- + +# Implementations of BitLinear layer with fused LayerNorm and quantized Linear layer. 
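+#
+# The idea: keep full-precision master weights and quantize on the fly during
+# the forward pass, with activations rounded to 8 bits per token and weights to
+# the ternary set {-1, 0, +1} (~1.58 bits); the straight-through estimator
+# passes gradients to the full-precision weights. As a toy example of the
+# weight quantizer below: w = [0.6, -1.0] has mean |w| = 0.8, so scale = 1.25
+# and round(w * 1.25) = [1, -1], de-quantized back to [0.8, -0.8].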
+# [The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/abs/2402.17764) +# [Scalable MatMul-free Language Modeling](https://arxiv.org/abs/2406.02528) + +# Code adapted from https://github.com/ridgerchu/matmulfreellm/ + +from __future__ import annotations + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + +from ..modules.layernorm import RMSNorm +from ..utils import contiguous + + +def activation_quant(x): + """ + Per-token quantization to 8 bits. No grouping is needed for quantization. + + Args: + x: An activation tensor with shape [n, d]. + + Returns: + A quantized activation tensor with shape [n, d]. + """ + # Compute the scale factor + scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5) + # Quantize and then de-quantize the tensor + y = (x * scale).round().clamp_(-128, 127) / scale + return y + + +def weight_quant(w): + """ + Per-tensor quantization to 1.58 bits. No grouping is needed for quantization. + + Args: + w: A weight tensor with shape [d, k]. + + Returns: + A quantized weight tensor with shape [d, k]. + """ + # Compute the scale factor + scale = 1.0 / w.abs().mean().clamp_(min=1e-5) + # Quantize and then de-quantize the tensor + u = (w * scale).round().clamp_(-1, 1) / scale + return u + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +@triton.jit +def _layer_norm_fwd_quant_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + RESIDUAL_OUT, # pointer to the residual + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + N, # number of columns in X + eps, # epsilon to avoid division by zero + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr +): + # Map the program id to the row of X and Y it should compute. 
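+    # One program handles one row: the whole feature dimension fits in a single
+    # block (BLOCK_N >= N is enforced by the host-side wrapper), so the mean,
+    # variance and 8-bit quantization scale are all computed without inner loops.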
+ row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + + y = x_hat * w if HAS_WEIGHT else x_hat + if HAS_BIAS: + y = y + b + + # Aply quantization to the output + scale = 127.0 / tl.maximum(tl.max(tl.abs(y), 0), 1e-5) + # Quantize and then de-quantize the tensor + y = tl.math.round(y * scale) + y = tl.maximum(tl.minimum(y, 127), -128) / scale + + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _layer_norm_fwd_quant( + x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False +): + if residual is not None: + residual_dtype = residual.dtype + M, N = x.shape + # allocate output + y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype): + residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype) + else: + residual_out = None + mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device="cuda") + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + with torch.cuda.device(x.device.index): + _layer_norm_fwd_quant_kernel[(M,)]( + x, + y, + weight, + bias, + residual, + residual_out, + mean, + rstd, + x.stride(0), + y.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + N, + eps, + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + weight is not None, + bias is not None, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype + return y, mean, rstd, residual_out if residual_out is not None else x + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) +# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) 
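+# The backward kernel parallelizes over blocks of rows: each program owns
+# `rows_per_program` rows and accumulates its partial weight/bias gradients in
+# registers; the host-side wrapper then reduces the (num_programs, N) partial
+# sums with a single `.sum(0)` on the Python side.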
+@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + DRESIDUAL_IN, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + rows_per_program, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + if RECOMPUTE_OUTPUT: + y = xhat * w if HAS_WEIGHT else xhat + if HAS_BIAS: + y = y + b + + # Aply quantization to the output + scale = 127.0 / tl.maximum(tl.max(tl.abs(y), 0), 1e-5) + # Quantize and then de-quantize the tensor + y = tl.math.round(y * scale) + y = tl.maximum(tl.minimum(y, 127), -128) / scale + + tl.store(Y + cols, y, mask=mask) + wdy = dy + if HAS_WEIGHT: + wdy = dy * w + dw += dy * xhat + if HAS_BIAS: + db += dy + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + if HAS_WEIGHT: + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + + +def _layer_norm_bwd( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + dresidual=None, + 
has_residual=False, + is_rms_norm=False, + x_dtype=None, + recompute_output=False, +): + M, N = x.shape + # allocate output + dx = torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device) + dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None + y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) if weight is not None else None + _db = torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) if bias is not None else None + rows_per_program = math.ceil(M / sm_count) + grid = (sm_count,) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + y, + dy, + dx, + _dw, + _db, + dresidual, + dresidual_in, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + eps, + rows_per_program, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + weight is not None, + bias is not None, + ) + dw = _dw.sum(0).to(weight.dtype) if weight is not None else None + db = _db.sum(0).to(bias.dtype) if bias is not None else None + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype: + dresidual_in = dx + return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y) + + +class LayerNormLinearQuantFn(torch.autograd.Function): + + @staticmethod + @contiguous + def forward( + ctx, + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + residual_dtype = residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None) + y, mean, rstd, residual_out = _layer_norm_fwd_quant( + x, + norm_weight, + norm_bias, + eps, + residual, + out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + ) + y = y.reshape(x_shape_og) + dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype + linear_weight = weight_quant(linear_weight).to(dtype) + linear_bias = linear_bias.to(dtype) if linear_bias is not None else None + out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) + # We don't store y, will be recomputed in the backward pass to save memory + ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.linear_bias_is_none = linear_bias is None + return out if not prenorm else (out, residual_out.reshape(x_shape_og)) + + 
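+    # NOTE: the backward pass below re-runs the fused normalization/quantization
+    # with `recompute_output=True` to reconstruct `y` instead of saving it in the
+    # forward pass, trading a small amount of recomputation for activation memory.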
@staticmethod + @contiguous + def backward(ctx, dout, *args): + x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dy = F.linear(dout, linear_weight.t()) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd( + dy, + x, + norm_weight, + norm_bias, + ctx.eps, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + recompute_output=True + ) + dlinear_weight = torch.einsum("bo,bi->oi", dout, y) + return ( + dx.reshape(ctx.x_shape_og), + dnorm_weight, + dnorm_bias, + dlinear_weight, + dlinear_bias, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_linear_quant_fn( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormLinearQuantFn.apply( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm, + ) + + +class BitLinear(nn.Linear): + """ + A custom linear layer that applies quantization on both activations and weights. + This is primarily for training; kernel optimization is needed for efficiency in deployment. + """ + + def __init__(self, in_features, out_features, bias=False): + """ + Initializes the BitLinear layer. + + Args: + in_features: Size of each input sample. + out_features: Size of each output sample. + bias: If set to False, the layer will not learn an additive bias. Default: True. + """ + # Initialize the superclass nn.Linear with the given parameters + super(BitLinear, self).__init__(in_features, out_features, bias=bias) + + self.norm = RMSNorm(in_features, eps=1e-8) + + def forward(self, x): + """ + Overrides the forward pass to include quantization. + + Args: + x: An input tensor with shape [n, d]. + + Returns: + An output tensor with shape [n, d]. + """ + # Weight tensor + w = self.weight + + # Apply RMS normalization to the input + x_norm = self.norm(x) + + # Apply quantization to both activations and weights + # Uses Straight-Through Estimator (STE) trick with .detach() for gradient flow + x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach() + w_quant = w + (weight_quant(w) - w).detach() + # Perform linear operation with quantized values + y = F.linear(x_quant, w_quant) + + return y + + +class FusedBitLinear(BitLinear): + """ + A custom linear layer that applies quantization on both activations and weights. + This is primarily for training; kernel optimization is needed for efficiency in deployment. + """ + + def __init__(self, in_features, out_features, bias=False): + """ + Initializes the BitLinear layer. + + Args: + in_features: Size of each input sample. + out_features: Size of each output sample. + bias: If set to False, the layer will not learn an additive bias. Default: True. 
+ """ + # Initialize the superclass nn.Linear with the given parameters + super(FusedBitLinear, self).__init__(in_features, out_features, bias=bias) + + def forward(self, x): + return layer_norm_linear_quant_fn( + x, + self.norm.weight, + self.norm.bias, + self.weight, + self.bias, + is_rms_norm=True + ) diff --git a/fla2/modules/fused_cross_entropy.py b/fla2/modules/fused_cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..3364680d414d31608b0a77204d62e4118ea80ee3 --- /dev/null +++ b/fla2/modules/fused_cross_entropy.py @@ -0,0 +1,398 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023, Tri Dao. + +from typing import Tuple + +import torch +import torch.nn as nn +import triton +import triton.language as tl + +# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for +# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent +# version of PyTorch. The following 2 lines are for backward compatibility with +# older PyTorch. +if "all_gather_into_tensor" not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base + + +@triton.heuristics( + { + "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, + } +) +@triton.jit +def cross_entropy_fwd_kernel( + loss_ptr, # data ptrs + lse_ptr, + z_loss_ptr, + logits_ptr, + labels_ptr, + smoothing, + logit_scale, + lse_square_scale, + ignored_index, + total_classes, + class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes + n_cols, # shapes + n_rows, + logits_row_stride, # strides + BLOCK_SIZE: tl.constexpr, + HAS_SMOOTHING: tl.constexpr, + # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE + SPLIT: tl.constexpr, +): + row_idx = tl.program_id(0) + col_block_idx = tl.program_id(1) + logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) + col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + label_idx = tl.load(labels_ptr + row_idx) + logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to( + tl.float32 + ) * logit_scale + max_logits = tl.max(logits, 0) + if HAS_SMOOTHING: + sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0) + lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits + tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse) + if label_idx == ignored_index: + loss = 0.0 + z_loss = 0.0 + else: + label_idx -= class_start_idx + if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min( + n_cols, (col_block_idx + 1) * BLOCK_SIZE + ): + logits_label = tl.load(logits_ptr + label_idx) * logit_scale + if HAS_SMOOTHING: + loss = ( + (lse if not SPLIT else 0.0) + - smoothing * sum_logits / total_classes + - (1 - smoothing) * logits_label + ) + else: + loss = (lse if not SPLIT else 0.0) - logits_label + else: + # If label is out of bounds, we set the CE loss to 0.0. 
But we still want the smoothing loss + if HAS_SMOOTHING: + loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes) + else: + loss = 0.0 + if not SPLIT: + z_loss = lse_square_scale * lse * lse + loss += z_loss + else: + z_loss = 0.0 + tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss) + if not SPLIT: + tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss) + + +@triton.heuristics( + { + "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, + } +) +@triton.jit +def cross_entropy_bwd_kernel( + dlogits_ptr, # data ptrs + dloss_ptr, + logits_ptr, + lse_ptr, + labels_ptr, + smoothing, + logit_scale, + lse_square_scale, + ignored_index, + total_classes, + class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes + n_cols, # shapes + logits_row_stride, # strides + dlogits_row_stride, + dloss_row_stride, + BLOCK_SIZE: tl.constexpr, + HAS_SMOOTHING: tl.constexpr, +): + row_idx = tl.program_id(0) + col_block_idx = tl.program_id(1) + logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) + dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64) + col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + label_idx = tl.load(labels_ptr + row_idx) + if label_idx != ignored_index: + dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride) + else: + dloss = 0.0 + logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to( + tl.float32 + ) * logit_scale + lse = tl.load(lse_ptr + row_idx) + probs = tl.exp(logits - lse) + probs += 2.0 * lse_square_scale * lse * probs + label_idx -= class_start_idx + if HAS_SMOOTHING: + smooth_negative = smoothing / total_classes + probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative + else: + probs = tl.where(col_offsets == label_idx, probs - 1.0, probs) + tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols) + + +class CrossEntropyLossFunction(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + logits, + labels, + smoothing=0.0, + logit_scale=1.0, + lse_square_scale=0.0, + ignored_index=-100, + inplace_backward=False, + process_group=None, + ): + n_rows, n_cols = logits.shape + assert labels.shape == (n_rows,) + world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) + total_classes = world_size * n_cols + rank = 0 if process_group is None else torch.distributed.get_rank(process_group) + class_start_idx = rank * n_cols + + if logits.stride(-1) != 1: + logits = logits.contiguous() + # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py + MAX_BLOCK_SIZE = 64 * 1024 + BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE) + num_warps = ( + 4 + if BLOCK_SIZE < 2048 + else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32)) + ) + # We may split the lse computation across multiple blocks, then do a reduction + # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k) + # where having just one thread block processing more than 64k elements is slow. 
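+        # Partial LSEs combine exactly, since
+        #     logsumexp_i(x_i) = logsumexp_k(lse_k), with lse_k the LSE over block k,
+        # which is what the reductions below rely on (per column block and, with
+        # tensor parallel, per rank).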
+ split = world_size > 1 or n_cols > MAX_BLOCK_SIZE + n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE + loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,) + losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) + lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) + z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) + with torch.cuda.device(logits.device.index): + cross_entropy_fwd_kernel[(n_rows, n_splits)]( + losses, # data ptrs + lse, + z_losses, + logits, + labels, + smoothing, + logit_scale, + lse_square_scale, + ignored_index, + total_classes, + class_start_idx, + n_cols, # shapes + n_rows, + logits.stride(0), # strides + BLOCK_SIZE=BLOCK_SIZE, # constants + num_warps=num_warps, + SPLIT=split, + ) + + if split: + # If there's no smoothing, if labels are in the vocab of this partition, losses contains + # - predicted logit, and 0 otherwise. + # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains + # -0.9 * predicted logit - 0.1 * sum logit / total_classes. + # For labels not in the vocab of this partition, losses contains + # -0.1 * sum logit / total_classes. + if n_splits > 1: + lse = torch.logsumexp(lse, dim=0) + losses = losses.sum(dim=0) + if world_size > 1: + lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device) + torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group) + handle_losses = torch.distributed.all_reduce( + losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True + ) + lse = torch.logsumexp(lse_allgather, dim=0) + handle_losses.wait() + # After the allreduce, if there's no smoothing, the total losses are - predicted_logit, + # we just have to add the (global) lse. + # If there's smoothing=0.1, the total losses are + # -0.9 * predicted_logit - 0.1 * sum logit / total_classes. + # Again, we just have to add the (global) lse. + losses += lse + if lse_square_scale != 0.0: + z_losses = lse_square_scale * lse.square() + z_losses.masked_fill_(labels == ignored_index, 0.0) + losses += z_losses + else: + z_losses = torch.zeros_like(losses) + losses.masked_fill_(labels == ignored_index, 0.0) + + ctx.save_for_backward(logits, lse, labels) + ctx.mark_non_differentiable(z_losses) + ctx.smoothing = smoothing + ctx.logit_scale = logit_scale + ctx.lse_square_scale = lse_square_scale + ctx.ignored_index = ignored_index + ctx.total_classes = total_classes + ctx.class_start_idx = class_start_idx + ctx.inplace_backward = inplace_backward + + return losses, z_losses + + @staticmethod + def backward(ctx, grad_losses, grad_z_losses): + del grad_z_losses # z_losses are only for logging. + + logits, lse, labels = ctx.saved_tensors + dlogits = logits if ctx.inplace_backward else torch.empty_like(logits) + n_rows, n_cols = logits.shape + BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024) + num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16) + def grid(META): return (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) 
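+        # The kernel computes the standard softmax gradient, probs - onehot(label)
+        # (with the label-smoothing and z-loss corrections), scaled by
+        # dloss * logit_scale; with `inplace_backward` it overwrites `logits`.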
+        with torch.cuda.device(logits.device.index):
+            cross_entropy_bwd_kernel[grid](
+                dlogits,  # data ptrs
+                grad_losses,
+                logits,
+                lse,
+                labels,
+                ctx.smoothing,
+                ctx.logit_scale,
+                ctx.lse_square_scale,
+                ctx.ignored_index,
+                ctx.total_classes,
+                ctx.class_start_idx,
+                n_cols,  # shapes
+                logits.stride(0),  # strides
+                dlogits.stride(0),
+                grad_losses.stride(0),
+                BLOCK_SIZE=BLOCK_SIZE,  # constants
+                num_warps=num_warps,
+            )
+        return dlogits, None, None, None, None, None, None, None, None
+
+
+def cross_entropy_loss(
+    logits: torch.Tensor,
+    labels: torch.Tensor,
+    label_smoothing: float = 0.0,
+    logit_scale: float = 1.0,
+    lse_square_scale: float = 0.0,
+    ignored_index=-100,
+    inplace_backward: bool = False,
+    process_group=None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Arguments:
+        logits: (batch, vocab_size)
+        labels: (batch,)
+        label_smoothing: float
+        logit_scale: float. Multiply logits by this scale before calculating the loss.
+        lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
+            This is also referred to as "z-loss".
+        ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
+        inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
+            This saves memory.
+        process_group: if not None, we're doing Tensor Parallel: each process is responsible for
+            one part of the vocab. The loss will be aggregated across processes.
+    Returns:
+        losses: (batch,), float
+        z_losses: (batch,), float
+    """
+    return CrossEntropyLossFunction.apply(
+        logits,
+        labels,
+        label_smoothing,
+        logit_scale,
+        lse_square_scale,
+        ignored_index,
+        inplace_backward,
+        process_group,
+    )
+
+
+class FusedCrossEntropyLoss(nn.Module):
+    def __init__(
+        self,
+        ignore_index=-100,
+        reduction="mean",
+        label_smoothing=0.0,
+        logit_scale=1.0,
+        lse_square_scale=0.0,
+        inplace_backward=False,
+        process_group=None,
+        return_z_loss=False,
+    ):
+        """
+        Arguments:
+            ignore_index: int. If target == ignore_index, the loss is set to 0.0.
+            reduction: str. One of 'mean', 'sum' or 'none'.
+            label_smoothing: float
+            logit_scale: float. Multiply logits by this scale before calculating the loss.
+            lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
+                This is also referred to as "z-loss".
+            inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
+                This saves memory.
+            process_group: if not None, we're doing Tensor Parallel: each process is responsible for
+                one part of the vocab. The loss will be aggregated across processes.
+            return_z_loss: bool. If True, we return the component of the loss contributed by
+                the lse_square_scale value. This value is only for logging and does not support
+                backprop.
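+
+        Example (illustrative sketch; vocab/batch sizes are arbitrary):
+            >>> loss_fn = FusedCrossEntropyLoss(reduction="mean")
+            >>> logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
+            >>> labels = torch.randint(0, 32000, (8,), device="cuda")
+            >>> loss = loss_fn(logits, labels)
+            >>> loss.backward()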
+ """ + super().__init__() + if reduction not in ["mean", "none", "sum"]: + raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'") + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.logit_scale = logit_scale + self.lse_square_scale = lse_square_scale + self.inplace_backward = inplace_backward + self.process_group = process_group + self.return_z_loss = return_z_loss + + def forward(self, input, target): + """ + Arguments: + input: (batch, vocab_size) + target: (batch,) + Returns: + losses: (batch,) if reduction is 'none', else (1,), dtype float + z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss) + """ + assert input.is_cuda and target.is_cuda, "Only support CUDA tensors" + loss, z_loss = cross_entropy_loss( + input, + target, + label_smoothing=self.label_smoothing, + logit_scale=self.logit_scale, + lse_square_scale=self.lse_square_scale, + ignored_index=self.ignore_index, + inplace_backward=self.inplace_backward, + process_group=self.process_group, + ) + if self.reduction == "mean": + loss = loss.sum() / (target != self.ignore_index).sum() + elif self.reduction == "sum": + loss = loss.sum() + else: + loss = loss + + if not self.return_z_loss: + return loss + + if self.reduction == "mean": + z_loss = z_loss.sum() / (target != self.ignore_index).sum() + elif self.reduction == "sum": + z_loss = z_loss.sum() + else: + z_loss = z_loss + + return loss, z_loss diff --git a/fla2/modules/fused_norm_gate.py b/fla2/modules/fused_norm_gate.py new file mode 100644 index 0000000000000000000000000000000000000000..6db7c0791eed746222ac20f5965a91e1f4f2d1b2 --- /dev/null +++ b/fla2/modules/fused_norm_gate.py @@ -0,0 +1,889 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023, Tri Dao. +# https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/ops/triton/layernorm.py +# Implement residual + layer_norm / rms_norm. + +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. 
+ +from __future__ import annotations + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + +from ..utils import contiguous + + +def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( + dtype + ) + return out if not prenorm else (out, x) + + +def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + \ + bias if bias is not None else (x * rstd * weight) + out = out.to(dtype) + return out if not prenorm else (out, x) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + O, # pointer to the gate + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + RESIDUAL_OUT, # pointer to the residual + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + N, # number of columns in X + eps, # epsilon to avoid division by zero + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr +): + # Map the program id to the row of X and Y it should compute. 
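+    # Per row (sketch): y = norm(x [+ residual]) * o * sigmoid(o), where norm is
+    # LayerNorm or RMSNorm depending on the IS_RMS_NORM constexpr flag.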
+    row = tl.program_id(0)
+    X += row * stride_x_row
+    Y += row * stride_y_row
+    O += row * stride_x_row
+    if HAS_RESIDUAL:
+        RESIDUAL += row * stride_res_row
+    if STORE_RESIDUAL_OUT:
+        RESIDUAL_OUT += row * stride_res_out_row
+    # Compute mean and variance
+    cols = tl.arange(0, BLOCK_N)
+    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+    if HAS_RESIDUAL:
+        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
+        x += residual
+    if STORE_RESIDUAL_OUT:
+        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
+    if not IS_RMS_NORM:
+        mean = tl.sum(x, axis=0) / N
+        tl.store(Mean + row, mean)
+        xbar = tl.where(cols < N, x - mean, 0.0)
+        var = tl.sum(xbar * xbar, axis=0) / N
+    else:
+        xbar = tl.where(cols < N, x, 0.0)
+        var = tl.sum(xbar * xbar, axis=0) / N
+    rstd = 1 / tl.sqrt(var + eps)
+    tl.store(Rstd + row, rstd)
+    # Normalize and apply linear transformation
+    mask = cols < N
+    if HAS_WEIGHT:
+        w = tl.load(W + cols, mask=mask).to(tl.float32)
+    if HAS_BIAS:
+        b = tl.load(B + cols, mask=mask).to(tl.float32)
+    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
+    y = x_hat * w if HAS_WEIGHT else x_hat
+    if HAS_BIAS:
+        y = y + b
+
+    # Swish output gate
+    o = tl.load(O + cols, mask=cols < N, other=0.0).to(tl.float32)
+    y = y * o * tl.sigmoid(o)
+
+    # Write output
+    tl.store(Y + cols, y, mask=mask)
+
+
+def _layer_norm_fwd(
+    x, o, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False
+):
+    if residual is not None:
+        residual_dtype = residual.dtype
+    M, N = x.shape
+    assert x.stride(-1) == 1
+    if residual is not None:
+        assert residual.stride(-1) == 1
+        assert residual.shape == (M, N)
+    if weight is not None:
+        assert weight.shape == (N,)
+        assert weight.stride(-1) == 1
+    if bias is not None:
+        assert bias.stride(-1) == 1
+        assert bias.shape == (N,)
+    # allocate output
+    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
+    assert y.stride(-1) == 1
+    if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
+        residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
+        assert residual_out.stride(-1) == 1
+    else:
+        residual_out = None
+    # Allocate stats on x's device (not the current default CUDA device), so
+    # that inputs on a non-default GPU don't mix pointers across devices.
+    mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
+    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
+    # Less than 64KB per feature: enqueue fused kernel
+    MAX_FUSED_SIZE = 65536 // x.element_size()
+    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+    if N > BLOCK_N:
+        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+    # heuristics for number of warps
+    with torch.cuda.device(x.device.index):
+        _layer_norm_fwd_1pass_kernel[(M,)](
+            x,
+            o,
+            y,
+            weight,
+            bias,
+            residual,
+            residual_out,
+            mean,
+            rstd,
+            x.stride(0),
+            y.stride(0),
+            residual.stride(0) if residual is not None else 0,
+            residual_out.stride(0) if residual_out is not None else 0,
+            N,
+            eps,
+            is_rms_norm,
+            BLOCK_N,
+            residual is not None,
+            residual_out is not None,
+            weight is not None,
+            bias is not None,
+        )
+    # residual_out is None if residual is None and residual_dtype == input_dtype
+    return y, mean, rstd, residual_out if residual_out is not None else x
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16),
+        triton.Config({}, num_warps=32),
+    ],
+    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
+) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) +# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + O, # pointer to the gate + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DO, # pointer to the gate gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + DRESIDUAL_IN, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + rows_per_program, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + O += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + DO += row_start * stride_dx_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + o = tl.load(O + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + + y = xhat * w if HAS_WEIGHT else xhat + if HAS_BIAS: + y = y + b + if RECOMPUTE_OUTPUT: + tl.store(Y + cols, y, mask=mask) + + sigmoid_o = tl.sigmoid(o) + do = dy * y * (sigmoid_o + o * sigmoid_o * (1 - sigmoid_o)) + dy = dy * o * sigmoid_o + wdy = dy + if HAS_WEIGHT: + wdy = dy * w + dw += dy * xhat + if HAS_BIAS: + db += dy + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + tl.store(DX + cols, dx, mask=mask) + tl.store(DO + cols, do, mask=mask) + + X += stride_x_row + O += stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if 
STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + DO += stride_dx_row + if HAS_WEIGHT: + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + + +def _layer_norm_bwd( + dy, + x, + o, + weight, + bias, + eps, + mean, + rstd, + dresidual=None, + has_residual=False, + is_rms_norm=False, + x_dtype=None, + recompute_output=False, +): + M, N = x.shape + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if dresidual is not None: + assert dresidual.stride(-1) == 1 + assert dresidual.shape == (M, N) + if weight is not None: + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + dx = ( + torch.empty_like(x) + if x_dtype is None + else torch.empty(M, N, dtype=x_dtype, device=x.device) + ) + do = ( + torch.empty_like(o) + if x_dtype is None + else torch.empty(M, N, dtype=x_dtype, device=x.device) + ) + dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None + y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + _dw = ( + torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + if weight is not None + else None + ) + _db = ( + torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) + if bias is not None + else None + ) + rows_per_program = math.ceil(M / sm_count) + grid = (sm_count,) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid]( + x, + o, + weight, + bias, + y, + dy, + dx, + do, + _dw, + _db, + dresidual, + dresidual_in, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + eps, + rows_per_program, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + weight is not None, + bias is not None, + ) + dw = _dw.sum(0).to(weight.dtype) if weight is not None else None + db = _db.sum(0).to(bias.dtype) if bias is not None else None + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype: + dresidual_in = dx + return (dx, do, dw, db, dresidual_in) if not recompute_output else (dx, do, dw, db, dresidual_in, y) + + +class LayerNormSwishGateFn(torch.autograd.Function): + + @staticmethod + @contiguous + def forward( + ctx, + x, + o, + weight, + bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + o_shape_og = o.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + o = o.reshape(-1, o.shape[-1]) + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = _layer_norm_fwd( + x, o, weight, bias, eps, residual, 
residual_dtype=residual_dtype, is_rms_norm=is_rms_norm + ) + ctx.save_for_backward(residual_out, o, weight, bias, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.o_shape_og = o_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + y = y.reshape(x_shape_og) + return y if not prenorm else (y, residual_out.reshape(x_shape_og)) + + @staticmethod + @contiguous + def backward(ctx, dy, *args): + x, o, weight, bias, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + assert dresidual.shape == x.shape + else: + dresidual = None + dx, do, dw, db, dresidual_in = _layer_norm_bwd( + dy, + x, + o, + weight, + bias, + ctx.eps, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + ) + return ( + dx.reshape(ctx.x_shape_og), + do.reshape(ctx.o_shape_og), + dw, + db, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +class LayerNormSwishGateLinearFn(torch.autograd.Function): + + @staticmethod + @contiguous + def forward( + ctx, + x, + o, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + o_shape_og = o.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + o = o.reshape(-1, o.shape[-1]) + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = _layer_norm_fwd( + x, + o, + norm_weight, + norm_bias, + eps, + residual, + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm + ) + y = y.reshape(x_shape_og) + dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype + linear_weight = linear_weight.to(dtype) + linear_bias = linear_bias.to(dtype) if linear_bias is not None else None + out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) + # We don't store y, will be recomputed in the backward pass to save memory + ctx.save_for_backward(residual_out, o, norm_weight, norm_bias, linear_weight, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.o_shape_og = o_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.linear_bias_is_none = linear_bias is None + return out if not prenorm else (out, residual_out.reshape(x_shape_og)) + + @staticmethod + @contiguous + def backward(ctx, dout, *args): + x, o, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dy = F.linear(dout, linear_weight.t()) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + assert dresidual.shape == x.shape + else: + dresidual = None + dx, do, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd( + dy, + x, + o, + norm_weight, + norm_bias, + ctx.eps, + mean, + rstd, + dresidual=dresidual, + has_residual=ctx.has_residual, + is_rms_norm=ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + recompute_output=True, + ) + 
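+        # The projection-weight gradient is an outer-product sum over rows:
+        # dW[o, i] = sum_b dout[b, o] * y[b, i]; the einsum below computes it
+        # from the recomputed normed-and-gated activations y.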
dlinear_weight = torch.einsum("bo,bi->oi", dout, y) + return ( + dx.reshape(ctx.x_shape_og), + do.reshape(ctx.o_shape_og), + dnorm_weight, + dnorm_bias, + dlinear_weight, + dlinear_bias, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_swish_gate_fn( + x, + o, + weight, + bias, + residual=None, + prenorm=False, + residual_in_fp32=False, + eps=1e-6 +): + return LayerNormSwishGateFn.apply( + x, + o, + weight, + bias, + residual, + eps, + prenorm, + residual_in_fp32, + False + ) + + +def rms_norm_swish_gate_fn( + x, + o, + weight, + bias, + residual=None, + prenorm=False, + residual_in_fp32=False, + eps=1e-6 +): + return LayerNormSwishGateFn.apply( + x, + o, + weight, + bias, + residual, + eps, + prenorm, + residual_in_fp32, + True + ) + + +def layer_norm_swish_gate_linear_fn( + x, + o, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + prenorm=False, + residual_in_fp32=False, + eps=1e-6 +): + return LayerNormSwishGateLinearFn.apply( + x, + o, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + False + ) + + +def rms_norm_swish_gate_linear_fn( + x, + o, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + prenorm=False, + residual_in_fp32=False, + eps=1e-6 +): + return LayerNormSwishGateLinearFn.apply( + x, + o, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + True + ) + + +class FusedLayerNormSwishGate(nn.Module): + + def __init__( + self, + hidden_size, + elementwise_affine: bool = True, + eps=1e-5 + ) -> FusedLayerNormSwishGate: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, o, residual=None, prenorm=False, residual_in_fp32=False): + return layer_norm_swish_gate_fn( + x, + o, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32 + ) + + +class FusedRMSNormSwishGate(nn.Module): + + def __init__( + self, + hidden_size, + elementwise_affine: bool = True, + eps=1e-5 + ) -> FusedRMSNormSwishGate: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, o, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm_swish_gate_fn( + x, + o, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32 + ) + + +class FusedLayerNormSwishGateLinear(nn.Module): + + def __init__( + self, + hidden_size, + elementwise_affine: bool = True, + eps=1e-5 + ) -> 
FusedLayerNormSwishGateLinear: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, o, weight, bias, residual=None, prenorm=False, residual_in_fp32=False): + return layer_norm_swish_gate_linear_fn( + x, + o, + self.weight, + self.bias, + weight, + bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32 + ) + + +class FusedRMSNormSwishGateLinear(nn.Module): + + def __init__( + self, + hidden_size, + elementwise_affine: bool = True, + eps=1e-5 + ) -> FusedRMSNormSwishGateLinear: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, o, weight, bias, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm_swish_gate_linear_fn( + x, + o, + self.weight, + self.bias, + weight, + bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32 + ) diff --git a/fla2/modules/l2norm.py b/fla2/modules/l2norm.py new file mode 100644 index 0000000000000000000000000000000000000000..34257f80778b6551b25343ea90f638258bf9d208 --- /dev/null +++ b/fla2/modules/l2norm.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +@triton.jit +def _l2_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + stride_x_row, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
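+    # Each program normalizes one row: y = x / sqrt(sum(x * x) + eps). Note this
+    # differs slightly from F.normalize(x, dim=-1), which divides by
+    # max(||x||, eps) instead of folding eps under the square root.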
+    row = tl.program_id(0)
+    X += row * stride_x_row
+    Y += row * stride_x_row
+    # Compute the squared L2 norm of the row
+    cols = tl.arange(0, BLOCK_N)
+    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+    xbar = tl.where(cols < N, x, 0.0)
+    var = tl.sum(xbar * xbar, axis=0)
+    rstd = 1 / tl.sqrt(var + eps)
+    # tl.store(Rstd + row, rstd)
+    # Normalize
+    mask = cols < N
+    y = x * rstd
+    # Write output
+    tl.store(Y + cols, y, mask=mask)
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16),
+        triton.Config({}, num_warps=32),
+    ],
+    key=["N"],
+)
+# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
+# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
+# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
+# @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
+@triton.jit
+def _l2_norm_bwd_kernel(
+    X,  # pointer to the input
+    # Y, # pointer to the output to be recomputed
+    DY,  # pointer to the output gradient
+    DX,  # pointer to the input gradient
+    stride_x_row,  # how much to increase the pointer when moving by 1 row
+    N,  # number of columns in X
+    eps,  # epsilon to avoid division by zero
+    BLOCK_N: tl.constexpr,
+):
+    # Map the program id to the elements of X, DX, and DY it should compute.
+    row = tl.program_id(0)
+    X += row * stride_x_row
+    DX += row * stride_x_row
+    DY += row * stride_x_row
+
+    # Y += row * stride_y_row
+    cols = tl.arange(0, BLOCK_N)
+    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+    x = tl.where(cols < N, x, 0.0)
+    var = tl.sum(x * x)
+    rstd = 1 / tl.sqrt(var + eps)
+    # tl.store(Rstd + row, rstd)
+    mask = cols < N
+    dy = tl.load(DY + cols, mask=cols < N, other=0.0).to(tl.float32)
+    dy = tl.where(cols < N, dy, 0.0)
+    # d/dx of y = x / sqrt(sum(x^2) + eps):
+    dx = dy * rstd - tl.sum(dy * x) * (1 / (var + eps)) * rstd * x
+    tl.store(DX + cols, dx, mask=mask)
+
+
+def _l2_norm_fwd(
+    x, eps=1e-6
+):
+    x_shape_og = x.shape
+    x = x.reshape(-1, x.shape[-1])
+    if x.stride(-1) != 1:
+        x = x.contiguous()
+    M, N = x.shape
+    assert x.stride(-1) == 1
+    # allocate output
+    y = torch.empty_like(x)
+    assert y.stride(-1) == 1
+    # rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
+    # Less than 64KB per feature: enqueue fused kernel
+    MAX_FUSED_SIZE = 65536 // x.element_size()
+    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+    if N > BLOCK_N:
+        raise RuntimeError("This l2 norm doesn't support feature dim >= 64KB.")
+    # heuristics for number of warps
+    with torch.cuda.device(x.device.index):
+        _l2_norm_fwd_1pass_kernel[(M,)](
+            x,
+            y,
+            x.stride(0),
+            N,
+            eps,
+            BLOCK_N,
+        )
+    return y.reshape(x_shape_og)
+
+
+def _l2_norm_bwd(
+    x, dy, eps=1e-6,
+):
+    x_shape_og = x.shape
+    x = x.reshape(-1, x.shape[-1])
+    dy = dy.reshape(-1, dy.shape[-1])
+    if x.stride(-1) != 1:
+        x = x.contiguous()
+    if dy.stride(-1) != 1:
+        dy = dy.contiguous()
+    assert dy.shape == x.shape
+    # allocate output
+    dx = torch.empty_like(x)
+    N = x.shape[-1]
+    M = x.shape[0]
+    assert x.stride(-1) == 1
+    assert dy.stride(-1) == 1
+    # rstd = torch.empty((M,),
dtype=torch.float32, device="cuda") + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError( + "This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + with torch.cuda.device(x.device.index): + _l2_norm_bwd_kernel[(M,)]( + x, + dy, + dx, + x.stride(0), + N, + eps, + BLOCK_N, + ) + return dx.reshape(x_shape_og) + + +class L2NormFN(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + x, + eps=1e-6, + ): + # reshape input data into 2D tensor + y = _l2_norm_fwd(x, eps) + ctx.eps = eps + ctx.x_dtype = x.dtype + ctx.save_for_backward(x) + return y + + @staticmethod + def backward(ctx, dy, *args): + x, = ctx.saved_tensors + dx = _l2_norm_bwd( + x, + dy, + ctx.eps, + ) + return ( + dx, + None + ) + + +l2_norm_fn = L2NormFN.apply diff --git a/fla2/modules/layernorm.py b/fla2/modules/layernorm.py new file mode 100644 index 0000000000000000000000000000000000000000..bedc9c5a5cfe55367796feb3672f40af8a48ff8e --- /dev/null +++ b/fla2/modules/layernorm.py @@ -0,0 +1,940 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023, Tri Dao. +# https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/ops/triton/layernorm.py +# Implement residual + layer_norm / rms_norm. + +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. 
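+# Concretely (sketch of the scheme): each program writes its partial weight/bias
+# grads to one row of an (S, N) float32 buffer, and the host finishes the
+# reduction, e.g. `dw.view(G, -1, N).sum(1)` in `_layer_norm_bwd` below.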
+ +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + +from ..utils import contiguous + + +def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( + dtype + ) + return out if not prenorm else (out, x) + + +def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) + out = out.to(dtype) + return out if not prenorm else (out, x) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + RESIDUAL_OUT, # pointer to the residual + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + N, # number of columns in X + G, # number of groups + eps, # epsilon to avoid division by zero + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr +): + # Map the program id to the row of X and Y it should compute. 
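+    # Rows are interleaved across groups: row i uses group (i % G), and each
+    # group owns its own (N,)-slice of W and B (indexed below via a group
+    # offset). With G == 1 this reduces to ordinary layer/RMS norm.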
+    row = tl.program_id(0)
+    group = row % G
+    X += row * stride_x_row
+    Y += row * stride_y_row
+    if HAS_RESIDUAL:
+        RESIDUAL += row * stride_res_row
+    if STORE_RESIDUAL_OUT:
+        RESIDUAL_OUT += row * stride_res_out_row
+    # Compute mean and variance
+    cols = tl.arange(0, BLOCK_N)
+    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+    if HAS_RESIDUAL:
+        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
+        x += residual
+    if STORE_RESIDUAL_OUT:
+        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
+    if not IS_RMS_NORM:
+        mean = tl.sum(x, axis=0) / N
+        tl.store(Mean + row, mean)
+        xbar = tl.where(cols < N, x - mean, 0.0)
+        var = tl.sum(xbar * xbar, axis=0) / N
+    else:
+        xbar = tl.where(cols < N, x, 0.0)
+        var = tl.sum(xbar * xbar, axis=0) / N
+    rstd = 1 / tl.sqrt(var + eps)
+    tl.store(Rstd + row, rstd)
+    # Normalize and apply linear transformation
+    mask = cols < N
+    if HAS_WEIGHT:
+        # NOTE: the per-group offset reuses stride_x_row, which equals N for the
+        # contiguous (M, N) inputs this wrapper passes in.
+        w = tl.load(W + group * stride_x_row + cols, mask=mask).to(tl.float32)
+    if HAS_BIAS:
+        b = tl.load(B + group * stride_x_row + cols, mask=mask).to(tl.float32)
+    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
+
+    y = x_hat * w if HAS_WEIGHT else x_hat
+    if HAS_BIAS:
+        y = y + b
+    # Write output
+    tl.store(Y + cols, y, mask=mask)
+
+
+def _layer_norm_fwd(
+    x,
+    weight,
+    bias,
+    eps,
+    residual=None,
+    out_dtype=None,
+    residual_dtype=None,
+    is_rms_norm=False,
+    num_groups=1
+):
+    if residual is not None:
+        residual_dtype = residual.dtype
+    M, N, G = *x.shape, num_groups
+    if residual is not None:
+        assert residual.shape == (M, N)
+    if weight is not None:
+        assert weight.shape == (G * N,)
+    if bias is not None:
+        assert bias.shape == (G * N,)
+    # allocate output
+    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
+    if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
+        residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
+    else:
+        residual_out = None
+    # Allocate stats on x's device (not the current default CUDA device), so
+    # that inputs on a non-default GPU don't mix pointers across devices.
+    mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
+    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
+    # Less than 64KB per feature: enqueue fused kernel
+    MAX_FUSED_SIZE = 65536 // x.element_size()
+    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+    if N > BLOCK_N:
+        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+    # heuristics for number of warps
+    with torch.cuda.device(x.device.index):
+        _layer_norm_fwd_1pass_kernel[(M,)](
+            x,
+            y,
+            weight,
+            bias,
+            residual,
+            residual_out,
+            mean,
+            rstd,
+            x.stride(0),
+            y.stride(0),
+            residual.stride(0) if residual is not None else 0,
+            residual_out.stride(0) if residual_out is not None else 0,
+            N,
+            G,
+            eps,
+            is_rms_norm,
+            BLOCK_N,
+            residual is not None,
+            residual_out is not None,
+            weight is not None,
+            bias is not None,
+        )
+    # residual_out is None if residual is None and residual_dtype == input_dtype
+    return y, mean, rstd, residual_out if residual_out is not None else x
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16),
+        triton.Config({}, num_warps=32),
+    ],
+    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
+)
+# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
+# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
+# @triton.heuristics({"STORE_DRESIDUAL": lambda args:
args["DRESIDUAL_IN"] is not None}) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + DRESIDUAL_IN, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + G, # number of groups + rows_per_program, + programs_per_group, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + row_block_id = tl.program_id(0) + group_id, program_id_in_group = row_block_id // programs_per_group, row_block_id % programs_per_group + + row_start = group_id + program_id_in_group * G * rows_per_program + row_end = min(row_start + G * rows_per_program, M) + + cols = tl.arange(0, BLOCK_N) + mask = cols < N + + if HAS_WEIGHT: + w = tl.load(W + group_id * stride_x_row + cols, mask=mask).to(tl.float32) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + group_id * stride_x_row + cols, mask=mask, other=0.0).to(tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + + for row in range(row_start, row_end, G): + # Load data to SRAM + x = tl.load(X + row * stride_x_row + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + row * stride_dy_row + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + if RECOMPUTE_OUTPUT: + y = xhat * w if HAS_WEIGHT else xhat + if HAS_BIAS: + y = y + b + tl.store(Y + row * stride_y_row + cols, y, mask=mask) + wdy = dy + if HAS_WEIGHT: + wdy = dy * w + dw += dy * xhat + if HAS_BIAS: + db += dy + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + row * stride_dres_row + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + row * stride_dres_in_row + cols, dx, mask=mask) + tl.store(DX + row * stride_dx_row + cols, dx, mask=mask) + + if HAS_WEIGHT: + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + + +def _layer_norm_bwd( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + dresidual=None, + has_residual=False, + is_rms_norm=False, + x_dtype=None, + recompute_output=False, + num_groups=1 +): + M, N, G = *x.shape, num_groups + assert dy.shape == (M, N) + if dresidual is not None: + assert dresidual.shape == (M, N) + if weight is not None: + assert weight.shape == (G * N,) + if bias is not None: + assert bias.shape == (G * N,) + # allocate output + dx = torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, 
device=x.device) + dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None + y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # each program handles one group only + S = triton.cdiv(torch.cuda.get_device_properties(x.device).multi_processor_count, G) * G + dw = torch.empty((S, N), dtype=torch.float32, device=weight.device) if weight is not None else None + db = torch.empty((S, N), dtype=torch.float32, device=bias.device) if bias is not None else None + rows_per_program = triton.cdiv(M, S) + programs_per_group = S // G + grid = (S,) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + y, + dy, + dx, + dw, + db, + dresidual, + dresidual_in, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + G, + rows_per_program, + programs_per_group, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + weight is not None, + bias is not None, + ) + dw = dw.view(G, -1, N).sum(1).to(weight).view_as(weight) if weight is not None else None + db = db.view(G, -1, N).sum(1).to(bias).view_as(bias) if bias is not None else None + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype: + dresidual_in = dx + return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y) + + +class LayerNormFn(torch.autograd.Function): + + @staticmethod + @contiguous + def forward( + ctx, + x, + weight, + bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + num_groups=1 + ): + x_shape_og = x.shape + + if x.shape[-1] % num_groups != 0: + raise ValueError('num_channels must be divisible by num_groups') + # reshape input data into 2D tensor + x = x.reshape(-1, (x.shape[-1] // num_groups)) + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape_as(x) + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = _layer_norm_fwd( + x, weight, bias, eps, residual, + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + num_groups=num_groups + ) + ctx.save_for_backward(residual_out, weight, bias, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.num_groups = num_groups + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + y = y.reshape(x_shape_og) + return y if not prenorm else (y, residual_out.reshape(x_shape_og)) + + @staticmethod + @contiguous + def backward(ctx, dy, *args): + x, weight, bias, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, (dy.shape[-1] // ctx.num_groups)) + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, x.shape[-1]) + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dw, db, dresidual_in = _layer_norm_bwd( + dy, + x, + weight, + bias, + ctx.eps, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + 
num_groups=ctx.num_groups + ) + return ( + dx.reshape(ctx.x_shape_og), + dw, + db, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + None + ) + + +def layer_norm_fn( + x, + weight, + bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False +): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm + ) + + +def group_norm_fn( + x, + weight, + bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + num_groups=1 +): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm, + num_groups + ) + + +def rms_norm_fn( + x, + weight, + bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False +): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + eps, + prenorm, + residual_in_fp32, + True + ) + + +class LayerNorm(nn.Module): + + def __init__( + self, + hidden_size: int, + elementwise_affine: bool = True, + bias: bool = False, + eps: float = 1e-5 + ) -> LayerNorm: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + if bias: + self.bias = nn.Parameter(torch.zeros(hidden_size)) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return layer_norm_fn( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32 + ) + + +class GroupNorm(nn.Module): + + def __init__( + self, + num_groups: int, + hidden_size: int, + elementwise_affine: bool = True, + bias: bool = False, + eps: float = 1e-5 + ) -> GroupNorm: + super().__init__() + + if hidden_size % num_groups != 0: + raise ValueError('num_channels must be divisible by num_groups') + + self.num_groups = num_groups + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + if bias: + self.bias = nn.Parameter(torch.zeros(hidden_size)) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.num_groups}, {self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return group_norm_fn( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + num_groups=self.num_groups + ) + + +class RMSNorm(nn.Module): + + def __init__( + self, + hidden_size: int, + elementwise_affine: bool = True, + bias: bool = False, + eps: float = 1e-5 + ) -> RMSNorm: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if elementwise_affine: + 
self.weight = nn.Parameter(torch.ones(hidden_size)) + if bias: + self.bias = nn.Parameter(torch.zeros(hidden_size)) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm_fn( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + ) + + +class LayerNormLinearFn(torch.autograd.Function): + + @staticmethod + @contiguous + def forward( + ctx, + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + num_groups=1 + ): + x_shape_og = x.shape + + if x.shape[-1] % num_groups != 0: + raise ValueError('num_channels must be divisible by num_groups') + # reshape input data into 2D tensor + x = x.reshape(-1, (x.shape[-1] // num_groups)) + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape_as(x) + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = _layer_norm_fwd( + x, + norm_weight, + norm_bias, + eps, + residual, + out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + num_groups=num_groups + ) + y = y.reshape(x_shape_og) + dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype + linear_weight = linear_weight.to(dtype) + linear_bias = linear_bias.to(dtype) if linear_bias is not None else None + out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) + # We don't store y, will be recomputed in the backward pass to save memory + ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.num_groups = num_groups + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.linear_bias_is_none = linear_bias is None + return out if not prenorm else (out, residual_out.reshape(x_shape_og)) + + @staticmethod + @contiguous + def backward(ctx, dout, *args): + x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dy = F.linear(dout, linear_weight.t()) + dy = dy.reshape(-1, (dy.shape[-1] // ctx.num_groups)) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, x.shape[-1]) + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd( + dy, + x, + norm_weight, + norm_bias, + ctx.eps, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + recompute_output=True, + num_groups=ctx.num_groups + ) + dlinear_weight = torch.einsum("bo,bi->oi", dout, y.view(-1, linear_weight.shape[-1])) + return ( + dx.reshape(ctx.x_shape_og), + dnorm_weight, + dnorm_bias, + dlinear_weight, + dlinear_bias, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + None + ) + + +def layer_norm_linear_fn( + x, + norm_weight, + norm_bias, + linear_weight, + 
linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + num_groups=1 +): + return LayerNormLinearFn.apply( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm, + num_groups + ) + + +class LayerNormLinear(nn.Module): + + def __init__( + self, + hidden_size, + elementwise_affine: bool = True, + bias: bool = False, + eps: float = 1e-5 + ) -> LayerNormLinear: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + if bias: + self.bias = nn.Parameter(torch.zeros(hidden_size)) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False): + return layer_norm_linear_fn( + x, + self.weight, + self.bias, + weight, + bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + is_rms_norm=False + ) + + +class GroupNormLinear(nn.Module): + + def __init__( + self, + num_groups: int, + hidden_size: int, + elementwise_affine: bool = True, + bias: bool = False, + eps: float = 1e-5 + ) -> GroupNormLinear: + super().__init__() + + if hidden_size % num_groups != 0: + raise ValueError('num_channels must be divisible by num_groups') + + self.num_groups = num_groups + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + if bias: + self.bias = nn.Parameter(torch.zeros(hidden_size)) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.num_groups}, {self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False): + return layer_norm_linear_fn( + x, + self.weight, + self.bias, + weight, + bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + is_rms_norm=False, + num_groups=self.num_groups + ) + + +class RMSNormLinear(nn.Module): + + def __init__( + self, + hidden_size, + elementwise_affine: bool = True, + bias: bool = False, + eps: float = 1e-5 + ) -> RMSNormLinear: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + if bias: + self.bias = nn.Parameter(torch.zeros(hidden_size)) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False): + return layer_norm_linear_fn( + x, + self.weight, + self.bias, + weight, + bias, + residual=residual, + 
eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + is_rms_norm=True + ) diff --git a/fla2/modules/rotary.py b/fla2/modules/rotary.py new file mode 100644 index 0000000000000000000000000000000000000000..66f9f1c6fa4b43e1cfd488373a0537a8c533bfae --- /dev/null +++ b/fla2/modules/rotary.py @@ -0,0 +1,310 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023, Tri Dao. + +from typing import Optional, Tuple, Union + +import torch +from einops import rearrange, repeat + +from ..ops.rotary import apply_rotary + + +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) + + +def apply_rotary_emb_torch(x, cos, sin, interleaved=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat( + sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + return torch.cat( + [x[..., :ro_dim] * cos + + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], + dim=-1, + ) + + +class ApplyRotaryEmb(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + out = apply_rotary( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=interleaved, + inplace=inplace, + ) + if isinstance(seqlen_offsets, int): + # Can't save int with save_for_backward + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + ctx.inplace = inplace + ctx.max_seqlen = max_seqlen + return out if not inplace else x + + @staticmethod + def backward(ctx, do): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cu_seqlens = ctx.saved_tensors + # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with + # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. + if not ctx.interleaved and not ctx.inplace: + do = do.clone() + dx = apply_rotary( + do, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=ctx.interleaved, + inplace=ctx.inplace, + conjugate=True, + ) + return dx, None, None, None, None, None, None, None + + +def apply_rotary_emb( + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + """ + Arguments: + x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + cos, sin: (seqlen_rotary, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + inplace: if True, apply rotary embedding in-place. + seqlen_offsets: (batch_size,) or int. 
Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Return: + out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. + """ + return ApplyRotaryEmb.apply( + x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen + ) + +
+# For backward compatibility +apply_rotary_emb_func = apply_rotary_emb + +
+class RotaryEmbedding(torch.nn.Module): + """ + The rotary position embeddings from RoFormer_ (Su et al.). + A crucial insight from the method is that the query and keys are + transformed by rotation matrices which depend on the relative positions. + + Other implementations are available in the Rotary Transformer repo_ and in + GPT-NeoX_; GPT-NeoX was an inspiration. + + .. _RoFormer: https://arxiv.org/abs/2104.09864 + .. _repo: https://github.com/ZhuiyiTechnology/roformer + .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox + + If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). + A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 + Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py + """ + + def __init__( + self, + dim: int, + base=10000.0, + interleaved=False, + scale_base=None, + pos_idx_in_fp32=True, + device=None, + ): + """ + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, + otherwise they might be in lower precision. + This option was added because previously (before 2023-07-02), when we construct + the position indices, we use the dtype of self.inv_freq. In most cases this would + be fp32, but if the model is trained in pure bf16 (not mixed precision), then + self.inv_freq would be bf16, and the position indices are also in bf16. + Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the + embeddings for some positions will coincide. + To maintain compatibility with models previously trained in pure bf16, + we add this option. + """ + super().__init__() + self.dim = dim + self.base = float(base) + self.pos_idx_in_fp32 = pos_idx_in_fp32 + # Generate and save the inverse frequency buffer (non-trainable) + inv_freq = self._compute_inv_freq(device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.interleaved = interleaved + self.scale_base = scale_base + scale = ( + (torch.arange(0, dim, 2, device=device, + dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) + if scale_base is not None + else None + ) + self.register_buffer("scale", scale, persistent=False) + + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + + def _compute_inv_freq(self, device=None): + return 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim) + ) + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + # Reset the tables if the sequence length has changed, + # if we're on a new device (possibly due to tracing for instance), + # or if we're switching from inference mode to training + if ( + seqlen > self._seq_len_cached + or self._cos_cached is None + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + or (self.training and self._cos_cached.is_inference()) + ): + self._seq_len_cached = seqlen + # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 + # And the output of arange can be quite large, so bf16 would lose a lot of precision. + # However, for compatibility reasons, we add an option to use the dtype of self.inv_freq. + if self.pos_idx_in_fp32: + t = torch.arange(seqlen, device=device, dtype=torch.float32) + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. + # We want to recompute self.inv_freq if it was not loaded in fp32 + if self.inv_freq.dtype != torch.float32: + inv_freq = self._compute_inv_freq(device=device) + else: + inv_freq = self.inv_freq + else: + t = torch.arange(seqlen, device=device, + dtype=self.inv_freq.dtype) + inv_freq = self.inv_freq + # Don't do einsum, it converts fp32 to fp16 under AMP + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, inv_freq) + if self.scale is None: + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + else: + power = ( + torch.arange(seqlen, dtype=self.scale.dtype, + device=self.scale.device) + - seqlen // 2 + ) / self.scale_base + scale = self.scale.to( + device=power.device) ** rearrange(power, "s -> s 1") + # We want the multiplication by scale to happen in fp32 + self._cos_cached = (torch.cos(freqs) * scale).to(dtype) + self._sin_cached = (torch.sin(freqs) * scale).to(dtype) + self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) + self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + seqlen_offset: Union[int, torch.Tensor] = 0, + max_seqlen: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + q: (batch, seqlen, nheads, headdim) + k: (batch, seqlen, nheads, headdim) + seqlen_offset: (batch_size,) or int. Each sequence in q and k is shifted by this amount. + Most commonly used in inference when we have KV cache. + If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one + should pass in max_seqlen, which will update the cos / sin cache up to that length. + Apply rotary embedding to q and k. + """ + seqlen = q.shape[1] + if max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=q.device, dtype=q.dtype) + elif isinstance(seqlen_offset, int): + self._update_cos_sin_cache(seqlen + seqlen_offset, device=q.device, dtype=q.dtype) + if self.scale is None: + q = apply_rotary_emb_func( + q, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + k = apply_rotary_emb_func( + k, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + + else: + q = apply_rotary_emb_func( + q, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + k = apply_rotary_emb_func( + k, + self._cos_k_cached, + self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + + return q, k diff --git a/fla2/ops/__init__.py b/fla2/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2147e2be1be45b7641affb4395107d2f0fd35d61 --- /dev/null +++ b/fla2/ops/__init__.py @@ -0,0 +1,18 @@ +# # -*- coding: utf-8 -*- + +# from .based import fused_chunk_based, parallel_based +# from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla +# from .retention import (chunk_retention, fused_chunk_retention, +# fused_recurrent_retention, parallel_retention) + +# __all__ = [ +# 'fused_chunk_based', +# 'parallel_based', +# 'chunk_gla', +# 'fused_chunk_gla', +# 'fused_recurrent_gla', +# 'chunk_retention', +# 'fused_chunk_retention', +# 'fused_recurrent_retention', +# 'parallel_retention' +# ]
diff --git a/fla2/ops/abc/__init__.py b/fla2/ops/abc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1fa366a836aa307b9e4cd4a486e8600f8ac473b1 --- /dev/null +++ b/fla2/ops/abc/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_abc +from .chunk_gate import chunk_gated_abc +from .recurrent_fuse import fused_recurrent_gated_abc + +__all__ = [ + 'chunk_abc', + 'chunk_gated_abc', + 'fused_recurrent_gated_abc' +] diff --git a/fla2/ops/abc/chunk.py b/fla2/ops/abc/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..b9902e4cd5013aa79e4dba654db0ab2a84004f15 --- /dev/null +++ b/fla2/ops/abc/chunk.py @@ -0,0 +1,1192 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ...ops.utils import (logcumsumexp_fwd_kernel, softmax_bwd_kernel, + softmax_fwd_kernel) +from ...utils import contiguous + +
+@triton.jit +def chunk_abc_fwd_kernel_h( + k, + v, + z, + h, + h0, + ht, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + NORMK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + if NORMK: + p_z0 = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), (i_k * BK,), (BK,), (0,)) + else: + p_z0 = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), (i_v * BV,), (BV,), (0,)) + b_zp = tl.load(p_z0).to(tl.float32) + for i_t in range(NT): + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + if NORMK: + p_zc = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = tl.exp(b_zp - b_zc), b_zc + # [BK, BV] + b_h = b_h * b_r[:, None]
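+ # rebase the chunk's keys onto the running normalizer b_zc so that the decayed state b_h + # and the fresh contribution tl.dot(b_k, b_v) are accumulated on a common scale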
+ b_k = tl.exp(b_k - b_zc[:, None]).to(b_k.dtype) + else: + p_zc = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = tl.exp(b_zp - b_zc), b_zc + # [BK, BV] + b_h = b_h * b_r[None, :] + b_v = tl.exp(b_v - b_zc[None, :]).to(b_v.dtype) + # [BK, BV] + b_h += tl.dot(b_k, b_v, allow_tf32=False) + + if STORE_FINAL_STATE: + p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_abc_fwd_kernel_intra_K( + v, + z, + o, + A, + s_v_h, + s_v_t, + s_v_d, + T: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BV] + b_o = tl.zeros([BC, BV], dtype=tl.float32) + for i_j in range(0, i_i): + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BC, BC] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_o += tl.dot(b_A, tl.exp(b_v - b_zn[None, :]).to(b_v.dtype), allow_tf32=False) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_o *= tl.exp(b_zn[None, :] - b_z) + + o_i = tl.arange(0, BC) + o_A = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + # [BC,] + b_A = tl.load(A + o_A + j, mask=m_A, other=0) + # [BV,] + b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32) + # [BC, BV] + # avoid 0 * inf = inf + m_i = o_i[:, None] >= j + b_o += tl.where(m_i, b_A[:, None] * tl.exp(b_v[None, :] - b_z), 0) + p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_abc_fwd_kernel_K( + q, + k, + z, + h, + o, + A, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 
1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_o += tl.dot(b_q, b_h, allow_tf32=False) + # [BT, BT] + b_A += tl.dot(b_q, b_k, allow_tf32=False) + p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BT, BV] + b_z = tl.load(p_z, boundary_check=(0, 1)) + # [BT, BV] + p_zp = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), (i_p * V + i_v * BV,), (BV,), (0,)) + b_zp = tl.load(p_zp, boundary_check=(0,)) + b_o = b_o * tl.exp(b_zp[None, :] - b_z) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BT] + b_A = tl.where(m_s, b_A, 0.) + if i_v == 0: + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_abc_fwd_kernel_intra_V( + q, + k, + z, + A, + s_k_h, + s_k_t, + s_k_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + n_bh = tl.num_programs(2) + + if i_i > i_j: + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_q = (b_q * tl.exp(b_zn[None, :] - b_z) * scale).to(b_q.dtype) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k = tl.exp(b_k - b_zn[:, None]).to(b_k.dtype) + # [BC, BC] + b_A = tl.dot(b_q, b_k, allow_tf32=False) + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + elif i_i == i_j: + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,)) + p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + + o_i = tl.arange(0, BC) + o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + # [BK,] + b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32) + # [BC,] + b_A = tl.sum(b_q * tl.exp(b_k[None, :] - b_z) * scale, 1) + b_A = tl.where(o_i >= j, b_A, 0.) 
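+ # write column j of the diagonal attention block; entries with o_i < j were zeroed above to keep it causal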
+ tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A) + + p_k = tl.advance(p_k, (K,)) + + +@triton.jit +def chunk_abc_fwd_kernel_V( + q, + v, + z, + h, + o, + A, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_zp = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), (i_p * K + i_k * BK,), (BK,), (0,)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_z = tl.load(p_z, boundary_check=(0, 1)) + # [BT, BK] + b_zp = tl.load(p_zp, boundary_check=(0,)) + b_q = (b_q * tl.exp(b_zp[None, :] - b_z)).to(b_q.dtype) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # works but dkw, owing to divine benevolence + # [BT, BV] + if i_k >= 0: + b_o += tl.dot(b_q, b_h, allow_tf32=False) + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_o += tl.dot(b_A, b_v, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_abc_bwd_kernel_dh( + q, + z, + do, + dh, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + NORMK: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + b_zp = tl.full([BK if NORMK else BV], float('inf'), dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + i_p = tl.maximum(i_t * BT - 1, 0) + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + if NORMK: + p_z = tl.make_block_ptr(z + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_zc = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), (i_p * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = tl.exp(b_zc - b_zp), b_zc + # [BK, BT] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_q = (b_q * tl.exp(b_zc[:, None] - 
b_z)).to(b_q.dtype) + # [BK, BV] + b_dh = b_dh * b_r[:, None] + else: + p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_zc = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), (i_p * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = tl.exp(b_zc - b_zp), b_zc + # [BT, BV] + b_z = tl.load(p_z, boundary_check=(0,)) + b_do = (b_do * tl.exp(b_zc[None, :] - b_z)).to(b_do.dtype) + # [BK, BV] + b_dh = b_dh * b_r[None, :] + # [BK, BV] + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + + +@triton.jit +def chunk_abc_bwd_kernel_V( + k, + v, + z, + h, + A, + do, + dh, + dq, + dk, + dv, + dA, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + n_bh = tl.num_programs(2) + + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_zc = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + + # [BK,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k = tl.exp(b_k - b_zc[None, :]).to(b_k.dtype) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BV] + b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + if i_k == 0: + b_dv += tl.dot(b_A, b_do, allow_tf32=False) + b_do = (b_do * scale).to(b_do.dtype) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BT, BT] + b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + # [BT, BK] + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_zp = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), (i_p * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zp = tl.load(p_zp, boundary_check=(0,)) + # [BT, BK] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_z = tl.exp(b_zp[None, :] - b_z) + # [BT, BK] + b_dq = b_dq * b_z + b_dk = b_dk * b_k + + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = 
tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT,), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + # [BT, BT] + b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype) + if i_k == 0: + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_abc_bwd_kernel_intra_V( + q, + k, + z, + dA, + dq, + dk, + s_k_h, + s_k_t, + s_k_d, + T: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BK] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_zq = tl.exp(b_zn[None, :] - b_z) + b_dq = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(0, i_i): + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kz = tl.exp(b_k - b_zn[None, :]).to(b_k.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dq += tl.dot(b_dA, b_kz, allow_tf32=False) + b_dq *= b_zq + + o_i = tl.arange(0, BC) + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC + m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + p_kj = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0) + # [BK,] + b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] >= j + # [BC, BK] + b_dq += tl.where(m_i, b_dA[:, None] * tl.exp(b_kj[None, :] - b_z), 0.) 
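+ # dq for the diagonal sub-chunk is accumulated one key row j at a time; as in the forward + # kernels, the tl.where mask also guards masked entries against 0 * inf = inf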
+ p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kz = tl.exp(b_k - b_zn[None, :]) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_qz = (b_q * tl.exp(b_zn[None, :] - b_z)).to(b_q.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dk += tl.dot(tl.trans(b_dA), b_qz, allow_tf32=False) + b_dk *= b_kz + + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC) + for j in range(0, BC): + p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + p_zj = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0) + # [BK,] + b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32) + b_zj = tl.load(p_zj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] <= j + b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_k - b_zj[None, :]), 0.) 
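+ # mirror of the dq path above: key row i collects gradient from every query j >= i + # within the same diagonal sub-chunk before dk is written back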
+ p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_abc_bwd_kernel_intra_K( + v, + z, + do, + dA, + s_v_h, + s_v_t, + s_v_d, + scale, + T: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + n_bh = tl.num_programs(2) + + if i_i > i_j: + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (V, T), (s_v_d, s_v_t), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1)) + p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_dA = tl.make_block_ptr(dA+(i_bh+i_v*n_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BV,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BV] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * tl.exp(b_zn[None, :] - b_z) * scale).to(b_do.dtype) + # [BV, BC] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v = tl.exp(b_v - b_zn[:, None]).to(b_v.dtype) + # [BC, BC] + b_dA = tl.dot(b_do, b_v, allow_tf32=False) + tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1)) + elif i_i == i_j: + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_j * BC) * V + i_v * BV,), (BV,), (0,)) + p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) * scale + + o_i = tl.arange(0, BC) + o_A = (i_bh + i_v * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + # [BV,] + b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32) + # [BC,] + b_dA = tl.sum(b_do * tl.exp(b_v[None, :] - b_z), 1) + b_dA = tl.where(o_i >= j, b_dA, 0) + tl.store(dA + o_A + j, b_dA.to(b_do.dtype), mask=m_A) + + p_v = tl.advance(p_v, (V,)) + + +@triton.jit +def chunk_abc_bwd_kernel_K( + q, + k, + v, + z, + h, + A, + do, + dh, + dq, + dk, + dv, + dA, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + n_bh = tl.num_programs(2) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh) * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = 
tl.load(p_k, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k), allow_tf32=False) + b_A = tl.where(m_s, b_A, 0.) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_zp = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), (i_p * V + i_v * BV,), (BV,), (0,)) + p_zc = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K*V, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BV,] + b_zp = tl.load(p_zp, boundary_check=(0,)) + b_zc = tl.load(p_zc, boundary_check=(0,)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v = tl.exp(b_v - b_zc[None, :]).to(b_v.dtype) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_z = tl.exp(b_zp[None, :] - b_z) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * b_z * scale).to(b_do.dtype) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + # [BT, BV] + b_dv = b_v * tl.dot(b_k, b_dh, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BT] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BT, BK] + b_dq += tl.dot(b_dA, b_k, allow_tf32=False) + b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q, allow_tf32=False) + + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_abc_bwd_kernel_intra_KV( + v, + z, + A, + do, + dv, + s_v_h, + s_v_t, + s_v_d, + T: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * s_v_h, (T*V,), (s_v_d,), ((i_t * BT + i_i * BC + BC - 1) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dv = tl.zeros([BC, BV], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, 
BV), (1, 0)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * tl.exp(b_zn[None, :] - b_z)).to(b_do.dtype) + # [BC, BC] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_dv += tl.dot(b_A, b_do, allow_tf32=False) + b_dv *= tl.exp(b_v - b_zn[None, :]) + + o_i = tl.arange(0, BC) + for j in range(0, BC): + p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T * BT,), (1,), ((i_t * BT + i_i * BC + j) * BT + i_i * BC,), (BC,), (0,)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + # [BC,] + b_A = tl.load(p_A, boundary_check=(0,)) + # [BV,] + b_z = tl.load(p_z, boundary_check=(0,)) + b_do = tl.load(p_do, boundary_check=(0,)) + # [BC, BV] + m_i = o_i[:, None] <= j + b_dv += tl.where(m_i, tl.exp(b_v - b_z[None, :]) * b_A[:, None] * b_do[None, :], 0.) + p_dv = tl.make_block_ptr(dv + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_abc_bwd_kernel_rcum_inter( + s, + z, + ss, + doo, + s_s_h, + s_s_t, + s_s_d, + T: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + NT: tl.constexpr +): + i_m, i_bh = tl.program_id(0), tl.program_id(1) + + b_sp = tl.zeros([BS,], dtype=tl.float32) + b_zp = tl.full([BS,], float('inf'), dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + p_zc = tl.make_block_ptr(z + i_bh * s_s_h, (T * S,), (s_s_d,), ((i_t * BT) * S + i_m * BS,), (BS,), (0,)) + p_ss = tl.make_block_ptr(ss + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + p_doo = tl.make_block_ptr(doo + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + # [BS,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_ss = tl.load(p_ss, boundary_check=(0, 1)) + + b_doo = tl.exp(b_s - b_zp[None, :]) * b_sp[None, :] + tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1)) + # [BS,] + b_sp = b_sp * tl.exp(b_zc - b_zp) + tl.sum(b_ss * tl.exp(b_zc[None, :] - b_z), 0) + b_zp = b_zc + + +@triton.jit +def chunk_abc_bwd_kernel_rcum_intra( + s, + z, + ss, + doo, + s_s_h, + s_s_t, + s_s_d, + T: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BS: tl.constexpr, + NC: tl.constexpr +): + i_s, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + o_i = tl.arange(0, BC) + m_o = tl.full([BC, BC], 1., dtype=tl.float32) + + p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT + i_i * BC, i_s * BS), (BC, BS), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * s_s_h, (T*S,), (s_s_d,), ((i_t * BT + i_i * BC + BC - 1) * S + i_s * BS,), (BS,), (0,)) + p_doo = tl.make_block_ptr(doo + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT + i_i * 
BC, i_s * BS), (BC, BS), (1, 0)) + # [BC, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)) + # [BS,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + + b_doo = tl.zeros([BC, BS], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 0)) + p_ss = tl.make_block_ptr(ss + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 0)) + # [BC, BS] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_ss = tl.load(p_ss, boundary_check=(0, 1)) + # [BC, BS] + b_doo += b_ss * tl.exp(b_zn[None, :] - b_z) + b_doo = tl.exp(b_s - b_zn[None, :]) * tl.dot(m_o.to(b_s.dtype), b_doo.to(b_s.dtype), allow_tf32=False) + + for j in range(0, BC): + p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T * S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,)) + p_ss = tl.make_block_ptr(ss + i_bh * s_s_h, (T * S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,)) + # [BS,] + b_z = tl.load(p_z, boundary_check=(0,)) + b_ss = tl.load(p_ss, boundary_check=(0,)) + # [BC, BS] + m_i = o_i[:, None] <= j + b_doo += tl.where(m_i, tl.exp(b_s - b_z[None, :]) * b_ss[None, :], 0.) + b_doo += tl.load(p_doo, boundary_check=(0, 1)) + tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1)) + +
+class ChunkABCFunction(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, q, k, v, s, initial_state, output_final_state): + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + BT, BC = 64, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + BM = min(64, triton.next_power_of_2(M)) + NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC) + NV, NM = triton.cdiv(V, BV), triton.cdiv(M, BM) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + def fwd_pre(s, B, H, T, S): + # keep cumulative normalizer in fp32 + z = torch.empty_like(s, dtype=torch.float) + grid = (B * H,) + logcumsumexp_fwd_kernel[grid]( + s, z, + s.stride(1), s.stride(2), s.stride(3), + T=T, S=S + ) + return z + + def fwd_inner(q, k, v, z, B, H, T, K, V, BT, BK, BV, NT, normk=False, h0=None, ht=None): + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + h = q.new_empty(B, H, NT * K, V) + grid = (NV, NK, B * H) + chunk_abc_fwd_kernel_h[grid]( + k, v, z, h, h0, ht, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + NORMK=normk, + USE_INITIAL_STATE=h0 is not None, + STORE_FINAL_STATE=ht is not None, + num_warps=num_warps, + num_stages=num_stages + ) + return h + + final_state = None + if output_final_state: + final_state = (q.new_empty(B, H, K, M, dtype=torch.float), + q.new_empty(B, H, M, V, dtype=torch.float)) + + z = fwd_pre(s, B, H, T, M) + scale = K ** -0.5 + hk = fwd_inner( + q=q, k=k, v=s, z=z, + B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT, + normk=False, + h0=initial_state[0] if initial_state is not None else None, + ht=final_state[0] if final_state is not None else None + ) + ok1 = torch.empty_like(s) + Ak = q.new_empty(B, H, T, BT) + grid = (NM, NT, B * H) + chunk_abc_fwd_kernel_K[grid]( + q, k, z, hk, ok1, Ak, + k.stride(1), k.stride(2), k.stride(3), + s.stride(1), s.stride(2), s.stride(3), + hk.stride(1), hk.stride(2), hk.stride(3), + scale=scale, + T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, + num_warps=num_warps, + num_stages=num_stages + ) + ok0 = torch.empty_like(s) + grid = (NM, NT * NC, B * H) +
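# ok1 from chunk_abc_fwd_kernel_K holds only the inter-chunk output read from hk; + # this pass adds the intra-chunk part built from the local attention block Ak +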
chunk_abc_fwd_kernel_intra_K[grid]( + s, z, ok0, Ak, + s.stride(1), s.stride(2), s.stride(3), + T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + ok = ok0.add_(ok1) + + scale = 1. + # equivalent to: + # p = ok.softmax(-1, torch.float) + # p is kept in fp32 for safe softmax backward + p = torch.empty_like(ok, dtype=torch.float) + grid = (NT, B * H) + softmax_fwd_kernel[grid]( + ok, p, + s.stride(1), s.stride(2), s.stride(3), + T=T, S=M, BT=BT + ) + qv = p.to(q.dtype) + + scale = 1. + hv = fwd_inner( + q=qv, k=s, v=v, z=z, + B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT, + normk=True, + h0=initial_state[1] if initial_state is not None else None, + ht=final_state[1] if final_state is not None else None + ) + Av = q.new_zeros(NM, B, H, T, BT) + grid = (NM, NT * NC * NC, B * H) + chunk_abc_fwd_kernel_intra_V[grid]( + qv, s, z, Av, + s.stride(1), s.stride(2), s.stride(3), + scale=scale, + T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + Av = Av.sum(0) + ov = torch.empty_like(v) + grid = (NV, NT, B * H) + chunk_abc_fwd_kernel_V[grid]( + qv, v, z, hv, ov, Av, + s.stride(1), s.stride(2), s.stride(3), + v.stride(1), v.stride(2), v.stride(3), + hv.stride(1), hv.stride(2), hv.stride(3), + scale=scale, + T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + ctx.save_for_backward(q, k, v, s, z, ok, p, hk, hv, Av) + ctx.BT = BT + return ov, final_state + + @staticmethod + @contiguous + def backward(ctx, dov, dht=None): + q, k, v, s, z, ok, p, hk, hv, Av = ctx.saved_tensors + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + BT, BC = ctx.BT, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + BM = min(64, triton.next_power_of_2(M)) + NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC) + NK, NM = triton.cdiv(K, BK), triton.cdiv(M, BM) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + def bwd_inner(q, z, do, B, H, T, K, V, BT, BK, BV, NT, scale, normk=False): + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + dh = q.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + chunk_abc_bwd_kernel_dh[grid]( + q, z, do, dh, + q.stride(1), q.stride(2), q.stride(3), + do.stride(1), do.stride(2), do.stride(3), + dh.stride(1), dh.stride(2), dh.stride(3), + scale=scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + NORMK=normk, + num_warps=num_warps, + num_stages=num_stages + ) + return dh + + def bwd_post(s, z, ss, B, H, T, S, BT, BC, BS, NT, NC, NS): + doo = torch.empty_like(s) + grid = (NS, B * H) + chunk_abc_bwd_kernel_rcum_inter[grid]( + s, z, ss, doo, + s.stride(1), s.stride(2), s.stride(3), + T=T, S=S, BT=BT, BS=BS, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + grid = (NS, NT * NC, B * H) + chunk_abc_bwd_kernel_rcum_intra[grid]( + s, z, ss, doo, + s.stride(1), s.stride(2), s.stride(3), + T=T, S=S, BT=BT, BC=BC, BS=BS, NC=NC, + num_warps=num_warps, + num_stages=num_stages + ) + return doo + + scale = 1. 
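+ # the value-stage backward runs unscaled (scale = 1.): the effective queries are the + # softmax weights p, which are already normalized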
+ qv = p.to(q.dtype) + dhv = bwd_inner( + qv, z, dov, + B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT, + scale=scale, + normk=True + ) + dp1 = torch.empty_like(p) + dsv1 = torch.empty_like(s, dtype=torch.float) + dv = v.new_empty(NM, *v.shape) + dAv = q.new_zeros(B, H, T, BT) + grid = (NM, NT, B * H) + chunk_abc_bwd_kernel_V[grid]( + s, v, z, hv, Av, dov, dhv, dp1, dsv1, dv, dAv, + s.stride(1), s.stride(2), s.stride(3), + v.stride(1), v.stride(2), v.stride(3), + hv.stride(1), hv.stride(2), hv.stride(3), + scale=scale, + T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + dv = dv.sum(0) + dp0 = torch.empty_like(p) + dsv0 = s.new_zeros(s.shape, dtype=torch.float) + grid = (NM, NT * NC, B * H) + chunk_abc_bwd_kernel_intra_V[grid]( + qv, s, z, dAv, dp0, dsv0, + s.stride(1), s.stride(2), s.stride(3), + T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + dp = dp1.add_(dp0) + dsv = dsv1.add_(dsv0) + + # softmax gradient, equivalent to: + # dok = p * (dp - (p * dp).sum(-1, True)) + dok = torch.empty_like(ok) + grid = (NT, B * H) + softmax_bwd_kernel[grid]( + p, dp, dok, + s.stride(1), s.stride(2), s.stride(3), + T=T, S=M, BT=BT + ) + + scale = K ** -0.5 + dhk = bwd_inner( + q, z, dok, + B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT, + scale=scale, + normk=False + ) + dAk = q.new_zeros(NM, B, H, T, BT) + grid = (NM, NT * NC * NC, B * H) + chunk_abc_bwd_kernel_intra_K[grid]( + s, z, dok, dAk, + s.stride(1), s.stride(2), s.stride(3), + scale=scale, + T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + dAk = dAk.sum(0) + + Ak = q.new_zeros(NK, B, H, T, BT) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dsk1 = s.new_empty(NK, *s.shape, dtype=torch.float) + grid = (NK, NT, B * H) + chunk_abc_bwd_kernel_K[grid]( + q, k, s, z, hk, Ak, dok, dhk, dq, dk, dsk1, dAk, + q.stride(1), q.stride(2), q.stride(3), + s.stride(1), s.stride(2), s.stride(3), + hk.stride(1), hk.stride(2), hk.stride(3), + scale=scale, + T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, + num_warps=num_warps, + num_stages=num_stages + ) + Ak = Ak.sum(0) + dsk1 = dsk1.sum(0) + dsk0 = torch.empty_like(s, dtype=torch.float) + grid = (NM, NT * NC, B * H) + chunk_abc_bwd_kernel_intra_KV[grid]( + s, z, Ak, dok, dsk0, + s.stride(1), s.stride(2), s.stride(3), + T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + ds = dsv.add_(dsk1.add_(dsk0)) + ds -= bwd_post(s, z, ok * dok + p * dp, B, H, T, M, BT, BC, BM, NT, NC, NM) + ds = ds.to(s.dtype) + return dq, dk, dv, ds, None, None + + +def chunk_abc( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + initial_state: Optional[Tuple[torch.Tensor]] = None, + output_final_state: Optional[bool] = False +) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]: + ov, final_state = ChunkABCFunction.apply(q, k, v, s, initial_state, output_final_state) + return ov, final_state diff --git a/fla2/ops/abc/chunk_gate.py b/fla2/ops/abc/chunk_gate.py new file mode 100644 index 0000000000000000000000000000000000000000..481a6b0b85ea59730f6c5a78872f3b483e50b864 --- /dev/null +++ b/fla2/ops/abc/chunk_gate.py @@ -0,0 +1,1333 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl +from einops import reduce + +from ...ops.utils import (chunk_global_reversed_cumsum, chunk_local_cumsum, softmax_bwd_kernel, + softmax_fwd_kernel) +from ...utils import 
contiguous + + + +@triton.jit +def chunk_gated_abc_fwd_kernel_h( + k, + v, + g, + h, + h0, + ht, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + GATEK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + for i_t in range(NT): + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + o_t = min(i_t * BT + BT, T) + + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + if GATEK: + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((o_t - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BK, BV] + b_h *= tl.exp(b_gn)[:, None] + # [BK, BT] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype) + else: + p_g = tl.make_block_ptr(g + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bh * s_v_h, (T * V,), (s_v_d,), ((o_t - 1) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BK, BV] + b_h *= tl.exp(b_gn)[None, :] + # [BT, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_v = (b_v * tl.exp(b_gn[None, :] - b_g)).to(b_v.dtype) + # [BK, BV] + b_h += tl.dot(b_k, b_v, allow_tf32=False) + + if STORE_FINAL_STATE: + p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gated_abc_fwd_kernel_intra_K( + v, + g, + o, + A, + s_v_h, + s_v_t, + s_v_d, + T: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr, + NG: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bg = i_bh // NG + i_t, i_i = i_c // NC, i_c % NC + + p_g = tl.make_block_ptr(g + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bg * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BC, BV] + b_o = tl.zeros([BC, BV], dtype=tl.float32) + for i_j in range(0, i_i): + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_v = tl.make_block_ptr(v + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + p_gv = tl.make_block_ptr(g + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_v = 
tl.load(p_v, boundary_check=(0, 1)) + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_vg = (b_v * tl.exp(b_gn[None, :] - b_gv)).to(b_v.dtype) + # [BC, BC] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_o += tl.dot(b_A, b_vg, allow_tf32=False) + # [BC, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_o *= tl.exp(b_g - b_gn[None, :]) + + o_i = tl.arange(0, BC) + o_A = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + p_v = tl.make_block_ptr(v + i_bg * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + p_gv = tl.make_block_ptr(g + i_bg * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + # [BC,] + b_A = tl.load(A + o_A + j, mask=m_A, other=0) + # [BV,] + b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32) + b_gv = tl.load(p_gv, boundary_check=(0,)).to(tl.float32) + # [BC, BV] + b_vg = b_v[None, :] * tl.exp(b_g - b_gv[None, :]) + # avoid 0 * inf = inf + b_o += tl.where(o_i[:, None] >= j, b_A[:, None] * b_vg, 0.) + p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + + b_o += tl.load(p_o, boundary_check=(0, 1)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gated_abc_fwd_kernel_K( + q, + k, + h, + g, + o, + A, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bg = i_bh // NG + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bg * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + i_bg * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_o += tl.dot(b_q, b_h, allow_tf32=False) + # [BT, BT] + b_A += tl.dot(b_q, b_k, allow_tf32=False) + p_g = tl.make_block_ptr(g + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BT, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_o = b_o * tl.exp(b_g) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BT] + b_A = tl.where(m_s, b_A, 0.) 
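+ # b_A depends only on q and k, so a single program along the V axis + # (i_v == 0) is responsible for writing it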
+ if i_v == 0: + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gated_abc_fwd_kernel_intra_Vk( + q, + k, + g, + A, + s_k_h, + s_k_t, + s_k_d, + i_k, + i_c, + i_bh, + scale, + T: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + NG: tl.constexpr +): + i_bg = i_bh // NG + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + + b_A = tl.zeros([BC, BC], tl.float32) + if i_i > i_j: + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bg * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gk = tl.make_block_ptr(g + i_bg * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gn = tl.make_block_ptr(g + i_bg * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,)) + + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_qg = (b_q * tl.exp(b_g - b_gn[None, :]) * scale).to(b_q.dtype) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype) + # [BC, BC] + b_A = tl.dot(b_qg, b_kg, allow_tf32=False) + if i_k != 0: + b_A += tl.load(p_A, boundary_check=(0, 1)) + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + elif i_i == i_j: + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bg * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,)) + p_gk = tl.make_block_ptr(g + i_bg * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + + o_i = tl.arange(0, BC) + # [BC, BC] + m_A = o_i[:, None] >= o_i[None, :] + for j in range(0, BC): + # [BK,] + b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32) + b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32) + # [BC,] + b_Aj = tl.sum(b_q * b_k[None, :] * tl.exp(b_g - b_gk[None, :]) * scale, 1) + b_A = tl.where((o_i == j)[None, :], b_Aj[:, None], b_A) + + p_k = tl.advance(p_k, (K,)) + p_gk = tl.advance(p_gk, (K,)) + b_A = tl.where(m_A, b_A, 0.) 
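+ # partial results are accumulated across K blocks: the first K block + # initializes A, later blocks add onto it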
+ if i_k != 0: + b_A += tl.load(p_A, boundary_check=(0, 1)) + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + else: + # set the upper triangular part to 0 + if i_k == 0: + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gated_abc_fwd_kernel_intra_V( + q, + k, + g, + A, + s_k_h, + s_k_t, + s_k_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + NK: tl.constexpr, + NG: tl.constexpr +): + i_c, i_bh = tl.program_id(0), tl.program_id(1) + + for i_k in range(0, NK): + chunk_gated_abc_fwd_kernel_intra_Vk( + q, + k, + g, + A, + s_k_h, + s_k_t, + s_k_d, + i_k, + i_c, + i_bh, + scale, + T, + K, + BT, + BC, + BK, + NC, + NG, + ) + + +@triton.jit +def chunk_gated_abc_fwd_kernel_V( + q, + v, + g, + h, + o, + A, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bg = i_bh // NG + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + i_bg * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_g = tl.load(p_g, boundary_check=(0, 1)) + # [BT, BK] + b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # works but dkw, owing to divine benevolence + # [BT, BV] + if i_k >= 0: + b_o += tl.dot(b_qg, b_h, allow_tf32=False) + p_v = tl.make_block_ptr(v + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_o += tl.dot(b_A, b_v, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gated_abc_bwd_kernel_dh( + q, + g, + do, + dh, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + NG: tl.constexpr, + GATEK: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bg = i_bh // NG + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + o_t = min(i_t * BT + BT, T) + + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, 
boundary_check=(0, 1)) + + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + if GATEK: + p_g = tl.make_block_ptr(g + i_bg * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gn = tl.make_block_ptr(g + i_bg * s_k_h, (T * K,), (s_k_d,), ((o_t - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BK, BV] + b_dh *= tl.exp(b_gn)[:, None] + # [BK, BT] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_q = (b_q * tl.exp(b_g)).to(b_q.dtype) + else: + p_g = tl.make_block_ptr(g + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bg * s_v_h, (T * V,), (s_v_d,), ((o_t - 1) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BK, BV] + b_dh *= tl.exp(b_gn)[None, :] + # [BT, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_do = (b_do * tl.exp(b_g)).to(b_do.dtype) + # [BK, BV] + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + + +@triton.jit +def chunk_gated_abc_bwd_kernel_V( + k, + v, + h, + g, + A, + do, + dh, + dq, + dk, + dv, + dA, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bg = i_bh // NG + n_bh = tl.num_programs(2) + o_t = min(i_t * BT + BT, T) + + p_k = tl.make_block_ptr(k + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bg * s_k_h, (T * K,), (s_k_d,), ((o_t - 1) * K + i_k * BK,), (BK,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + + # [BK,] + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_gn = tl.exp(tl.load(p_gn, boundary_check=(0,))[None, :] - b_gk) + b_k = (b_k * b_gn).to(b_k.dtype) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bg * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BV] + b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + if i_k == 0: + b_dv += tl.dot(b_A, b_do, allow_tf32=False) + b_do = (b_do * scale).to(b_do.dtype) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BT, BT] + b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + # 
[BT, BK] + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + b_dq = b_dq * tl.exp(b_gk) + b_dk = b_dk * b_gn + + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + # [BT, BT] + b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype) + if i_k == 0: + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gated_abc_bwd_kernel_intra_V( + q, + k, + g, + dA, + dq, + dk, + dg, + s_k_h, + s_k_t, + s_k_d, + T: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + NG: tl.constexpr, + OVERWRITE: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bg = i_bh // NG + i_t, i_i = i_c // NC, i_c % NC + + p_g = tl.make_block_ptr(g + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bg * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BC, BK] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_dq = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(0, i_i): + p_k = tl.make_block_ptr(k + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = (b_k * tl.exp(b_gn[None, :] - b_gk)).to(b_k.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dq += tl.dot(b_dA, b_kg, allow_tf32=False) + b_dq *= tl.exp(b_g - b_gn[None, :]) + + o_i = tl.arange(0, BC) + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC + m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + p_kj = tl.make_block_ptr(k + i_bg * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,)) + p_gkj = tl.make_block_ptr(g + i_bg * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0) + # [BK,] + b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32) + b_gkj = tl.load(p_gkj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] >= j + # [BC, BK] + b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * tl.exp(b_g - b_gkj[None, :]), 0.) 
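+ # accumulate onto the inter-chunk dq already written by chunk_gated_abc_bwd_kernel_V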
+ p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + + b_dq = b_dq + tl.load(p_dq, boundary_check=(0, 1)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + p_k = tl.make_block_ptr(k + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bg * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_qg = (b_q * tl.exp(b_g - b_gn[None, :])).to(b_q.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dk += tl.dot(tl.trans(b_dA), b_qg, allow_tf32=False) + b_dk *= tl.exp(b_gn[None, :] - b_gk) + + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC) + for j in range(0, BC): + p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + p_gqj = tl.make_block_ptr(g + i_bg * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0) + # [BK,] + b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32) + b_gqj = tl.load(p_gqj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] <= j + b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_gqj[None, :] - b_gk), 0.) 
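+ # merge with the inter-chunk dk below, then form the gate gradient by the + # product rule, equivalent to: dg = q * dq - k * dk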
+ p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + + b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32) + b_dk = b_dk + tl.load(p_dk, boundary_check=(0, 1)).to(tl.float32) + b_dg = b_q * b_dq - b_k * b_dk + if not OVERWRITE: + b_dg = b_dg + tl.load(p_dg, boundary_check=(0, 1)).to(tl.float32) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gated_abc_bwd_kernel_intra_K( + v, + g, + do, + dA, + s_v_h, + s_v_t, + s_v_d, + scale, + T: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr, + NG: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + i_bg = i_bh // NG + n_bh = tl.num_programs(2) + + p_dA = tl.make_block_ptr(dA+(i_bh+i_v*n_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + + # [BC, BC] + b_dA = tl.zeros([BC, BC], dtype=tl.float32) + if i_i > i_j: + p_v = tl.make_block_ptr(v + i_bg * s_v_h, (V, T), (s_v_d, s_v_t), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1)) + p_gv = tl.make_block_ptr(g + i_bg * s_v_h, (V, T), (s_v_d, s_v_t), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1)) + p_gn = tl.make_block_ptr(g + i_bg * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,)) + p_g = tl.make_block_ptr(g + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + # [BV,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BC, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * tl.exp(b_g - b_gn[None, :]) * scale).to(b_do.dtype) + # [BV, BC] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_vg = (b_v * tl.exp(b_gn[:, None] - b_gv)).to(b_v.dtype) + # [BC, BC] + b_dA = tl.dot(b_do, b_vg, allow_tf32=False) + elif i_i == i_j: + p_v = tl.make_block_ptr(v + i_bg * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_j * BC) * V + i_v * BV,), (BV,), (0,)) + p_gv = tl.make_block_ptr(g + i_bg * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_j * BC) * V + i_v * BV,), (BV,), (0,)) + p_g = tl.make_block_ptr(g + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) * scale + + o_i = tl.arange(0, BC) + # [BC, BC] + m_dA = o_i[:, None] >= o_i[None, :] + for j in range(0, BC): + # [BV,] + b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32) + b_gv = tl.load(p_gv, boundary_check=(0,)).to(tl.float32) + # [BC,] + b_dAj = tl.sum(b_do * b_v[None, :] * tl.exp(b_g - b_gv[None, :]), 1) + b_dA = tl.where((o_i == j)[None, :], b_dAj[:, None], b_dA) + + p_v = tl.advance(p_v, (V,)) + p_gv = tl.advance(p_gv, (V,)) + b_dA = tl.where(m_dA, b_dA, 0.) 
+ tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gated_abc_bwd_kernel_K( + q, + k, + v, + h, + g, + A, + do, + dh, + dq, + dk, + dv, + dA, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bg = i_bh // NG + n_bh = tl.num_programs(2) + + o_i = tl.arange(0, BT) + o_t = min(i_t * BT + BT, T) + m_s = o_i[:, None] >= o_i[None, :] + + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bg * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh) * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k), allow_tf32=False) + b_A = tl.where(m_s, b_A, 0.) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bg * s_h_h + i_t * K*V, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_g = tl.make_block_ptr(g + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bg * s_v_h, (T * V,), (s_v_d,), ((o_t - 1) * V + i_v * BV,), (BV,), (0,)) + + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BV,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_v = b_v * tl.exp(b_gn[None, :] - b_g) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * tl.exp(b_g) * scale).to(b_do.dtype) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + b_dk += tl.dot(b_v.to(b_dh.dtype), tl.trans(b_dh), allow_tf32=False) + # [BT, BV] + b_dv = tl.exp(b_gn[None, :] - b_g) * tl.dot(b_k, b_dh, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BT] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BT, BK] + b_dq += tl.dot(b_dA, b_k, allow_tf32=False) + b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q, allow_tf32=False) + + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + 
+ +@triton.jit +def chunk_gated_abc_bwd_kernel_intra_KV( + v, + g, + o, + A, + do, + dv, + dg, + s_v_h, + s_v_t, + s_v_d, + T: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr, + NG: tl.constexpr, + OVERWRITE: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bg = i_bh // NG + i_t, i_i = i_c // NC, i_c % NC + + p_gv = tl.make_block_ptr(g + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bg * s_v_h, (T*V,), (s_v_d,), ((i_t * BT + i_i * BC + BC - 1) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BC, BV] + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_dv = tl.zeros([BC, BV], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_g = tl.make_block_ptr(g + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * tl.exp(b_g - b_gn[None, :])).to(b_do.dtype) + # [BC, BC] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_dv += tl.dot(b_A, b_do, allow_tf32=False) + b_dv *= tl.exp(b_gn[None, :] - b_gv) + + o_i = tl.arange(0, BC) + for j in range(0, BC): + p_g = tl.make_block_ptr(g + i_bg * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T * BT,), (1,), ((i_t * BT + i_i * BC + j) * BT + i_i * BC,), (BC,), (0,)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + # [BC,] + b_A = tl.load(p_A, boundary_check=(0,)) + # [BV,] + b_g = tl.load(p_g, boundary_check=(0,)) + b_do = tl.load(p_do, boundary_check=(0,)) + # [BC, BV] + m_i = o_i[:, None] <= j + b_dv += tl.where(m_i, tl.exp(b_g[None, :] - b_gv) * b_A[:, None] * b_do[None, :], 0.) 
+ p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_v = tl.make_block_ptr(v + i_bg * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_dg = tl.make_block_ptr(dg + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + + b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32) + b_v = tl.load(p_v, boundary_check=(0, 1)).to(tl.float32) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(tl.float32) + b_dv = b_dv + tl.load(p_dv, boundary_check=(0, 1)).to(tl.float32) + b_dg = b_o * b_do - b_v * b_dv + if not OVERWRITE: + b_dg = b_dg + tl.load(p_dg, boundary_check=(0, 1)).to(tl.float32) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, gatek=False, h0=None, ht=None): + NT = triton.cdiv(T, BT) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + h = q.new_empty(B, H, NT * K, V) + grid = (NV, NK, B * H) + chunk_gated_abc_fwd_kernel_h[grid]( + k, v, g, h, h0, ht, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + GATEK=gatek, + USE_INITIAL_STATE=h0 is not None, + STORE_FINAL_STATE=ht is not None, + num_warps=num_warps, + num_stages=num_stages + ) + return h + + +def fwd_v(q, k, v, g, B, H, T, K, V, BT, BK, BV, BC, h0=None, ht=None, scale=1.): + HQ = q.shape[1] + NT = triton.cdiv(T, BT) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + NC = triton.cdiv(BT, BC) + NG = HQ // H + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + h = fwd_inner( + q=q, k=k, v=v, g=g, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + gatek=True, + h0=h0, + ht=ht + ) + A = q.new_empty(B, HQ, T, BT) + grid = (NT * NC * NC, B * HQ) + chunk_gated_abc_fwd_kernel_intra_V[grid]( + q, k, g, A, + k.stride(1), k.stride(2), k.stride(3), + scale, + T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC, NK=NK, NG=NG, + num_warps=num_warps, + num_stages=num_stages + ) + o = v.new_empty(B, HQ, T, V) + grid = (NV, NT, B * HQ) + chunk_gated_abc_fwd_kernel_V[grid]( + q, v, g, h, o, A, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NG=NG, + num_warps=num_warps, + num_stages=num_stages + ) + return o, h, A + + +def fwd_k(q, k, v, g, B, H, T, K, V, BT, BK, BV, BC, h0=None, ht=None, scale=1.): + HQ = q.shape[1] + NT = triton.cdiv(T, BT) + NV = triton.cdiv(V, BV) + NC = triton.cdiv(BT, BC) + NG = HQ // H + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + h = fwd_inner( + q=q, k=k, v=v, g=g, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + gatek=False, + h0=h0, + ht=ht + ) + o = v.new_empty(B, HQ, T, V) + A = q.new_empty(B, HQ, T, BT) + grid = (NV, NT, B * HQ) + chunk_gated_abc_fwd_kernel_K[grid]( + q, k, h, g, o, A, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NG=NG, + num_warps=num_warps, + num_stages=num_stages 
+ ) + grid = (NV, NT * NC, B * HQ) + chunk_gated_abc_fwd_kernel_intra_K[grid]( + v, g, o, A, + v.stride(1), v.stride(2), v.stride(3), + T=T, V=V, BT=BT, BC=BC, BV=BV, NC=NC, NG=NG, + num_warps=num_warps, + num_stages=num_stages + ) + return o, h, A + + +def bwd_inner(q, g, do, B, H, T, K, V, BT, BK, BV, scale, gatek=False): + HQ = q.shape[1] + NT = triton.cdiv(T, BT) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + NG = HQ // H + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + dh = q.new_empty(B, HQ, NT * K, V) + grid = (NK, NV, B * HQ) + chunk_gated_abc_bwd_kernel_dh[grid]( + q, g, do, dh, + q.stride(1), q.stride(2), q.stride(3), + do.stride(1), do.stride(2), do.stride(3), + dh.stride(1), dh.stride(2), dh.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, NG=NG, + GATEK=gatek, + num_warps=num_warps, + num_stages=num_stages + ) + return dh + + +def bwd_v(q, k, v, g, h, A, do, dg, B, H, T, K, V, BT, BK, BV, BC, scale=1.): + HQ = q.shape[1] + NT = triton.cdiv(T, BT) + NK = triton.cdiv(K, BK) + NC = triton.cdiv(BT, BC) + NG = HQ // H + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + overwrite_dg = dg is None + dh = bwd_inner( + q, g, do, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + scale=scale, + gatek=True + ) + dq = torch.empty_like(q, dtype=torch.float) + dk = k.new_empty(B, HQ, T, K, dtype=torch.float) + dv = v.new_empty(NK, B, HQ, T, V) + dg = g.new_empty(B, HQ, T, K, dtype=torch.float) if dg is None else dg + dA = v.new_empty(B, HQ, T, BT) + + grid = (NK, NT, B * HQ) + chunk_gated_abc_bwd_kernel_V[grid]( + k, v, h, g, A, do, dh, dq, dk, dv, dA, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NG=NG, + num_warps=num_warps, + num_stages=num_stages + ) + dv = dv.sum(0, dtype=dv.dtype) + grid = (NK, NT * NC, B * HQ) + chunk_gated_abc_bwd_kernel_intra_V[grid]( + q, k, g, dA, dq, dk, dg, + k.stride(1), k.stride(2), k.stride(3), + T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC, NG=NG, + OVERWRITE=overwrite_dg, + num_warps=num_warps, + num_stages=num_stages + ) + return dq, dk, dv, dg + + +def bwd_k(q, k, v, g, h, o, do, dg, B, H, T, K, V, BT, BK, BV, BC, scale=1.): + HQ = q.shape[1] + NT = triton.cdiv(T, BT) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + NC = triton.cdiv(BT, BC) + NG = HQ // H + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + overwrite_dg = dg is None + dh = bwd_inner( + q, g, do, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + scale=scale, + gatek=False + ) + dA = q.new_empty(NV, B, HQ, T, BT) + grid = (NV, NT * NC * NC, B * HQ) + chunk_gated_abc_bwd_kernel_intra_K[grid]( + v, g, do, dA, + v.stride(1), v.stride(2), v.stride(3), + scale, + T=T, V=V, BT=BT, BC=BC, BV=BV, NC=NC, NG=NG, + num_warps=num_warps, + num_stages=num_stages + ) + dA = dA.sum(0, dtype=dA.dtype) + + A = do.new_empty(NK, B, HQ, T, BT) + dq = torch.empty_like(q) + dk = k.new_empty(B, HQ, T, K) + dv = v.new_empty(NK, B, HQ, T, V) + dg = g.new_empty(B, HQ, T, V, dtype=torch.float) if dg is None else dg + grid = (NK, NT, B * HQ) + chunk_gated_abc_bwd_kernel_K[grid]( + q, k, v, h, g, A, do, dh, dq, dk, dv, dA, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NG=NG, + num_warps=num_warps, + num_stages=num_stages + ) + A = A.sum(0, dtype=A.dtype) + dv = dv.sum(0, dtype=dv.dtype) + grid = (NV, NT * NC, B * HQ) + 
chunk_gated_abc_bwd_kernel_intra_KV[grid]( + v, g, o, A, do, dv, dg, + v.stride(1), v.stride(2), v.stride(3), + T=T, V=V, BT=BT, BC=BC, BV=BV, NC=NC, NG=NG, + OVERWRITE=overwrite_dg, + num_warps=num_warps, + num_stages=num_stages + ) + return dq, dk, dv, dg + + +class ChunkGatedABCFunction(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, q, k, v, s, g, scale, hk0, hv0, output_final_state, checkpoint_level): + B, H, T, K, V, M = *k.shape, v.shape[-1], s.shape[-1] + BT, BC = 64, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + BM = min(64, triton.next_power_of_2(M)) + + hkt, hvt = None, None + if output_final_state: + hkt = q.new_empty(B, H, K, M, dtype=torch.float) + hvt = q.new_empty(B, H, M, V, dtype=torch.float) + + g_cumsum = chunk_local_cumsum(g, BT) + g_org, g = g, g_cumsum + ok, hk, _ = fwd_k( + q=q, k=k, v=s, g=g, + B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, BC=BC, + h0=hk0, + ht=hkt, + scale=scale + ) + + # equivalent to: + # p = ok.softmax(-1, torch.float) + # p is kept in fp32 for safe softmax backward + p = torch.empty_like(ok, dtype=torch.float) + def grid(meta): return (triton.cdiv(meta['T'], meta['BT']), p.shape[0] * p.shape[1]) + softmax_fwd_kernel[grid]( + ok, p, + s.stride(1), s.stride(2), s.stride(3), + T=T, S=M, BT=BT + ) + + ov, hv, Av = fwd_v( + q=p.to(q.dtype), k=s, v=v, g=g, + B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, BC=BC, + h0=hv0, + ht=hvt, + scale=1. + ) + + if checkpoint_level >= 1: + del g + g = g_org + if checkpoint_level > 1: + del hk + del hv + hk, hv = None, None + else: + hk0, hv0 = None, None + + ctx.save_for_backward(q, k, v, s, g, ok, p, hk, hv, Av, hk0, hv0) + ctx.checkpoint_level = checkpoint_level + ctx.scale = scale + ctx.BT = BT + return ov, (hkt, hvt) + + @staticmethod + @contiguous + def backward(ctx, dov, dht=None): + q, k, v, s, g, ok, p, hk, hv, Av, hk0, hv0 = ctx.saved_tensors + qv = p.to(q.dtype) + B, H, T, K, V, M = *k.shape, v.shape[-1], s.shape[-1] + BT, BC = ctx.BT, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + BM = min(64, triton.next_power_of_2(M)) + + if ctx.checkpoint_level >= 1: + g = chunk_local_cumsum(g, BT) + + # rerun the forward pass to recover the hidden states hk and hv if checkpoint_level > 1 + if ctx.checkpoint_level > 1: + hk = fwd_inner( + q=q, k=k, v=s, g=g, + B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, + gatek=False, + h0=hk0, + ht=None + ) + hv = fwd_inner( + q=qv, k=s, v=v, g=g, + B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, + gatek=True, + h0=hv0, + ht=None + ) + + dqv, dsv, dv, dg = bwd_v( + q=qv, k=s, v=v, g=g, h=hv, A=Av, do=dov, dg=None, + B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, BC=BC, + scale=1. + ) + + # softmax gradient, equivalent to: + # dok = qv * (dqv - (qv * dqv).sum(-1, True)) + dok = torch.empty_like(ok) + def grid(meta): return (triton.cdiv(meta['T'], meta['BT']), p.shape[0] * p.shape[1]) + softmax_bwd_kernel[grid]( + p, dqv, dok, + s.stride(1), s.stride(2), s.stride(3), + T=T, S=M, BT=BT + ) + + dq, dk, dsk, dg = bwd_k( + q=q, k=k, v=s, g=g, h=hk, o=ok, do=dok, dg=dg, + B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, BC=BC, + scale=ctx.scale + ) + + ds = dsv.add_(dsk) + # reversed cumsum, equivalent to: + # + # def reversed_cumsum(x, dim=-1): + # c = x.cumsum(dim) + # return x + c.index_select(dim, x.new_tensor([c.shape[dim]-1], dtype=torch.long)) - c + dg = chunk_global_reversed_cumsum(dg).to(s.dtype) + if q.shape[1] != H: + dk, dv, ds, dg = map(lambda x: reduce(x, 'b (h g) ...
-> b h ...', 'sum', h=H), (dk, dv, ds, dg)) + return dq, dk, dv, ds, dg, None, None, None, None, None + + +def chunk_gated_abc( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + initial_state: Optional[Tuple[torch.Tensor]] = None, + output_final_state: Optional[bool] = False, + checkpoint_level: Optional[int] = 2 +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `(B, HQ, T, K)`. + k (torch.Tensor): + keys of shape `(B, H, T, K)`. GQA is performed if `H` is not equal to `HQ`. + v (torch.Tensor): + values of shape `(B, H, T, V)`. + s (torch.Tensor): + slot representations of shape `(B, H, T, M)`. + g (Optional[torch.Tensor]): + Forget gates of shape `(B, H, T, M)` applied to keys. + If not provided, this function is equivalent to vanilla ABC. + scale (Optional[float]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[Tuple[torch.Tensor]]): + Initial state tuple with tensors of shape `(B, H, K, M)` and `(B, H, M, V)`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state tuple, with tensors of shape `(B, H, K, M)` and `(B, H, M, V)`. Default: `False`. + checkpoint_level (Optional[int]): + Checkpointing level; higher values save more memory but require more recomputation during the backward pass. + Default: `2`: + - Level `0`: no memory saved, no recomputation. + - Level `1`: recompute the fp32 cumulative values during backward. + - Level `2`: recompute the fp32 cumulative values and forward hidden states during backward. + """ + assert checkpoint_level in [0, 1, 2] + if g is None: + # TODO: these 3 steps take a huge amount of time and ought to be optimized + z = s.float().logcumsumexp(2) + g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z + s = torch.exp(s - z).to(k.dtype) + if scale is None: + scale = q.shape[-1] ** -0.5 + + hk0, hv0 = None, None + if initial_state is not None: + hk0, hv0 = initial_state + ov, final_state = ChunkGatedABCFunction.apply(q, k, v, s, g, scale, hk0, hv0, output_final_state, checkpoint_level) + return ov, final_state diff --git a/fla2/ops/abc/naive.py b/fla2/ops/abc/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..a7f25c40db73bcf33d1599761be0008cc5be7c59 --- /dev/null +++ b/fla2/ops/abc/naive.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- + +from typing import Optional, Tuple + +import torch +from einops import repeat + + +def naive_recurrent_abc( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + initial_state: Optional[Tuple[torch.Tensor]] = None, + output_final_state: Optional[bool] = False +) -> Tuple[torch.Tensor, torch.Tensor]: + dtype = q.dtype + + NG = q.shape[1]//k.shape[1] + # [batch_size, n_heads, seq_len, n_slots] + if g is None: + z = s.float().logcumsumexp(2) + g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z + s = torch.exp(s - z) + q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g)) + k, v, s, g = map(lambda x: repeat(x, 'b h t d -> b (h g) t d', g=NG), (k, v, s, g)) + if initial_state is not None: + initial_state = tuple(map(lambda x: repeat(x, 'b h k v -> b (h g) k v', g=NG), initial_state)) + + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + + hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device) + ok = torch.zeros_like(s) + + if scale is None: + scale = q.shape[-1] ** -0.5 + + final_state = None + if initial_state is not None: + hk += initial_state[0] + + for i in range(T): + q_i = q[:, :, i] * scale
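+ # gated update of the key-side memory followed by the slot logits readout, + # equivalent to: hk = hk * exp(g) + einsum('...k,...m->...km', k, s); ok = einsum('...k,...km->...m', q, hk)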
+ k_i = k[:, :, i] + v_i = s[:, :, i] + g_i = g[:, :, i].exp() + hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :] + ok[:, :, i] = (q_i[..., None] * hk).sum(-2) + + qv = ok.softmax(-1) + hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device) + ov = torch.zeros_like(v) + if initial_state is not None: + hv += initial_state[1] + + for i in range(T): + q_i = qv[:, :, i] + k_i = s[:, :, i] + v_i = v[:, :, i] + g_i = g[:, :, i].exp() + hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :] + ov[:, :, i] = (q_i[..., None] * hv).sum(-2) + + if output_final_state: + final_state = (hk.view(B, -1, NG, K, M)[:, :, 0], hv.view(B, -1, NG, M, V)[:, :, 0]) + return ov.to(dtype), final_state + + +def naive_cumsum_abc( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor +) -> Tuple[torch.Tensor, None]: + """ + A simple implementation of vanilla ABC that is more closely aligned with the description in the paper. + It is intended for demonstration purposes only, with no numerical stability guaranteed. + """ + + dtype = q.dtype + q, k, v, s = map(lambda x: x.float(), (q, k, v, s)) + + scale = q.shape[-1] ** -0.5 + # [batch_size, n_heads, seq_len, n_slots] + s = (s - s.max(2, True)[0]).exp() + z = s.cumsum(2) + # [batch_size, n_heads, seq_len, n_slots, d_head] + K = (s.unsqueeze(-1) * k.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1) + V = (s.unsqueeze(-1) * v.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1) + # [batch_size, n_heads, seq_len, n_slots] + p = torch.einsum('...d,...md->...m', q * scale, K).softmax(-1) + # [batch_size, n_heads, seq_len, d_head] + o = torch.einsum('...m,...md->...d', p, V) + return o.to(dtype), None diff --git a/fla2/ops/abc/recurrent_fuse.py b/fla2/ops/abc/recurrent_fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..0e73f854236027fd985b040b5e5eccf88a46ffbf --- /dev/null +++ b/fla2/ops/abc/recurrent_fuse.py @@ -0,0 +1,490 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2024, Yu Zhang, Songlin Yang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous + + +@triton.jit +def fused_recurrent_gated_abc_inference_kernel( + q, + k, + v, + s, + g, + o, + hk0, + hv0, + hkt, + hvt, + scale, + K: tl.constexpr, + V: tl.constexpr, + M: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr +): + i_bh = tl.program_id(0) + i_bg = i_bh // NG + + b_s = tl.load(s + i_bg * M + tl.arange(0, M)).to(tl.float32) + b_g = tl.load(g + i_bg * M + tl.arange(0, M)).to(tl.float32) + b_g = tl.exp(b_g) + + b_ok = tl.zeros([M], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + + p_hk0 = hk0 + i_bg * K * M + (o_k[None, :]) * M + tl.arange(0, M)[:, None] + # [BK,] + mask_k = o_k < K + # [M, BK] + mask_hk = (tl.arange(0, M) < M)[:, None] & mask_k[None, :] + # [M, BK] + b_hk = tl.load(p_hk0, mask=mask_hk, other=0.).to(tl.float32) + # [BK,] + b_q = tl.load(q + i_bh * K + o_k, mask=mask_k, other=0.).to(tl.float32) * scale + b_k = tl.load(k + i_bg * K + o_k, mask=mask_k, other=0.).to(tl.float32) + b_hk = b_hk * b_g[:, None] + b_k[None, :] * b_s[:, None] + b_ok += tl.sum(b_hk * b_q[None, :], axis=1) + + if i_bh % NG == 0: + p_hkt = hkt + i_bg * K * M + o_k[None, :] * M + tl.arange(0, M)[:, None] + tl.store(p_hkt, b_hk.to(p_hkt.dtype.element_ty), mask=mask_hk) + + b_qv = tl.softmax(b_ok) + for i_v in range(tl.cdiv(V, BV)): + o_v = i_v * BV + tl.arange(0, BV) + + p_hv0 = hv0 + i_bg * M * V +
tl.arange(0, M)[None, :] * V + o_v[:, None] + # [BV,] + mask_v = o_v < V + # [BV, M] + mask_hv = mask_v[:, None] & (tl.arange(0, M) < M)[None, :] + # [BV, M] + b_hv = tl.load(p_hv0, mask=mask_hv, other=0).to(tl.float32) + # [BV,] + b_v = tl.load(v + i_bg * V + o_v, mask=mask_v, other=0).to(tl.float32) + b_hv = b_hv * b_g[None, :] + b_s[None, :] * b_v[:, None] + b_ov = tl.sum(b_hv * b_qv[None, :], axis=1) + + tl.store(o + i_bh * V + o_v, b_ov.to(o.dtype.element_ty), mask=mask_v) + + if i_bh % NG == 0: + p_hvt = hvt + i_bg * M * V + tl.arange(0, M)[None, :] * V + o_v[:, None] + tl.store(p_hvt, b_hv.to(p_hvt.dtype.element_ty), mask=mask_hv) + + +@triton.jit +def fused_recurrent_gated_abc_fwd_kernel( + q, + k, + v, + gk, + gv, + o, + h0, + ht, + s_k_h, + s_v_h, + scale, + B: tl.constexpr, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + REVERSE: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + + if USE_GK: + p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + if USE_GV: + p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + + mask_k = (i_k * BK + tl.arange(0, BK)) < K + mask_v = (i_v * BV + tl.arange(0, BV)) < V + + b_h = tl.zeros([BV, BK], dtype=tl.float32) + mask_h = mask_k[None, :] & mask_v[:, None] + + if USE_INITIAL_STATE: + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + if USE_GK: + b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32) + b_h = b_h * tl.exp(b_gk)[None, :] + if USE_GV: + b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32) + b_h = b_h * tl.exp(b_gv)[:, None] + b_h += b_k[None, :] * b_v[:, None] + b_o = b_h * b_q[None, :] + b_o = tl.sum(b_o, axis=1) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + p_q += -K if REVERSE else K + p_k += -K if REVERSE else K + p_o += -V if REVERSE else V + p_v += -V if REVERSE else V + if USE_GK: + p_gk += -K if REVERSE else K + if USE_GV: + p_gv += -V if REVERSE else V + + if STORE_FINAL_STATE: + p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +@triton.jit +def fused_recurrent_gated_abc_bwd_kernel( + q, + k, + v, + gk, + gv, + do, + dq, + dk, + dv, + dh0, + h0, + s_k_h, + s_v_h, + scale, + B: tl.constexpr, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + REVERSE: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + 
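+ # two passes over time: a forward scan first rebuilds the running state to + # produce dq; after the barrier, a reverse scan accumulates dh for dk and dv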
+ p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + if USE_GK: + p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + if USE_GV: + p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + mask_k = i_k * BK + tl.arange(0, BK) < K + mask_v = i_v * BV + tl.arange(0, BV) < V + mask_h = mask_k[:, None] & mask_v[None, :] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32) + if USE_GK: + b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32) + b_h = b_h * tl.exp(b_gk)[:, None] + if USE_GV: + b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32) + b_h = b_h * tl.exp(b_gv)[None, :] + b_h += b_k[:, None] * b_v[None, :] + b_dq = tl.sum(b_h * b_do[None, :], axis=1) * scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_k) + + p_k += -K if REVERSE else K + p_v += -V if REVERSE else V + p_q += -K if REVERSE else K + p_do += -V if REVERSE else V + p_dq += -K if REVERSE else K + if USE_GK: + p_gk += -K if REVERSE else K + if USE_GV: + p_gv += -V if REVERSE else V + + # sync threads + tl.debug_barrier() + + p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0) + p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0) + p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0) + if USE_GK: + p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + if USE_GV: + p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0) + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for _ in range(T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32) + b_dh += b_q[:, None] * b_do[None, :] + b_dk = tl.sum(b_dh * b_v[None, :], axis=1) + b_dv = tl.sum(b_dh * b_k[:, None], axis=0) + if USE_GK: + b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32) + b_dh *= tl.exp(b_gk)[:, None] + if USE_GV: + b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32) + b_dh *= tl.exp(b_gv)[None, :] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v) + + p_q += K if REVERSE else -K + p_k += K if REVERSE else -K + p_v += V if 
REVERSE else -V + p_do += V if REVERSE else -V + p_dk += K if REVERSE else -K + p_dv += V if REVERSE else -V + if USE_GK: + p_gk += K if REVERSE else -K + if USE_GV: + p_gv += V if REVERSE else -V + + if USE_INITIAL_STATE: + p_dh0 = dh0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_h) + + +class FusedRecurrentGatedABCFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: torch.Tensor, + scale: Optional[float] = None, + hk0: Optional[torch.Tensor] = None, + hv0: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + inference_mode: bool = False + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]: + B, H, T, K, V, M = *k.shape, v.shape[-1], s.shape[-1] + HQ = q.shape[1] + + BK, BV, BM = min(K, 64), min(V, 64), min(M, 64) + NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM) + NG = HQ // H + num_warps = 1 + num_stages = 1 + + hkt, hvt = None, None + if output_final_state: + hkt, hvt = (hk0, hv0) if inference_mode and NG == 1 else (q.new_empty(B, H, K, M, dtype=torch.float), q.new_empty(B, H, M, V, dtype=torch.float)) + + if inference_mode: + BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 16) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + o = v.new_empty(B, HQ, T, V) + grid = (B * HQ,) + fused_recurrent_gated_abc_inference_kernel[grid]( + q, k, v, s, g, o, hk0, hv0, hkt, hvt, + scale=scale, + K=K, V=V, M=M, BK=BK, BV=BV, NG=NG, + num_warps=num_warps, + num_stages=num_stages + ) + return o, (hkt, hvt) + + ok = q.new_empty(NK, B, H, T, M, dtype=torch.float) + gk, gv = None, g + grid = (NM, NK, B * H) + fused_recurrent_gated_abc_fwd_kernel[grid]( + q, k, s, gk, gv, ok, hk0, hkt, + k.stride(1), + s.stride(1), + scale=scale, + B=B, H=H, T=T, K=K, V=M, BK=BK, BV=BM, + USE_INITIAL_STATE=hk0 is not None, + STORE_FINAL_STATE=hkt is not None, + USE_GK=False, + USE_GV=True, + REVERSE=reverse, + num_warps=num_warps, + num_stages=num_stages + ) + ok = ok.sum(0) + + qv = ok.softmax(-1, dtype=torch.float) + ov = q.new_empty(NM, B, H, T, V, dtype=torch.float) + gk, gv = g, None + grid = (NV, NM, B * H) + fused_recurrent_gated_abc_fwd_kernel[grid]( + qv, s, v, gk, gv, ov, hv0, hvt, + s.stride(1), + v.stride(1), + scale=1., + B=B, H=H, T=T, K=M, V=V, BK=BM, BV=BV, + USE_INITIAL_STATE=hv0 is not None, + STORE_FINAL_STATE=hvt is not None, + USE_GK=True, + USE_GV=False, + REVERSE=reverse, + num_warps=num_warps, + num_stages=num_stages + ) + ov = ov.sum(0) + + ctx.save_for_backward(q, k, v, s, g, qv, hk0, hv0, ok) + ctx.scale = scale + ctx.reverse = reverse + return ov.to(q.dtype), (hkt, hvt) + + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, dht=None): + q, k, v, s, g, qv, hk0, hv0, ok = ctx.saved_tensors + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + scale = ctx.scale + + BK, BV, BM = min(K, 64), min(V, 64), min(M, 64) + NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM) + num_warps = 1 + num_stages = 1 + + dqv = q.new_empty(NV, B, H, T, M, dtype=torch.float) + dsv = q.new_empty(NV, B, H, T, M, dtype=torch.float) + dv = q.new_empty(NM, B, H, T, V, dtype=torch.float) + dhk0 = torch.empty_like(hk0)if hk0 is not None else None + dhv0 = torch.empty_like(hv0)if hv0 is not None else None + + gk, gv = g, None + grid = 
(NV, NM, B * H)
+        fused_recurrent_gated_abc_bwd_kernel[grid](
+            qv, s, v, gk, gv, do, dqv, dsv, dv, dhv0, hv0,
+            s.stride(1),
+            v.stride(1),
+            scale=1.,
+            B=B, H=H, T=T, K=M, V=V, BK=BM, BV=BV,
+            USE_INITIAL_STATE=hv0 is not None,
+            REVERSE=ctx.reverse,
+            USE_GK=gk is not None,
+            USE_GV=gv is not None,
+            num_warps=num_warps,
+            num_stages=num_stages
+        )
+        dqv = dqv.sum(0)
+        dsv = dsv.sum(0)
+        dv = dv.sum(0)
+        dgk = dqv * qv.float() - dsv * s.float()
+        dgk_cumsum = dgk.cumsum(-2)
+        dgk = dgk + dgk_cumsum[:, :, -1, None] - dgk_cumsum
+
+        dok = qv * (dqv - (qv * dqv).sum(-1, True))
+        dq = q.new_empty(NM, B, H, T, K, dtype=torch.float)
+        dk = q.new_empty(NM, B, H, T, K, dtype=torch.float)
+        dsk = q.new_empty(NK, B, H, T, M, dtype=torch.float)
+        gk, gv = None, g
+        grid = (NM, NK, B * H)
+        fused_recurrent_gated_abc_bwd_kernel[grid](
+            q, k, s, gk, gv, dok, dq, dk, dsk, dhk0, hk0,
+            q.stride(1),
+            s.stride(1),
+            scale=scale,
+            B=B, H=H, T=T, K=K, V=M, BK=BK, BV=BM,
+            USE_INITIAL_STATE=hk0 is not None,
+            REVERSE=ctx.reverse,
+            USE_GK=gk is not None,
+            USE_GV=gv is not None,
+            num_warps=num_warps,
+            num_stages=num_stages
+        )
+        dq = dq.sum(0)
+        dk = dk.sum(0)
+        dsk = dsk.sum(0)
+
+        dgv = dok.float() * ok.float() - dsk * s.float()
+        dgv_cumsum = dgv.cumsum(-2)
+        dgv = dgv + dgv_cumsum[:, :, -1, None] - dgv_cumsum
+
+        ds = dsk.add_(dsv)
+        dg = dgk.add_(dgv)
+
+        return dq.to(q), dk.to(k), dv.to(v), ds.to(s), dg.to(g), None, dhk0, dhv0, None, None, None
+
+
+def fused_recurrent_gated_abc(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    s: torch.Tensor,
+    g: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    initial_state: Optional[Tuple[torch.Tensor]] = None,
+    output_final_state: Optional[bool] = False
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""
+    Args:
+        q (torch.Tensor):
+            queries of shape `(B, H, T, K)`
+        k (torch.Tensor):
+            keys of shape `(B, H, T, K)`
+        v (torch.Tensor):
+            values of shape `(B, H, T, V)`
+        s (torch.Tensor):
+            slot representations of shape `(B, H, T, M)`
+        g (Optional[torch.Tensor]):
+            Forget gates of shape `(B, H, T, M)` applied to keys.
+            If not provided, this function is equivalent to vanilla ABC.
+        scale (Optional[float]):
+            Scale factor for attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        initial_state (Optional[Tuple[torch.Tensor]]):
+            Initial state tuple with tensors of shape `(B, H, K, M)` and `(B, H, M, V)`. Default: `None`.
+        output_final_state (Optional[bool]):
+            Whether to output the final state tuple, with tensors of shape `(B, H, K, M)` and `(B, H, M, V)`.
+            Default: `False`.
+    """
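+    # NOTE (editorial): when `g` is None, the branch below rewrites vanilla ABC
+    # as a gated recurrence in log space. With z_t = logcumsumexp(s)_t, the gate
+    # g_t = z_{t-1} - z_t and the rescaled slots s_t' = exp(s_t - z_t) satisfy
+    # h_t = exp(g_t) * h_{t-1} + outer(s_t', v_t) = sum_{j<=t} exp(s_j - z_t) * v_j,
+    # i.e. a cumulative softmax over time, so no separate normalizer is needed.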
+ """ + if g is None: + # TODO: this 3 steps took huge amount of time, ought to be optimized + z = s.float().logcumsumexp(2) + g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z + s = torch.exp(s - z).to(k.dtype) + if scale is None: + scale = q.shape[-1] ** -0.5 + if initial_state is None: + initial_state = (None, None) + inference_mode = q.shape[2] == 1 and not q.requires_grad + ov, final_state = FusedRecurrentGatedABCFunction.apply( + q, k, v, s, g, scale, *initial_state, output_final_state, False, inference_mode + ) + return ov, final_state diff --git a/fla2/ops/based/__init__.py b/fla2/ops/based/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5bcfcdc536a2a3eea00541e768207e633e8485fe --- /dev/null +++ b/fla2/ops/based/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .chunk_fuse import fused_chunk_based +from .parallel import parallel_based + +__all__ = [ + 'fused_chunk_based', + 'parallel_based' +] diff --git a/fla2/ops/based/__pycache__/__init__.cpython-312.pyc b/fla2/ops/based/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..264d89e2f12941378b9d2010913443c3259139f8 Binary files /dev/null and b/fla2/ops/based/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/ops/based/__pycache__/__init__.cpython-38.pyc b/fla2/ops/based/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1e3c443b47921b333f1ad464584a5afd257b389 Binary files /dev/null and b/fla2/ops/based/__pycache__/__init__.cpython-38.pyc differ diff --git a/fla2/ops/based/__pycache__/__init__.cpython-39.pyc b/fla2/ops/based/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9bb2069517d6e970499c25c9253b875618c7d93 Binary files /dev/null and b/fla2/ops/based/__pycache__/__init__.cpython-39.pyc differ diff --git a/fla2/ops/based/__pycache__/chunk_fuse.cpython-312.pyc b/fla2/ops/based/__pycache__/chunk_fuse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..107cf7726be2685b8f1d9ac8441df7ebf88cb21c Binary files /dev/null and b/fla2/ops/based/__pycache__/chunk_fuse.cpython-312.pyc differ diff --git a/fla2/ops/based/__pycache__/chunk_fuse.cpython-38.pyc b/fla2/ops/based/__pycache__/chunk_fuse.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eae926fb6eb487eb506cf2dba104e654bd49c253 Binary files /dev/null and b/fla2/ops/based/__pycache__/chunk_fuse.cpython-38.pyc differ diff --git a/fla2/ops/based/__pycache__/parallel.cpython-312.pyc b/fla2/ops/based/__pycache__/parallel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5336df9b77e9e0b5c4d62cc7d8c7edbca557ae2b Binary files /dev/null and b/fla2/ops/based/__pycache__/parallel.cpython-312.pyc differ diff --git a/fla2/ops/based/chunk_fuse.py b/fla2/ops/based/chunk_fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..76ed5da8a7855ed60ed39abfcdf5d978a1f08169 --- /dev/null +++ b/fla2/ops/based/chunk_fuse.py @@ -0,0 +1,389 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous + +# on-the-fly computation without materializing hidden statets into HBMs + + +@triton.jit +def fused_chunk_based_fwd_kernel( + q, # query [B, H, L, K] + k, # key [B, H, L, V] + v, # value [B, H, L, V] + o, # output [B, H, L, V] + z, # 
+    s_qk_h,  # stride size: L * K
+    s_qk_t,  # stride size: K
+    s_qk_d,  # stride size: 1
+    s_vo_h,  # stride size: L * V
+    s_vo_t,  # stride size: V
+    s_vo_d,  # stride size: 1
+    scale,  # K ** -0.5
+    B: tl.constexpr,  # batch size
+    H: tl.constexpr,  # number of heads
+    T: tl.constexpr,  # sequence length
+    K: tl.constexpr,  # key dimension
+    V: tl.constexpr,  # value dimension
+    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
+    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
+    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
+):
+    # indices
+    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+    o_i = tl.arange(0, BT)
+
+    # [BT, BT]
+    m_s = o_i[:, None] >= o_i[None, :]
+
+    # [BV], zeroth-order Taylor expansion
+    b_h_0o = tl.zeros([BV], dtype=tl.float32)
+    # [BK, BV], first-order Taylor expansion
+    b_h_1o = tl.zeros([BK, BV], dtype=tl.float32)
+    # [BK*BK, BV], second-order Taylor expansion
+    b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
+
+    # make block pointers
+    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
+    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
+    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
+    p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * s_vo_h, (T, V), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
+
+    p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT)
+    k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
+    k_1o = tl.zeros([1, BK], dtype=tl.float32)
+    k_0o = 0
+
+    for i in range(0, tl.cdiv(T, BT)):
+        # [BK, BT]
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        # [BK*BK, BT]
+        b_k_2o = b_k[:, None, :] * b_k[None, :, :]
+        b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype)
+        # [BT, BV]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        # [BT, BK]
+        b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype)
+        b_o = tl.zeros([BT, BV], dtype=tl.float32)
+        b_z = tl.zeros([BT], dtype=tl.float32)
+
+        # interchunk
+        # zeroth-order
+        b_o += b_h_0o
+        b_z += k_0o
+        # first-order
+        b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False)
+        b_z += tl.sum(b_q * k_1o, axis=1)
+        # second-order
+        b_q_2o = b_q[:, :, None] * b_q[:, None, :]
+        b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype)
+        b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5
+        b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5
+
+        # update running statistics
+        k_1o += tl.sum(b_k, axis=1)[None, :]
+        k_2o += tl.sum(b_k_2o, axis=1)[None, :]
+        k_0o += BT
+
+        # intrachunk
+        # [BT, BT]
+        b_s = tl.dot(b_q, b_k, allow_tf32=False)
+        b_s = 1 + b_s + 0.5 * b_s * b_s
+        b_s = tl.where(m_s, b_s, 0)
+        b_z += tl.sum(b_s, axis=1)
+        b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
+        # [BT, BV]
+        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+        tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=(i * BT + tl.arange(0, BT)) < T)
+
+        # update hidden state
+        # [BK*BK, BV]
+        b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False)
+        b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False)
+        b_h_0o = b_h_0o + tl.sum(b_v, axis=0)
+
+        p_q = tl.advance(p_q, (BT, 0))
+        p_k = tl.advance(p_k, (0, BT))
+        p_v = tl.advance(p_v, (BT, 0))
+        p_o = tl.advance(p_o, (BT, 0))
+        p_z += BT
+
+
+# Similar to Algorithm 1 of https://arxiv.org/abs/2006.16236
+@triton.jit
+def fused_chunk_based_bwd_kernel(
+    # NV: number of splits in the V dimension; NK: number of splits in the K dimension
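+    # NOTE (editorial): this kernel runs two passes separated by a barrier:
+    # a forward-in-time pass that carries (b_h_1o, b_h_2o) to produce dq, and a
+    # reverse pass over chunks that carries (b_dh_0o, b_dh_1o, b_dh_2o) to
+    # produce dk and dv, mirroring Algorithm 1 of the linear-attention paper.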
+    q,  # query [B, H, L, K]
+    k,  # key [B, H, L, K]
+    v,  # value [B, H, L, V]
+    do,  # gradient of output [B, H, L, V]
+    dz,  # gradient of normalizer [B, H, L]
+    dq,  # gradient of query [NV, B, H, L, K]
+    dk,  # gradient of key [NV, B, H, L, K]
+    dv,  # gradient of value [NK, B, H, L, V]
+    s_qk_h,  # stride size: L * K
+    s_qk_t,  # stride size: K
+    s_qk_d,  # stride size: 1
+    s_vo_h,  # stride size: L * V
+    s_vo_t,  # stride size: V
+    s_vo_d,  # stride size: 1
+    scale,  # K ** -0.5
+    B: tl.constexpr,  # batch size
+    H: tl.constexpr,  # number of heads
+    T: tl.constexpr,  # sequence length
+    K: tl.constexpr,  # key dimension
+    V: tl.constexpr,  # value dimension
+    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
+    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
+    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
+):
+    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+    o_i = tl.arange(0, BT)
+    m_s = o_i[:, None] >= o_i[None, :]
+
+    # [BV], zeroth-order Taylor expansion
+    # b_h_0o = tl.zeros([BV], dtype=tl.float32)
+    # [BV, BK], first-order Taylor expansion
+    b_h_1o = tl.zeros([BV, BK], dtype=tl.float32)
+    # [BV, BK*BK], second-order Taylor expansion
+    b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32)
+
+    k_1o = tl.zeros([1, BK], dtype=tl.float32)
+    k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
+
+    for i in range(0, tl.cdiv(T, BT)):
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
+        p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0))
+        p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT
+        b_dq = tl.zeros([BT, BK], dtype=tl.float32)
+
+        # load tensors
+        # [BT, BK]
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        b_q = (b_q * scale).to(b_q.dtype)
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
+        b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T)
+        # [BV, BT]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+
+        # inter-chunk
+        b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False)
+        if i_v == 0:
+            b_dq += b_dz[:, None] * k_1o
+        b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5
+        if i_v == 0:
+            b_dq_2o += (b_dz[:, None] * k_2o) * 0.5
+        b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK])
+        b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1)
+        b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2)
+        b_dq *= scale
+
+        # intra-chunk
+        # [BT, BT]
+        b_ds = tl.dot(b_do, b_v, allow_tf32=False)
+        if i_v == 0:
+            b_ds += b_dz[:, None]
+        b_ds = tl.where(m_s, b_ds, 0) * scale
+        b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
+        b_s = tl.where(m_s, b_s, 0)
+        b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False)
+
+        # store
+        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
+
+        # update hidden state
+        # [BT, BK*BK]
+        b_k_2o = b_k[:, :, None] * b_k[:, None, :]
+        b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
+        # [BV, BK*BK]
+        b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False)
+        # [BV, BK]
+        b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False)
+
+        if i_v == 0:
+            # update running statistics
+            k_1o += tl.sum(b_k,
axis=0)[None, :] + k_2o += tl.sum(b_k_2o, axis=0)[None, :] + + tl.debug_barrier() + b_h_1o = None + b_h_2o = None + + # [BK, BV], first-order taylor expansion + b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32) + # [BK, BK, BV] second-order taylor expansion + b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32) + b_dh_0o = tl.zeros([BV], dtype=tl.float32) + m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :] + + dq_1o = tl.zeros([1, BK], dtype=tl.float32) + dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32) + + for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i, i_v*BV), (BT, BV), (1, 0)) + p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i + + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T) + b_q = (b_q * scale).to(b_k.dtype) + + # intra chunk + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + if i_v == 0: + b_ds += b_dz[None, :] + b_ds = tl.where(m_s, b_ds, 0) + b_s = tl.dot(b_k, b_q, allow_tf32=False) + b_s2 = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_s2 = tl.where(m_s, b_s2, 0) + b_ds *= (1+b_s) + + b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False) + b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False) + + # inter chunk + b_k_2o = b_k[:, :, None] * b_k[:, None, :] + b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype) + + b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False) + b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False) + b_dv += b_dh_0o + + b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False) + + if i_v == 0: + b_dk += dq_1o + + b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype), tl.trans(b_v), allow_tf32=False) + if i_v == 0: + b_dk_2o += dq_2o + b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT]) + b_k_fp32 = tl.trans(b_k.to(tl.float32)) + b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0) + b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1) + b_dk += tl.trans(b_dk2) + + # hidden state update + b_dh_0o += tl.sum(b_do, axis=0) + b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False) + b_q_2o = b_q[None, :, :] * b_q[:, None, :] + b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype) + b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5 + + if i_v == 0: + dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :] + dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None] + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkBasedFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, scale=1): + B, H, T, K, V = 
*k.shape, v.shape[-1] + + scale = scale + BT = 16 + BK, BV = min(K, 16), min(V, 32) + BK, BV = max(BK, 16), max(BV, 16) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + num_warps = 4 + + # the norm of o might explode, so we need to use float32 here + o = q.new_empty(NK, B, H, T, V, dtype=torch.float32) + z = q.new_empty(NK, B, H, T, dtype=torch.float32) + + grid = (NV, NK, B * H) + fused_chunk_based_fwd_kernel[grid]( + q, k, v, o, z, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + scale, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + ) + o = o.sum(0) + z = z.sum(0) + ctx.save_for_backward(q, k, v) + ctx.scale = scale + return o.to(q.dtype), z.to(z.dtype) + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, dz): + q, k, v = ctx.saved_tensors + B, H, T, K, V = *k.shape, v.shape[-1] + scale = ctx.scale + + BT = 16 + BK, BV = min(K, 16), min(V, 32) + BK, BV = max(BK, 16), max(BV, 16) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 4 + + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + grid = (NV, NK, B * H) + + fused_chunk_based_bwd_kernel[grid]( + q, k, v, do, dz, dq, dk, dv, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + scale, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None + + +triton_fused_chunk_based = FusedChunkBasedFunction.apply + + +def fused_chunk_based( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + use_norm: bool = True +): + assert q.shape[-1] <= 16, 'only support feature dimension up to 16.' + if scale is None: + scale = q.shape[-1] ** -0.5 + o, z = triton_fused_chunk_based(q, k, v, scale) + if use_norm: + o = o / (z[..., None] + 1e-6) + return o.to(q.dtype) diff --git a/fla2/ops/based/naive.py b/fla2/ops/based/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..4de614137ed28567ebb1df39c0892f498b91fb5a --- /dev/null +++ b/fla2/ops/based/naive.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch +from einops import rearrange + + +def naive_parallel_based( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + use_norm: bool = True +): + if scale is None: + scale = q.shape[-1] ** -0.5 + q = q * scale + attn = q @ k.transpose(-2, -1) + attn = 1 + attn + 1/2 * (attn ** 2) + attn.masked_fill_(~torch.tril(torch.ones( + q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0) + o = attn @ v + if use_norm: + z = attn.sum(-1) + return o / (z[..., None] + 1e-6) + else: + return o + + +def naive_chunk_based(q, k, v, chunk_size=256): + q = q * (q.shape[-1] ** -0.5) + # compute normalizer. 
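+    # NOTE (editorial): the normalizer has the closed form
+    #   z_t = sum_{j<=t} (1 + q_t.k_j + 0.5*(q_t.k_j)^2)
+    #       = (t + 1) + q_t . (sum_{j<=t} k_j) + 0.5 * sum_{j<=t} (q_t.k_j)^2,
+    # which is exactly what the three cumulative-sum terms below compute
+    # (the (t + 1) term is the arange expression).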
+    k_cumsum = torch.cumsum(k, dim=-2)
+    kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3)
+    # first order
+    z = (q * k_cumsum).sum(-1)
+    # second order
+    z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5
+    # zeroth order
+    z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :]
+
+    # compute o
+    # constant term
+    _o = v.cumsum(-2)
+
+    q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size)
+
+    k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size)
+    v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size)
+
+    intra_chunk_attn = q @ k.transpose(-2, -1)
+    intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2)
+    intra_chunk_attn.masked_fill_(~torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device)), 0)
+    o = intra_chunk_attn @ v
+
+    # quadratic term
+    kv = torch.einsum('b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v)
+    kv = kv.cumsum(2)
+    kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
+
+    o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q)
+
+    # linear term
+    kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v)
+    kv = kv.cumsum(2)
+    kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
+    o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q)
+
+    o = rearrange(o, 'b h n c d -> b h (n c) d')
+    o = o + _o
+    return o / (z[..., None] + 1e-6)
diff --git a/fla2/ops/based/parallel.py b/fla2/ops/based/parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c68dd9f9eae1fffd1b206a43f2f602cf46e9fac
--- /dev/null
+++ b/fla2/ops/based/parallel.py
@@ -0,0 +1,403 @@
+
+# -*- coding: utf-8 -*-
+
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+
+# Based: An Educational and Effective Sequence Mixer
+# https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based
+
+
+@triton.jit
+def parallel_based_fwd_kernel(
+    q,  # query [B, H, L, K]
+    k,  # key [B, H, L, K]
+    v,  # value [B, H, L, V]
+    o,  # output [B, H, L, V]
+    z,  # normalizer [B, H, L]
+    s_qk_h,  # stride size: L * K
+    s_qk_t,  # stride size: K
+    s_qk_d,  # stride size: 1
+    s_vo_h,  # stride size: L * V
+    s_vo_t,  # stride size: V
+    s_vo_d,  # stride size: 1
+    scale,  # K ** -0.5
+    B: tl.constexpr,  # batch size
+    H: tl.constexpr,  # number of heads
+    T: tl.constexpr,  # sequence length
+    K: tl.constexpr,  # key dimension
+    V: tl.constexpr,  # value dimension
+    BTL: tl.constexpr,  # BLOCK SIZE along the sequence dimension for Q
+    BTS: tl.constexpr,  # BLOCK SIZE along the sequence dimension for K/V
+    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
+    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
+):
+    # i_c: chunk index.
used for sequence parallelism + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(V, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0)) + + # [BQ, BD] block Q, in the shared memory throughout the whole kernel + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_o = tl.zeros([BTL, BV], dtype=tl.float32) + b_z = tl.zeros([BTL], dtype=tl.float32) + + # Q block and K block have no overlap + # no need for mask, thereby saving flops + for _ in range(0, i_c * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_s = tl.dot(b_q, (b_k), allow_tf32=False) + b_s = 1 + b_s + 0.5 * b_s * b_s + b_z += tl.sum(b_s, axis=1) + + # [BQ, BD] + b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + + # # rescale interchunk output + tl.debug_barrier() + o_q = tl.arange(0, BTL) + # # sync threads, easy for compiler to optimize + # tl.debug_barrier() + + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_z += tl.sum(b_s, axis=1) + # [BTL, BV] + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + o_k += BTS + + p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=((i_c * BTL + tl.arange(0, BTL)) < T)) + + +@triton.jit +def _parallel_based_bwd_dq( + i_bh, + i_c, + i_k, + i_v, + i_h, + q, + k, + v, + do, + dz, + dq, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, +): + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), + (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, K), + (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_q = (b_q * scale).to(b_q.dtype) + b_dq = tl.zeros([BTL, BK], dtype=tl.float32) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i_c * BTL + 
tl.arange(0, BTL) + b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T) + + for _ in range(0, i_c * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + # [BQ, BD] + b_dq += tl.dot((b_ds * (1 + b_s)).to(b_v.dtype), b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + + b_dq *= scale + o_q = tl.arange(0, BTL) + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BTL, BK] + b_dq += tl.dot((b_ds + b_ds * b_s).to(b_k.dtype), + b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + o_k += BTS + p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, K), + (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def _parallel_based_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + K: tl.constexpr, V: tl.constexpr, +): + # compute dk dv + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load( + p_v, boundary_check=(0, 1)) + b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros( + [BTL, BV], dtype=tl.float32) + + for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS] + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * \ + scale # [BTL, BTS] + b_s2 = 1 + b_s + 0.5 * b_s * b_s + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale + if i_v == 0: + b_ds += b_dz[None, :] * scale + else: + b_ds = b_ds + b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False) + + tl.debug_barrier() + o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL) + for i in range(i_c*BTL, (i_c+1)*BTL, BTS): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, 
s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + # [BK, BQ] + m_s = o_k[:, None] <= o_q[None, :] + b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale + b_s2 = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_s2 = tl.where(m_s, b_s2, 0) + + b_ds = tl.dot(b_v, b_do, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[None, :] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + # [BK, BD] + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), + tl.trans(b_q), allow_tf32=False) + o_q += BTS + + p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h, (T, K), + (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h, (T, V), + (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def parallel_based_bwd_kernel( + q, + k, + v, + do, + dz, + dq, + dk, + dv, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + scale, + B: tl.constexpr, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(V, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + i_h = i_bh % H + _parallel_based_bwd_dq( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, K=K, V=V + ) + tl.debug_barrier() + _parallel_based_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, K, V + ) + + +class ParallelBasedFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, scale): + BTL, BTS = 128, 32 + assert BTL % BTS == 0 + # assert q.shape[-1] % 16 == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + B, H, T, K, V = *k.shape, v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + grid = (NK * NV, triton.cdiv(T, BTL), B * H) + + assert NK == 1, "will encounter some synchronization issue if not." 
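+        # NOTE (editorial): NK == 1 always holds here, since BK is the next
+        # power of two of K capped at 128 and parallel_based() asserts
+        # q.shape[-1] <= 128; with NK > 1 the partial outputs written by
+        # different K-blocks would need extra cross-program synchronization.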
+ + o = torch.empty(NK, B, H, T, V, device=q.device) + z = torch.empty(NK, B, H, T, device=q.device) + parallel_based_fwd_kernel[grid]( + q, k, v, o, z, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + scale, + B=B, H=H, T=T, K=K, V=V, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + ctx.save_for_backward(q, k, v) + ctx.scale = scale + return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype) + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, dz): + q, k, v = ctx.saved_tensors + scale = ctx.scale + BTL, BTS = 64, 32 + assert BTL % BTS == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + B, H, T, K, V = *k.shape, v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + grid = (NK * NV, triton.cdiv(T, BTL), B * H) + + assert NK == 1, "will encounter some synchronization issue if not" + + dq = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device) + dk = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device) + dv = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device) + + parallel_based_bwd_kernel[grid]( + q, k, v, do, dz, dq, dk, dv, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + scale, + B=B, H=H, T=T, K=K, V=V, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + + return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None + + +triton_parallel_based = ParallelBasedFunction.apply + + +def parallel_based( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + use_norm: bool = True +): + assert q.shape[-1] <= 128, "only support feature dim up to 128" + if scale is None: + scale = q.shape[-1] ** -0.5 + o, z = triton_parallel_based(q, k, v, scale) + if use_norm: + o = o / (z[..., None] + 1e-6) + return o.to(q.dtype) diff --git a/fla2/ops/common/__pycache__/chunk_h.cpython-310.pyc b/fla2/ops/common/__pycache__/chunk_h.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8512793e4965c80f9ba62a5abaf26e66cfd1e751 Binary files /dev/null and b/fla2/ops/common/__pycache__/chunk_h.cpython-310.pyc differ diff --git a/fla2/ops/common/__pycache__/fused_recurrent.cpython-310.pyc b/fla2/ops/common/__pycache__/fused_recurrent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..034dd3a97eed2cad8bc0aed6ce4a95799b460d23 Binary files /dev/null and b/fla2/ops/common/__pycache__/fused_recurrent.cpython-310.pyc differ diff --git a/fla2/ops/common/chunk_h.py b/fla2/ops/common/chunk_h.py new file mode 100644 index 0000000000000000000000000000000000000000..87585482689ad361843122056f22b529c232eebe --- /dev/null +++ b/fla2/ops/common/chunk_h.py @@ -0,0 +1,249 @@ +import triton +import triton.language as tl +import torch + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=["BT", "BK", "BV", "USE_G", 'USE_GK', 'USE_GV'], +) +@triton.jit +def chunk_fwd_kernel_h( + k, + v, + h, + g, + gk, + gv, + h0, + ht, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + 
    STORE_FINAL_STATE: tl.constexpr,
+    USE_G: tl.constexpr,
+    USE_GK: tl.constexpr,
+    USE_GV: tl.constexpr
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+    # [BK, BV]
+    b_h = tl.zeros([BK, BV], dtype=tl.float32)
+
+    if USE_INITIAL_STATE:
+        p_h0 = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+
+    for i_t in range(NT):
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+
+        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
+        # [BK, BT]
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        # [BT, BV]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        last_idx = min((i_t + 1) * BT, T) - 1
+
+        # scalar decay
+        if USE_G:
+            b_g_last = tl.load(g + i_bh * T + last_idx)
+            b_h *= tl.exp(b_g_last)
+            p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+            b_g = tl.load(p_g, boundary_check=(0,))
+            b_v = (b_v * tl.exp(b_g_last - b_g)[:, None]).to(b_v.dtype)
+
+        # vector decay, h = Diag(gk) @ h
+        if USE_GK:
+            p_gk_last = tl.make_block_ptr(gk + i_bh * s_qk_h, (T * K,), (s_qk_d,), (last_idx * K + i_k * BK,), (BK,), (0,))
+            b_gk_last = tl.load(p_gk_last, boundary_check=(0,))
+            b_h *= tl.exp(b_gk_last)[:, None]
+
+            p_gk = tl.make_block_ptr(gk + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+            b_gk = tl.load(p_gk, boundary_check=(0, 1))
+            b_k = (b_k * tl.exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)
+
+        # vector decay, h = h @ Diag(gv)
+        if USE_GV:
+            p_gv_last = tl.make_block_ptr(gv + i_bh * s_vo_h, (T * V,), (s_vo_d,), (last_idx * V + i_v * BV,), (BV,), (0,))
+            b_gv_last = tl.load(p_gv_last, boundary_check=(0,))
+            b_h *= tl.exp(b_gv_last)[None, :]
+
+            p_gv = tl.make_block_ptr(gv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+            b_gv = tl.load(p_gv, boundary_check=(0, 1))
+            b_v = (b_v * tl.exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)
+
+        b_h += tl.dot(b_k, b_v, allow_tf32=False)
+
+    if STORE_FINAL_STATE:
+        p_ht = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+    ],
+    key=["BT", "BK", "BV", "USE_G", 'USE_GK', 'USE_GV'],
+)
+@triton.jit
+def chunk_bwd_kernel_dh(
+    q,
+    g,
+    gk,
+    gv,
+    do,
+    dh,
+    dht,
+    dh0,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    USE_G: tl.constexpr,
+    USE_GK: tl.constexpr,
+    USE_GV: tl.constexpr,
+    STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
+    LOAD_FINAL_STATE_GRADIENT: tl.constexpr
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    # [BK, BV]
+    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
+    if LOAD_FINAL_STATE_GRADIENT:
+        p_dht = tl.make_block_ptr(dht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)
+
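+    # NOTE (editorial): walking chunks from NT-1 down to 0, the loop below first
+    # stores the current dh (the gradient state seen by chunk i_t) and then
+    # updates it as dh <- decay * dh + (gated, scaled q_t)^T @ do_t, i.e. the
+    # time-reversed counterpart of the forward state recurrence.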
+    for i_t in range(NT - 1, -1, -1):
+        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
+        last_idx = min(i_t * BT + BT, T) - 1
+        # [BK, BT]
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        b_q = (b_q * scale).to(b_q.dtype)
+        # [BT, BV]
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+
+        if USE_G:
+            p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+            b_g = tl.load(p_g, boundary_check=(0,))
+            b_q = (b_q * tl.exp(b_g)[None, :]).to(b_q.dtype)
+            b_g_last = tl.load(g + i_bh * T + last_idx)
+            b_dh *= tl.exp(b_g_last)
+
+        if USE_GK:
+            p_gk = tl.make_block_ptr(gk + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+            b_gk = tl.load(p_gk, boundary_check=(0, 1))
+            b_q = (b_q * tl.exp(b_gk)).to(b_q.dtype)
+            p_gk_last = tl.make_block_ptr(gk + i_bh * s_qk_h, (T * K,), (s_qk_d,), (last_idx * K + i_k * BK,), (BK,), (0,))
+            b_gk_last = tl.load(p_gk_last, boundary_check=(0,))
+            b_dh *= tl.exp(b_gk_last)[:, None]
+
+        if USE_GV:
+            p_gv = tl.make_block_ptr(gv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+            b_gv = tl.load(p_gv, boundary_check=(0, 1))
+            b_do = (b_do * tl.exp(b_gv)).to(b_do.dtype)
+            p_gv_last = tl.make_block_ptr(gv + i_bh * s_vo_h, (T * V,), (s_vo_d,), (last_idx * V + i_v * BV,), (BV,), (0,))
+            b_gv_last = tl.load(p_gv_last, boundary_check=(0,))
+            b_dh *= tl.exp(b_gv_last)[None, :]
+
+        b_dh += tl.dot(b_q, b_do, allow_tf32=False)
+
+    if STORE_INITIAL_STATE_GRADIENT:
+        p_dh0 = tl.make_block_ptr(dh0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
+
+
+def chunk_fwd_h_fn(k, v, g, gk, gv, BT, h0, output_final_state):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    ht = None
+    if output_final_state:
+        ht = k.new_empty(B, H, K, V, dtype=torch.float32)
+
+    BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
+    h = k.new_empty(B, H, NT * K, V)
+    grid = (NK, NV, B * H)
+
+    USE_G, USE_GK, USE_GV = g is not None, gk is not None, gv is not None
+
+    chunk_fwd_kernel_h[grid](
+        k, v, h, g, gk, gv, h0, ht,
+        k.stride(1), k.stride(2), k.stride(3),
+        v.stride(1), v.stride(2), v.stride(3),
+        h.stride(1), h.stride(2),
+        T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
+        USE_INITIAL_STATE=h0 is not None,
+        STORE_FINAL_STATE=output_final_state,
+        USE_G=USE_G, USE_GK=USE_GK, USE_GV=USE_GV
+    )
+    return h, ht
+
+
+def chunk_bwd_dh_fn(q, k, v, g, gk, gv, do, h0, dht, BT, scale):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    BT = 64
+    BK = min(triton.next_power_of_2(K), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
+    dh = k.new_empty(B, H, NT * K, V)
+    grid = (NK, NV, B * H)
+    if h0 is not None:
+        dh0 = torch.empty_like(h0, dtype=torch.float32)
+    else:
+        dh0 = None
+    USE_GATE = (g is not None) or (gk is not None) or (gv is not None)
+    assert not (USE_GATE and dht is not None), "Cannot load final state gradient and use gates at the same time"
+    chunk_bwd_kernel_dh[grid](
+        q, g, gk, gv, do, dh, dht, dh0,
+        q.stride(1), q.stride(2),
q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + dh.stride(1), dh.stride(2), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + USE_G=g is not None, USE_GK=gk is not None, USE_GV=gv is not None, + STORE_INITIAL_STATE_GRADIENT=dh0 is not None, + LOAD_FINAL_STATE_GRADIENT=dht is not None + ) + return dh, dh0 + + + diff --git a/fla2/ops/common/fused_recurrent.py b/fla2/ops/common/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..2cadffd58d6bdb46098f3d56b74cd7c28f6cef1f --- /dev/null +++ b/fla2/ops/common/fused_recurrent.py @@ -0,0 +1,346 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Songlin Yang, Yu Zhang +from typing import Tuple +import torch +import triton +import triton.language as tl + +from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +from ...ops.utils import chunk_global_reversed_cumsum, chunk_global_cumsum + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8) + ], + key=["BK", "BV", "USE_GK", "USE_GV", "USE_G"], +) +@triton.jit +def fused_recurrent_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, K] + k, # key [B, H, L, K] + v, # value [B, H, L, V] + g, # log gate [B, H, L] or None + gk, # log gate [B, H, L, K] or None + gv, # log gate [B, H, L, V] or None + o, # output [NK, B, H, L, V] + h0, # initial hidden state [B, H, K, V] + ht, # final hidden state [B, H, K, V] + s_qk_h, # stride size: L * K + s_vo_h, # stride size: L * V + scale, # K ** -0.5 + B: tl.constexpr, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state + REVERSE: tl.constexpr, # whether to reverse the recurrence + USE_GK: tl.constexpr, # whether to use gk + USE_GV: tl.constexpr, # whether to use gv + USE_G: tl.constexpr, # whether to use g +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + + if USE_G: + p_g = g + i_bh * T + ((T-1) if REVERSE else 0) + if USE_GK: + p_gk = gk + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + if USE_GV: + p_gv = gv + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + + mask_bk = (i_k * BK + tl.arange(0, BK)) < K + mask_bv = (i_v * BV + tl.arange(0, BV)) < V + mask_kv = mask_bk[None, :] & mask_bv[:, None] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + if USE_GK: + b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32) + b_h = b_h * tl.exp(b_gk[None, :]) + if USE_GV: + b_gv = tl.load(p_gv, mask=mask_bv, 
other=0).to(tl.float32)
+            b_h = b_h * tl.exp(b_gv[:, None])
+        if USE_G:
+            b_g = tl.load(p_g).to(tl.float32)
+            b_h = b_h * tl.exp(b_g)
+        b_h += b_k[None, :] * b_v[:, None]
+        b_o = b_h * b_q[None, :]
+        b_o = tl.sum(b_o, axis=1)
+        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)
+        p_q += -K if REVERSE else K
+        p_k += -K if REVERSE else K
+        p_o += -V if REVERSE else V
+        p_v += -V if REVERSE else V
+        if USE_GK:
+            p_gk += -K if REVERSE else K
+        if USE_GV:
+            p_gv += -V if REVERSE else V
+        if USE_G:
+            p_g += -1 if REVERSE else 1
+
+    if STORE_FINAL_STATE:
+        p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
+        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv)
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8)
+    ],
+    key=["BK", "BV", "USE_GK", "USE_GV", "USE_G"],
+)
+# Similar to Algorithm 1 of https://arxiv.org/abs/2006.16236
+@triton.jit
+def fused_recurrent_bwd_kernel(
+    # B: batch_size, H: n_heads, T: seq_len, D: d_head
+    # NV: number of splits in the V dimension. NK: number of splits in the K dimension
+    q,  # query [B, H, L, K]
+    k,  # key [B, H, L, K]
+    v,  # value [B, H, L, V]
+    g,  # log gate [B, H, L]
+    gk,  # log gate [B, H, L, K] \alpha
+    gv,  # log gate [B, H, L, V] \beta
+    do,  # gradient wrt output [B, H, L, V]
+    dq,  # gradient wrt query [NV, B, H, L, K]
+    dk,  # gradient wrt key [NV, B, H, L, K]
+    dv,  # gradient wrt value [NK, B, H, L, V]
+    dht,  # gradient wrt final hidden state [B, H, K, V]
+    dh0,  # gradient wrt initial hidden state [B, H, K, V]
+    h0,  # initial hidden state [B, H, K, V]
+    s_qk_h,  # stride size: L * K
+    s_vo_h,  # stride size: L * V
+    scale,  # K ** -0.5
+    B,
+    H,
+    T,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
+    REVERSE: tl.constexpr,  # whether to do autoregressive modeling in the reverse direction
+    USE_GK: tl.constexpr,  # whether to use gk
+    USE_GV: tl.constexpr,  # whether to use gv
+    USE_G: tl.constexpr,  # whether to use g
+    USE_FINAL_STATE_GRADIENT: tl.constexpr,  # whether to compute gradient wrt final state
+    STORE_INITIAL_STATE_GRADIENT: tl.constexpr,  # whether to store gradient wrt initial state
+):
+    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+    p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
+    p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
+    p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
+    p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
+    p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
+    if USE_GK:
+        p_gk = gk + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)
+    if USE_GV:
+        p_gv = gv + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)
+    if USE_G:
+        p_g = g + i_bh * T + ((T-1) if REVERSE else 0)
+    mask_bk = i_k * BK + tl.arange(0, BK) < K
+    mask_bv = i_v * BV + tl.arange(0, BV) < V
+    mask_kv = mask_bk[:, None] & mask_bv[None, :]
+    b_h = tl.zeros([BK, BV], dtype=tl.float32)
+    if USE_INITIAL_STATE:
+        p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
+        b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
+    for i in range(0,
T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + if USE_GK: + b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32) + b_h = b_h * tl.exp(b_gk[:, None]) + if USE_GV: + b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32) + b_h = b_h * tl.exp(b_gv[None, :]) + if USE_G: + b_g = tl.load(p_g).to(tl.float32) + b_h = b_h * tl.exp(b_g) + b_h += b_k[:, None] * b_v[None, :] + b_dq = b_h * b_do[None, :] + d_q = tl.sum(b_dq, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk) + + p_k += -K if REVERSE else K + p_v += -V if REVERSE else V + p_q += -K if REVERSE else K + p_do += -V if REVERSE else V + p_dq += -K if REVERSE else K + if USE_GK: + p_gk += -K if REVERSE else K + if USE_GV: + p_gv += -V if REVERSE else V + if USE_G: + p_g += -1 if REVERSE else 1 + + # sync threads + tl.debug_barrier() + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0) + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0) + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0) + if USE_GK: + p_gk = gk + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + if USE_GV: + p_gv = gv + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0) + if USE_G: + p_g = g + i_bh * T + ((T - 1) if not REVERSE else 0) + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_dht = dht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + b_dh += tl.load(p_dht, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(T): + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_dh += b_q[:, None] * b_do[None, :] + d_k = tl.sum(b_dh * b_v[None, :], axis=1) + d_v = tl.sum(b_dh * b_k[:, None], axis=0) + if USE_GK: + b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32) + b_dh *= tl.exp(b_gk)[:, None] + if USE_GV: + b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32) + b_dh *= tl.exp(b_gv)[None, :] + if USE_G: + b_g = tl.load(p_g).to(tl.float32) + b_dh *= tl.exp(b_g) + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv) + + p_q += K if REVERSE else -K + p_k += K if REVERSE else -K + p_v += V if REVERSE else -V + p_do += V if REVERSE else -V + p_dk += K if REVERSE else -K + p_dv += V if REVERSE else -V + if USE_GK: + p_gk += K if REVERSE else -K + if USE_GV: + p_gv += V if REVERSE else -V + if USE_G: + p_g += 1 if REVERSE else -1 + + if STORE_INITIAL_STATE_GRADIENT: + p_dh0 = dh0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_kv) + + + +class FusedRecurrentFunction(torch.autograd.Function): + + @staticmethod + @contiguous + 
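+    # NOTE (editorial): outputs are allocated with a leading NK dimension so
+    # each K-block program writes a disjoint slice; the partial outputs are
+    # then reduced with a plain .sum(0) on the host side after the launch.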
@autocast_custom_fwd + def forward(ctx, q, k, v, g, gk, gv, scale=None, initial_state=None, output_final_state=False, reverse=False): + B, H, T, K, V = *q.shape, v.shape[-1] + # default scale + if scale is None: + scale = K ** -0.5 + + BK, BV = min(K, 64), min(V, 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + o = q.new_empty(NK, B, H, T, V, dtype=torch.float32) + + h0 = initial_state + if output_final_state: + ht = q.new_empty(B, H, K, V, dtype=torch.float32) + else: + ht = None + + grid = (NV, NK, B * H) + fused_recurrent_fwd_kernel[grid]( + q, k, v, g, gk, gv, o, h0, ht, + q.stride(1), v.stride(1), + scale, + B=B, H=H, T=T, K=K, V=V, + BK=BK, BV=BV, + USE_INITIAL_STATE=h0 is not None, + STORE_FINAL_STATE=ht is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + USE_G=g is not None, + REVERSE=reverse, + ) + + o = o.sum(0) + ctx.save_for_backward(q, k, v, g, gk, gv, h0, o) + ctx.scale = scale + ctx.reverse = reverse + return o.to(q.dtype), ht + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, dht): + q, k, v, g, gk, gv, h0, o = ctx.saved_tensors + batch_size, n_heads, seq_len, K = q.shape + V = v.shape[-1] + scale = ctx.scale + + BK, BV = min(K, 64), min(V, 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + dq = q.new_empty(NV, batch_size, n_heads, seq_len, K, dtype=torch.float32) + dk = q.new_empty(NV, batch_size, n_heads, seq_len, K, dtype=torch.float32) + dv = q.new_empty(NK, batch_size, n_heads, seq_len, V, dtype=torch.float32) + dh0 = torch.empty_like(h0) if (h0 is not None) else None + grid = (NV, NK, batch_size * n_heads) + + fused_recurrent_bwd_kernel[grid]( + q, k, v, g, gk, gv, do, dq, dk, dv, dht, dh0, h0, + q.stride(1), + v.stride(1), scale, + B=batch_size, H=n_heads, T=seq_len, K=K, V=V, BK=BK, BV=BV, + USE_INITIAL_STATE=h0 is not None, + REVERSE=ctx.reverse, + USE_GK=gk is not None, + USE_GV=gv is not None, + USE_G=g is not None, + USE_FINAL_STATE_GRADIENT=dht is not None, + STORE_INITIAL_STATE_GRADIENT=dh0 is not None + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + fn = chunk_global_cumsum if ctx.reverse else chunk_global_reversed_cumsum + dgk = fn(dq * q.float() - dk * k.float()) if gk is not None else None + dgv = fn(do.float() * o.float() - dv * v.float()) if gv is not None else None + dg = fn((dq * q.float() - dk * k.float()).sum(-1)) if g is not None else None + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg, dgk, dgv, None, dh0, None, None + + +def fused_recurrent(q, k, v, g=None, gk=None, gv=None, scale=None, initial_state=None, output_final_state=False, reverse=False): + return FusedRecurrentFunction.apply(q, k, v, g, gk, gv, scale, initial_state, output_final_state, reverse) diff --git a/fla2/ops/delta_rule/README.md b/fla2/ops/delta_rule/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1ab2d485a9552d70238c1f68288c72c62f9e0ef2 --- /dev/null +++ b/fla2/ops/delta_rule/README.md @@ -0,0 +1,4 @@ +- Delta Rule + +The implementation of delta rule described in https://arxiv.org/abs/2102.11174 + diff --git a/fla2/ops/delta_rule/chunk.py b/fla2/ops/delta_rule/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..41ee3fee316199a6106d2740356da7944aca8ba5 --- /dev/null +++ b/fla2/ops/delta_rule/chunk.py @@ -0,0 +1,543 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +import torch +import triton +import triton.language as tl + +from ...ops.delta_rule.wy_fast import (bwd_prepare_wy_repr, + fwd_prepare_wy_repr, 
+                                       fwd_recompute_w_u)
+from ...ops.utils import contiguous
+from ...utils import autocast_custom_bwd, autocast_custom_fwd
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_prepare_dv_kernel(
+    q,
+    k,
+    do,
+    dv,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    scale,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+
+    b_A = tl.zeros([BT, BT], dtype=tl.float32)
+
+    for i_k in range(tl.cdiv(K, BK)):
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        b_q = (b_q * scale).to(b_k.dtype)
+        b_A += tl.dot(b_k, b_q, allow_tf32=False)
+
+    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty)
+
+    for i_v in range(tl.cdiv(V, BV)):
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_dv = tl.dot(b_A, b_do, allow_tf32=False)
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+
+
+def fwd_prepare_dv(q, k, do, BT):
+    dv = torch.empty_like(do)
+    B, H, T, K, V = *k.shape, do.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    fwd_prepare_dv_kernel[(NT, B*H)](
+        q, k, do, dv,
+        k.stride(1), k.stride(2), k.stride(3),
+        do.stride(1), do.stride(2), do.stride(3),
+        T, K, V, K**-0.5, BT, BK, BV
+    )
+    return dv
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_fwd_kernel_h(
+    k,
+    v,
+    d,
+    v_new,
+    h,
+    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
+    final_state,  # final state of the chunk [B, H, D_head_K, D_head_V]
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    USE_INITIAL_STATE: tl.constexpr,
+    STORE_FINAL_STATE: tl.constexpr
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+    # [BK, BV]
+    b_h = tl.zeros([BK, BV], dtype=tl.float32)
+
+    if USE_INITIAL_STATE:
+        p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+
+    for i_t in range(NT):
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
+        b_h_cumsum = tl.zeros([BK, BV], dtype=tl.float32)
+        # since we need to keep all of DK in SRAM, we face a severe SRAM memory burden; subchunking alleviates this burden.
+        for i_c in range(tl.cdiv(BT, BC)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t),
+                                    (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
+            p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d),
+                                    (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
+            p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d),
+                                    (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+            p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d),
+                                        (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            # [BT, BK]
+            b_d = tl.load(p_d, boundary_check=(0, 1))
+            # [BT, BV]
+            b_v = tl.load(p_v, boundary_check=(0, 1))
+            b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)
+            # [BK, BV]
+            tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
+            b_h_cumsum += tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
+        b_h += b_h_cumsum
+
+    if STORE_FINAL_STATE:
+        p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_linear_attn_fwd_kernel_o(
+    q,
+    k,
+    v,
+    h,
+    o,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+    o_i = tl.arange(0, BT)
+    m_s = o_i[:, None] >= o_i[None, :]
+
+    b_o = tl.zeros([BT, BV], dtype=tl.float32)
+    b_s = tl.zeros([BT, BT], dtype=tl.float32)
+    for i_k in range(tl.cdiv(K, BK)):
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        # [BT, BK]
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        b_q = (b_q * scale).to(b_q.dtype)
+        # [BK, BT]
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        # [BK, BV]
+        b_h = tl.load(p_h, boundary_check=(0, 1))
+        b_o += tl.dot(b_q, b_h, allow_tf32=False)
+        b_s += tl.dot(b_q, b_k, allow_tf32=False)
+
+    b_s = tl.where(m_s, b_s, 0)
+    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    b_v = tl.load(p_v, boundary_check=(0, 1))
+    b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False))
+    p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_bwd_kernel_dhu(
+    q,
+    k,
+    d,
+    do,
+    dh,
+    dv,
+    dv2,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT:
tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32) + for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), + (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), + (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), + (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), + (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), + (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_d = tl.load(p_d, boundary_check=(0, 1)) + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + p_dv2 = tl.make_block_ptr(dv2 + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), + (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BV] + b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False) + b_dh_tmp -= tl.dot(b_d, b_dv.to(b_q.dtype), allow_tf32=False) + b_dh += b_dh_tmp + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_bwd_kernel_dqkw( + q, + k, + v, + w, + h, + do, + dh, + dq, + dk, + dv, + dw, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_i = tl.arange(0, BT) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT, BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, 
(T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BT, BT] + b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False) + + # [BT, BT] + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0).to(b_q.dtype) + b_dq += tl.dot(b_ds, b_k, allow_tf32=False) + b_dq *= scale + b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False)) + + p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state): + B, H, T, K, V = *k.shape, u.shape[-1] + + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension larger than 256." + BV = 16 if BK > 128 else 32 + BV = 64 if BK <= 64 else BV + BC = 16 if BK > 128 else 32 + BC = 64 if BK <= 64 else BC + BC = min(BT, BC) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + h = k.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + v_new = torch.empty_like(u) + chunk_delta_rule_fwd_kernel_h[grid]( + k, u, w, v_new, h, initial_state, final_state, + k.stride(1), k.stride(2), k.stride(3), + u.stride(1), u.stride(2), u.stride(3), + h.stride(1), h.stride(2), + H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + ) + return h, v_new + + +def chunk_bwd_dhu_fn(q, k, w, do, dv, BT): + B, H, T, K, V = *q.shape, do.shape[-1] + + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension being larger than 256." 
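+    # block-size heuristic (same as in chunk_fwd_h_fn above): shrink the value
+    # tile BV and the sub-chunk length BC as BK grows, so that the [BK, BV]
+    # state block and the [BC, BK]/[BC, BV] sub-chunk tiles fit in SRAM together.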
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+    BC = 16 if BK > 128 else 32
+    BC = 64 if BK <= 64 else BC
+    BC = min(BT, BC)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
+    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
+
+    dh = q.new_empty(B, H, NT * K, V)
+    # dv_new = torch.empty_like(do)
+    grid = (NK, NV, B * H)
+    dv2 = torch.empty_like(dv)
+    chunk_delta_rule_bwd_kernel_dhu[grid](
+        q, k, w, do, dh, dv, dv2,
+        q.stride(1), q.stride(2), q.stride(3),
+        do.stride(1), do.stride(2), do.stride(3),
+        dh.stride(1), dh.stride(2),
+        K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,
+    )
+    return dh, dv2
+
+
+def chunk_fwd_o_fn(q, k, v_new, h, BT):
+    B, H, T, K, V = *q.shape, v_new.shape[-1]
+
+    o = torch.empty_like(v_new)
+    BK = min(triton.next_power_of_2(K), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NV = triton.cdiv(V, BV)
+    NT = triton.cdiv(T, BT)
+    grid = (NV, NT, B * H)
+    chunk_linear_attn_fwd_kernel_o[grid](
+        q, k, v_new, h, o,
+        q.stride(1), q.stride(2), q.stride(3),
+        v_new.stride(1), v_new.stride(2), v_new.stride(3),
+        h.stride(1), h.stride(2),
+        scale=K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
+    )
+    return o
+
+
+def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT):
+    B, H, T, K, V = *q.shape, v_new.shape[-1]
+
+    BK = min(triton.next_power_of_2(K), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NK = triton.cdiv(K, BK)
+    NT = triton.cdiv(T, BT)
+    grid = (NK, NT, B * H)
+    dq = torch.empty_like(q)
+    dk = torch.empty_like(k)
+    dw = torch.empty_like(w)
+    chunk_delta_rule_bwd_kernel_dqkw[grid](
+        q, k, v_new, w, h, do, dh, dq, dk, du, dw,
+        q.stride(1), q.stride(2), q.stride(3),
+        v_new.stride(1), v_new.stride(2), v_new.stride(3),
+        dh.stride(1), dh.stride(2),
+        scale=K ** -0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
+    )
+    return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)
+
+
+class ChunkDeltaRuleFunction(torch.autograd.Function):
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=1):
+        # obtain the WY representation; u is effectively the new v.
+        w, u, A = fwd_prepare_wy_repr(k, v, beta, BT)
+        # materialize the chunk-level hidden states h
+        final_state = None
+        if output_final_state:
+            final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+                                      dtype=torch.float32, requires_grad=False)
+        h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)
+        # obtain output
+        o = chunk_fwd_o_fn(q, k, v_new, h, BT)
+        # save memory
+        if checkpoint_level == 1:
+            h, v_new = None, None
+        ctx.save_for_backward(q, k, v, beta, A, h, v_new, initial_state)
+        ctx.BT = BT
+        return o.to(q.dtype), final_state
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, d_ht=None):
+        q, k, v, beta, A, h, v_new, initial_state = ctx.saved_tensors
+        BT = ctx.BT
+        w, u = fwd_recompute_w_u(k, v, beta, A, BT)
+        # checkpoint_level=1: h and v_new were freed in forward, so recompute them.
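+        # backward pipeline: dv from the intra-chunk attention (fwd_prepare_dv),
+        # then dh and the updated dv from the inter-chunk recurrence
+        # (chunk_bwd_dhu_fn), then dq/dk/dw (chunk_bwd_dqkw_fn), and finally
+        # the WY-representation backward (bwd_prepare_wy_repr), which also
+        # yields dbeta.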
+        if h is None:
+            h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
+        dv = fwd_prepare_dv(q, k, do, BT)
+        dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)
+        dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)
+        dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, dv, BT)
+        dk.add_(dk2)
+        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None, None
+
+
+def chunk_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    BT: int,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False
+):
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
+    o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state)
+    return o, final_state
diff --git a/fla2/ops/delta_rule/naive.py b/fla2/ops/delta_rule/naive.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e4c628f0472d00081386a121655df146b018bb0
--- /dev/null
+++ b/fla2/ops/delta_rule/naive.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+
+import torch
+from einops import rearrange
+
+
+def delta_rule_recurrence(q, k, v, beta):
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    o = torch.zeros_like(v)
+    S = torch.zeros(b, h, d_k, d_v).to(v)
+    q = q * (d_k ** -0.5)
+
+    if beta.ndim < v.ndim:
+        beta = beta[..., None]
+
+    for i in range(l):
+        _k = k[:, :, i]
+        _q = q[:, :, i]
+        _v = v[:, :, i].clone()
+        beta_i = beta[:, :, i]
+        _v = _v - (S.clone() * _k[..., None]).sum(-2)
+        _v = _v * beta_i
+        S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
+        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
+
+    return o
+
+
+def delta_rule_chunkwise(q, k, v, beta, chunk_size=32):
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    q = q * (d_k ** -0.5)
+    v = v * beta[..., None]
+    k_beta = k * beta[..., None]
+
+    assert l % chunk_size == 0
+
+    # note that the diagonal is masked.
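+    # The in-place loop below is forward substitution: starting from
+    # attn = -tril(K_beta @ K^T, -1) it builds T - I with
+    # T = (I + tril(K_beta @ K^T, -1))^(-1); after the identity is added,
+    # u = attn @ v and w = attn @ k_beta form the WY-style representation.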
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0) + q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, k_beta]) + attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0) + + for i in range(1, chunk_size): + attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2) + + attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device) + # u + k_cumsum = attn @ v + # w + k_cumdecay = attn @ k_beta + + v = k_cumsum + S = k.new_zeros(b, h, d_k, d_v) + o = torch.zeros_like(v) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1) + for i in range(0, l // chunk_size): + q_i, k_i, v_i = q[:, :, i], k[:, :, i], v[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0) + v_prime = k_cumdecay[:, :, i] @ S + v_new = v_i - v_prime + o_inter = q_i @ S + o[:, :, i] = o_inter + attn @ v_new + # chunk state update + S = S + k_i.transpose(-1, -2) @ v_new + + return rearrange(o, 'b h n c d -> b h (n c) d') + + +if __name__ == '__main__': + B = 2 + H = 4 + L = 256 + DK = 128 + DV = 128 + q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True) + k = (torch.randn(B, H, L, DK)).cuda() + k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True) + v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True) + beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True) + + o = delta_rule_recurrence(q, k, v, beta) + do = torch.randn(B, H, L, DV).cuda() + o.backward(do, retain_graph=True) + q_grad, q.grad = q.grad, None + k_grad, k.grad = k.grad, None + v_grad, v.grad = v.grad, None + beta_grad, beta.grad = beta.grad, None + + o2 = delta_rule_chunkwise(q, k, v, beta) + o2.backward(do) + assert torch.allclose(o, o2, atol=1e-4), breakpoint() + assert torch.allclose(q.grad, q_grad, atol=1e-4), breakpoint() + assert torch.allclose(k.grad, k_grad, atol=1e-4), breakpoint() + assert torch.allclose(v.grad, v_grad, atol=1e-4), breakpoint() + assert torch.allclose(beta.grad, beta_grad, atol=1e-4), breakpoint() + print("All passed!") diff --git a/fla2/ops/rotary.py b/fla2/ops/rotary.py new file mode 100644 index 0000000000000000000000000000000000000000..18ccc5f06a231f6a92aa2bfdca290fe9a65ffae7 --- /dev/null +++ b/fla2/ops/rotary.py @@ -0,0 +1,252 @@ +# Copyright (c) 2023, Tri Dao. 
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py + +from typing import Optional, Union + +import torch + +import triton +import triton.language as tl + + +# @triton.autotune( +# configs=[ +# triton.Config({"BLOCK_M": 2}), +# triton.Config({"BLOCK_M": 4}), +# triton.Config({"BLOCK_M": 8}), +# triton.Config({"BLOCK_M": 16}), +# ], +# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"], +# ) +@triton.jit +def rotary_kernel( + OUT, # Pointers to matrices + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, # this could be int or a pointer + # Matrix dimensions + seqlen, + nheads, + rotary_dim, + seqlen_ro, + CACHE_KEY_SEQLEN, + # strides + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + # Meta-parameters + BLOCK_K: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + rotary_dim_half = rotary_dim // 2 + + if not IS_VARLEN: + X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + + if pid_m * BLOCK_M >= seqlen: + return + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + rk = tl.arange(0, BLOCK_K) + rk_half = tl.arange(0, BLOCK_K // 2) + + if not INTERLEAVED: + # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT + X = X + (rm[:, None] * stride_x_seqlen + + rk_half[None, :] * stride_x_headdim) + COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + cos = tl.load( + COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0 + ).to(tl.float32) + sin = tl.load( + SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0 + ).to(tl.float32) + x0 = tl.load( + X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0 + ).to(tl.float32) + x1 = tl.load( + X + rotary_dim_half * stride_x_headdim, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + # write back result + OUT = OUT + (rm[:, None] * stride_out_seqlen + + rk_half[None, :] * stride_out_headdim) + tl.store(OUT, o0, mask=(rm[:, None] < seqlen) + & (rk_half[None, :] < rotary_dim_half)) + tl.store( + OUT + rotary_dim_half * stride_out_headdim, + o1, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + ) + else: + # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow. + # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...]. + # Loading x0 will be fast but x1 will be slow. + # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...]. 
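+        # (rk_repeat = [0, 0, 1, 1, ...] below realizes this pairing: both
+        # elements of an even/odd pair share the same rotation angle.)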
+        # Then we do the calculation and use tl.where to pick out the right outputs
+        # for the even and for the odd indices.
+        rk_swap = rk + ((rk + 1) % 2) * 2 - 1  # 1, 0, 3, 2, 5, 4, ...
+        rk_repeat = tl.arange(0, BLOCK_K) // 2
+        X0 = X + (rm[:, None] * stride_x_seqlen +
+                  rk[None, :] * stride_x_headdim)
+        X1 = X + (rm[:, None] * stride_x_seqlen +
+                  rk_swap[None, :] * stride_x_headdim)
+        COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
+        SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
+        cos = tl.load(
+            COS,
+            mask=(rm_cs[:, None] < seqlen_ro) & (
+                rk_repeat[None, :] < rotary_dim_half),
+            other=1.0,
+        ).to(tl.float32)
+        sin = tl.load(
+            SIN,
+            mask=(rm_cs[:, None] < seqlen_ro) & (
+                rk_repeat[None, :] < rotary_dim_half),
+            other=0.0,
+        ).to(tl.float32)
+        x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
+            tl.float32
+        )
+        x1 = tl.load(
+            X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
+        ).to(tl.float32)
+        if CONJUGATE:
+            sin = -sin
+        x0_cos = x0 * cos
+        x1_sin = x1 * sin
+        out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
+        OUT = OUT + (rm[:, None] * stride_out_seqlen +
+                     rk[None, :] * stride_out_headdim)
+        tl.store(OUT, out, mask=(rm[:, None] < seqlen)
+                 & (rk[None, :] < rotary_dim))
+
+
+def apply_rotary(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
+    interleaved=False,
+    inplace=False,
+    conjugate=False,
+) -> torch.Tensor:
+    """
+    Arguments:
+        x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, nheads, headdim).
+        cos: (seqlen_ro, rotary_dim / 2)
+        sin: (seqlen_ro, rotary_dim / 2)
+        seqlen_offsets: integer or integer tensor of size (batch,)
+        cu_seqlens: (batch + 1,) or None
+        max_seqlen: int
+    Returns:
+        y: (batch, seqlen, nheads, headdim)
+    """
+    is_varlen = cu_seqlens is not None
+    if not is_varlen:
+        batch, seqlen, nheads, headdim = x.shape
+    else:
+        assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
+        total_seqlen, nheads, headdim = x.shape
+        batch_p_1 = cu_seqlens.shape[0]
+        batch = batch_p_1 - 1
+        seqlen = max_seqlen
+    seqlen_ro, rotary_dim = cos.shape
+    assert sin.shape == cos.shape
+    rotary_dim *= 2
+    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
+    assert headdim <= 256, "Only support headdim <= 256"
+    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
+
+    assert (
+        cos.dtype == sin.dtype
+    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
+    assert (
+        x.dtype == cos.dtype
+    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
+
+    cos, sin = cos.contiguous(), sin.contiguous()
+    if isinstance(seqlen_offsets, torch.Tensor):
+        assert seqlen_offsets.shape == (batch,)
+        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
+        seqlen_offsets = seqlen_offsets.contiguous()
+    else:
+        assert seqlen_offsets + seqlen <= seqlen_ro
+
+    output = torch.empty_like(x) if not inplace else x
+    if rotary_dim < headdim and not inplace:
+        output[..., rotary_dim:].copy_(x[..., rotary_dim:])
+
+    BLOCK_K = (
+        32
+        if rotary_dim <= 32
+        else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
+    )
+    def grid(META): return (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
+    BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
+
+    # Need this, otherwise
Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) + with torch.cuda.device(x.device.index): + rotary_kernel[grid]( + output, # data ptrs + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, # shapes + nheads, + rotary_dim, + seqlen_ro, + # key for triton cache (limit number of compilations) + seqlen // 128, + # batch_strides if not varlen else 0 + output.stride(0) if not is_varlen else 0, + output.stride(-3), # seqlen_stride or total_seqlen_stride + output.stride(-2), # nheads_stride + output.stride(-1), # headdim_stride + # batch_strides if not varlen else 0 + x.stride(0) if not is_varlen else 0, + x.stride(-3), # seqlen stride or total_seqlen_stride + x.stride(-2), # nheads stride + x.stride(-1), # headdim stride + BLOCK_K, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + BLOCK_M, + ) + return output diff --git a/fla2/ops/utils.py b/fla2/ops/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..22518b225d704409c6f32820d6878c4da46e6ac0 --- /dev/null +++ b/fla2/ops/utils.py @@ -0,0 +1,471 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from ..utils import contiguous + + +@triton.autotune( + configs=[ + triton.Config({'BT': 16}, num_warps=2), + triton.Config({'BT': 16}, num_warps=4), + triton.Config({'BT': 16}, num_warps=8), + triton.Config({'BT': 32}, num_warps=2), + triton.Config({'BT': 32}, num_warps=4), + triton.Config({'BT': 32}, num_warps=8), + triton.Config({'BT': 64}, num_warps=2), + triton.Config({'BT': 64}, num_warps=4), + triton.Config({'BT': 64}, num_warps=8), + ], + key=['S'] +) +@triton.jit +def logcumsumexp_fwd_kernel( + s, + z, + s_s_h, + s_s_t, + s_s_d, + T: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr +): + i_bh = tl.program_id(0) + o_i = tl.arange(0, BT) + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.) 
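+    # streaming log-cumsum-exp: b_mp tracks the running maximum and b_zp the
+    # running sum, rescaled by exp(b_mp - b_mc) whenever the maximum grows;
+    # this is the same renormalization trick as in online softmax.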
+ + b_mp = tl.full([S,], float('-inf'), dtype=tl.float32) + b_zp = tl.zeros([S,], dtype=tl.float32) + for i_t in range(tl.cdiv(T, BT)): + p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0)) + + # [BT, S] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + # [S,] + b_mc = tl.max(b_s, 0) + # workaround for compiler bugs + if i_t > 0: + b_mc = tl.maximum(b_mp, b_mc) + b_zp = b_zp * tl.exp(b_mp - b_mc) + # [BT, S] + b_s = tl.exp(b_s - b_mc) + b_z = tl.dot(m_s, b_s, allow_tf32=False) + b_zp + # [S,] + b_zc = tl.max(b_z, 0) + b_mp = b_mc + b_zp = b_zc + # [BT, BS] + # small eps to prevent underflows + b_z = tl.log(tl.where(b_z != 0, b_z, 1e-20)) + b_mc + tl.store(p_z, b_z.to(p_z.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['S'] +) +@triton.jit +def softmax_fwd_kernel( + s, + p, + s_s_h, + s_s_t, + s_s_d, + T: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0)) + p_p = tl.make_block_ptr(p + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0)) + + # [BT, S] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + # [BT] + b_m = tl.max(b_s, 1) + + # [BT, BS] + b_s = tl.exp(b_s - b_m[:, None]) + b_z = tl.sum(b_s, 1) + b_p = tl.where(b_s != 0, b_s / b_z[:, None], 0.) + tl.store(p_p, b_p.to(p_p.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['S'] +) +@triton.jit +def softmax_bwd_kernel( + p, + dp, + ds, + s_s_h, + s_s_t, + s_s_d, + T: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_p = tl.make_block_ptr(p + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0)) + p_dp = tl.make_block_ptr(dp + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0)) + p_ds = tl.make_block_ptr(ds + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0)) + # [BT, BS] + b_p = tl.load(p_p, boundary_check=(0, 1)).to(tl.float32) + b_dp = tl.load(p_dp, boundary_check=(0, 1)).to(tl.float32) + # [BT,] + b_pp = tl.sum(b_p * b_dp, 1) + # [BT, BS] + b_ds = b_p * b_dp - b_p * b_pp[:, None] + tl.store(p_ds, b_ds.to(p_ds.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({'BT': 16}, num_warps=2), + triton.Config({'BT': 16}, num_warps=4), + triton.Config({'BT': 16}, num_warps=8), + triton.Config({'BT': 32}, num_warps=2), + triton.Config({'BT': 32}, num_warps=4), + triton.Config({'BT': 32}, num_warps=8), + triton.Config({'BT': 64}, num_warps=2), + triton.Config({'BT': 64}, num_warps=4), + triton.Config({'BT': 64}, num_warps=8), + ], + key=['S'] +) +@triton.jit +def chunk_global_reversed_cumsum_vector_kernel( + s, + z, + s_s_h, + s_s_t, + s_s_d, + T: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr +): + i_s, i_bh = tl.program_id(0), tl.program_id(1) + o_i = tl.arange(0, BT) + m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.) 
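+    # m_s is upper-triangular here (o_i[:, None] <= o_i[None, :]), so
+    # tl.dot(m_s, b_s) yields within-chunk suffix sums; b_z accumulates the
+    # totals of later chunks as i_t walks backwards over the sequence.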
+ + b_z = tl.zeros([BS], dtype=tl.float32) + for i_t in range(tl.cdiv(T, BT) - 1, -1, -1): + p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1)) + + if i_t >= 0: + b_z += tl.sum(b_s, 0) + + +@triton.autotune( + configs=[ + triton.Config({'BT': 16}, num_warps=2), + triton.Config({'BT': 16}, num_warps=4), + triton.Config({'BT': 16}, num_warps=8), + triton.Config({'BT': 32}, num_warps=2), + triton.Config({'BT': 32}, num_warps=4), + triton.Config({'BT': 32}, num_warps=8), + triton.Config({'BT': 64}, num_warps=2), + triton.Config({'BT': 64}, num_warps=4), + triton.Config({'BT': 64}, num_warps=8), + ], + key=['S'] +) +@triton.jit +def chunk_global_cumsum_vector_kernel( + s, + z, + s_s_h, + s_s_t, + s_s_d, + T: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr +): + i_s, i_bh = tl.program_id(0), tl.program_id(1) + o_i = tl.arange(0, BT) + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.) + b_z = tl.zeros([BS], dtype=tl.float32) + for i_t in range(tl.cdiv(T, BT)): + p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1)) + if i_t >= 0: + b_z += tl.sum(b_s, 0) + + +@triton.autotune( + configs=[ + triton.Config({'BT': 16}, num_warps=2), + triton.Config({'BT': 32}, num_warps=4), + triton.Config({'BT': 32}, num_warps=2), + triton.Config({'BT': 64}, num_warps=8), + triton.Config({'BT': 64}, num_warps=4), + ], + key=[] +) +@triton.jit +def chunk_global_reversed_cumsum_scalar_kernel( + s, + o, + T: tl.constexpr, + BT: tl.constexpr, +): + i_bh = tl.program_id(0) + b_z = tl.zeros([], dtype=tl.float32) + for i_t in range(tl.cdiv(T, BT) - 1, -1, -1): + p_s = tl.make_block_ptr(s + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_zz = tl.sum(b_s, axis=0) + b_z += b_zz + b_o = b_s - tl.cumsum(b_s, axis=0) + b_z[None] + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.autotune( + configs=[ + triton.Config({'BT': 16}, num_warps=2), + triton.Config({'BT': 32}, num_warps=4), + triton.Config({'BT': 32}, num_warps=2), + triton.Config({'BT': 64}, num_warps=8), + triton.Config({'BT': 64}, num_warps=4), + ], + key=[] +) +@triton.jit +def chunk_global_cumsum_scalar_kernel( + s, + o, + T: tl.constexpr, + BT: tl.constexpr, +): + i_bh = tl.program_id(0) + b_z = tl.zeros([], dtype=tl.float32) + for i_t in range(tl.cdiv(T, BT)): + p_s = tl.make_block_ptr(s + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + b_z[None] + b_zz = tl.sum(b_s, axis=0) + b_z += b_zz + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.autotune( + configs=[ + 
+        triton.Config({'BS': 16}, num_warps=2),
+        triton.Config({'BS': 16}, num_warps=4),
+        triton.Config({'BS': 16}, num_warps=8),
+        triton.Config({'BS': 32}, num_warps=2),
+        triton.Config({'BS': 32}, num_warps=4),
+        triton.Config({'BS': 32}, num_warps=8),
+        triton.Config({'BS': 64}, num_warps=2),
+        triton.Config({'BS': 64}, num_warps=4),
+        triton.Config({'BS': 64}, num_warps=8),
+    ],
+    key=['S', 'BT']
+)
+@triton.jit
+def chunk_local_cumsum_vector_kernel(
+    s,
+    o,
+    s_s_h,
+    s_s_t,
+    s_s_d,
+    T: tl.constexpr,
+    S: tl.constexpr,
+    BT: tl.constexpr,
+    BS: tl.constexpr
+):
+    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    o_i = tl.arange(0, BT)
+    m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
+    p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
+    p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
+    # [BT, BS]
+    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
+    b_o = tl.dot(m_s, b_s, allow_tf32=False)
+    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8)
+    ],
+    key=['BT']
+)
+@triton.jit
+def chunk_local_cumsum_scalar_kernel(
+    s,
+    o,
+    T: tl.constexpr,
+    BT: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    p_s = tl.make_block_ptr(s + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    p_o = tl.make_block_ptr(o + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    # [BT]
+    b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
+    b_o = tl.cumsum(b_s, axis=0)
+    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
+
+
+def chunk_local_cumsum_vector(g, BT):
+    B, H, T, S = g.shape
+    NT = triton.cdiv(T, BT)
+    g_org, g = g, torch.empty_like(g, dtype=torch.float)
+    def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)
+    # keep the cumulative normalizer in fp32
+    # this kernel is equivalent to
+    # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
+    chunk_local_cumsum_vector_kernel[grid](
+        g_org, g,
+        g.stride(1), g.stride(2), g.stride(3),
+        T=T, S=S, BT=BT
+    )
+    return g
+
+
+def chunk_local_cumsum_scalar(g, BT):
+    B, H, T = g.shape
+    NT = triton.cdiv(T, BT)
+    g_org, g = g, torch.empty_like(g, dtype=torch.float)
+    grid = (NT, B * H)
+    chunk_local_cumsum_scalar_kernel[grid](
+        g_org, g,
+        T=T, BT=BT
+    )
+    return g
+
+
+@contiguous
+def chunk_local_cumsum(g, BT):
+    if len(g.shape) == 3:
+        return chunk_local_cumsum_scalar(g, BT)
+    elif len(g.shape) == 4:
+        return chunk_local_cumsum_vector(g, BT)
+    else:
+        raise ValueError(
+            f"Unsupported shape {g.shape}. "
+            "Should be either (batch_size, num_heads, seq_len) or (batch_size, num_heads, seq_len, dim)"
+        )
+
+
+@contiguous
+def chunk_global_reversed_cumsum_vector(
+    s: torch.Tensor,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    B, H, T, S = s.shape
+    BS = 32
+    dtype = dtype or s.dtype
+    grid = (triton.cdiv(S, BS), B * H)
+    z = torch.empty_like(s, dtype=dtype)
+    chunk_global_reversed_cumsum_vector_kernel[grid](
+        s, z,
+        s.stride(1), s.stride(2), s.stride(3),
+        T=T, S=S, BS=BS
+    )
+    return z
+
+
+@contiguous
+def chunk_global_reversed_cumsum_scalar(
+    s: torch.Tensor,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    B, H, T = s.shape
+    dtype = dtype or s.dtype
+    grid = (B * H,)
+    z = torch.empty_like(s, dtype=dtype)
+    chunk_global_reversed_cumsum_scalar_kernel[grid](
+        s, z,
+        T=T
+    )
+    return z
+
+
+@contiguous
+def chunk_global_cumsum_vector(
+    s: torch.Tensor,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    B, H, T, S = s.shape
+    BS = 32
+    dtype = dtype or s.dtype
+    grid = (triton.cdiv(S, BS), B * H)
+    z = torch.empty_like(s, dtype=dtype)
+    chunk_global_cumsum_vector_kernel[grid](
+        s, z,
+        s.stride(1), s.stride(2), s.stride(3),
+        T=T, S=S, BS=BS
+    )
+    return z
+
+
+@contiguous
+def chunk_global_cumsum_scalar(
+    s: torch.Tensor,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    B, H, T = s.shape
+    dtype = dtype or s.dtype
+    grid = (B * H,)
+    z = torch.empty_like(s, dtype=dtype)
+    chunk_global_cumsum_scalar_kernel[grid](
+        s, z,
+        T=T
+    )
+    return z
+
+
+@contiguous
+def chunk_global_cumsum(s, dtype=None):
+    if len(s.shape) == 3:
+        return chunk_global_cumsum_scalar(s, dtype)
+    elif len(s.shape) == 4:
+        return chunk_global_cumsum_vector(s, dtype)
+    else:
+        raise ValueError(f"Unsupported shape {s.shape}. "
+                         f"Should be either (batch_size, num_heads, seq_len) or (batch_size, num_heads, seq_len, dim)")
+
+
+@contiguous
+def chunk_global_reversed_cumsum(s, dtype=None):
+    if len(s.shape) == 3:
+        return chunk_global_reversed_cumsum_scalar(s, dtype)
+    elif len(s.shape) == 4:
+        return chunk_global_reversed_cumsum_vector(s, dtype)
+    else:
+        raise ValueError(f"Unsupported shape {s.shape}. "
+                         f"Should be either (batch_size, num_heads, seq_len) or (batch_size, num_heads, seq_len, dim)")
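As a quick sanity check for the cumsum helpers above, the sketch below compares them against plain torch.cumsum references. It is illustrative only: it assumes a CUDA device is available and that the package is importable as fla2 (both assumptions, not part of the diff), and the tolerance is arbitrary.

import torch
from fla2.ops.utils import chunk_global_cumsum, chunk_global_reversed_cumsum

B, H, T, S = 2, 4, 128, 64
g = torch.randn(B, H, T, S, device='cuda')
# chunk_global_cumsum computes an inclusive prefix sum over the time axis
assert torch.allclose(chunk_global_cumsum(g), g.cumsum(2), atol=1e-4)
# chunk_global_reversed_cumsum computes an inclusive suffix sum over the time axis
ref = g.flip(2).cumsum(2).flip(2)
assert torch.allclose(chunk_global_reversed_cumsum(g), ref, atol=1e-4)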