diff --git a/fla/layers/__pycache__/__init__.cpython-312.pyc b/fla/layers/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adad64ad3abc2d69b093dd1db2597bd9a8ac0e75 Binary files /dev/null and b/fla/layers/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla/layers/__pycache__/bitattn.cpython-312.pyc b/fla/layers/__pycache__/bitattn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..403dc67aa4d4824e2abcaffef3e5b3b1ef69ef22 Binary files /dev/null and b/fla/layers/__pycache__/bitattn.cpython-312.pyc differ diff --git a/fla/layers/__pycache__/forgetting_attn.cpython-312.pyc b/fla/layers/__pycache__/forgetting_attn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7102985e01eb2c6b39030524963c1afa6bf8f626 Binary files /dev/null and b/fla/layers/__pycache__/forgetting_attn.cpython-312.pyc differ diff --git a/fla/layers/__pycache__/gla.cpython-312.pyc b/fla/layers/__pycache__/gla.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a27f95c553cf26fbe0d258c877966cb5d6e2f585 Binary files /dev/null and b/fla/layers/__pycache__/gla.cpython-312.pyc differ diff --git a/fla/layers/__pycache__/hgrn2.cpython-312.pyc b/fla/layers/__pycache__/hgrn2.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80101b4dc112e76c6c3aa9030ad1e84000cc72ae Binary files /dev/null and b/fla/layers/__pycache__/hgrn2.cpython-312.pyc differ diff --git a/fla/models/gated_deltaproduct/__init__.py b/fla/models/gated_deltaproduct/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea7067f53a23309a9817b6de3bc3eb22480cf753 --- /dev/null +++ b/fla/models/gated_deltaproduct/__init__.py @@ -0,0 +1,14 @@ +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig +from 
fla.models.gated_deltaproduct.modeling_gated_deltaproduct import GatedDeltaProductForCausalLM, GatedDeltaProductModel + +AutoConfig.register(GatedDeltaProductConfig.model_type, GatedDeltaProductConfig) +AutoModel.register(GatedDeltaProductConfig, GatedDeltaProductModel) +AutoModelForCausalLM.register(GatedDeltaProductConfig, GatedDeltaProductForCausalLM) + +__all__ = [ + "GatedDeltaProductConfig", + "GatedDeltaProductForCausalLM", + "GatedDeltaProductModel", +] diff --git a/fla/models/gsa/__init__.py b/fla/models/gsa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a134f758e0bea0eb844a2db73957936078f889b6 --- /dev/null +++ b/fla/models/gsa/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.gsa.configuration_gsa import GSAConfig +from fla.models.gsa.modeling_gsa import GSAForCausalLM, GSAModel + +AutoConfig.register(GSAConfig.model_type, GSAConfig) +AutoModel.register(GSAConfig, GSAModel) +AutoModelForCausalLM.register(GSAConfig, GSAForCausalLM) + + +__all__ = ['GSAConfig', 'GSAForCausalLM', 'GSAModel'] diff --git a/fla/models/gsa/configuration_gsa.py b/fla/models/gsa/configuration_gsa.py new file mode 100644 index 0000000000000000000000000000000000000000..9e849796bb0c8eaca203ef1fcb5f3139bcec96b8 --- /dev/null +++ b/fla/models/gsa/configuration_gsa.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class GSAConfig(PretrainedConfig): + + model_type = 'gsa' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + hidden_size: int = 2048, + gate_logit_normalizer: Optional[int] = 8, + clamp_min: Optional[float] = None, + clamp_max: Optional[float] = None, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 4, + num_kv_heads: Optional[int] = None, + 
num_slots: Optional[int] = 64, + use_short_conv: bool = False, + conv_size: int = 4, + exapnd_k: float = 1, + exapnd_v: float = 1, + feature_map: str = 'swish', + use_output_gate: bool = False, + use_norm: bool = True, + max_position_embeddings: int = 2048, + hidden_act: str = "swish", + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + initializer_range: float = 0.006, + tie_word_embeddings: bool = False, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + **kwargs + ): + self.hidden_size = hidden_size + self.gate_logit_normalizer = gate_logit_normalizer + self.clamp_min = clamp_min + self.clamp_max = clamp_max + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.num_slots = num_slots + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.expand_k = exapnd_k + self.expand_v = exapnd_v + self.feature_map = feature_map + self.use_output_gate = use_output_gate + self.use_norm = use_norm + self.max_position_embeddings = max_position_embeddings + self.hidden_act = hidden_act + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to 
initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla/models/rwkv7/__init__.py b/fla/models/rwkv7/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6f132f3fc8de7108242e1accc51e55f4a4e6ed5 --- /dev/null +++ b/fla/models/rwkv7/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.rwkv7.configuration_rwkv7 import RWKV7Config +from fla.models.rwkv7.modeling_rwkv7 import RWKV7ForCausalLM, RWKV7Model + +AutoConfig.register(RWKV7Config.model_type, RWKV7Config, True) +AutoModel.register(RWKV7Config, RWKV7Model, True) +AutoModelForCausalLM.register(RWKV7Config, RWKV7ForCausalLM, True) + + +__all__ = ['RWKV7Config', 'RWKV7ForCausalLM', 'RWKV7Model'] diff --git a/fla/models/transformer_top/configuration_transformer.py b/fla/models/transformer_top/configuration_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..57d9266145d0ba6c34eee3ec86136d8cde0c6c3b --- /dev/null +++ b/fla/models/transformer_top/configuration_transformer.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class TOPTransformerConfig(PretrainedConfig): + + model_type = 'top_transformer' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + hidden_size: int = 2048, + num_hidden_layers: int = 24, + num_heads: int = 32, + num_kv_heads: int = None, + qkv_bias: bool = False, + qk_norm: bool = False, + window_size: Optional[int] = None, + rope_theta: 
Optional[float] = 10000., + max_position_embeddings: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + initializer_range: float = 0.006, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + use_top_loss: bool = False, + top_loss_ratio: float = 0.5, + top_window_size: Optional[int] = None, + **kwargs, + ): + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.qkv_bias = qkv_bias + self.qk_norm = qk_norm + self.window_size = window_size + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + + self.initializer_range = initializer_range + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + self.use_top_loss = use_top_loss + self.top_loss_ratio = top_loss_ratio + self.top_window_size = top_window_size if top_window_size is not None else max_position_embeddings + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla/modules/fused_cross_entropy.py b/fla/modules/fused_cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..f85091f66fe5539d4d6c68ca801b3b51ac8b94e4 --- /dev/null +++ b/fla/modules/fused_cross_entropy.py @@ -0,0 +1,419 @@ +# -*- 
coding: utf-8 -*- + +# Copyright (c) 2023, Tri Dao. + +from typing import Any, Tuple + +import torch +import torch.nn as nn +import triton +import triton.language as tl + +from fla.ops.utils.op import exp, log +from fla.utils import input_guard + +# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for +# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent +# version of PyTorch. The following 2 lines are for backward compatibility with +# older PyTorch. +if "all_gather_into_tensor" not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base + + +@triton.heuristics({ + "HAS_SMOOTHING": lambda args: args["label_smoothing"] > 0.0, +}) +@triton.jit +def cross_entropy_fwd_kernel( + loss_ptr, # data ptrs + lse_ptr, + z_loss_ptr, + logits_ptr, + labels_ptr, + label_smoothing, + logit_scale, + lse_square_scale, + ignore_index, + total_classes, + class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes + n_cols, # shapes + n_rows, + logits_row_stride, # strides + BLOCK_SIZE: tl.constexpr, + HAS_SMOOTHING: tl.constexpr, + # if SPLIT (e.g. 
tensor parallel), don't include the LSE in the loss since it's not the final LSE + SPLIT: tl.constexpr, +): + row_idx = tl.program_id(0) + col_block_idx = tl.program_id(1) + logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) + col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + label_idx = tl.load(labels_ptr + row_idx) + logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")) + logits = logits.to(tl.float32) * logit_scale + max_logits = tl.max(logits, 0) + if HAS_SMOOTHING: + sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0) + lse = log(tl.sum(exp(logits - max_logits), 0)) + max_logits + tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse) + if label_idx == ignore_index: + loss = 0.0 + z_loss = 0.0 + else: + label_idx -= class_start_idx + if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min( + n_cols, (col_block_idx + 1) * BLOCK_SIZE + ): + logits_label = tl.load(logits_ptr + label_idx) * logit_scale + if HAS_SMOOTHING: + loss = ( + (lse if not SPLIT else 0.0) + - label_smoothing * sum_logits / total_classes + - (1 - label_smoothing) * logits_label + ) + else: + loss = (lse if not SPLIT else 0.0) - logits_label + else: + # If label is out of bounds, we set the CE loss to 0.0. 
But we still want the label_smoothing loss + if HAS_SMOOTHING: + loss = label_smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes) + else: + loss = 0.0 + if not SPLIT: + z_loss = lse_square_scale * lse * lse + loss += z_loss + else: + z_loss = 0.0 + tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss) + if not SPLIT: + tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss) + + +@triton.heuristics({ + "HAS_SMOOTHING": lambda args: args["label_smoothing"] > 0.0, +}) +@triton.jit +def cross_entropy_bwd_kernel( + dlogits_ptr, # data ptrs + dloss_ptr, + logits_ptr, + lse_ptr, + labels_ptr, + label_smoothing, + logit_scale, + lse_square_scale, + ignore_index, + total_classes, + class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes + n_cols, # shapes + logits_row_stride, # strides + dlogits_row_stride, + dloss_row_stride, + BLOCK_SIZE: tl.constexpr, + HAS_SMOOTHING: tl.constexpr, +): + row_idx = tl.program_id(0) + col_block_idx = tl.program_id(1) + logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) + dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64) + col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + label_idx = tl.load(labels_ptr + row_idx) + if label_idx != ignore_index: + dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride) + else: + dloss = 0.0 + logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to( + tl.float32 + ) * logit_scale + lse = tl.load(lse_ptr + row_idx) + probs = exp(logits - lse) + probs += 2.0 * lse_square_scale * lse * probs + label_idx -= class_start_idx + if HAS_SMOOTHING: + smooth_negative = label_smoothing / total_classes + probs = tl.where(col_offsets == label_idx, probs - (1 - label_smoothing), probs) - smooth_negative + else: + probs = tl.where(col_offsets == label_idx, probs - 1.0, probs) + tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < 
n_cols) + + +def fused_cross_entropy_forward( + logits: torch.Tensor, + target: torch.Tensor, + label_smoothing: float = 0.0, + logit_scale: float = 1.0, + lse_square_scale: float = 0.0, + ignore_index: int = -100, + process_group=None, +): + n_rows, n_cols = logits.shape + assert target.shape == (n_rows,) + world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) + total_classes = world_size * n_cols + rank = 0 if process_group is None else torch.distributed.get_rank(process_group) + class_start_idx = rank * n_cols + + if logits.stride(-1) != 1: + logits = logits.contiguous() + # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py + MAX_BLOCK_SIZE = 64 * 1024 + BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE) + num_warps = ( + 4 + if BLOCK_SIZE < 2048 + else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32)) + ) + # We may split the lse computation across multiple blocks, then do a reduction + # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k) + # where having just one thread block processing more than 64k elements is slow. 
+ split = world_size > 1 or n_cols > MAX_BLOCK_SIZE + n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE + loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,) + losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) + lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) + z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) + + cross_entropy_fwd_kernel[(n_rows, n_splits)]( + losses, # data ptrs + lse, + z_losses, + logits, + target, + label_smoothing, + logit_scale, + lse_square_scale, + ignore_index, + total_classes, + class_start_idx, + n_cols, # shapes + n_rows, + logits.stride(0), # strides + BLOCK_SIZE=BLOCK_SIZE, # constants + num_warps=num_warps, + SPLIT=split + ) + + if split: + # If there's no label_smoothing, if target are in the vocab of this partition, losses contains + # - predicted logit, and 0 otherwise. + # If there's label_smoothing=0.1, for target in the vocab of this partition, losses contains + # -0.9 * predicted logit - 0.1 * sum logit / total_classes. + # For target not in the vocab of this partition, losses contains + # -0.1 * sum logit / total_classes. + if n_splits > 1: + lse = torch.logsumexp(lse, dim=0) + losses = losses.sum(dim=0) + if world_size > 1: + lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device) + torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group) + handle_losses = torch.distributed.all_reduce( + losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True + ) + lse = torch.logsumexp(lse_allgather, dim=0) + handle_losses.wait() + # After the allreduce, if there's no label_smoothing, the total losses are - predicted_logit, + # we just have to add the (global) lse. + # If there's label_smoothing=0.1, the total losses are + # -0.9 * predicted_logit - 0.1 * sum logit / total_classes. + # Again, we just have to add the (global) lse. 
+ losses += lse + if lse_square_scale != 0.0: + z_losses = lse_square_scale * lse.square() + z_losses.masked_fill_(target == ignore_index, 0.0) + losses += z_losses + else: + z_losses = torch.zeros_like(losses) + losses.masked_fill_(target == ignore_index, 0.0) + + return losses, z_losses, lse, total_classes, class_start_idx + + +class CrossEntropyLossFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + logits, + target, + label_smoothing=0.0, + logit_scale=1.0, + lse_square_scale=0.0, + ignore_index=-100, + inplace_backward=False, + process_group=None, + ): + losses, z_losses, lse, total_classes, class_start_idx = fused_cross_entropy_forward( + logits, + target, + label_smoothing, + logit_scale, + lse_square_scale, + ignore_index, + process_group, + ) + ctx.save_for_backward(logits, lse, target) + ctx.mark_non_differentiable(z_losses) + ctx.label_smoothing = label_smoothing + ctx.logit_scale = logit_scale + ctx.lse_square_scale = lse_square_scale + ctx.ignore_index = ignore_index + ctx.total_classes = total_classes + ctx.class_start_idx = class_start_idx + ctx.inplace_backward = inplace_backward + + return losses, z_losses + + @staticmethod + @input_guard + def backward(ctx, grad_losses, grad_z_losses): + del grad_z_losses # z_losses are only for logging. 
+ + logits, lse, target = ctx.saved_tensors + dlogits = logits if ctx.inplace_backward else torch.empty_like(logits) + n_rows, n_cols = logits.shape + BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024) + num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16) + def grid(META): return (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa + cross_entropy_bwd_kernel[grid]( + dlogits, # data ptrs + grad_losses, + logits, + lse, + target, + ctx.label_smoothing, + ctx.logit_scale, + ctx.lse_square_scale, + ctx.ignore_index, + ctx.total_classes, + ctx.class_start_idx, + n_cols, # shapes + logits.stride(0), # strides + dlogits.stride(0), + grad_losses.stride(0), + BLOCK_SIZE=BLOCK_SIZE, # constants + num_warps=num_warps, + ) + return dlogits, None, None, None, None, None, None, None, None + + +def cross_entropy_loss( + logits: torch.Tensor, + target: torch.Tensor, + label_smoothing: float = 0.0, + logit_scale: float = 1.0, + lse_square_scale: float = 0.0, + ignore_index=-100, + inplace_backward: bool = False, + process_group=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + logits: [batch, vocab_size] + target: [batch,] + label_smoothing: float + logit_scale: float. + Multiply logits by this scale before calculating the loss. + lse_square_scale: float. + If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. + This is also referred to as "z-loss". + ignore_index: int. + If target == ignore_index, the loss is set to 0.0. + inplace_backward: bool. + If True, we do the backward pass in-place by modifying the logits. + This saves memory. + process_group: + if not None, we're doing Tensor Parallel: each process is responsible for + one part of the vocab. The loss will be aggregated across processes. 
+ Returns: + losses: [batch,], float + z_losses: [batch,], float + """ + return CrossEntropyLossFunction.apply( + logits, + target, + label_smoothing, + logit_scale, + lse_square_scale, + ignore_index, + inplace_backward, + process_group, + ) + + +class FusedCrossEntropyLoss(nn.Module): + def __init__( + self, + ignore_index: int = -100, + reduction: str = "mean", + label_smoothing: float = 0.0, + logit_scale: float = 1.0, + lse_square_scale: float = 0.0, + inplace_backward: bool = False, + process_group: Any = None, + return_z_loss: bool = False, + ): + """ + Arguments: + ignore_index: int. If target == ignore_index, the loss is set to 0.0. + label_smoothing: float + lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. + This is also referred to as "z-loss". + inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. + This saves memory. + process_group: if not None, we're doing Tensor Parallel: each process is responsible for + one part of the vocab. The loss will be aggregated across processes. + return_z_loss: bool. If True, we return the component of the loss contributed by + the lse_square_scale value. This value is only for logging and does not support + backprop. 
+ """ + super().__init__() + if reduction not in ["mean", "none", "sum"]: + raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'") + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.logit_scale = logit_scale + self.lse_square_scale = lse_square_scale + self.inplace_backward = inplace_backward + self.process_group = process_group + self.return_z_loss = return_z_loss + + def forward(self, input, target): + """ + Arguments: + input: (batch, vocab_size) + target: (batch,) + Returns: + losses: (batch,) if reduction is 'none', else (1,), dtype float + z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss) + """ + assert input.is_cuda and target.is_cuda, "Only support CUDA tensors" + loss, z_loss = cross_entropy_loss( + input, + target, + label_smoothing=self.label_smoothing, + logit_scale=self.logit_scale, + lse_square_scale=self.lse_square_scale, + ignore_index=self.ignore_index, + inplace_backward=self.inplace_backward, + process_group=self.process_group, + ) + if self.reduction == "mean": + loss = loss.sum() / (target != self.ignore_index).sum() + elif self.reduction == "sum": + loss = loss.sum() + else: + loss = loss + + if not self.return_z_loss: + return loss + + if self.reduction == "mean": + z_loss = z_loss.sum() / (target != self.ignore_index).sum() + elif self.reduction == "sum": + z_loss = z_loss.sum() + else: + z_loss = z_loss + + return loss, z_loss diff --git a/fla/ops/abc/__init__.py b/fla/ops/abc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fdac8d900fc51485a55716443ee1f00424b522b9 --- /dev/null +++ b/fla/ops/abc/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_abc + +__all__ = [ + 'chunk_abc' +] diff --git a/fla/ops/abc/chunk.py b/fla/ops/abc/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..8538e04800cd71414782ff72668df1fbd97984b1 --- 
/dev/null +++ b/fla/ops/abc/chunk.py @@ -0,0 +1,1116 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils import logcumsumexp_fwd_kernel, softmax_bwd, softmax_fwd +from fla.ops.utils.op import exp +from fla.utils import input_guard + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_fwd_kernel_h( + k, + v, + z, + h, + h0, + ht, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + NORMK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + if NORMK: + p_z0 = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_k * BK,), (BK,), (0,)) + else: + p_z0 = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_v * BV,), (BV,), (0,)) + b_zp = tl.load(p_z0).to(tl.float32) + for i_t in range(NT): + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + if NORMK: + p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = exp(b_zp - b_zc), b_zc + # [BK, BV] + b_h = b_h * 
b_r[:, None] + b_k = exp(b_k - b_zc[:, None]).to(b_k.dtype) + else: + p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = exp(b_zp - b_zc), b_zc + # [BK, BV] + b_h = b_h * b_r[None, :] + b_v = exp(b_v - b_zc[None, :]).to(b_v.dtype) + # [BK, BV] + b_h += tl.dot(b_k, b_v, allow_tf32=False) + + if STORE_FINAL_STATE: + p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_fwd_kernel_intra_K( + v, + z, + o, + A, + T, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BV] + b_o = tl.zeros([BC, BV], dtype=tl.float32) + for i_j in range(0, i_i): + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BC, BC] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_o += tl.dot(b_A, exp(b_v - b_zn[None, :]).to(b_v.dtype), allow_tf32=False) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_o *= exp(b_zn[None, :] - b_z) + + o_i = tl.arange(0, BC) + o_A = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + p_v = tl.make_block_ptr(v + i_bh * 
T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + # [BC,] + b_A = tl.load(A + o_A + j, mask=m_A, other=0) + # [BV,] + b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32) + # [BC, BV] + # avoid 0 * inf = inf + m_i = o_i[:, None] >= j + b_o += tl.where(m_i, b_A[:, None] * exp(b_v[None, :] - b_z), 0) + p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_fwd_kernel_K( + q, + k, + z, + h, + o, + A, + scale, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_o += tl.dot(b_q, b_h, allow_tf32=False) + # [BT, BT] + b_A += tl.dot(b_q, b_k, allow_tf32=False) + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BT, BV] + b_z = tl.load(p_z, boundary_check=(0, 1)) + # [BT, BV] + p_zp = tl.make_block_ptr(z + i_bh * T*V, (T * V,), 
(1,), (i_p * V + i_v * BV,), (BV,), (0,)) + b_zp = tl.load(p_zp, boundary_check=(0,)) + b_o = b_o * exp(b_zp[None, :] - b_z) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BT] + b_A = tl.where(m_s, b_A, 0.) + if i_v == 0: + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_fwd_kernel_intra_V( + q, + k, + z, + A, + scale, + T, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + n_bh = tl.num_programs(2) + + if i_i > i_j: + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_q = (b_q * exp(b_zn[None, :] - b_z) * scale).to(b_q.dtype) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k = exp(b_k - b_zn[:, None]).to(b_k.dtype) + # [BC, BC] + b_A = tl.dot(b_q, b_k, allow_tf32=False) + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + elif i_i == i_j: + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + 
i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,)) + p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + + o_i = tl.arange(0, BC) + o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + # [BK,] + b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32) + # [BC,] + b_A = tl.sum(b_q * exp(b_k[None, :] - b_z) * scale, 1) + b_A = tl.where(o_i >= j, b_A, 0.) + tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A) + + p_k = tl.advance(p_k, (K,)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_fwd_kernel_V( + q, + v, + z, + h, + o, + A, + scale, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_zp = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_z = tl.load(p_z, boundary_check=(0, 1)) + # [BT, BK] + b_zp = tl.load(p_zp, boundary_check=(0,)) + b_q = (b_q * exp(b_zp[None, :] - b_z)).to(b_q.dtype) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # works but dkw, owing to divine benevolence + # [BT, BV] + if i_k >= 0: + b_o += tl.dot(b_q, 
b_h, allow_tf32=False) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_o += tl.dot(b_A.to(b_v.dtype), b_v, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_dh( + q, + z, + do, + dh, + scale, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + NORMK: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + b_zp = tl.full([BK if NORMK else BV], float('inf'), dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + i_p = tl.maximum(i_t * BT - 1, 0) + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + if NORMK: + p_z = tl.make_block_ptr(z + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = exp(b_zc - b_zp), b_zc + # [BK, BT] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_q = (b_q * exp(b_zc[:, None] 
- b_z)).to(b_q.dtype) + # [BK, BV] + b_dh = b_dh * b_r[:, None] + else: + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_p * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = exp(b_zc - b_zp), b_zc + # [BT, BV] + b_z = tl.load(p_z, boundary_check=(0,)) + b_do = (b_do * exp(b_zc[None, :] - b_z)).to(b_do.dtype) + # [BK, BV] + b_dh = b_dh * b_r[None, :] + # [BK, BV] + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_V( + k, + v, + z, + h, + A, + do, + dh, + dq, + dk, + dv, + dA, + scale, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + n_bh = tl.num_programs(2) + + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + + # [BK,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k = exp(b_k - b_zc[None, :]).to(b_k.dtype) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * V * K, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = 
tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BV] + b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + if i_k == 0: + b_dv += tl.dot(b_A.to(b_do.dtype), b_do, allow_tf32=False) + b_do = (b_do * scale).to(b_do.dtype) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BT, BT] + b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + # [BT, BK] + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_zp = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zp = tl.load(p_zp, boundary_check=(0,)) + # [BT, BK] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_z = exp(b_zp[None, :] - b_z) + # [BT, BK] + b_dq = b_dq * b_z + b_dk = b_dk * b_k + + p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT,), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + # [BT, BT] + b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype) + if i_k == 0: + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def 
chunk_abc_bwd_kernel_intra_V( + q, + k, + z, + dA, + dq, + dk, + T, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BK] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_zq = exp(b_zn[None, :] - b_z) + b_dq = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(0, i_i): + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kz = exp(b_k - b_zn[None, :]).to(b_k.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dq += tl.dot(b_dA, b_kz, allow_tf32=False) + b_dq *= b_zq + + o_i = tl.arange(0, BC) + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC + m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + p_kj = tl.make_block_ptr(k + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0) + # [BK,] + b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] >= j + # [BC, BK] + b_dq += tl.where(m_i, b_dA[:, None] * exp(b_kj[None, :] - b_z), 0.) 
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*K, (T*K,), (1,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kz = exp(b_k - b_zn[None, :]) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_qz = (b_q * exp(b_zn[None, :] - b_z)).to(b_q.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dk += tl.dot(tl.trans(b_dA), b_qz, allow_tf32=False) + b_dk *= b_kz + + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC) + for j in range(0, BC): + p_qj = tl.make_block_ptr(q + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + p_zj = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0) + # [BK,] + b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32) + b_zj = tl.load(p_zj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] <= j + b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * exp(b_k - b_zj[None, :]), 0.) 
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_intra_K( + v, + z, + do, + dA, + scale, + T, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + n_bh = tl.num_programs(2) + + if i_i > i_j: + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1)) + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_dA = tl.make_block_ptr(dA+(i_bh+i_v*n_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BV,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BV] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * exp(b_zn[None, :] - b_z) * scale).to(b_do.dtype) + # [BV, BC] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v = exp(b_v - b_zn[:, None]).to(b_v.dtype) + # [BC, BC] + b_dA = tl.dot(b_do, b_v, allow_tf32=False) + tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1)) + elif i_i == i_j: + p_v = tl.make_block_ptr(v + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_j * BC) * V + i_v * BV,), (BV,), (0,)) + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_z = 
tl.load(p_z, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) * scale + + o_i = tl.arange(0, BC) + o_A = (i_bh + i_v * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + # [BV,] + b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32) + # [BC,] + b_dA = tl.sum(b_do * exp(b_v[None, :] - b_z), 1) + b_dA = tl.where(o_i >= j, b_dA, 0) + tl.store(dA + o_A + j, b_dA.to(b_do.dtype), mask=m_A) + + p_v = tl.advance(p_v, (V,)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_K( + q, + k, + v, + z, + h, + A, + do, + dh, + dq, + dk, + dv, + dA, + scale, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + n_bh = tl.num_programs(2) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh) * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k), allow_tf32=False) + b_A = tl.where(m_s, b_A, 0.) 
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_zp = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_p * V + i_v * BV,), (BV,), (0,)) + p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,)) + p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BV,] + b_zp = tl.load(p_zp, boundary_check=(0,)) + b_zc = tl.load(p_zc, boundary_check=(0,)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v = exp(b_v - b_zc[None, :]).to(b_v.dtype) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_z = exp(b_zp[None, :] - b_z) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * b_z * scale).to(b_do.dtype) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + # [BT, BV] + b_dv = b_v * tl.dot(b_k, b_dh, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BT] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BT, BK] + b_dq += tl.dot(b_dA, b_k, 
allow_tf32=False) + b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q, allow_tf32=False) + + p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_intra_KV( + v, + z, + A, + do, + dv, + T, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*V, (T*V,), (1,), ((i_t * BT + i_i * BC + BC - 1) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dv = tl.zeros([BC, BV], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * exp(b_zn[None, :] - b_z)).to(b_do.dtype) + # [BC, BC] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_dv += tl.dot(b_A, b_do, allow_tf32=False) + b_dv *= exp(b_v - b_zn[None, :]) + + o_i = tl.arange(0, BC) + for j in range(0, BC): + p_z = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + p_A = 
tl.make_block_ptr(A + i_bh * T * BT, (T * BT,), (1,), ((i_t * BT + i_i * BC + j) * BT + i_i * BC,), (BC,), (0,)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + # [BC,] + b_A = tl.load(p_A, boundary_check=(0,)) + # [BV,] + b_z = tl.load(p_z, boundary_check=(0,)) + b_do = tl.load(p_do, boundary_check=(0,)) + # [BC, BV] + m_i = o_i[:, None] <= j + b_dv += tl.where(m_i, exp(b_v - b_z[None, :]) * b_A[:, None] * b_do[None, :], 0.) + p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_rcum_inter( + s, + z, + ss, + doo, + T, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + NT: tl.constexpr +): + i_m, i_bh = tl.program_id(0), tl.program_id(1) + + b_sp = tl.zeros([BS,], dtype=tl.float32) + b_zp = tl.full([BS,], float('inf'), dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + p_zc = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT) * S + i_m * BS,), (BS,), (0,)) + p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + p_doo = tl.make_block_ptr(doo + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + # [BS,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_ss = tl.load(p_ss, boundary_check=(0, 1)) + + b_doo = exp(b_s - b_zp[None, :]) * b_sp[None, :] + tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1)) + # [BS,] + b_sp = b_sp * exp(b_zc - b_zp) + tl.sum(b_ss * exp(b_zc[None, :] - b_z), 0) + b_zp = b_zc + + 
+@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_rcum_intra( + s, + z, + ss, + doo, + T, + S: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BS: tl.constexpr, + NC: tl.constexpr +): + i_s, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + o_i = tl.arange(0, BC) + m_o = tl.full([BC, BC], 1., dtype=tl.float32) + + p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_i * BC, i_s * BS), (BC, BS), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + BC - 1) * S + i_s * BS,), (BS,), (0,)) + p_doo = tl.make_block_ptr(doo + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_i * BC, i_s * BS), (BC, BS), (1, 0)) + # [BC, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)) + # [BS,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + + b_doo = tl.zeros([BC, BS], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 0)) + p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 0)) + # [BC, BS] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_ss = tl.load(p_ss, boundary_check=(0, 1)) + # [BC, BS] + b_doo += b_ss * exp(b_zn[None, :] - b_z) + b_doo = exp(b_s - b_zn[None, :]) * tl.dot(m_o.to(b_s.dtype), b_doo.to(b_s.dtype), allow_tf32=False) + + for j in range(0, BC): + p_z = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,)) + p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,)) + # [BS,] + b_z = tl.load(p_z, boundary_check=(0,)) + b_ss = tl.load(p_ss, boundary_check=(0,)) + # [BC, BS] + m_i = o_i[:, None] <= j + b_doo += tl.where(m_i, exp(b_s - b_z[None, :]) * b_ss[None, :], 0.) 
+ b_doo += tl.load(p_doo, boundary_check=(0, 1)) + tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1)) + + +class ChunkABCFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward(ctx, q, k, v, s, initial_state, output_final_state): + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + BT, BC = 64, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + BM = min(64, triton.next_power_of_2(M)) + NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC) + NV, NM = triton.cdiv(V, BV), triton.cdiv(M, BM) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + def fwd_pre(s, B, H, T, S): + # keep cumulative normalizer in fp32 + z = torch.empty_like(s, dtype=torch.float) + grid = (B * H,) + logcumsumexp_fwd_kernel[grid]( + s, z, + T=T, S=S + ) + return z + + def fwd_inner(q, k, v, z, B, H, T, K, V, BT, BK, BV, NT, normk=False, h0=None, ht=None): + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + h = q.new_empty(B, H, NT * K, V) + grid = (NV, NK, B * H) + chunk_abc_fwd_kernel_h[grid]( + k, v, z, h, h0, ht, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + NORMK=normk, + USE_INITIAL_STATE=h0 is not None, + STORE_FINAL_STATE=ht is not None, + num_warps=num_warps, + num_stages=num_stages + ) + return h + + final_state = None + if output_final_state: + final_state = (q.new_empty(B, H, K, M, dtype=torch.float), + q.new_empty(B, H, M, V, dtype=torch.float)) + + z = fwd_pre(s, B, H, T, M) + scale = K ** -0.5 + hk = fwd_inner( + q=q, k=k, v=s, z=z, + B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT, + normk=False, + h0=initial_state[0] if initial_state is not None else None, + ht=final_state[0] if final_state is not None else None + ) + ok1 = torch.empty_like(s) + Ak = q.new_empty(B, H, T, BT) + grid = (NM, NT, B * H) + chunk_abc_fwd_kernel_K[grid]( + q, k, z, hk, ok1, Ak, + scale=scale, + T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + ok0 = 
torch.empty_like(s) + grid = (NM, NT * NC, B * H) + chunk_abc_fwd_kernel_intra_K[grid]( + s, z, ok0, Ak, + T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + ok = ok0.add_(ok1) + + scale = 1. + # p is kept in fp32 for safe softmax backward + p = softmax_fwd(ok, dtype=torch.float) + qv = p.to(q.dtype) + + scale = 1. + hv = fwd_inner( + q=qv, k=s, v=v, z=z, + B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT, + normk=True, + h0=initial_state[1] if initial_state is not None else None, + ht=final_state[1] if final_state is not None else None + ) + Av = q.new_zeros(NM, B, H, T, BT) + grid = (NM, NT * NC * NC, B * H) + chunk_abc_fwd_kernel_intra_V[grid]( + qv, s, z, Av, + scale=scale, + T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + Av = Av.sum(0) + ov = torch.empty_like(v) + grid = (NV, NT, B * H) + chunk_abc_fwd_kernel_V[grid]( + qv, v, z, hv, ov, Av, + scale=scale, + T=T, + K=M, + V=V, + BT=BT, + BK=BM, + BV=BV, + NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + ctx.save_for_backward(q, k, v, s, z, ok, p, hk, hv, Av) + ctx.BT = BT + return ov, final_state + + @staticmethod + @input_guard + def backward(ctx, dov, dht=None): + q, k, v, s, z, ok, p, hk, hv, Av = ctx.saved_tensors + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + BT, BC = ctx.BT, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + BM = min(64, triton.next_power_of_2(M)) + NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC) + NK, NM = triton.cdiv(K, BK), triton.cdiv(M, BM) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + def bwd_inner(q, z, do, B, H, T, K, V, BT, BK, BV, NT, scale, normk=False): + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + dh = q.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + chunk_abc_bwd_kernel_dh[grid]( + q, z, do, dh, + scale=scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + NORMK=normk, + num_warps=num_warps, + num_stages=num_stages + ) + 
return dh + + def bwd_post(s, z, ss, B, H, T, S, BT, BC, BS, NT, NC, NS): + doo = torch.empty_like(s) + grid = (NS, B * H) + chunk_abc_bwd_kernel_rcum_inter[grid]( + s, z, ss, doo, + T=T, S=S, BT=BT, BS=BS, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + grid = (NS, NT * NC, B * H) + chunk_abc_bwd_kernel_rcum_intra[grid]( + s, z, ss, doo, + T=T, S=S, BT=BT, BC=BC, BS=BS, NC=NC, + num_warps=num_warps, + num_stages=num_stages + ) + return doo + + scale = 1. + qv = p.to(q.dtype) + dhv = bwd_inner( + qv, z, dov, + B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT, + scale=scale, + normk=True + ) + dp1 = torch.empty_like(p) + dsv1 = torch.empty_like(s, dtype=torch.float) + dv = v.new_empty(NM, *v.shape) + dAv = q.new_zeros(B, H, T, BT) + grid = (NM, NT, B * H) + chunk_abc_bwd_kernel_V[grid]( + s, v, z, hv, Av, dov, dhv, dp1, dsv1, dv, dAv, + scale=scale, + T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + dv = dv.sum(0) + dp0 = torch.empty_like(p) + dsv0 = s.new_zeros(s.shape, dtype=torch.float) + grid = (NM, NT * NC, B * H) + chunk_abc_bwd_kernel_intra_V[grid]( + qv, s, z, dAv, dp0, dsv0, + T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + dp = dp1.add_(dp0) + dsv = dsv1.add_(dsv0) + + # softmax gradient, equivalent to: + # dok = p * (dp - (p * dp).sum(-1, True)) + dok = softmax_bwd(p, dp, dtype=ok.dtype) + + scale = K ** -0.5 + dhk = bwd_inner( + q, z, dok, + B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT, + scale=scale, + normk=False + ) + dAk = q.new_zeros(NM, B, H, T, BT) + grid = (NM, NT * NC * NC, B * H) + chunk_abc_bwd_kernel_intra_K[grid]( + s, z, dok, dAk, + scale=scale, + T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + dAk = dAk.sum(0) + + Ak = q.new_zeros(NK, B, H, T, BT) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dsk1 = s.new_empty(NK, *s.shape, dtype=torch.float) + grid = (NK, NT, B * H) + 
chunk_abc_bwd_kernel_K[grid]( + q, k, s, z, hk, Ak, dok, dhk, dq, dk, dsk1, dAk, + scale=scale, + T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + Ak = Ak.sum(0) + dsk1 = dsk1.sum(0) + dsk0 = torch.empty_like(s, dtype=torch.float) + grid = (NM, NT * NC, B * H) + chunk_abc_bwd_kernel_intra_KV[grid]( + s, z, Ak, dok, dsk0, + T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + ds = dsv.add_(dsk1.add_(dsk0)) + ds -= bwd_post(s, z, ok * dok + p * dp, B, H, T, M, BT, BC, BM, NT, NC, NM) + ds = ds.to(s.dtype) + return dq, dk, dv, ds, None, None + + +@torch.compiler.disable +def chunk_abc( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + initial_state: Optional[Tuple[torch.Tensor]] = None, + output_final_state: bool = False, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + s (torch.Tensor): + slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]` + initial_state (Optional[Tuple[torch.Tensor, torch.Tensor]]): + Initial states of shape `[B, H, K, M]` and `[B, H, M, V]`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[B, H, K, M]` and `[B, H, M, V]`. Default: `False`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. + Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[B, H, K, M]` and `[B, H, M, V]` if `output_final_state=True` else `None`. 
+ """ + if not head_first: + q, k, v, s = map(lambda x: x.transpose(1, 2), (q, k, v, s)) + o, final_state = ChunkABCFunction.apply(q, k, v, s, initial_state, output_final_state) + if not head_first: + o = o.transpose(1, 2) + return o, final_state diff --git a/fla/ops/abc/naive.py b/fla/ops/abc/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..a7f25c40db73bcf33d1599761be0008cc5be7c59 --- /dev/null +++ b/fla/ops/abc/naive.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch +from einops import repeat + + +def naive_recurrent_abc( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: Optional[torch.Tensor] = None, + scale: Optional[int] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: Optional[bool] = False +) -> torch.Tensor: + dtype = q.dtype + + NG = q.shape[1]//k.shape[1] + # [batch_size, n_heads, seq_len, n_slots] + if g is None: + z = s.float().logcumsumexp(2) + g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z + s = torch.exp(s - z) + q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g)) + k, v, s, g = map(lambda x: repeat(x, 'b h t d -> b (h g) t d', g=NG), (k, v, s, g)) + if initial_state is not None: + initial_state = tuple(map(lambda x: repeat(x, 'b h k v -> b (h g) k v', g=NG), initial_state)) + + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + + hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device) + ok = torch.zeros_like(s) + + if scale is None: + scale = q.shape[-1] ** -0.5 + + final_state = None + if initial_state is not None: + hk += initial_state[0] + + for i in range(T): + q_i = q[:, :, i] * scale + k_i = k[:, :, i] + v_i = s[:, :, i] + g_i = g[:, :, i].exp() + hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :] + ok[:, :, i] = (q_i[..., None] * hk).sum(-2) + + qv = ok.softmax(-1) + hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device) + ov = torch.zeros_like(v) + if initial_state 
is not None: + hv += initial_state[1] + + for i in range(T): + q_i = qv[:, :, i] + k_i = s[:, :, i] + v_i = v[:, :, i] + g_i = g[:, :, i].exp() + hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :] + ov[:, :, i] = (q_i[..., None] * hv).sum(-2) + + if output_final_state: + final_state = (hk.view(B, -1, NG, K, M)[:, :, 0], hv.view(B, -1, NG, M, V)[:, :, 0]) + return ov.to(dtype), final_state + + +def naive_cumsum_abc( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor +) -> torch.Tensor: + """ + A simple implementation of vanilla ABC that is more aligned with the descriptions in the paper. + This is just for demonstration purposes, with no numerical stability guaranteed. + """ + + dtype = q.dtype + q, k, v, s = map(lambda x: x.float(), (q, k, v, s)) + + scale = q.shape[-1] ** -0.5 + # [batch_size, n_heads, seq_len, n_slots] + s = (s - s.max(2, True)[0]).exp() + z = s.cumsum(2) + # [batch_size, n_heads, seq_len, n_slots, d_head] + K = (s.unsqueeze(-1) * k.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1) + V = (s.unsqueeze(-1) * v.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1) + # [batch_size, n_heads, seq_len, n_slots] + p = torch.einsum('...d,...md->...m', q * scale, K).softmax(-1) + # [batch_size, n_heads, seq_len, d_head] + o = torch.einsum('...m,...md->...d', p, V) + return o.to(dtype), None diff --git a/fla/ops/attn/__pycache__/parallel.cpython-312.pyc b/fla/ops/attn/__pycache__/parallel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed4f7bcdb92b3dea0c6001a8968a27f796996276 Binary files /dev/null and b/fla/ops/attn/__pycache__/parallel.cpython-312.pyc differ diff --git a/fla/ops/based/__init__.py b/fla/ops/based/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f20b31ba0ea4c7d345761fbd6ab5f6ced5136236 --- /dev/null +++ b/fla/ops/based/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .fused_chunk import fused_chunk_based +from .parallel import
parallel_based + +__all__ = [ + 'fused_chunk_based', + 'parallel_based' +] diff --git a/fla/ops/based/fused_chunk.py b/fla/ops/based/fused_chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..ff5db4fb73022c677662a4f7d29d6b2ec3015194 --- /dev/null +++ b/fla/ops/based/fused_chunk.py @@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.jit(do_not_specialize=['T']) +def fused_chunk_based_fwd_kernel( + q, + k, + v, + o, + z, + scale, # K ** -0.5 + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + + # [BV], zero-order taylor expansion + b_h_0o = tl.zeros([BV], dtype=tl.float32) + # [BK, BV], first-order taylor expansion + b_h_1o = tl.zeros([BK, BV], dtype=tl.float32) + # [BK, BK, BV] second-order taylor expansion + b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0)) + + p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT) + k_2o = tl.zeros([1, BK * BK], dtype=tl.float32) + k_1o = tl.zeros([1, BK], dtype=tl.float32) + k_0o = 0 + + for i in range(0, tl.cdiv(T, BT)): + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK*BK, BT] + b_k_2o = b_k[:, None, :] * b_k[None, :, :] 
+ b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BK] + b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype) + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_z = tl.zeros([BT], dtype=tl.float32) + + # interchunk + # zero-order + b_o += b_h_0o + b_z += k_0o + # first-order + b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False) + b_z += tl.sum(b_q * k_1o, axis=1) + # second-order + b_q_2o = b_q[:, :, None] * b_q[:, None, :] + b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype) + b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5 + b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5 + + # update running statistics + k_1o += tl.sum(b_k, axis=1)[None, :] + k_2o += tl.sum(b_k_2o, axis=1)[None, :] + k_0o += BT + + # intrachunk + # [BT, BT] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_z += tl.sum(b_s, axis=1) + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + # [TB, BV] + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=(i * BT + tl.arange(0, BT)) < T) + + # update hidden state + # [BK, BV] + b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False) + b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False) + b_h_0o = b_h_0o + tl.sum(b_v, axis=0) + + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + p_z += BT + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_chunk_based_bwd_kernel( + # NV: number of split in the V dimension. 
NK: number of split in the K dimension + q, + k, + v, + do, + dz, + dq, + dk, + dv, + scale, # K ** -0.5 + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + # [BV], zero-order taylor expansion + # b_h_0o = tl.zeros([BV], dtype=tl.float32) + # [BK, BV], first-order taylor expansion + b_h_1o = tl.zeros([BV, BK], dtype=tl.float32) + # [BK, BK, BV] second-order taylor expansion + b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32) + + k_1o = tl.zeros([1, BK], dtype=tl.float32) + k_2o = tl.zeros([1, BK * BK], dtype=tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0)) + p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + + # load tensors + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T) + # [BV, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # inter-chunk + b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False) + if i_v == 0: + b_dq += b_dz[:, None] * k_1o + b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5 + if i_v == 0: + b_dq_2o += (b_dz[:, None] * k_2o) * 0.5 + b_dq_2o = 
tl.reshape(b_dq_2o, [BT, BK, BK]) + b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1) + b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2) + b_dq *= scale + + # intra-chunk + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + b_ds = tl.where(m_s, b_ds, 0) * scale + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False) + + # store + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # update hidden state + # [BT, BK*BK] + b_k_2o = b_k[:, :, None] * b_k[:, None, :] + b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype) + # [BV, BK*BK] + b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False) + # [BV, BK] + b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False) + + if i_v == 0: + # update running statistics + k_1o += tl.sum(b_k, axis=0)[None, :] + k_2o += tl.sum(b_k_2o, axis=0)[None, :] + + tl.debug_barrier() + b_h_1o = None + b_h_2o = None + + # [BK, BV], first-order taylor expansion + b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32) + # [BK, BK, BV] second-order taylor expansion + b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32) + b_dh_0o = tl.zeros([BV], dtype=tl.float32) + m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :] + + dq_1o = tl.zeros([1, BK], dtype=tl.float32) + dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32) + + for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT): + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * T*K, (T, K), (K, 1), (i, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + 
(i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (i, i_v*BV), (BT, BV), (1, 0)) + p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i + + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T) + b_q = (b_q * scale).to(b_k.dtype) + + # intra chunk + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + if i_v == 0: + b_ds += b_dz[None, :] + b_ds = tl.where(m_s, b_ds, 0) + b_s = tl.dot(b_k, b_q, allow_tf32=False) + b_s2 = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_s2 = tl.where(m_s, b_s2, 0) + b_ds *= (1+b_s) + + b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False) + b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False) + + # inter chunk + b_k_2o = b_k[:, :, None] * b_k[:, None, :] + b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype) + + b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False) + b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False) + b_dv += b_dh_0o + + b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False) + + if i_v == 0: + b_dk += dq_1o + + b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype), tl.trans(b_v), allow_tf32=False) + if i_v == 0: + b_dk_2o += dq_2o + b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT]) + b_k_fp32 = tl.trans(b_k.to(tl.float32)) + b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0) + b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1) + b_dk += tl.trans(b_dk2) + + # hidden state update + b_dh_0o += tl.sum(b_do, axis=0) + b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False) + b_q_2o = b_q[None, :, :] * b_q[:, None, :] + b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype) + b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5 + + if i_v == 0: + dq_1o += (tl.sum(b_dz[None, :] * b_q, 
axis=1))[None, :] + dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None] + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkBasedFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, scale=1): + B, H, T, K, V = *k.shape, v.shape[-1] + + scale = scale + BT = 16 + BK, BV = min(K, 16), min(V, 32) + BK, BV = max(BK, 16), max(BV, 16) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + num_warps = 4 + + # the norm of o might explode, so we need to use float32 here + o = q.new_empty(NK, B, H, T, V, dtype=torch.float32) + z = q.new_empty(NK, B, H, T, dtype=torch.float32) + + grid = (NV, NK, B * H) + fused_chunk_based_fwd_kernel[grid]( + q, k, v, o, z, + scale, + T=T, B=B, H=H, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + ) + o = o.sum(0) + z = z.sum(0) + ctx.save_for_backward(q, k, v) + ctx.scale = scale + return o.to(q.dtype), z.to(z.dtype) + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dz): + q, k, v = ctx.saved_tensors + B, H, T, K, V = *k.shape, v.shape[-1] + scale = ctx.scale + + BT = 16 + BK, BV = min(K, 16), min(V, 32) + BK, BV = max(BK, 16), max(BV, 16) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 4 + + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + grid = (NV, NK, B * H) + + fused_chunk_based_bwd_kernel[grid]( + q, k, v, do, dz, dq, dk, dv, + scale, + T=T, B=B, H=H, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None + + +def fused_chunk_based( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + use_norm: bool = True, + head_first: bool = True +): + assert 
# -*- coding: utf-8 -*-

from typing import Optional

import torch


def naive_parallel_based(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    use_norm: bool = True
):
    """
    Quadratic-time reference implementation of Based linear attention.

    The attention kernel is the 2nd-order Taylor expansion of `exp`:
    `1 + s + s**2 / 2`, applied under a causal (lower-triangular) mask.

    Args:
        q, k: query/key tensors of shape `[B, H, T, K]`.
        v: value tensor of shape `[B, H, T, V]`.
        scale: query scale; defaults to `K ** -0.5`.
        use_norm: if True, normalize outputs by the per-position kernel sum.

    Returns:
        Output tensor of shape `[B, H, T, V]`.
    """
    if scale is None:
        scale = q.shape[-1] ** -0.5
    q = q * scale
    attn = q @ k.transpose(-2, -1)
    attn = 1 + attn + 1 / 2 * (attn ** 2)
    causal = torch.tril(torch.ones(q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device))
    attn = attn.masked_fill(~causal, 0)
    o = attn @ v
    if use_norm:
        # 1e-6 guards against division by a zero normalizer
        z = attn.sum(-1)
        return o / (z[..., None] + 1e-6)
    return o


def naive_chunk_based(q, k, v, chunk_size=256):
    """
    Chunkwise reference implementation of Based linear attention.

    Mathematically equivalent to `naive_parallel_based(q, k, v, use_norm=True)`:
    the zeroth/first/second-order terms of the Taylor kernel are accumulated
    with shifted cumulative inter-chunk states plus masked intra-chunk attention.

    Args:
        q, k: query/key tensors of shape `[B, H, T, K]`.
        v: value tensor of shape `[B, H, T, V]`.
        chunk_size: chunk length; `T` must be divisible by it.

    Raises:
        ValueError: if `T` is not a multiple of `chunk_size`.
    """
    T = q.shape[-2]
    if T % chunk_size != 0:
        raise ValueError(f'sequence length {T} must be divisible by chunk_size {chunk_size}')
    q = q * (q.shape[-1] ** -0.5)

    # compute normalizer: 0th/1st/2nd-order kernel terms summed over the causal prefix
    k_cumsum = torch.cumsum(k, dim=-2)
    kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3)
    # first order
    z = (q * k_cumsum).sum(-1)
    # second order
    z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5
    # zero-th order: each position attends to itself and all predecessors
    z += torch.arange(1, T + 1, dtype=z.dtype, device=z.device)[None, None, :]

    # compute o
    # constant (zeroth-order) term: running sum of values
    o_const = v.cumsum(-2)

    B, H = q.shape[0], q.shape[1]
    num_chunks = T // chunk_size
    q = q.reshape(B, H, num_chunks, chunk_size, -1)
    k = k.reshape(B, H, num_chunks, chunk_size, -1)
    v = v.reshape(B, H, num_chunks, chunk_size, -1)

    # intra-chunk: masked 1st+2nd-order attention within each chunk
    intra_chunk_attn = q @ k.transpose(-2, -1)
    intra_chunk_attn = intra_chunk_attn + 1 / 2 * (intra_chunk_attn ** 2)
    causal = torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device))
    o = intra_chunk_attn.masked_fill(~causal, 0) @ v

    # inter-chunk quadratic term via shifted cumulative (k ⊗ k ⊗ v) state
    kkv = torch.einsum('b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v).cumsum(2)
    kkv = torch.cat([torch.zeros_like(kkv[:, :, :1]), kkv[:, :, :-1]], dim=2)
    o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kkv, q, q)

    # inter-chunk linear term via shifted cumulative (k ⊗ v) state
    kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v).cumsum(2)
    kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
    o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q)

    o = o.reshape(B, H, T, -1) + o_const
    return o / (z[..., None] + 1e-6)
+# https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based + + +@triton.jit(do_not_specialize=['T']) +def parallel_based_fwd_kernel( + q, + k, + v, + o, + z, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + # i_c: chunk index. used for sequence parallelism + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(V, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BTS, BV), (1, 0)) + + # [BQ, BD] block Q, in the shared memory throughout the whole kernel + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_o = tl.zeros([BTL, BV], dtype=tl.float32) + b_z = tl.zeros([BTL], dtype=tl.float32) + + # Q block and K block have no overlap + # no need for mask, thereby saving flops + for _ in range(0, i_c * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_s = tl.dot(b_q, (b_k), allow_tf32=False) + b_s = 1 + b_s + 0.5 * b_s * b_s + b_z += tl.sum(b_s, axis=1) + + # [BQ, BD] + b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + + # # rescale interchunk output + tl.debug_barrier() + o_q = tl.arange(0, BTL) + # # sync threads, easy for compiler to optimize + # tl.debug_barrier() + + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0)) + # Q block and K block have overlap. 
masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_z += tl.sum(b_s, axis=1) + # [BTL, BV] + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + o_k += BTS + + p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=((i_c * BTL + tl.arange(0, BTL)) < T)) + + +@triton.jit +def _parallel_based_bwd_dq( + i_bh, + i_c, + i_k, + i_v, + q, + k, + v, + do, + dz, + dq, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, +): + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + p_q = tl.make_block_ptr(q + (i_bh) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dq = tl.zeros([BTL, BK], dtype=tl.float32) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, 0), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL) + b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T) + + for _ in range(0, i_c * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) 
+ # [BTL, BTS] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + # [BQ, BD] + b_dq += tl.dot((b_ds * (1 + b_s)).to(b_v.dtype), b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + + b_dq *= scale + o_q = tl.arange(0, BTL) + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BTL, BK] + b_dq += tl.dot((b_ds + b_ds * b_s).to(b_k.dtype), b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + o_k += BTS + p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def _parallel_based_bwd_dkv( + i_bh, + i_c, + i_k, + i_v, + q, + k, + v, + do, + dz, + dk, + dv, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, +): + # compute dk dv + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + b_k, b_v = tl.load(p_k, 
boundary_check=(0, 1)), tl.load(p_v, boundary_check=(0, 1)) + b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros([BTL, BV], dtype=tl.float32) + + for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS): + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS] + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * scale # [BTL, BTS] + b_s2 = 1 + b_s + 0.5 * b_s * b_s + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale + if i_v == 0: + b_ds += b_dz[None, :] * scale + else: + b_ds = b_ds + b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False) + + tl.debug_barrier() + o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL) + for i in range(i_c*BTL, (i_c+1)*BTL, BTS): + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + # [BK, BQ] + m_s = o_k[:, None] <= o_q[None, :] + b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale + b_s2 = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_s2 = tl.where(m_s, b_s2, 0) + + b_ds = tl.dot(b_v, b_do, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[None, :] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + # [BK, BD] + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_dk 
+= tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False) + o_q += BTS + + p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit(do_not_specialize=['T']) +def parallel_based_bwd_kernel( + q, + k, + v, + do, + dz, + dq, + dk, + dv, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(V, BV) + i_k = i_kv // (NV) + i_v = i_kv % NV + _parallel_based_bwd_dq( + i_bh, i_c, i_k, i_v, + q, k, v, do, dz, dq, + scale, T, B, H, BTL, BTS, BK, BV, K, V + ) + tl.debug_barrier() + _parallel_based_bwd_dkv( + i_bh, i_c, i_k, i_v, + q, k, v, do, dz, dk, dv, + scale, T, B, H, BTL, BTS, BK, BV, K, V + ) + + +class ParallelBasedFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, scale): + BTL, BTS = 128, 32 + assert BTL % BTS == 0 + # assert q.shape[-1] % 16 == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + B, H, T, K, V = *k.shape, v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + grid = (NK * NV, triton.cdiv(T, BTL), B * H) + + assert NK == 1, "will encounter some synchronization issue if not." 
        # Split outputs: one slice per K-block (NK of them), reduced by sum(0) below.
        # NOTE(review): torch.empty without an explicit dtype allocates in the
        # default dtype (normally float32) rather than q.dtype — presumably
        # intentional so the unnormalized accumulators stay in high precision;
        # confirm against the fused_chunk variant, which does this explicitly.
        o = torch.empty(NK, B, H, T, V, device=q.device)
        z = torch.empty(NK, B, H, T, device=q.device)
        parallel_based_fwd_kernel[grid](
            q, k, v, o, z,
            scale,
            B=B,
            H=H,
            T=T,
            K=K,
            V=V,
            BTL=BTL,
            BTS=BTS,
            BK=BK,
            BV=BV,
            num_warps=num_warps,
            num_stages=num_stages
        )
        ctx.save_for_backward(q, k, v)
        ctx.scale = scale
        # reduce over the K-split dimension and cast back to the input dtype
        return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dz):
        """Compute dq/dk/dv from the output gradient `do` and normalizer gradient `dz`."""
        q, k, v = ctx.saved_tensors
        scale = ctx.scale
        # tile sizes: BTL = outer (query) block length, BTS = inner (key/value) block length
        BTL, BTS = 64, 32
        assert BTL % BTS == 0
        # per-dimension block sizes: next power of two, clamped to [16, 128]
        BK = min(128, triton.next_power_of_2(k.shape[-1]))
        BV = min(128, triton.next_power_of_2(v.shape[-1]))
        BK, BV = max(BK, 16), max(BV, 16)
        B, H, T, K, V = *k.shape, v.shape[-1]
        num_stages = 2
        num_warps = 4
        NK = triton.cdiv(K, BK)
        NV = triton.cdiv(V, BV)
        grid = (NK * NV, triton.cdiv(T, BTL), B * H)

        # the kernel only supports a single split along K
        assert NK == 1, "will encounter some synchronization issue if not"

        # gradients carry a leading split dimension (NV for dq/dk, NK for dv)
        # that is reduced by sum(0) after the kernel runs
        dq = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
        dk = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
        dv = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device)

        parallel_based_bwd_kernel[grid](
            q, k, v, do, dz, dq, dk, dv,
            scale,
            B=B,
            H=H,
            T=T,
            K=K,
            V=V,
            BTL=BTL,
            BTS=BTS,
            BK=BK,
            BV=BV,
            num_warps=num_warps,
            num_stages=num_stages
        )

        # final None: no gradient for the non-tensor `scale` argument
        return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None


triton_parallel_based = ParallelBasedFunction.apply


def parallel_based(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    use_norm: bool = True,
    head_first: bool = True
):
    """
    Parallel (quadratic-time, Triton-accelerated) Based attention.

    Args:
        q, k: query/key tensors; the feature (last) dimension must be <= 128.
        v: value tensor.
        scale: query scale; defaults to `q.shape[-1] ** -0.5`.
        use_norm: if True, divide outputs by the kernel normalizer `z`.
        head_first: if True, inputs are laid out `[B, H, T, D]`; otherwise
            `[B, T, H, D]`, and they are transposed internally (and back).

    Returns:
        Output tensor with the same layout and dtype as `q`.
    """
    assert q.shape[-1] <= 128, "only support feature dim up to 128"
    if scale is None:
        scale = q.shape[-1] ** -0.5
    if not head_first:
        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
    o, z = triton_parallel_based(q, k, v, scale)
    if use_norm:
        # 1e-6 guards against division by a zero normalizer
        o = o / (z[..., None] + 1e-6)
    if not head_first:
        o = o.transpose(1, 2)
    return o.to(q.dtype)
'USE_GV'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_fwd_kernel_h( + k, + v, + h, + g, + gk, + gv, + h0, + ht, + offsets, + split_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + NS = tl.cdiv(T, BS) + boh = tl.load(split_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + NS = tl.cdiv(T, BS) + boh = i_n * NS + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + i_s = i_t // (BS // BT) + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + o_h = (i_nh * NS + i_s).to(tl.int64) * K*V + p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + o_h = ((boh + i_s) * H + i_h).to(tl.int64) * K*V + p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + if i_t % (BS // BT) == 0: + tl.store(p_h, 
b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + last_idx = min((i_t + 1) * BT, T) - 1 + + # scalar decay + if USE_G: + if HEAD_FIRST: + b_g_last = tl.load(g + i_nh * T + last_idx) + p_g = g + i_nh * T + i_t * BT + tl.arange(0, BT) + p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT) + else: + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h + b_h *= exp(b_g_last) + b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) + b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype) + + # vector decay, h = Diag(gk) @ h + if USE_GK: + if HEAD_FIRST: + p_gk = tl.make_block_ptr(gk + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK) + p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) + else: + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) 
@triton.heuristics({
    'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in BKV_LIST
        for BV in BKV_LIST
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
)
@triton.jit(do_not_specialize=['T'])
def chunk_bwd_kernel_dh(
    q,
    g,
    gk,
    gv,
    do,
    dh,
    dht,
    dh0,
    offsets,
    split_offsets,
    scale,
    T,
    HQ: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
    STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Backward recurrence for the chunk state gradient dh: walks chunks in
    # reverse, one program per (sequence, query-head, [BK, BV] tile).
    # NG = HQ // H is the GQA group size; gates (g/gk/gv) are indexed by the
    # shared KV head (i_bg / i_h), while q/do are indexed by the query head.
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_bg = i_nh // NG
    i_n, i_hq = i_nh // HQ, i_nh % HQ
    i_h = i_hq // NG
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        NS = tl.cdiv(T, BS)
        boh = tl.load(split_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        NS = tl.cdiv(T, BS)
        boh = i_n * NS

    # Running state-gradient tile, fp32 accumulator. [BK, BV]
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        # Seed with the incoming gradient w.r.t. the final state.
        p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT - 1, -1, -1):
        i_s = i_t // (BS // BT)
        if HEAD_FIRST:
            o_dh = (i_nh * NS + i_s).to(tl.int64) * K*V
            p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            o_dh = ((boh + i_s) * H + i_h).to(tl.int64) * K*V
            p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        # Checkpoint dh once per split (first chunk of each split, reached last
        # in this reverse sweep).
        if i_t % (BS // BT) == 0:
            tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
        last_idx = min(i_t * BT + BT, T) - 1
        # [BK, BT]
        if HEAD_FIRST:
            p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))

        # Scalar gate: rescale q within the chunk, decay dh across the chunk.
        if USE_G:
            if HEAD_FIRST:
                p_g = g + i_bg * T + i_t * BT + tl.arange(0, BT)
                p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
                b_g_last = tl.load(g + i_bg * T + last_idx)
            else:
                p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
                b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
            b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
            b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)

            b_dh *= exp(b_g_last)

        # Key-dimension gate: rescale q rows, decay dh rows.
        if USE_GK:
            if HEAD_FIRST:
                p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
                p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
                p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
            else:
                p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
                p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)

            b_gk = tl.load(p_gk, boundary_check=(0, 1))
            b_q = (b_q * exp(b_gk)).to(b_q.dtype)
            b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
            b_dh *= exp(b_gk_last)[:, None]

        # Value-dimension gate: rescale do rows, decay dh columns.
        if USE_GV:
            if HEAD_FIRST:
                p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
                p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
                p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
            else:
                p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
                p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)

            b_gv = tl.load(p_gv, boundary_check=(0, 1))
            b_do = (b_do * exp(b_gv)).to(b_do.dtype)

            b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
            b_dh *= exp(b_gv_last)[None, :]

        # dh <- dh + q^T @ do for this chunk.
        b_dh += tl.dot(b_q, b_do)

    if STORE_INITIAL_STATE_GRADIENT:
        p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
def chunk_fwd_h(
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    gk: torch.Tensor,
    gv: torch.Tensor,
    h0: torch.Tensor,
    output_final_state: bool,
    offsets: Optional[torch.Tensor] = None,
    head_first: bool = True,
    chunk_size: int = 64,
    split_size: Optional[int] = None,
    states_in_fp32: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Launch `chunk_fwd_kernel_h` to compute per-split chunk states.

    Returns:
        h:  checkpointed states, `[B, H, NS, K, V]` if `head_first` else
            `[B, NS, H, K, V]` (fp32 when `states_in_fp32`).
        ht: final state `[N, H, K, V]` (fp32) when `output_final_state`,
            otherwise `None`.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    # Block sizes are clamped to the (power-of-two padded) sequence length.
    cap = max(16, triton.next_power_of_2(T))
    BT = min(chunk_size, cap)
    BS = BT if split_size is None else min(split_size, cap)
    assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"

    # N: the actual number of sequences in the batch with either equal or variable lengths
    if offsets is None:
        split_offsets = None
        N, NS = B, triton.cdiv(T, BS)
    else:
        split_offsets = prepare_chunk_offsets(offsets, BS)
        N, NS = len(offsets) - 1, split_offsets[-1]

    state_dtype = torch.float if states_in_fp32 else k.dtype
    state_shape = (B, H, NS, K, V) if head_first else (B, NS, H, K, V)
    h = k.new_empty(*state_shape, dtype=state_dtype)
    ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None

    def grid(meta):
        return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)

    chunk_fwd_kernel_h[grid](
        k=k,
        v=v,
        h=h,
        g=g,
        gk=gk,
        gv=gv,
        h0=h0,
        ht=ht,
        offsets=offsets,
        split_offsets=split_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_GV=gv is not None,
        HEAD_FIRST=head_first,
    )
    return h, ht
def chunk_bwd_dh(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    gk: torch.Tensor,
    gv: torch.Tensor,
    do: torch.Tensor,
    h0: torch.Tensor,
    dht: torch.Tensor,
    scale: float,
    offsets: Optional[torch.Tensor] = None,
    head_first: bool = True,
    chunk_size: int = 64,
    split_size: Optional[int] = None,
    states_in_fp32: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Launch `chunk_bwd_kernel_dh` to compute per-split state gradients.

    Returns:
        dh:  checkpointed state gradients, `[B, HQ, NS, K, V]` if `head_first`
             else `[B, NS, HQ, K, V]`.
        dh0: gradient w.r.t. the initial state (same shape as `h0`, fp32),
             or `None` when `h0` is `None`.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
        HQ = q.shape[1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
        HQ = q.shape[2]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T)))
    assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
    # N: the actual number of sequences in the batch with either equal or variable lengths
    # NG: number of groups in GQA
    if offsets is None:
        split_offsets, N, NS = None, B, triton.cdiv(T, BS)
    else:
        split_offsets = prepare_chunk_offsets(offsets, BS)
        N, NS = len(offsets) - 1, split_offsets[-1]
    NG = HQ // H

    if head_first:
        dh = k.new_empty(B, HQ, NS, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
    else:
        dh = k.new_empty(B, NS, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
    dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None

    # BUGFIX: the grid's third axis must enumerate *query* heads (N * HQ), not
    # KV heads (N * H): the kernel decomposes `i_nh` as
    # `i_n, i_hq = i_nh // HQ, i_nh % HQ` and `dh` is allocated with HQ heads,
    # so launching only N * H programs under GQA (HQ > H) would leave most
    # query heads unprocessed. Matches the parallel variant, which launches
    # B * HQ / N * HQ programs.
    def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ)
    chunk_bwd_kernel_dh[grid](
        q=q,
        g=g,
        gk=gk,
        gv=gv,
        do=do,
        dh=dh,
        dht=dht,
        dh0=dh0,
        offsets=offsets,
        split_offsets=split_offsets,
        scale=scale,
        T=T,
        HQ=HQ,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BS=BS,
        NG=NG,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_GV=gv is not None,
        HEAD_FIRST=head_first
    )
    return dh, dh0
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64, 128]
        for BV in [32, 64, 128]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
)
@triton.jit(do_not_specialize=['T'])
def chunk_fwd_kernel_h_parallel(
    k,
    v,
    h,
    g,
    gk,
    gv,
    h0,
    ht,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Phase 1 of the fully parallel state pass: every chunk is handled by an
    # independent program that writes only its *local* contribution
    # (k @ v after in-chunk decay) into the NEXT chunk's slot of `h`; the
    # sequential prefix is summed later by `chunk_fwd_kernel_h_reduction`.
    i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    NV = tl.cdiv(V, BV)
    # i_b: batch index
    # i_h: head index
    # i_n: sequence index
    # i_t: chunk index within current sequence
    # i_tg: (global) chunk index across all sequences
    i_k, i_v = i_kv // NV, i_kv % NV
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # `indices` maps each global chunk to its (sequence, local chunk) pair.
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        bos, eos = i_b * T, i_b * T + T
        NT = tl.cdiv(T, BT)
        i_n, i_tg = i_b, i_b * NT + i_t
    i_nh = i_n * H + i_h

    if HEAD_FIRST:
        p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
    else:
        p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

    # Chunk 0's slot holds the initial state (or zeros) rather than a
    # contribution, so the reduction's running sum starts correctly.
    if i_t == 0:
        if USE_INITIAL_STATE:
            p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
            b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
        else:
            b_h = tl.zeros([BK, BV], dtype=tl.float32)
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))

    # [BK, BT]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    # [BT, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))

    # Last valid token of this (possibly partial) chunk.
    last_idx = min(i_t * BT + BT, T) - 1
    # scalar decay
    if USE_G:
        if HEAD_FIRST:
            b_g_last = tl.load(g + i_bh * T + last_idx)
            p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)
            p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
        else:
            b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
            p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
        b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
        b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype)

    # vector decay, h = Diag(gk) @ h
    if USE_GK:
        if HEAD_FIRST:
            p_gk = tl.make_block_ptr(gk + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_gk_last = gk + i_bh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
            p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
        else:
            p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)

        b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)

        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)

    # vector decay, h = h @ Diag(gv)
    if USE_GV:
        if HEAD_FIRST:
            p_gv = tl.make_block_ptr(gv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_gv_last = gv + i_bh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
            p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
        else:
            p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)

        b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)

        b_gv = tl.load(p_gv, boundary_check=(0, 1))
        b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)

    # Local contribution only (note: `=`, not `+=` — the prefix sum happens in
    # the reduction kernel).
    b_h = tl.dot(b_k, b_v)
    if i_t < NT - 1:
        # Write into the NEXT chunk's slot.
        if HEAD_FIRST:
            p_h = tl.make_block_ptr(h + (i_bh * NT + i_t + 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_h = tl.make_block_ptr(h + ((i_tg + 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
    elif STORE_FINAL_STATE:
        # Last chunk's contribution goes to `ht`; the reduction kernel adds the
        # accumulated prefix to it afterwards.
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64, 128]
        for BV in [32, 64, 128]
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3]
    ],
    key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
)
@triton.jit(do_not_specialize=['T'])
def chunk_fwd_kernel_h_reduction(
    h,
    g,
    gk,
    gv,
    kvt,
    ht,
    offsets,
    chunk_offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Phase 2 of the parallel state pass: sequentially folds the per-chunk
    # contributions written by `chunk_fwd_kernel_h_parallel` into a decayed
    # prefix sum, rewriting `h` in place so slot i_t ends up holding the state
    # *entering* chunk i_t. `kvt` holds the last chunk's local contribution
    # (stored into `ht` by phase 1) and is added at the end to finish `ht`.
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # Running decayed prefix sum. [BK, BV]
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    for i_t in range(NT):
        if HEAD_FIRST:
            p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
        # Slot 0 already holds the initial state; only later slots are rewritten.
        if i_t > 0:
            tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))

        last_idx = min(i_t * BT + BT, T) - 1
        # scalar decay
        if USE_G:
            if HEAD_FIRST:
                b_g_last = tl.load(g + i_nh * T + last_idx)
            else:
                b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
            b_h *= exp(b_g_last)

        # vector decay, h = Diag(gk) @ h
        if USE_GK:
            if HEAD_FIRST:
                p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
                p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
            else:
                p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)

            b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
            b_h *= exp(b_gk_last)[:, None]

        # vector decay, h = h @ Diag(gv)
        if USE_GV:
            if HEAD_FIRST:
                p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
                p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
            else:
                p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)

            b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
            b_h *= exp(b_gv_last)[None, :]

    if STORE_FINAL_STATE:
        # ht = (decayed prefix over all chunks) + last chunk's contribution.
        p_kvt = tl.make_block_ptr(kvt + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h += tl.load(p_kvt, boundary_check=(0, 1)).to(tl.float32)
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({
    'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64, 128]
        for BV in [32, 64, 128]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
)
@triton.jit(do_not_specialize=['T'])
def chunk_bwd_kernel_dh_parallel(
    q,
    g,
    gk,
    gv,
    do,
    dh,
    dht,
    dh0,
    offsets,
    indices,
    scale,
    T,
    HQ: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
    STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Phase 1 of the parallel backward state pass: each chunk independently
    # computes its local contribution q^T @ do (after in-chunk gate rescaling)
    # and writes it into the PREVIOUS chunk's slot of `dh`; the reverse prefix
    # sum is completed by `chunk_bwd_kernel_dh_reduction`.
    # NG = HQ // H (GQA group size); gates are indexed by KV head, q/do by
    # query head.
    i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    NV = tl.cdiv(V, BV)
    i_k, i_v = i_kv // NV, i_kv % NV
    i_b, i_hq, i_bg = i_bh // HQ, i_bh % HQ, i_bh // NG
    i_h = i_hq // NG
    if USE_OFFSETS:
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        bos, eos = i_b * T, i_b * T + T
        NT = tl.cdiv(T, BT)
        i_n, i_tg = i_b, i_b * NT + i_t
    i_nh = i_n * HQ + i_hq

    if HEAD_FIRST:
        p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
    else:
        p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

    # The last chunk's slot is seeded with dht (or zeros) — the reverse-sweep
    # analogue of chunk 0 holding the initial state in the forward pass.
    if i_t == NT - 1:
        if USE_FINAL_STATE_GRADIENT:
            p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
            b_dh = tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)
        else:
            b_dh = tl.zeros([BK, BV], dtype=tl.float32)
        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))

    # [BK, BT]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)
    # [BT, BV]
    b_do = tl.load(p_do, boundary_check=(0, 1))

    # Apply in-chunk gate rescaling to q (scalar & key gates) and do (value gate).
    if USE_G:
        if HEAD_FIRST:
            p_g = g + i_bg * T + i_t * BT + tl.arange(0, BT)
            p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
        else:
            p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
        b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
        b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)

    if USE_GK:
        if HEAD_FIRST:
            p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        else:
            p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_q = (b_q * exp(b_gk)).to(b_q.dtype)

    if USE_GV:
        if HEAD_FIRST:
            p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_gv = tl.load(p_gv, boundary_check=(0, 1))
        b_do = (b_do * exp(b_gv)).to(b_do.dtype)

    # Local contribution only (`=`, not `+=`): the reduction kernel sums it.
    b_dh = tl.dot(b_q, b_do)
    if i_t > 0:
        # Write into the PREVIOUS chunk's slot.
        if HEAD_FIRST:
            p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t - 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_dh = tl.make_block_ptr(dh + ((i_tg - 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
    elif STORE_INITIAL_STATE_GRADIENT:
        # Chunk 0's contribution is parked in `dh0` (named `doq0` in the
        # reduction kernel) and folded into the true dh0 there.
        p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({
    'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64, 128]
        for BV in [32, 64, 128]
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3]
    ],
    key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
)
@triton.jit(do_not_specialize=['T'])
def chunk_bwd_kernel_dh_reduction(
    g,
    gk,
    gv,
    dh,
    doq0,
    dh0,
    offsets,
    chunk_offsets,
    T,
    HQ: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
    STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Phase 2 of the parallel backward state pass: reverse sweep that folds the
    # per-chunk contributions written by `chunk_bwd_kernel_dh_parallel` into a
    # decayed suffix sum, rewriting `dh` in place. `doq0` holds chunk 0's local
    # contribution; it is added last to produce the initial-state gradient dh0.
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_bg = i_nh // NG
    i_n, i_hq = i_nh // HQ, i_nh % HQ
    i_h = i_hq // NG
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # Running decayed suffix sum. [BK, BV]
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    for i_t in range(NT - 1, -1, -1):
        if HEAD_FIRST:
            p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_dh += tl.load(p_dh, boundary_check=(0, 1)).to(tl.float32)
        # The last chunk's slot already holds dht; only earlier slots are rewritten.
        if i_t < NT - 1:
            tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))

        last_idx = min(i_t * BT + BT, T) - 1
        # Decay the running gradient across the chunk boundary (scalar gate).
        if USE_G:
            if HEAD_FIRST:
                b_g_last = tl.load(g + i_bg * T + last_idx)
            else:
                b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
            b_dh *= exp(b_g_last)

        # Key-dimension gate decay on dh rows.
        if USE_GK:
            if HEAD_FIRST:
                p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
                p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
            else:
                p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)

            b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
            b_dh *= exp(b_gk_last)[:, None]

        # Value-dimension gate decay on dh columns.
        if USE_GV:
            if HEAD_FIRST:
                p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
                p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
            else:
                p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)

            b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
            b_dh *= exp(b_gv_last)[None, :]

    if STORE_INITIAL_STATE_GRADIENT:
        # dh0 = (decayed suffix over all chunks) + chunk 0's contribution.
        p_doq0 = tl.make_block_ptr(doq0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_dh += tl.load(p_doq0, boundary_check=(0, 1)).to(tl.float32)
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
def chunk_fwd_h(
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    gk: torch.Tensor,
    gv: torch.Tensor,
    h0: torch.Tensor,
    output_final_state: bool,
    states_in_fp32: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Two-phase (parallel + reduction) computation of per-chunk states.

    Phase 1 (`chunk_fwd_kernel_h_parallel`) writes each chunk's local
    contribution; phase 2 (`chunk_fwd_kernel_h_reduction`) turns them into a
    decayed prefix sum in place and finalizes `ht`.

    Returns:
        h:  per-chunk entering states, `[B, H, NT, K, V]` if `head_first`
            else `[B, NT, H, K, V]`; cast back to `k.dtype` unless
            `states_in_fp32`.
        ht: final state `[N, H, K, V]` (fp32) when `output_final_state`,
            else `None`.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if offsets is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        # Build (sequence, local-chunk) pairs for every global chunk, plus the
        # per-sequence base chunk offsets, when not supplied by the caller.
        if indices is None:
            indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()])
            indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
        N, NT = len(offsets) - 1, len(indices)
        chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1)

    # States are accumulated in fp32; optionally cast down after the reduction.
    h = k.new_empty(B, H, NT, K, V, dtype=torch.float) if head_first else k.new_empty(B, NT, H, K, V, dtype=torch.float)
    ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None
    def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * H)
    chunk_fwd_kernel_h_parallel[grid](
        k=k,
        v=v,
        h=h,
        g=g,
        gk=gk,
        gv=gv,
        h0=h0,
        ht=ht,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_GV=gv is not None,
        HEAD_FIRST=head_first
    )
    # Phase 1 left the last chunk's raw contribution in `ht`; rename it `kvt`
    # and allocate a fresh `ht` for the reduction to fill.
    kvt, ht = ht, (torch.empty_like(ht) if output_final_state else None)
    def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
    chunk_fwd_kernel_h_reduction[grid](
        h=h,
        g=g,
        gk=gk,
        gv=gv,
        kvt=kvt,
        ht=ht,
        offsets=offsets,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_GV=gv is not None,
        HEAD_FIRST=head_first
    )
    h = h.to(k.dtype) if not states_in_fp32 else h
    return h, ht
def chunk_bwd_dh(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    gk: torch.Tensor,
    gv: torch.Tensor,
    do: torch.Tensor,
    h0: torch.Tensor,
    dht: torch.Tensor,
    scale: float,
    states_in_fp32: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Two-phase (parallel + reduction) computation of chunk-state gradients.

    Phase 1 (`chunk_bwd_kernel_dh_parallel`) writes each chunk's local
    contribution; phase 2 (`chunk_bwd_kernel_dh_reduction`) completes the
    decayed reverse prefix sum in place and finalizes `dh0`.

    Returns:
        dh:  per-chunk state gradients, `[B, HQ, NT, K, V]` if `head_first`
             else `[B, NT, HQ, K, V]`; cast back to `q.dtype` unless
             `states_in_fp32`.
        dh0: initial-state gradient (shape of `h0`, fp32) or `None`.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
        HQ = q.shape[1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
        HQ = q.shape[2]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    # N: the actual number of sequences in the batch with either equal or variable lengths
    # NG: number of groups in GQA
    if offsets is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        if indices is None:
            indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()])
            indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
        N, NT = len(offsets) - 1, len(indices)
        chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1)
    NG = HQ // H

    if head_first:
        dh = k.new_empty(B, HQ, NT, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
    else:
        dh = k.new_empty(B, NT, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
    dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None

    # One program per (K-tile x V-tile, chunk, sequence x query-head).
    def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * HQ)
    chunk_bwd_kernel_dh_parallel[grid](
        q=q,
        g=g,
        gk=gk,
        gv=gv,
        do=do,
        dh=dh,
        dht=dht,
        dh0=dh0,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        HQ=HQ,
        H=H,
        K=K,
        V=V,
        BT=BT,
        NG=NG,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_GV=gv is not None,
        HEAD_FIRST=head_first
    )

    # Phase 1 parked chunk 0's raw contribution in `dh0`; rename it `doq0` and
    # allocate a fresh `dh0` for the reduction to fill.
    doq0, dh0 = dh0, (torch.empty_like(dh0) if dh0 is not None else None)
    def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ)
    chunk_bwd_kernel_dh_reduction[grid](
        g=g,
        gk=gk,
        gv=gv,
        dh=dh,
        doq0=doq0,
        dh0=dh0,
        offsets=offsets,
        chunk_offsets=chunk_offsets,
        T=T,
        HQ=HQ,
        H=H,
        K=K,
        V=V,
        BT=BT,
        NG=NG,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_GV=gv is not None,
        HEAD_FIRST=head_first
    )
    dh = dh.to(q.dtype) if not states_in_fp32 else dh
    return dh, dh0
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]


@triton.heuristics({
    'USE_G': lambda args: args['g'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in BKV_LIST
        for BV in BKV_LIST
        for num_warps in NUM_WARPS
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_fwd_kernel_o(
    q,
    k,
    v,
    h,
    g,
    o,
    offsets,
    indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    """Chunkwise forward output kernel.

    Computes, per (value-block, time-chunk, batch*head) program:
      o = (q @ h) * exp(g) * scale        (inter-chunk, via chunk state h)
        + tril(q @ k^T * decay) @ v * scale  (intra-chunk attention)

    USE_OFFSETS enables variable-length sequences: (offsets, indices) map the
    time-chunk program id to a (sequence, local chunk) pair.
    HEAD_FIRST switches between [B, H, T, ...] and [B, T, H, ...] memory layouts
    (only the strides s_qk/s_vo/s_g and base offsets differ).
    """
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        # variable-length path: remap chunk id -> (sequence, local chunk)
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # strides along the time axis for the two layouts
    s_qk = K if HEAD_FIRST else H*K
    s_vo = V if HEAD_FIRST else H*V
    s_g = 1 if HEAD_FIRST else H
    # offset calculation: move base pointers to this (batch, head) slice
    q += (i_bh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
    k += (i_bh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
    v += (i_bh * T*V) if HEAD_FIRST else ((bos * H + i_h) * V)
    o += (i_bh * T*V) if HEAD_FIRST else ((bos * H + i_h) * V)
    # h holds one [K, V] state per chunk; int64 cast avoids 32-bit overflow
    h += ((i_bh * NT + i_t).to(tl.int64) * K*V) if HEAD_FIRST else ((i_tg * H + i_h).to(tl.int64) * K*V)

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_A = tl.zeros([BT, BT], dtype=tl.float32)

    # accumulate both the inter-chunk (q @ h) and intra-chunk (q @ k^T) terms
    # over K-blocks in a single pass
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))

        # [BT, BK] @ [BK, BV] -> [BT, BV]
        b_o += tl.dot(b_q, b_h)
        # [BT, BK] @ [BK, BT] -> [BT, BT]
        b_A += tl.dot(b_q, b_k)

    if USE_G:
        # apply the (cumulative log-)decay gate: exp(g) on the state term,
        # exp(g_i - g_j) on intra-chunk scores (safe_exp guards overflow)
        g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
        p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
        b_g = tl.load(p_g, boundary_check=(0,))
        b_o = b_o * exp(b_g)[:, None]
        b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])

    # causal mask inside the chunk
    o_i = tl.arange(0, BT)
    m_A = o_i[:, None] >= o_i[None, :]
    b_A = tl.where(m_A, b_A, 0)

    p_v = tl.make_block_ptr(v, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))

    # to fix mma -> mma layout conversion
    # already solved by triton v3.2 or higher
    b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
    'USE_G': lambda args: args['g'] is not None,
    'USE_DW': lambda args: args['dw'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in NUM_WARPS
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G', 'USE_DW'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_bwd_kernel_dqkwg(
    q,
    k,
    v,
    h,
    g,
    do,
    dh,
    dq,
    dk,
    dg,
    w,
    dv,
    dw,
    offsets,
    indices,
    scale,
    B: tl.constexpr,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_DW: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    """Chunkwise backward kernel producing dq, dk and, optionally, dg and dw.

    Per (K-block, time-chunk, batch*head) program, it first accumulates the
    inter-chunk contributions through the chunk states (h, dh) over V-blocks,
    then adds the gated intra-chunk attention contributions.
    USE_DW additionally computes dw from dv and h ("for delta rule only").
    dg is written per-K-block (the wrapper sums over NK); its within-chunk
    reverse cumsum is deferred to a separate kernel (see comment below).
    """
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_G:
        # each K-block writes its own dg slice; summed afterwards in the wrapper
        dg += i_k * B * H * T
    if USE_OFFSETS:
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # offset calculation
    v += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
    do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
    h += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V
    dh += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V
    q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
    k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
    dq += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
    dk += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
    s_qk = K if HEAD_FIRST else H*K
    s_vo = V if HEAD_FIRST else H*V
    s_g = 1 if HEAD_FIRST else H

    # for delta rule only
    if USE_DW:
        dw += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
        dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
        w += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    b_ds = tl.zeros([BT, BT], dtype=tl.float32)
    b_dg_last = tl.zeros([1,], dtype=tl.float32) if USE_G else None
    b_dw = tl.zeros([BT, BK], dtype=tl.float32) if USE_DW else None

    # inter-chunk pass: accumulate over V-blocks through the chunk states
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        if USE_G:
            b_dg_last += (tl.sum(b_h * b_dh))
        # [BT, BV] @ [BV, BT] -> [BT, BT]
        b_ds += tl.dot(b_do, tl.trans(b_v))
        # [BT, BV] @ [BV, BK] -> [BT, BK]
        b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
        # [BT, BV] @ [BV, BK] -> [BT, BK]
        b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))
        if USE_DW:
            p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            b_dv = tl.load(p_dv, boundary_check=(0, 1))
            b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))

    # ungated delta-rule case: dw is final here (note the sign flip on store)
    if USE_DW and not USE_G:
        p_dw = tl.make_block_ptr(dw, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))

    tl.debug_barrier()
    o_i = tl.arange(0, BT)
    p_q = tl.make_block_ptr(q, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))

    p_dq = tl.make_block_ptr(dq, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))

    if USE_G:
        # gated path: fold decay factors into dq/dk and accumulate dg
        b_dg = tl.zeros([BT,], dtype=tl.float32)
        g += i_bh * T if HEAD_FIRST else bos * H + i_h
        dg += i_bh * T if HEAD_FIRST else bos * H + i_h
        p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
        b_g = tl.load(p_g, boundary_check=(0,))
        # decay value at the last valid position of this chunk
        b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * s_g)
        b_dg_last *= exp(b_g_last)

        if USE_DW:
            p_w = tl.make_block_ptr(w, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dw = tl.make_block_ptr(dw, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_dw = b_dw * exp(b_g)[:, None]
            tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
            b_dg -= tl.sum(b_w * b_dw, axis=1)

        b_dq = b_dq * exp(b_g)[:, None] * scale
        b_dg += tl.sum(b_dq * b_q, axis=1)

        b_dk = b_dk * safe_exp(-b_g + b_g_last)[:, None]
        b_dg -= tl.sum(b_k * b_dk, axis=1)
        b_dg_last += tl.sum(b_dk * b_k)

        # intra-chunk: causal-masked, decay-weighted score gradients
        b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * safe_exp(b_g[:, None] - b_g[None, :]), 0) * scale
        b_ds2 = b_ds * tl.dot(b_q, tl.trans(b_k))
        b_dg += tl.sum(b_ds2, axis=1)
        b_dg -= tl.sum(b_ds2, axis=0)

        b_ds = b_ds.to(b_k.dtype)
        # [BT, BK]
        b_dq += tl.dot(b_ds, b_k)
        b_dk += tl.dot(tl.trans(b_ds), b_q)
        p_dg = tl.make_block_ptr(dg, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
        # (SY 09/21) revcumsum in a separate kernel due to strange triton compiler issue
        # b_dg = tl.dot(tl.where(o_i[:, None] <= o_i[None, :], 1., 0.), b_dg, allow_tf32=False) + b_dg_last)
        b_dg = tl.where(o_i < min(BT, T-i_t*BT) - 1, b_dg, b_dg + b_dg_last)
        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
    else:
        # ungated path: plain causal masking, scale applied directly
        b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0)
        b_ds = b_ds.to(b_k.dtype)
        b_dq += tl.dot(b_ds, b_k)
        b_dk += tl.dot(tl.trans(b_ds), b_q) * scale
        b_dq *= scale
        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
    'USE_G': lambda args: args['g'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_bwd_kernel_dv(
    q,
    k,
    g,
    do,
    dv,
    dh,
    offsets,
    indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    """Backward kernel for dv: inter-chunk term k @ dh plus the intra-chunk
    term A^T @ do, where A is the (upper-triangular here, since it propagates
    gradients backward in time) decay-weighted score matrix."""
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    b_dv = tl.zeros([BT, BV], dtype=tl.float32)

    # offset calculation
    q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
    k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
    do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
    dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
    s_qk = K if HEAD_FIRST else H*K
    s_vo = V if HEAD_FIRST else H*V
    s_g = 1 if HEAD_FIRST else H
    dh += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V

    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_q = tl.make_block_ptr(q, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BT, BT] score block (k @ q^T, i.e. transposed attention scores)
        b_A += tl.dot(b_k, b_q)
        p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # inter-chunk contribution through the state gradient
        b_dv += tl.dot(b_k, b_dh.to(b_k.dtype))

    if USE_G:
        g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
        p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
        b_g = tl.load(p_g, boundary_check=(0,))
        b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * s_g)
        # decay from each position to the end of the chunk
        b_dv *= safe_exp(-b_g + b_g_last)[:, None]

    # upper-triangular mask: position i receives gradient from positions j >= i
    mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :])
    if USE_G:
        b_A = tl.where(mask, b_A * safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
    else:
        b_A = tl.where(mask, b_A * scale, 0).to(do.dtype.element_ty)
    p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_do = tl.load(p_do, boundary_check=(0, 1))
    b_dv += tl.dot(b_A.to(b_do.dtype), b_do)
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
    'USE_G': lambda args: args['g'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in NUM_WARPS
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_bwd_kernel_dv_local(
    q,
    k,
    g,
    do,
    dv,
    offsets,
    indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    """Intra-chunk-only variant of chunk_bwd_kernel_dv: computes dv = A^T @ do
    without any inter-chunk (dh) term, and OVERWRITES dv rather than
    accumulating. Launched on a 2D grid (time-chunk, batch*head); V-blocks are
    iterated inside the kernel."""
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # offset calculation
    q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
    k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
    do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
    dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
    s_qk = K if HEAD_FIRST else H*K
    s_vo = V if HEAD_FIRST else H*V
    s_g = 1 if HEAD_FIRST else H

    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_q = tl.make_block_ptr(q, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_A += tl.dot(b_k, b_q)

    if USE_G:
        g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
        p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
        b_g = tl.load(p_g, boundary_check=(0,))

    # upper-triangular mask (gradient flows backward in time)
    mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :])
    if USE_G:
        b_A = tl.where(mask, b_A * safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
    else:
        b_A = tl.where(mask, b_A * scale, 0).to(do.dtype.element_ty)

    for i_v in range(tl.cdiv(V, BV)):
        p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_dv = tl.dot(b_A.to(b_do.dtype), b_do)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))


def chunk_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    h: torch.Tensor,
    g: Optional[torch.Tensor] = None,  # cumsum of log decay
    scale: Optional[float] = None,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> torch.Tensor:
    """Launch chunk_fwd_kernel_o to compute the chunkwise forward output.

    Args:
        q, k: [B, H, T, K] if head_first else [B, T, H, K].
        v: [..., V] with the same batch/head/time layout.
        h: per-chunk [K, V] states produced by the forward state kernel.
        g: optional cumulative log-decay gates.
        scale: defaults to K ** -0.5 when None.
        offsets/indices: variable-length metadata (None for fixed length).
    Returns:
        o: output tensor with the same shape as v.
    """
    if head_first:
        B, H, T, K, V = *q.shape, v.shape[-1]
    else:
        B, T, H, K, V = *q.shape, v.shape[-1]
    if scale is None:
        scale = k.shape[-1] ** -0.5
    # chunk length: at least 16, capped at chunk_size, rounded to a power of 2
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)

    o = torch.empty_like(v)

    def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
    chunk_fwd_kernel_o[grid](
        q,
        k,
        v,
        h,
        g,
        o,
        offsets,
        indices,
        scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        HEAD_FIRST=head_first
    )
    return o
def chunk_bwd_dv(
    q: torch.Tensor,
    k: torch.Tensor,
    g: torch.Tensor,
    do: torch.Tensor,
    dh: torch.Tensor,
    scale: float,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> torch.Tensor:
    """Launch chunk_bwd_kernel_dv to compute dv (inter- + intra-chunk terms).

    Args:
        q, k: [B, H, T, K] if head_first else [B, T, H, K].
        g: cumulative log-decay gates (or None for the ungated path).
        do: gradient w.r.t. the output, [..., V].
        dh: per-chunk state gradients.
        scale: attention score scale.
        offsets/indices: variable-length metadata (None for fixed length).
    Returns:
        dv: gradient w.r.t. v, same shape as do.
    """
    if head_first:
        B, H, T, K, V = *k.shape, do.shape[-1]
    else:
        B, T, H, K, V = *k.shape, do.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    # H100 can have larger block size
    if check_shared_mem('hopper', k.device.index):
        CONST_TILING = 128
    elif check_shared_mem():
        # FIX: was `elif check_shared_mem:` which tested the function object
        # (always truthy), making the 32-tile fallback below unreachable.
        CONST_TILING = 64
    else:
        CONST_TILING = 32
    BK = min(triton.next_power_of_2(K), CONST_TILING)
    BV = min(triton.next_power_of_2(V), CONST_TILING)
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    NV = triton.cdiv(V, BV)

    dv = torch.empty_like(do)
    grid = (NV, NT, B * H)
    chunk_bwd_kernel_dv[grid](
        q,
        k,
        g,
        do,
        dv,
        dh,
        offsets,
        indices,
        scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return dv


def chunk_bwd_dv_local(
    q: torch.Tensor,
    k: torch.Tensor,
    g: torch.Tensor,
    do: torch.Tensor,
    dh: torch.Tensor,
    scale: float,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> torch.Tensor:
    """Launch chunk_bwd_kernel_dv_local: intra-chunk-only dv.

    NOTE: `dh` is accepted for signature parity with chunk_bwd_dv but is NOT
    passed to the kernel — the local variant has no inter-chunk term.
    Returns dv with the same shape as do.
    """
    if head_first:
        B, H, T, K, V = *k.shape, do.shape[-1]
    else:
        B, T, H, K, V = *k.shape, do.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    # H100 can have larger block size
    if check_shared_mem('hopper', k.device.index):
        CONST_TILING = 128
    elif check_shared_mem():
        # FIX: was `elif check_shared_mem:` (missing call) — always truthy.
        CONST_TILING = 64
    else:
        CONST_TILING = 32
    BK = min(triton.next_power_of_2(K), CONST_TILING)
    BV = min(triton.next_power_of_2(V), CONST_TILING)
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)

    dv = torch.empty_like(do)
    # 2D grid: the kernel iterates V-blocks internally
    grid = (NT, B * H)
    chunk_bwd_kernel_dv_local[grid](
        q,
        k,
        g,
        do,
        dv,
        offsets,
        indices,
        scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return dv


def chunk_bwd_dqkwg(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    do: torch.Tensor,
    h: torch.Tensor,
    dh: torch.Tensor,
    dv: Optional[torch.Tensor] = None,
    w: Optional[torch.Tensor] = None,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
    scale: float = 1.0,
    head_first: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Launch chunk_bwd_kernel_dqkwg to compute dq, dk, dw and dg.

    dg is allocated with a leading NK dimension (one slice per K-block written
    by the kernel) and reduced with sum(0) afterwards.
    FIX: the return annotation previously declared a 3-tuple although four
    values (dq, dk, dw, dg) are returned; dw/dg are None when w/g are None.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)

    CONST_TILING = 64 if check_shared_mem() else 32
    BK = min(triton.next_power_of_2(K), CONST_TILING)
    BV = min(triton.next_power_of_2(V), CONST_TILING)
    NK = triton.cdiv(K, BK)
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    # fp32 accumulation for the gate gradient, one slice per K-block
    dg = torch.empty(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None
    dw = torch.empty_like(w) if w is not None else None

    grid = (NK, NT, B * H)
    chunk_bwd_kernel_dqkwg[grid](
        q=q,
        k=k,
        v=v,
        h=h,
        g=g,
        do=do,
        dh=dh,
        dv=dv,
        w=w,
        dw=dw,
        dq=dq,
        dk=dk,
        dg=dg,
        offsets=offsets,
        indices=indices,
        scale=scale,
        B=B,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )

    if dg is not None:
        dg = dg.sum(0)
    return dq, dk, dw, dg
from typing import Optional

import torch
import triton
import triton.language as tl

from fla.ops.utils import chunk_global_cumsum
from fla.ops.utils.op import exp
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard


@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=["BK", "BV", "USE_GK", "USE_GV", "USE_G"],
)
@triton.jit(do_not_specialize=['T'])
def fused_recurrent_fwd_kernel(
    q,
    k,
    v,
    g,
    gk,
    gv,
    o,
    h0,
    ht,
    offsets,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    REVERSE: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    """Fused recurrent forward: steps the [BV, BK] state h one timestep at a
    time, applying optional scalar (g), key-wise (gk) and value-wise (gv)
    decay gates, with h_t = decay * h_{t-1} + k_t v_t^T and o_t = h_t q_t.

    Grid: (V-blocks, K-blocks, N*H). Each program owns one (BK, BV) tile of
    the state; partial outputs over K-blocks are summed on the host (o has a
    leading NK dimension). REVERSE walks the sequence backwards.
    """
    # indices (int64 to keep large pointer offsets safe)
    i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    # position pointers at the first (or last, if REVERSE) timestep
    if HEAD_FIRST:
        p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
        p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
        p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
        p_o = o + (i_k * B*H + i_nh) * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
        if USE_G:
            p_g = g + i_nh * T + ((T-1) if REVERSE else 0)
        if USE_GK:
            p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
        if USE_GV:
            p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
    else:
        p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
        p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
        p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
        p_o = o + ((i_k * all + bos) + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
        if USE_G:
            p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h
        if USE_GK:
            p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
        if USE_GV:
            p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)

    mask_k = (i_k * BK + tl.arange(0, BK)) < K
    mask_v = (i_v * BV + tl.arange(0, BV)) < V
    mask_h = mask_k[None, :] & mask_v[:, None]
    b_h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        # decay the state before the rank-1 update
        if USE_GK:
            b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
            b_h = b_h * exp(b_gk[None, :])
        if USE_GV:
            b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
            b_h = b_h * exp(b_gv[:, None])
        if USE_G:
            b_g = tl.load(p_g).to(tl.float32)
            b_h = b_h * exp(b_g)
        b_h += b_k[None, :] * b_v[:, None]
        b_o = b_h * b_q[None, :]
        b_o = tl.sum(b_o, axis=1)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
        # advance one timestep (stride H per step for the time-major layout)
        p_q += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
        p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
        p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
        p_o += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
        if USE_GK:
            p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
        if USE_GV:
            p_gv += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
        if USE_G:
            p_g += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H)

    if STORE_FINAL_STATE:
        p_ht = ht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)


@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BK', 'BV', 'USE_GK', 'USE_GV', 'USE_G'],
)
@triton.jit(do_not_specialize=['T'])
def fused_recurrent_bwd_kernel(
    q,
    k,
    v,
    g,
    gk,
    gv,
    h0,
    do,
    dq,
    dk,
    dv,
    dht,
    dh0,
    offsets,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    REVERSE: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_GV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    """Fused recurrent backward in two passes separated by tl.debug_barrier():
    1) re-run the forward recurrence (state b_h) to produce dq;
    2) walk the sequence in the opposite direction accumulating the state
       gradient b_dh to produce dk, dv, and optionally dh0.
    dq/dk carry a leading NV dimension and dv a leading NK dimension; the host
    sums them.
    """
    i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    # ---- pass 1: forward direction, computes dq ----
    if HEAD_FIRST:
        p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
        p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
        p_do = do + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
        p_dq = dq + (i_v * B*H + i_nh) * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
        if USE_G:
            p_g = g + i_nh * T + ((T-1) if REVERSE else 0)
        if USE_GK:
            p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
        if USE_GV:
            p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
    else:
        p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
        p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
        p_do = do + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
        p_dq = dq + ((i_v * all + bos) + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
        if USE_G:
            p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h
        if USE_GK:
            p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
        if USE_GV:
            p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)

    mask_k = i_k * BK + tl.arange(0, BK) < K
    mask_v = i_v * BV + tl.arange(0, BV) < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
        if USE_G:
            b_g = tl.load(p_g).to(tl.float32)
            b_h = b_h * exp(b_g)
        if USE_GK:
            b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
            b_h = b_h * exp(b_gk[:, None])
        if USE_GV:
            b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
            b_h = b_h * exp(b_gv[None, :])
        b_h += b_k[:, None] * b_v[None, :]
        # dq_t = (h_t @ do_t) * scale
        b_dq = b_h * b_do[None, :]
        b_dq = tl.sum(b_dq, axis=1) * scale
        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_k)

        p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
        p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
        p_do += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
        p_dq += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
        if USE_G:
            p_g += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H)
        if USE_GK:
            p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
        if USE_GV:
            p_gv += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V

    # sync threads
    tl.debug_barrier()

    # ---- pass 2: opposite direction, accumulates b_dh for dk/dv ----
    if HEAD_FIRST:
        p_q = q + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
        p_k = k + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
        p_v = v + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
        p_do = do + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
        p_dk = dk + (i_v * B*H + i_nh) * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
        p_dv = dv + (i_k * B*H + i_nh) * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
        if USE_G:
            p_g = g + i_nh * T + ((T - 1) if not REVERSE else 0)
        if USE_GK:
            p_gk = gk + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
        if USE_GV:
            p_gv = gv + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
    else:
        p_q = q + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
        p_k = k + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
        p_v = v + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
        p_do = do + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
        p_dk = dk + ((i_v * all + bos) + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
        p_dv = dv + ((i_k * all + bos) + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
        if USE_G:
            p_g = g + (bos + ((T - 1) if not REVERSE else 0)) * H + i_h
        if USE_GK:
            p_gk = gk + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
        if USE_GV:
            p_gv = gv + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)

    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        p_dht = dht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        b_dh += tl.load(p_dht, mask=mask_h, other=0).to(tl.float32)

    for _ in range(T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
        b_dh += b_q[:, None] * b_do[None, :]
        # dk/dv read b_dh BEFORE the gate decay is applied for the next step
        b_dk = tl.sum(b_dh * b_v[None, :], axis=1)
        b_dv = tl.sum(b_dh * b_k[:, None], axis=0)
        if USE_G:
            b_g = tl.load(p_g).to(tl.float32)
            b_dh *= exp(b_g)
        if USE_GK:
            b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
            b_dh *= exp(b_gk)[:, None]
        if USE_GV:
            b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
            b_dh *= exp(b_gv)[None, :]
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v)

        p_q += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
        p_k += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
        p_v += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
        p_do += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
        p_dk += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
        p_dv += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
        if USE_G:
            p_g += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H)
        if USE_GK:
            p_gk += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
        if USE_GV:
            p_gv += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V

    if STORE_INITIAL_STATE_GRADIENT:
        p_dh0 = dh0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_h)


def fused_recurrent_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    gk: Optional[torch.Tensor] = None,
    gv: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
):
    """Host launcher for fused_recurrent_fwd_kernel.

    Allocates a per-K-block partial output (leading NK dim, fp32) that is
    reduced with sum(0) after the launch, and an optional final state ht of
    shape [N, H, K, V] in fp32. Returns (o, ht); ht is None unless
    output_final_state is True.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    # N: number of sequences (variable-length when offsets is given)
    N = B if offsets is None else len(offsets) - 1
    BK, BV = min(K, 64), min(V, 64)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)

    h0 = initial_state
    if output_final_state:
        ht = q.new_empty(N, H, K, V, dtype=torch.float32)
    else:
        ht = None
    o = q.new_empty(NK, *v.shape, dtype=torch.float32)

    grid = (NV, NK, N * H)
    fused_recurrent_fwd_kernel[grid](
        q,
        k,
        v,
        g,
        gk,
        gv,
        o,
        h0,
        ht,
        offsets,
        scale,
        T=T,
        B=B,
        H=H,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_GV=gv is not None,
        REVERSE=reverse,
        HEAD_FIRST=head_first
    )
    o = o.sum(0)
    return o, ht
def fused_recurrent_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    gk: Optional[torch.Tensor] = None,
    gv: Optional[torch.Tensor] = None,
    o: Optional[torch.Tensor] = None,
    do: Optional[torch.Tensor] = None,
    dht: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    reverse: bool = False,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
):
    """Host launcher for fused_recurrent_bwd_kernel.

    dq/dk are allocated with a leading NV dim and dv with a leading NK dim
    (one partial per block column), then reduced with sum(0). Gate gradients
    dg/dgk/dgv are recovered afterwards via a global reverse cumulative sum of
    the pointwise products (e.g. dg_t = revcumsum((dq*q - dk*k).sum(-1))).
    Returns (dq, dk, dv, dg, dgk, dgv, dh0); gate grads are None when the
    corresponding gate is None, dh0 is None when initial_state is None.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    N = B if offsets is None else len(offsets) - 1

    BK, BV = min(K, 64), min(V, 64)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)

    dq = q.new_empty(NV, *q.shape, dtype=torch.float32)
    dk = q.new_empty(NV, *k.shape, dtype=torch.float32)
    dv = q.new_empty(NK, *v.shape, dtype=torch.float32)
    h0 = initial_state
    dh0 = torch.empty_like(initial_state) if initial_state is not None else None

    grid = (NV, NK, N * H)
    fused_recurrent_bwd_kernel[grid](
        q,
        k,
        v,
        g,
        gk,
        gv,
        h0,
        do,
        dq,
        dk,
        dv,
        dht,
        dh0,
        offsets,
        scale,
        B=B,
        T=T,
        H=H,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        USE_G=g is not None,
        USE_GK=gk is not None,
        USE_GV=gv is not None,
        REVERSE=reverse,
        HEAD_FIRST=head_first
    )
    dq = dq.sum(0)
    dk = dk.sum(0)
    dv = dv.sum(0)
    dg, dgk, dgv = None, None, None
    if g is not None:
        dg = chunk_global_cumsum(
            (dq * q.float() - dk * k.float()).sum(-1),
            reverse=not reverse,
            offsets=offsets,
            head_first=head_first
        )
    if gk is not None:
        dgk = chunk_global_cumsum(
            dq * q.float() - dk * k.float(),
            reverse=not reverse,
            offsets=offsets,
            head_first=head_first
        )
    if gv is not None:
        dgv = chunk_global_cumsum(
            do.float() * o.float() - dv * v.float(),
            reverse=not reverse,
            offsets=offsets,
            head_first=head_first
        )

    return dq, dk, dv, dg, dgk, dgv, dh0


class FusedRecurrentFunction(torch.autograd.Function):
    """Autograd wrapper tying fused_recurrent_fwd / fused_recurrent_bwd
    together. The backward return tuple is aligned positionally with the
    forward arguments (q, k, v, g, gk, gv, scale, initial_state,
    output_final_state, reverse, offsets, head_first)."""

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: Optional[torch.Tensor] = None,
        gk: Optional[torch.Tensor] = None,
        gv: Optional[torch.Tensor] = None,
        scale: Optional[float] = None,
        initial_state: Optional[torch.Tensor] = None,
        output_final_state: bool = False,
        reverse: bool = False,
        offsets: Optional[torch.LongTensor] = None,
        head_first: bool = True
    ):
        o, ht = fused_recurrent_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            gk=gk,
            gv=gv,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            reverse=reverse,
            offsets=offsets,
            head_first=head_first
        )
        # o is saved because dgv needs do * o in the backward pass
        ctx.save_for_backward(q, k, v, g, gk, gv, initial_state, o)
        ctx.scale = scale
        ctx.reverse = reverse
        ctx.offsets = offsets
        ctx.head_first = head_first
        return o.to(q.dtype), ht

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dht):
        q, k, v, g, gk, gv, initial_state, o = ctx.saved_tensors
        # not supported yet.
        # combining a nonzero final-state gradient with trainable gates is
        # rejected because the gate-gradient reconstruction above ignores dht
        if dht is not None:
            if not dht.eq(0).all():
                if g is not None:
                    assert g.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
                if gk is not None:
                    assert gk.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
                if gv is not None:
                    assert gv.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
        dq, dk, dv, dg, dgk, dgv, dh0 = fused_recurrent_bwd(
            q=q,
            k=k,
            v=v,
            g=g,
            gk=gk,
            gv=gv,
            o=o,
            do=do,
            dht=dht,
            scale=ctx.scale,
            initial_state=initial_state,
            reverse=ctx.reverse,
            offsets=ctx.offsets,
            head_first=ctx.head_first
        )
        # Nones align with scale, output_final_state, reverse, offsets, head_first
        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg, dgk, dgv, None, dh0, None, None, None, None


def fused_recurrent(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    gk: Optional[torch.Tensor] = None,
    gv: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = True
):
    """Public entry point: defaults scale to K ** -0.5 and dispatches to the
    autograd Function (cu_seqlens is forwarded as the varlen offsets)."""
    if scale is None:
        scale = k.shape[-1] ** -0.5
    return FusedRecurrentFunction.apply(
        q,
        k,
        v,
        g,
        gk,
        gv,
        scale,
        initial_state,
        output_final_state,
        reverse,
        cu_seqlens,
        head_first
    )
b/fla/ops/delta_rule/README.md new file mode 100644 index 0000000000000000000000000000000000000000..607b0d583c7ec2904c18c0f1d86fb0ec2dfdf583 --- /dev/null +++ b/fla/ops/delta_rule/README.md @@ -0,0 +1,90 @@ +# Chunkwise-form Parallelism of DeltaNet + +This section expands on the formulation presented in Appendix B of the DeltaNet paper.[^1] + +To reduce notational clutter, we focus on the first chunk, denoting $\mathbf{S}^r=\mathbf{S}_{[1]}^r$. By partially expanding the recurrence, we have: +```math +\begin{equation} +\begin{aligned} +\mathbf{S}^r &= \underbrace{\left(\prod_{i=1}^r \mathbf{I} - \beta^i \boldsymbol{k}^i \boldsymbol{k}^{i\top} \right)}_{:= \mathbf{P}^r} \cdot\mathbf{S}^{0} + \overbrace{\sum_{i=1}^{r} \underbrace{\left(\prod_{j=i+1}^r \mathbf{I} - \beta^j \boldsymbol{k}^j \boldsymbol{k}^{j\top} \right)}_{:= \mathbf{P}_{i+1}^r}\beta^i \boldsymbol{k}^i\boldsymbol{v}^{i\top}}^{:=\mathbf{H}^r} \\ +&=\mathbf{P}^r \cdot \mathbf{S}^{0} + \mathbf{H}^r +\end{aligned} +\end{equation} +``` + +where $\mathbf{P}_i^r$ involves cumulative products of generalized Householder matrices. +We abbreviate $\mathbf{P}_1^r$ as $\mathbf{P}^r$. 
These cumulative products can be computed efficiently using the classical WY representation:
\boldsymbol{k}^{r\top}\right) \mathbf{H}^{r-1} + \beta^r \boldsymbol{k}^r \boldsymbol{v}^{r\top}\\ +&= \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top} \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} +\beta^r \boldsymbol{k}^r \boldsymbol{v}^{r\top}\\ +&= \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} + \boldsymbol{k}^r \left(\beta^r \boldsymbol{v}^{r\top}-\beta^r \boldsymbol{k}^{r\top} \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top}\right) \\ +&= \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} + \boldsymbol{k}^r \beta^r\left(\boldsymbol{v}^{r}-\sum_{i=1}^{r-1}\left(\boldsymbol{k}^{r\top}\boldsymbol{k}^{i}\right)\boldsymbol{u}^{i} \right)^\top \\ +&=\sum_{i=1}^{r} \boldsymbol{k}^i \boldsymbol{u}^{i\top} +\end{align*} +``` + +In matrix form, $\mathbf{P}$ and $\mathbf{H}$ can be written as: +```math +\begin{equation} +\mathbf{P}=\mathbf{I}-\mathbf{K}^\top\mathbf{W} \in \mathbb{R}^{d_k \times d_k}, \qquad\mathbf{H}=\mathbf{K}^\top\mathbf{U} \in \mathbb{R}^{d_k\times d_v} +\end{equation} +``` + +Now we can derive the matrix form of $\mathbf{W}$ and $\mathbf{U}$: +```math +\begin{align*} +\mathbf{W} &= \mathrm{diag}(\beta) \mathbf{K} - \mathrm{tril}(\mathrm{diag}(\beta) \mathbf{K}\mathbf{K}^\top, -1)\mathbf{W}\\ +\left(\mathbf{I} + \mathrm{tril}(\mathrm{diag}(\beta) \mathbf{K}\mathbf{K}^\top, -1)\right) \mathbf{W} &= \mathrm{diag}(\beta) \mathbf{K} +\end{align*} +``` +A similar process holds for $\mathbf{U}$. 
We can further write $\mathbf{W}$ and $\mathbf{U}$ in matrix form: +```math +\begin{align*} +\mathbf{T} &= \left(\mathbf{I} + \mathrm{tril}\left(\mathrm{diag}(\beta)\mathbf{K} \mathbf{K}^\top,-1\right)\right)^{-1}\mathrm{diag}\left(\beta\right)\in \mathbb{R}^{C \times C}\\ +\mathbf{W} &= \mathbf{T} \mathbf{K}\in \mathbb{R}^{C \times d_k}\\ +\mathbf{U} &= \mathbf{T}\mathbf{V}\in \mathbb{R}^{C \times d_v} +\end{align*} +``` + +Substituting these back into the original equations yields a hardware-efficient chunkwise algorithm for DeltaNet that leverages matrix multiplications, enabling tensor core based GPU optimization: +```math +\begin{equation} +\begin{aligned} +\mathbf{S} &= \mathbf{P}\cdot\mathbf{S}^0 + \mathbf{H} \\ +&= \mathbf{S}^0 + \mathbf{K}^\top (\mathbf{U} -\mathbf{W} \mathbf{S}^0) \in \mathbb{R}^{d_k \times d_v}\\ +\mathbf{O} &= \mathbf{Q} \mathbf{S}^0 + (\mathbf{Q} \mathbf{K}^{\top} \odot \mathbf{M}) \left(\mathbf{U} - \mathbf{W} \mathbf{S}^0\right) \in \mathbb{R}^{C \times d_v} +\end{aligned} +\end{equation} +``` + +[^1]: https://arxiv.org/abs/2406.06484 diff --git a/fla/ops/delta_rule/__init__.py b/fla/ops/delta_rule/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e0acb6a7d0e4eec9a8dc697615604783b8858d13 --- /dev/null +++ b/fla/ops/delta_rule/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_delta_rule +from .fused_chunk import fused_chunk_delta_rule +from .fused_recurrent import fused_recurrent_delta_rule + +__all__ = [ + 'fused_chunk_delta_rule', + 'fused_recurrent_delta_rule', + 'chunk_delta_rule' +] diff --git a/fla/ops/delta_rule/chunk.py b/fla/ops/delta_rule/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..650b63547c81f3b4cd9ae89a33f4815c3abc0435 --- /dev/null +++ b/fla/ops/delta_rule/chunk.py @@ -0,0 +1,373 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +from 
einops import rearrange + +from fla.modules.l2norm import l2norm_bwd, l2norm_fwd +from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h +from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o +from fla.ops.common.utils import prepare_chunk_indices +from fla.ops.delta_rule.wy_fast import bwd_prepare_wy_repr, fwd_prepare_wy_repr, fwd_recompute_w_u +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +def chunk_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + T = q.shape[2] if head_first else q.shape[1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + # obtain WY representation. u is actually the new v. + w, u, A = fwd_prepare_wy_repr( + k=k, + v=v, + beta=beta, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=None, + initial_state=initial_state, + output_final_state=output_final_state, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=None, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + return o, A, final_state + + +def chunk_delta_rule_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + T = q.shape[2] if head_first else q.shape[1] + BT = min(chunk_size, 
max(triton.next_power_of_2(T), 16)) + w, u = fwd_recompute_w_u( + k=k, + v=v, + beta=beta, + A=A, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + h, v_new, _ = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=None, + initial_state=initial_state, + output_final_state=False, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dv = chunk_bwd_dv_local( + q=q, + k=k, + do=do, + g=None, + dh=None, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu( + q=q, + k=k, + w=w, + g=None, + h0=initial_state, + dht=dht, + do=do, + dv=dv, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dq, dk, dw, _ = chunk_bwd_dqkwg( + q=q, + k=k, + v=v_new, + h=h, + w=w, + dv=dv, + do=do, + dh=dh, + g=None, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dk2, dv, db = bwd_prepare_wy_repr( + k=k, + v=v, + beta=beta, + A=A, + dw=dw, + du=dv, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dk.add_(dk2) + return dq, dk, dv, db, dh0 + + +class ChunkDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + head_first: bool = True, + use_qk_l2norm_in_kernel: bool = True + ): + T = q.shape[2] if head_first else q.shape[1] + chunk_size = min(64, max(triton.next_power_of_2(T), 16)) + + q_orig = q + k_orig = k + + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q) + k = l2norm_fwd(k) + + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there 
are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None + + o, A, final_state = chunk_delta_rule_fwd( + q=q, + k=k, + v=v, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + ctx.save_for_backward(q_orig, k_orig, v, beta, A, initial_state) + ctx.chunk_size = chunk_size + ctx.scale = scale + ctx.offsets = offsets + ctx.indices = indices + ctx.head_first = head_first + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + return o.to(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward( + ctx, + do: torch.Tensor, + dht: torch.Tensor + ): + q, k, v, beta, A, initial_state = ctx.saved_tensors + use_qk_l2norm_in_kernel = ctx.use_qk_l2norm_in_kernel + if use_qk_l2norm_in_kernel: + q, q_orig = l2norm_fwd(q), q + k, k_orig = l2norm_fwd(k), k + + dq, dk, dv, db, dh0 = chunk_delta_rule_bwd( + q=q, + k=k, + v=v, + beta=beta, + A=A, + scale=ctx.scale, + initial_state=initial_state, + do=do, + dht=dht, + offsets=ctx.offsets, + indices=ctx.indices, + head_first=ctx.head_first, + chunk_size=ctx.chunk_size + ) + if use_qk_l2norm_in_kernel: + dq = l2norm_bwd(q_orig, dq) + dk = l2norm_bwd(k_orig, dk) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), db.to(beta.dtype), None, dh0, None, None, None, None, None, None + + +@torch.compiler.disable +def chunk_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, + use_qk_l2norm_in_kernel: bool = False +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if 
`head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        beta (torch.Tensor):
            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        scale (Optional[float]):
            Scale factor for the DeltaNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.
        use_qk_l2norm_in_kernel (Optional[bool]):
            Whether to use qk l2norm within the kernel for saving GPU memory.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.delta_rule import chunk_delta_rule
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
        >>> o, ht = chunk_delta_rule(
            q, k, v, beta,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = chunk_delta_rule(
            q, k, v, beta,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
    """
    assert q.dtype == k.dtype == v.dtype
    assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
    assert len(beta.shape) == 3, "beta must be of shape (batch size, num of head, seq len)."

    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
+ ) + if head_first: + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if head_first: + q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v)) + beta = rearrange(beta, 'b h t -> b t h') + scale = k.shape[-1] ** -0.5 if scale is None else scale + o, final_state = ChunkDeltaRuleFunction.apply( + q, + k, + v, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + False, + use_qk_l2norm_in_kernel + ) + if head_first: + o = rearrange(o, 'b t h v -> b h t v') + return o, final_state diff --git a/fla/ops/delta_rule/fused_chunk.py b/fla/ops/delta_rule/fused_chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..6347fb9af47d3d9f82c03ea9aedbfa09fc1bfbc1 --- /dev/null +++ b/fla/ops/delta_rule/fused_chunk.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +def fused_chunk_delta_rule( + **kwargs +): + raise NotImplementedError("fused_chunk_delta_rule is deprecated. 
Please use chunk_delta_rule instead.") diff --git a/fla/ops/delta_rule/naive.py b/fla/ops/delta_rule/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..6752aa89a3eac8c00726e09a730152af44343de4 --- /dev/null +++ b/fla/ops/delta_rule/naive.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + + +def delta_rule_recurrence(q, k, v, beta, initial_state=None, output_final_state=True): + orig_dtype = q.dtype + b, h, l, d_k = q.shape + q, k, v, beta = map(lambda x: x.float(), [q, k, v, beta]) + d_v = v.shape[-1] + o = torch.zeros_like(v) + S = torch.zeros(b, h, d_k, d_v).to(v) + q = q * (d_k ** -0.5) + + if beta.ndim < v.ndim: + beta = beta[..., None] + + if initial_state is not None: + S += initial_state + + for i in range(l): + _k = k[:, :, i] + _q = q[:, :, i] + _v = v[:, :, i].clone() + beta_i = beta[:, :, i] + _v = _v - (S.clone() * _k[..., None]).sum(-2) + _v = _v * beta_i + S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2) + o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) + S = None if output_final_state is False else S + return o.to(orig_dtype), S + + +def delta_rule_chunkwise(q, k, v, beta, chunk_size=32): + b, h, l, d_k = q.shape + d_v = v.shape[-1] + q = q * (d_k ** -0.5) + v = v * beta[..., None] + k_beta = k * beta[..., None] + + assert l % chunk_size == 0 + + # compute (I - tri(diag(beta) KK^T))^{-1} + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0) + q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, k_beta]) + attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0) + for i in range(1, chunk_size): + attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device) + + u = attn @ v + w = attn @ k_beta + S = k.new_zeros(b, h, d_k, d_v) + o = torch.zeros_like(v) + mask = 
torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1) + for i in range(0, l // chunk_size): + q_i, k_i = q[:, :, i], k[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0) + u_i = u[:, :, i] - w[:, :, i] @ S + o_inter = q_i @ S + o[:, :, i] = o_inter + attn @ u_i + S = S + k_i.transpose(-1, -2) @ u_i + + return rearrange(o, 'b h n c d -> b h (n c) d'), S + + +def delta_rule_parallel(q, k, v, beta, BM=128, BN=32): + b, h, l, d_k = q.shape + # d_v = v.shape[-1] + q = q * (d_k ** -0.5) + v = v * beta[..., None] + k_beta = k * beta[..., None] + # compute (I - tri(diag(beta) KK^T))^{-1} + q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=BN), [q, k, v, k_beta]) + mask = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=0) + T = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0) + for i in range(1, BN): + T[..., i, :i] = T[..., i, :i].clone() + (T[..., i, :, None].clone() * T[..., :, :i].clone()).sum(-2) + T = T + torch.eye(BN, dtype=torch.float, device=q.device) + + mask2 = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=1) + A_local = (q @ k.transpose(-1, -2)).masked_fill(mask2, 0) @ T + o_intra = A_local @ v + + # apply cumprod transition matrices on k to the last position within the chunk + k = k - ((k @ k.transpose(-1, -2)).masked_fill(mask, 0) @ T).transpose(-1, -2) @ k_beta + # apply cumprod transition matrices on q to the first position within the chunk + q = q - A_local @ k_beta + o_intra = A_local @ v + + A = torch.zeros(b, h, l, l, device=q.device) + + q, k, v, k_beta, o_intra = map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d'), [q, k, v, k_beta, o_intra]) + o = torch.empty_like(v) + for i in range(0, l, BM): + q_i = q[:, :, i:i+BM] + o_i = o_intra[:, :, i:i+BM] + # intra block + for j in range(i + BM - 2 * BN, i-BN, -BN): + k_j = k[:, :, j:j+BN] + A_ij = q_i @ k_j.transpose(-1, -2) + mask = torch.arange(i, i+BM) >= (j + 
BN) + A_ij = A_ij.masked_fill_(~mask[:, None].to(A_ij.device), 0) + A[:, :, i:i+BM, j:j+BN] = A_ij + q_i = q_i - A_ij @ k_beta[:, :, j:j+BN] + o_i += A_ij @ v[:, :, j:j+BN] + # inter block + for j in range(i - BN, -BN, -BN): + k_j = k[:, :, j:j+BN] + A_ij = q_i @ k_j.transpose(-1, -2) + A[:, :, i:i+BM, j:j+BN] = A_ij + q_i = q_i - A_ij @ k_beta[:, :, j:j+BN] + o_i += A_ij @ v[:, :, j:j+BN] + o[:, :, i:i+BM] = o_i + + for i in range(0, l//BN): + A[:, :, i*BN:i*BN+BN, i*BN:i*BN+BN] = A_local[:, :, i] + + return o, A diff --git a/fla/ops/delta_rule/parallel.py b/fla/ops/delta_rule/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..722f2dec76271a08398dc105d7d0b8a33917b62e --- /dev/null +++ b/fla/ops/delta_rule/parallel.py @@ -0,0 +1,394 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Tuple + +import torch +import triton +import triton.language as tl +from einops import rearrange + +from fla.ops.delta_rule.wy_fast import fwd_prepare_T +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=['BT', 'K', 'V'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_transform_qk_fwd_kernel( + q, + k, + v, + beta, + o, + A, + q_new, + k_new, + A_local, + scale, + T, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + BT: tl.constexpr, + OUTPUT_ATTENTIONS: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, 0), (BT, BV), (1, 0)) + b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(p_q.dtype.element_ty) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_v = tl.load(p_v, 
boundary_check=(0, 1)) + + p_T = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_T = tl.load(p_T, boundary_check=(0, 1)) + + o_i = tl.arange(0, BT) + m_t = o_i[:, None] >= o_i[None, :] + b_qk = tl.where(m_t, tl.dot(b_q, tl.trans(b_k), allow_tf32=False), 0).to(b_q.dtype) + m_t = o_i[:, None] > o_i[None, :] + b_kk = tl.where(m_t, tl.dot(b_k, tl.trans(b_k), allow_tf32=False), 0).to(b_k.dtype) + + p_beta = tl.make_block_ptr(beta + i_bh * T, (T, ), (1, ), (i_t * BT, ), (BT, ), (0, )) + b_beta = tl.load(p_beta, boundary_check=(0, )) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + + b_qkT = tl.dot(b_qk, b_T, allow_tf32=False).to(b_k.dtype) + + if OUTPUT_ATTENTIONS: + p_a = tl.make_block_ptr(A_local + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_a, b_qkT.to(p_a.dtype.element_ty), boundary_check=(0, 1)) + + b_kkT = tl.dot(b_kk, b_T, allow_tf32=False).to(b_k.dtype) + p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, 0), (BT, BV), (1, 0)) + tl.store(p_o, tl.dot(b_qkT, b_v).to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + p_q_new = tl.make_block_ptr(q_new + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + tl.store(p_q_new, (b_q - tl.dot(b_qkT, b_k_beta, allow_tf32=False)).to(p_q_new.dtype.element_ty), boundary_check=(0, 1)) + + p_k_new = tl.make_block_ptr(k_new + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + b_k_new = b_k - tl.dot(tl.trans(b_kkT), b_k_beta, allow_tf32=False) + tl.store(p_k_new, b_k_new.to(p_k_new.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_transform_qk_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + scale: float, + chunk_size: int, + output_attentions: bool +): + B, H, T, K = k.shape + BT = chunk_size + q_new = torch.empty_like(q) + k_new = torch.empty_like(k) + o = torch.empty_like(v) + grid = (triton.cdiv(T, BT), B*H) + V = v.shape[-1] + A_local = 
torch.empty_like(A) if output_attentions else None + chunk_transform_qk_fwd_kernel[grid]( + q, + k, + v, + beta, + o, + A, + q_new, + k_new, + A_local, + scale=scale, + T=T, + K=K, + V=V, + BT=BT, + BK=triton.next_power_of_2(K), + BV=triton.next_power_of_2(V), + OUTPUT_ATTENTIONS=output_attentions + ) + return q_new, k_new, o, A_local + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + ], + key=['BT'], +) +@triton.jit(do_not_specialize=['T']) +def save_intra_chunk_attn( + A, + A_local, + T, + BT: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + p_A = tl.make_block_ptr(A + i_bh * T * T, (T, T), (T, 1), (i_t * BT, i_t * BT), (BT, BT), (1, 0)) + p_A_local = tl.make_block_ptr(A_local + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_A_local = tl.load(p_A_local, boundary_check=(0, 1)) + tl.store(p_A, b_A_local.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'OUTPUT_ATTENTIONS': lambda args: args['attn'] is not None +}) +@triton.jit(do_not_specialize=['T']) +def parallel_delta_rule_fwd_kernel( + q, + k, + k2, # original k + v, + beta, + o, + o_new, + attn, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + OUTPUT_ATTENTIONS: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + + # the Q block is kept in the shared memory throughout the whole kernel + # [BT, BK] + b_q = tl.zeros([BT, BK], dtype=tl.float32) + b_q += tl.load(p_q, boundary_check=(0, 1)) + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, 0), (BT, BV), (1, 0)) + b_o += tl.load(p_o, boundary_check=(0, 1)) + + # As opposed to Flashattention, this kernel requires scanning the KV blocks from right to left + # Q block and K block have overlap. 
+ # masks required + for offset in range((i_t + 1) * BT - 2 * BS, i_t * BT - BS, -BS): + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (0, offset), (BK, BS), (0, 1)) + p_k2 = tl.make_block_ptr(k2 + i_bh * T*K, (T, K), (K, 1), (offset, 0), (BS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (offset, 0), (BS, BV), (1, 0)) + p_beta = tl.make_block_ptr(beta + i_bh * T, (T, ), (1, ), (offset, ), (BS, ), (0,)) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BS] + b_beta = tl.load(p_beta, boundary_check=(0,)) + # [BT, BS] + m_s = tl.arange(0, BT) >= (offset - i_t*BT + BS) + b_s = tl.dot(b_q.to(b_k.dtype), b_k, allow_tf32=False) + b_s = tl.where(m_s[:, None], b_s, 0) + + b_o += tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) + b_k2 = (tl.load(p_k2, boundary_check=(0, 1)) * b_beta[:, None]).to(b_v.dtype) + b_q -= tl.dot(b_s.to(b_v.dtype), b_k2, allow_tf32=False) + + if OUTPUT_ATTENTIONS: + p_a = tl.make_block_ptr(attn + i_bh * T * T, (T, T), (T, 1), (i_t * BT, offset), (BT, BS), (1, 0)) + tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1)) + + # Q block and K block have no overlap + # no need for mask, thereby saving flops + for offset in range(i_t * BT - BS, -BS, -BS): + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (0, offset), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (offset, 0), (BS, BV), (1, 0)) + p_beta = tl.make_block_ptr(beta + i_bh * T, (T, ), (1, ), (offset, ), (BS, ), (0,)) + p_k2 = tl.make_block_ptr(k2 + i_bh * T*K, (T, K), (K, 1), (offset, 0), (BS, BK), (1, 0)) + + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BS] + b_beta = tl.load(p_beta, boundary_check=(0,)) + # [BT, BS] + b_s = (tl.dot(b_q.to(b_k.dtype), b_k, allow_tf32=False)) + # [BT, BV] + b_o += tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) + b_k2 = (tl.load(p_k2, 
boundary_check=(0, 1)) * b_beta[:, None]).to(b_v.dtype) + b_q -= tl.dot(b_s.to(b_v.dtype), b_k2, allow_tf32=False).to(b_q.dtype) + + if OUTPUT_ATTENTIONS: + p_a = tl.make_block_ptr(attn + i_bh * T * T, (T, T), (T, 1), (i_t * BT, offset), (BT, BS), (1, 0)) + tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1)) + + p_o_new = tl.make_block_ptr(o_new + i_bh * T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + tl.store(p_o_new, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +class ParallelDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, beta, scale, output_attentions): + B, H, T, K, V = *k.shape, v.shape[-1] + assert q.shape[-1] <= 128, 'The maximum supported sequence length is 128.' + BT, BS = 128, 32 + BK = triton.next_power_of_2(k.shape[-1]) + BV = triton.next_power_of_2(v.shape[-1]) + assert BT % BS == 0 + + A = fwd_prepare_T(k, beta, BS) + attn = q.new_zeros(B, H, T, T) if output_attentions else None + q_new, k_new, o, A_local = chunk_transform_qk_fwd( + q, + k, + v, + beta, + A, + scale, + BS, + output_attentions + ) + + num_stages = 3 if K <= 64 else 2 + num_warps = 4 + grid = (triton.cdiv(T, BT), B * H) + o_new = torch.empty_like(o) + + parallel_delta_rule_fwd_kernel[grid]( + q=q_new, + k=k_new, + k2=k, + v=v, + beta=beta, + o=o, + o_new=o_new, + attn=attn, + T=T, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV, + num_stages=num_stages, + num_warps=num_warps + ) + + if output_attentions: + grid = (triton.cdiv(T, BS), B * H) + save_intra_chunk_attn[grid]( + A=attn, + A_local=A_local, + T=T, + BT=BS + ) + return o_new.to(q.dtype), attn + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, d_attn=None): + raise NotImplementedError('Backward pass is not implemented. 
Stay tuned!')


def parallel_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    scale: float = None,
    output_attentions: bool = False,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        beta (torch.Tensor):
            betas of shape `[B, H, T]` if `head_first=True` else `[B, T, H]`.
        scale (Optional[int]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        output_attentions (bool):
            Whether to output the materialized attention scores of shape [B, H, T, T]. Default: `False`.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format.
            Default: `True`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        attn (torch.Tensor):
            Attention scores of shape `[B, H, T, T]` if `output_attentions=True` else `None`. 
    """
    # Normalize to head-first layout for the kernel, then flip the output back.
    if not head_first:
        q, k, v, beta = map(lambda x: x.transpose(1, 2), (q, k, v, beta))
    o, attn = ParallelDeltaRuleFunction.apply(q, k, v, beta, scale, output_attentions)
    if not head_first:
        o = o.transpose(1, 2)
    return o, attn


def naive_delta_rule_parallel(q, k, v, beta, BM=128, BN=32):
    # Pure-PyTorch reference implementation of the parallel delta rule,
    # used to validate the Triton kernels. Inputs are head-first [b, h, l, d].
    b, h, l, d_k = q.shape
    q = q * (d_k ** -0.5)
    v = v * beta[..., None]
    k_beta = k * beta[..., None]
    # compute (I - tri(diag(beta) KK^T))^{-1}
    q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=BN), [q, k, v, k_beta])
    mask = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=0)
    T = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)
    # Forward substitution: row i accumulates contributions of rows < i.
    for i in range(1, BN):
        T[..., i, :i] = T[..., i, :i].clone() + (T[..., i, :, None].clone() * T[..., :, :i].clone()).sum(-2)
    T = T + torch.eye(BN, dtype=q.dtype, device=q.device)

    mask2 = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=1)
    A_local = (q @ k.transpose(-1, -2)).masked_fill(mask2, 0) @ T
    o_intra = A_local @ v

    # apply cumprod transition matrices on k to the last position within the chunk
    k = k - ((k @ k.transpose(-1, -2)).masked_fill(mask, 0) @ T).transpose(-1, -2) @ k_beta
    # apply cumprod transition matrices on q to the first position within the chunk
    q = q - A_local @ k_beta
    o_intra = A_local @ v

    A = torch.zeros(b, h, l, l, device=q.device)

    q, k, v, k_beta, o_intra = map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d'), [q, k, v, k_beta, o_intra])
    o = torch.empty_like(v)
    for i in range(0, l, BM):
        q_i = q[:, :, i:i+BM]
        o_i = o_intra[:, :, i:i+BM]
        # intra block
        for j in range(i + BM - 2 * BN, i-BN, -BN):
            k_j = k[:, :, j:j+BN]
            A_ij = q_i @ k_j.transpose(-1, -2)
            # Causal mask: query position must be >= the end of the key sub-chunk.
            mask = torch.arange(i, i+BM) >= (j + BN)
            A_ij = A_ij.masked_fill_(~mask[:, None].to(A_ij.device), 0)
            A[:, :, i:i+BM, j:j+BN] = A_ij
            q_i = q_i - A_ij @ k_beta[:, :, j:j+BN]
            o_i += A_ij @ v[:, :, j:j+BN]
        # inter block
for j in range(i - BN, -BN, -BN): + k_j = k[:, :, j:j+BN] + A_ij = q_i @ k_j.transpose(-1, -2) + A[:, :, i:i+BM, j:j+BN] = A_ij + q_i = q_i - A_ij @ k_beta[:, :, j:j+BN] + o_i += A_ij @ v[:, :, j:j+BN] + o[:, :, i:i+BM] = o_i + + for i in range(0, l//BN): + A[:, :, i*BN:i*BN+BN, i*BN:i*BN+BN] = A_local[:, :, i] + + return o, A diff --git a/fla/ops/delta_rule/wy_fast.py b/fla/ops/delta_rule/wy_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..5a863b91556ba0c33ef47f331d24d1c352d64c79 --- /dev/null +++ b/fla/ops/delta_rule/wy_fast.py @@ -0,0 +1,340 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from fla.ops.utils.solve_tril import solve_tril +from fla.utils import check_shared_mem, is_nvidia_hopper + +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'], +) +@triton.jit(do_not_specialize=['T']) +def fwd_recompute_w_u_kernel( + k, + v, + beta, + w, + u, + A, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_beta = 
tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A.to(b_vb.dtype), b_vb, allow_tf32=False) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + b_w = tl.dot(b_A.to(b_kb.dtype), b_kb, allow_tf32=False) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, 
num_warps=num_warps, num_stages=num_stages) + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'], +) +@triton.jit(do_not_specialize=['T']) +def bwd_prepare_wy_repr_kernel( + k, + v, + beta, + A, + dw, + du, + dk, + dv, + dbeta, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T*BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + else: + p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + + b_beta = tl.load(p_beta, boundary_check=(0,)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + + b_dbeta = tl.zeros([BT], dtype=tl.float32) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * 
BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False) + b_dv_beta = tl.dot(b_A, b_du, allow_tf32=False) + b_dv = b_dv_beta * b_beta[:, None] + b_dbeta += tl.sum(b_dv_beta * b_v, 1) + + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False) + b_dk_beta = tl.dot(b_A, b_dw, allow_tf32=False) + b_dk = b_dk_beta * b_beta[:, None] + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), b_A) + b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) + b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * 
T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_dk = tl.load(p_dk, boundary_check=(0, 1))
        b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)

        # Fold the dA contribution back into dk and dbeta.
        b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)
        b_dbeta += tl.sum(b_dk_beta * b_k, 1)
        b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False)
        b_dk += b_dk_beta * b_beta[:, None]
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

    if HEAD_FIRST:
        p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        p_dbeta = tl.make_block_ptr(dbeta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))


def fwd_prepare_wy_repr(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool = False,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the WY representation: returns (w, u, A).

    A is the inverse of the (strictly lower-triangular) per-chunk
    beta-scaled K K^T factor, obtained via `solve_tril`; w/u are the
    A-transformed beta-scaled k/v blocks from `fwd_recompute_w_u`.
    """
    A = chunk_scaled_dot_kkt_fwd(
        k=k,
        beta=beta,
        cu_seqlens=offsets,
        head_first=head_first,
        chunk_size=chunk_size,
        output_dtype=torch.float32
    )
    A = solve_tril(
        A=A,
        cu_seqlens=offsets,
        head_first=head_first,
        output_dtype=k.dtype
    )

    w, u = fwd_recompute_w_u(
        k=k,
        v=v,
        beta=beta,
        A=A,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )
    return w, u, A


def fwd_recompute_w_u(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int
) 
-> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + CONST_TILING = 64 if check_shared_mem() else 32 + BK = min(triton.next_power_of_2(K), CONST_TILING) + BV = min(triton.next_power_of_2(V), CONST_TILING) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + u = torch.empty_like(v) + w = torch.empty_like(k) + fwd_recompute_w_u_kernel[(NT, B*H)]( + k, + v, + beta, + w, + u, + A, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return w, u + + +def bwd_prepare_wy_repr( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + dw: torch.Tensor, + du: torch.Tensor, + offsets: Optional[torch.LongTensor], + indices: Optional[torch.LongTensor], + head_first: bool, + chunk_size: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + CONST_TILING = 64 if check_shared_mem() else 32 + BK = min(triton.next_power_of_2(K), CONST_TILING) + BV = min(triton.next_power_of_2(V), CONST_TILING) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dbeta = torch.empty_like(beta) + bwd_prepare_wy_repr_kernel[(NT, B * H)]( + k, + v, + beta, + A, + dw, + du, + dk, + dv, + dbeta, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return dk, dv, dbeta diff --git a/fla/ops/forgetting_attn/__init__.py b/fla/ops/forgetting_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e62c741d464f01b5c0c6707671061293b9d48644 --- /dev/null +++ b/fla/ops/forgetting_attn/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: 
utf-8 -*- + +from .parallel import parallel_forgetting_attn + +__all__ = [ + 'parallel_forgetting_attn' +] diff --git a/fla/ops/forgetting_attn/parallel.py b/fla/ops/forgetting_attn/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..88fea7f29b238e488848711ed894cb6cae7ea91b --- /dev/null +++ b/fla/ops/forgetting_attn/parallel.py @@ -0,0 +1,708 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl +from einops import rearrange, reduce + +from fla.ops.common.utils import prepare_chunk_indices +from fla.ops.utils import chunk_global_cumsum, chunk_local_cumsum +from fla.ops.utils.op import div, exp, log +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else []) + for num_stages in [2, 3, 4, 5] + ], + key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'], +) +@triton.jit +def parallel_forgetting_attn_fwd_kernel( + q, + k, + v, + g, + o, + lse, + scale, + offsets, + indices, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * 
K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + + # the Q block is kept in the shared memory throughout the whole kernel + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT,] + b_gq = tl.load(p_g, boundary_check=(0,)).to(tl.float32) + # [BT, BV] + b_o = tl.zeros([BT, BV], dtype=tl.float32) + + b_m = tl.full([BT], float('-inf'), dtype=tl.float32) + b_acc = tl.zeros([BT], dtype=tl.float32) + + # [BT] + o_q = i_t * BT + tl.arange(0, BT) + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BS] + o_k = i_s + tl.arange(0, BS) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BS,] + b_gk = tl.load(p_gk, boundary_check=(0,)) + # [BT, BS] + b_s = tl.dot(b_q, b_k) + b_gq[:, None] - b_gk[None, :] + b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf')) + + # [BT] + b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m + b_r = exp(b_mp - b_m) + # [BT, BS] + b_p = exp(b_s - b_m[:, None]) + # [BT] + b_acc = b_acc * b_r + tl.sum(b_p, 1) + # [BT, BV] + b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v) + + b_mp = b_m + + for i_s in range(i_t * BT - BS, -BS, -BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + p_gk = 
tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BS,] + b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32) + + b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32) + b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0. + # [BT, BS] + b_s = tl.dot(b_q, b_k) + b_gq[:, None] + (b_gn - b_gk)[None, :] + + b_gq += b_gn - b_gp + b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m + b_r = exp(b_mp - b_m) + # [BT, BS] + b_p = exp(b_s - b_m[:, None]) + # [BT] + b_acc = b_acc * b_r + tl.sum(b_p, 1) + # [BT, BV] + b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v) + + b_mp = b_m + + b_o = div(b_o, b_acc[:, None]) + b_m += log(b_acc) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,)) + + +@triton.jit +def parallel_forgetting_attn_bwd_kernel_preprocess( + o, + do, + delta, + B: tl.constexpr, + V: tl.constexpr +): + i_n = tl.program_id(0) + o_d = tl.arange(0, B) + m_d = o_d < V + + b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0) + b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32) + b_delta = tl.sum(b_o * b_do) + + tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else []) + for num_stages in [2, 3, 4] + ], + key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_forgetting_attn_bwd_kernel_dq( + q, + k, + v, + g, + lse, + delta, + do, + dq, + dg, + scale, + offsets, + indices, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: 
tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos * HQ + i_hq), (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BT] + b_gq = tl.load(p_g, boundary_check=(0,)).to(tl.float32) + b_lse = tl.load(p_lse, boundary_check=(0,)) + b_delta = tl.load(p_delta, boundary_check=(0,)) + + # [BT] + o_q = i_t * BT + tl.arange(0, BT) + # [BT, BK] + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + # [BT] + b_dg = tl.zeros([BT,], dtype=tl.float32) + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1)) + p_gk = tl.make_block_ptr(g 
+ bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BS] + o_k = i_s + tl.arange(0, BS) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BS,] + b_gk = tl.load(p_gk, boundary_check=(0,)) + # [BT, BS] + b_s = tl.dot(b_q, b_k) + (b_gq - b_lse)[:, None] - b_gk[None, :] + b_p = exp(tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf'))) + + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_do, b_v) + b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k)) + # [BT] + b_dg += tl.sum(b_ds, 1) + + for i_s in range(i_t * BT - BS, -BS, -BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1)) + p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BS,] + b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32) + + b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32) + b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0. 
+ # [BT, BS] + b_s = tl.dot(b_q, b_k) + (b_gq - b_lse)[:, None] + (b_gn - b_gk)[None, :] + b_p = exp(b_s) + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_do, b_v) + b_ds = b_p * (b_dp - b_delta[:, None]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k)) + # [BT] + b_dg += tl.sum(b_ds, 1) + + b_gq += b_gn - b_gp + + b_dq *= scale + + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_forgetting_attn_bwd_kernel_dkv( + q, + k, + v, + g, + lse, + delta, + do, + dk, + dv, + dg, + offsets, + indices, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 
1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos * HQ + i_hq), (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + # [BT] + b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32) + b_dg = tl.zeros([BT,], dtype=tl.float32) + + o_k = i_t * BT + tl.arange(0, BT) + m_k = o_k < T + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0)) + p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + p_gq = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BS] + o_q = i_s + tl.arange(0, BS) + # [BS, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BS, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BS] + b_lse = tl.load(p_lse, boundary_check=(0,)) + b_delta = tl.load(p_delta, boundary_check=(0,)) + b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32) + + m_q = o_q < T + m_s = (o_k[:, None] <= o_q[None, :]) & m_k[:, None] & m_q[None, :] + # [BT, BS] + b_s = tl.dot(b_k, tl.trans(b_q)) - b_gk[:, None] + (b_gq - b_lse)[None, :] + b_p = tl.where(m_s, exp(b_s), 0) + # [BT, BS] @ [BS, BV] -> [BT, BV] + b_dv += tl.dot(b_p.to(b_do.dtype), b_do) + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_v, tl.trans(b_do)) + # [BT, BS] + b_ds = b_p * (b_dp - b_delta[None, :]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) + # [BT] 
+ b_dg -= tl.sum(b_ds, 1) + + b_gk -= tl.load(g + (bos + min((i_t + 1) * BT, T) - 1) * HQ + i_hq).to(tl.float32) + for i_s in range((i_t + 1) * BT, T, BS): + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0)) + p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + p_gq = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BS] + o_q = i_s + tl.arange(0, BS) + # [BS, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BS, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BS] + b_lse = tl.load(p_lse, boundary_check=(0,)) + b_delta = tl.load(p_delta, boundary_check=(0,)) + b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32) + + b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32) + b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0. 
+ # [BT, BS] + b_s = tl.dot(b_k, tl.trans(b_q)) - (b_gk + b_gp)[:, None] + (b_gq - b_lse)[None, :] + b_p = exp(b_s) + # [BT, BS] @ [BS, BV] -> [BT, BV] + b_dv += tl.dot(b_p.to(b_do.dtype), b_do) + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_v, tl.trans(b_do)) + # [BT, BS] + b_ds = b_p * (b_dp - b_delta[None, :]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) + # [BT] + b_dg -= tl.sum(b_ds, 1) + + b_gk -= b_gn - b_gp + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + + +def parallel_forgetting_attn_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + scale: float, + chunk_size: int = 128, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] + G = HQ // H + BT = chunk_size + BK = max(16, triton.next_power_of_2(K)) + assert V <= 256, "V must be less than or equal to 256" + if check_shared_mem('hopper'): + BS = min(64, max(16, triton.next_power_of_2(T))) + else: + BS = min(32, max(16, triton.next_power_of_2(T))) + BV = min(256, max(16, triton.next_power_of_2(V))) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) + lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + + grid = (NV, NT, B * HQ) + parallel_forgetting_attn_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + o=o, + lse=lse, + scale=scale, + offsets=offsets, + indices=indices, + B=B, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV, + ) + return o, lse + + +def parallel_forgetting_attn_bwd_preprocess( + o: torch.Tensor, + do: torch.Tensor +): + V = o.shape[-1] + delta = torch.empty_like(o[..., 0], dtype=torch.float) + 
parallel_forgetting_attn_bwd_kernel_preprocess[(delta.numel(),)](
        o=o,
        do=do,
        delta=delta,
        B=triton.next_power_of_2(V),
        V=V,
    )
    return delta


def parallel_forgetting_attn_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    o: torch.Tensor,
    lse: torch.Tensor,
    do: torch.Tensor,
    scale: float = None,
    chunk_size: int = 128,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
):
    """Backward pass driver: computes dq/dk/dv/dg via two Triton kernels."""
    B, T, H, K, V = *k.shape, v.shape[-1]
    HQ = q.shape[2]
    # G: number of query heads per key/value head (GQA grouping factor).
    G = HQ // H
    BT = chunk_size
    BS = min(32, max(16, triton.next_power_of_2(T)))
    BK = max(16, triton.next_power_of_2(K))
    BV = max(16, triton.next_power_of_2(V))
    NV = triton.cdiv(V, BV)
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)

    # delta[i] = sum_v o[i] * do[i], needed by both dq and dkv kernels.
    delta = parallel_forgetting_attn_bwd_preprocess(o, do)
    dq = q.new_empty(B, T, HQ, K, dtype=q.dtype)
    # When grouped (H != HQ), accumulate dk/dv in fp32 before head-reduction.
    dk = q.new_empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float)
    dv = q.new_empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float)
    dg = q.new_empty(g.shape, dtype=torch.float)
    # NOTE: the original `dg` can be destroyed during autotuning
    # this is [a known triton issue](https://github.com/triton-lang/triton/issues/5082), which will be fixed in 3.3 (?) 
+    # so we need to make a copy of `dg`
+    dg2 = q.new_empty(g.shape, dtype=torch.float)
+    grid = (NV, NT, B * HQ)
+    parallel_forgetting_attn_bwd_kernel_dq[grid](
+        q=q,
+        k=k,
+        v=v,
+        g=g,
+        lse=lse,
+        delta=delta,
+        do=do,
+        dq=dq,
+        dg=dg,
+        offsets=offsets,
+        indices=indices,
+        scale=scale,
+        T=T,
+        B=B,
+        H=H,
+        HQ=HQ,
+        G=G,
+        K=K,
+        V=V,
+        BT=BT,
+        BS=BS,
+        BK=BK,
+        BV=BV
+    )
+    parallel_forgetting_attn_bwd_kernel_dkv[grid](
+        q=q,
+        k=k,
+        v=v,
+        g=g,
+        lse=lse,
+        delta=delta,
+        do=do,
+        dk=dk,
+        dv=dv,
+        dg=dg2,
+        offsets=offsets,
+        indices=indices,
+        scale=scale,
+        T=T,
+        B=B,
+        H=H,
+        HQ=HQ,
+        G=G,
+        K=K,
+        V=V,
+        BT=BT,
+        BS=BS,
+        BK=BK,
+        BV=BV
+    )
+    # GQA: fold the HQ query-head replicas back into the H key/value heads
+    dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum')
+    dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum')
+    dg = dg.add_(dg2)
+    return dq, dk, dv, dg
+
+
+# NOTE(review): removed the original `@torch.compile` on this class.
+# `torch.compile` expects an `nn.Module` or a plain callable and returns a
+# wrapper callable; decorating an `autograd.Function` subclass replaces the
+# class itself, so the later `ParallelForgettingAttentionFunction.apply(...)`
+# attribute access fails. autograd Functions must be used undecorated.
+class ParallelForgettingAttentionFunction(torch.autograd.Function):
+
+    @staticmethod
+    @input_guard
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, g, scale, offsets):
+        ctx.dtype = q.dtype
+        # larger chunks are affordable only when enough shared memory is available
+        if check_shared_mem('hopper'):
+            chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1])))
+        else:
+            chunk_size = min(64, max(16, triton.next_power_of_2(q.shape[1])))
+        # 2-d indices denoting the offsets of chunks in each sequence
+        # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
+        # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
+        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
+        indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
+
+        g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=False)
+        o, lse = parallel_forgetting_attn_fwd(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            scale=scale,
+            chunk_size=chunk_size,
+            offsets=offsets,
+            indices=indices
+        )
+        ctx.save_for_backward(q, k, v, g, o, lse)
+        ctx.chunk_size = chunk_size
+        ctx.offsets = offsets
+        ctx.indices = indices
+        ctx.scale 
= scale
+        return o.to(q.dtype)
+
+    @staticmethod
+    @input_guard
+    @autocast_custom_bwd
+    def backward(ctx, do):
+        q, k, v, g, o, lse = ctx.saved_tensors
+        dq, dk, dv, dg = parallel_forgetting_attn_bwd(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            o=o,
+            lse=lse,
+            do=do,
+            scale=ctx.scale,
+            chunk_size=ctx.chunk_size,
+            offsets=ctx.offsets,
+            indices=ctx.indices
+        )
+        # undo the forward-pass local cumsum applied to the gates
+        dg = chunk_global_cumsum(dg, reverse=True, head_first=False, offsets=ctx.offsets)
+        # NOTE(review): `forward` receives exactly 6 inputs (q, k, v, g, scale, offsets),
+        # so `backward` must return exactly 6 gradients. The original returned 12,
+        # which makes autograd raise "returned an incorrect number of gradients
+        # (expected 6, got 12)" on the first backward call.
+        return dq.to(q), dk.to(k), dv.to(v), dg.to(g), None, None
+
+
+def parallel_forgetting_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    scale: Optional[float] = None,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+    head_first: bool = False
+) -> torch.Tensor:
+    r"""
+    Args:
+        q (torch.Tensor):
+            queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
+        k (torch.Tensor):
+            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+            GQA will be applied if HQ is divisible by H.
+        v (torch.Tensor):
+            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+        g (torch.Tensor):
+            Forget gates (in **log space**) of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
+        scale (Optional[int]):
+            Scale factor for attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API.
+        head_first (Optional[bool]):
+            Whether the inputs are in the head-first format. Default: `False`.
+
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
+ """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if g is not None: + g = g.float() + if head_first: + q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v)) + g = rearrange(g, 'b h t -> b t h') + o = ParallelForgettingAttentionFunction.apply(q, k, v, g, scale, cu_seqlens) + if head_first: + o = rearrange(o, 'b t h d -> b h t d') + return o diff --git a/fla/ops/gated_delta_rule/__pycache__/__init__.cpython-312.pyc b/fla/ops/gated_delta_rule/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3efbd1215a5e510bb6a7150c6fab00c697cb899 Binary files /dev/null and b/fla/ops/gated_delta_rule/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla/ops/gated_delta_rule/chunk.py b/fla/ops/gated_delta_rule/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..abbb52a56fbaf62a4c818c32217dc8c95a0e2292 --- /dev/null +++ b/fla/ops/gated_delta_rule/chunk.py @@ -0,0 +1,392 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +from einops import rearrange + +from fla.modules.l2norm import l2norm_bwd, l2norm_fwd +from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h +from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o +from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr, fwd_prepare_wy_repr, fwd_recompute_w_u +from fla.ops.utils import chunk_local_cumsum +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + indices: 
Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first) + # obtain WY representation. u is actually the new v. + w, u, Aw, Au = fwd_prepare_wy_repr( + k=k, + v=v, + beta=beta, + g=g, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + + # obtain output + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + return g, o, Aw, Au, final_state + + +def chunk_gated_delta_rule_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + Aw: torch.Tensor, + Au: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + T = q.shape[2] if head_first else q.shape[1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + w, u = fwd_recompute_w_u( + k=k, + v=v, + beta=beta, + Aw=Aw, + Au=Au, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + h, v_new, _ = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=False, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dv = chunk_bwd_dv_local( + q=q, + k=k, + g=g, + do=do, + dh=None, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu( + q=q, + k=k, + w=w, + 
g=g, + h0=initial_state, + dht=dht, + do=do, + dv=dv, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dq, dk, dw, dg = chunk_bwd_dqkwg( + q=q, + k=k, + v=v_new, + w=w, + g=g, + h=h, + dv=dv, + do=do, + dh=dh, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dk2, dv, db, dg2 = bwd_prepare_wy_repr( + k=k, + v=v, + beta=beta, + g=g, + Aw=Aw, + Au=Au, + dw=dw, + du=dv, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dk.add_(dk2) + dg.add_(dg2) + assert dg.dtype == torch.float32, "dg should be fp32" + dg = chunk_local_cumsum(dg, chunk_size, reverse=True, offsets=offsets, indices=indices, head_first=head_first) + return dq, dk, dv, db, dg, dh0 + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + head_first: bool = True, + use_qk_l2norm_in_kernel: bool = False + ): + chunk_size = 64 + q_orig = q + k_orig = k + + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q) + k = l2norm_fwd(k) + + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = None + if offsets is not None: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()]) + indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets) + + g, o, Aw, Au, final_state = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + 
initial_state=initial_state, + output_final_state=output_final_state, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size, + ) + ctx.save_for_backward(q_orig, k_orig, v, g, beta, Aw, Au, initial_state, offsets, indices) + ctx.chunk_size = chunk_size + ctx.scale = scale + ctx.head_first = head_first + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + return o.to(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward( + ctx, + do: torch.Tensor, + dht: torch.Tensor + ): + q, k, v, g, beta, Aw, Au, initial_state, offsets, indices = ctx.saved_tensors + if ctx.use_qk_l2norm_in_kernel: + q, q_orig = l2norm_fwd(q), q + k, k_orig = l2norm_fwd(k), k + dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + Aw=Aw, + Au=Au, + scale=ctx.scale, + initial_state=initial_state, + do=do, + dht=dht, + offsets=offsets, + indices=indices, + head_first=ctx.head_first, + chunk_size=ctx.chunk_size + ) + if ctx.use_qk_l2norm_in_kernel: + dq = l2norm_bwd(q_orig, dq) + dk = l2norm_bwd(k_orig, dk) + return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None + + +@torch.compiler.disable +def chunk_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, + use_qk_l2norm_in_kernel: bool = False +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. 
+ beta (torch.Tensor): + betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') + >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() + >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + head_first=False + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens, + head_first=False + ) + """ + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." + assert len(beta.shape) == 3, "beta must be of shape [B, H, T] if head_first=True, or [B, T, H] if head_first=False." + + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." 
+ ) + if head_first: + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if head_first: + q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v)) + beta, g = map(lambda x: rearrange(x, 'b h t -> b t h'), (beta, g)) + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "Scale must be positive." + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + False, + use_qk_l2norm_in_kernel + ) + if head_first: + o = rearrange(o, 'b t h v -> b h t v') + return o, final_state diff --git a/fla/ops/gated_delta_rule/fused_recurrent.py b/fla/ops/gated_delta_rule/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..4c73b8a40be4044982d714f5922a7fc324a4fbb8 --- /dev/null +++ b/fla/ops/gated_delta_rule/fused_recurrent.py @@ -0,0 +1,321 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl +from einops import rearrange + +from fla.ops.utils.op import exp +from fla.utils import input_guard + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + offsets, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: 
tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state + IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * H + i_h) * V + o_v + if IS_BETA_HEADWISE: + p_beta = beta + (bos * H + i_h) * V + o_v + else: + p_beta = beta + bos * H + i_h + p_g = g + bos * H + i_h + p_o = o + ((i_k * all + bos) * H + i_h) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_g = tl.load(p_g).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6) + b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6) + b_q = b_q * scale + # [BK, BV] + b_h *= exp(b_g) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + p_q += H*K + 
p_k += H*K + p_o += H*V + p_v += H*V + p_g += H + p_beta += H * (V if IS_BETA_HEADWISE else 1) + + if STORE_FINAL_STATE: + p_ht = ht + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + use_qk_l2norm_in_kernel: bool = False, + offsets: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if offsets is None else len(offsets) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 1 + + o = q.new_empty(NK, *v.shape) + if output_final_state: + final_state = q.new_empty(N, H, K, V, dtype=torch.float32) + else: + final_state = None + + grid = (NK, NV, N * H) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + offsets=offsets, + scale=scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + BV=BV, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state + + +class FusedRecurrentFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False + ): + o, final_state = fused_recurrent_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + 
output_final_state=output_final_state, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + offsets=offsets + ) + + return o, final_state + + @staticmethod + @input_guard + def backward(ctx, do, dht): + raise NotImplementedError( + "Backward pass is not implemented yet and we do not have plans to implement it " + "because we haven't figured out how to compute dg without materializing the full " + "hidden states for all time steps." + ) + + +def fused_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False, + head_first: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g (torch.Tensor): + g (decays) of shape `[B, T, H]` if `head_first=False` else `(B, H, T)`. + beta (torch.Tensor): + betas of shape `[B, T, H]` if `head_first=False` else `(B, H, T)`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. 
+
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+        final_state (torch.Tensor):
+            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+
+    Examples::
+        >>> import torch
+        >>> import torch.nn.functional as F
+        >>> from einops import rearrange
+        >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
+        # inputs with equal lengths
+        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
+        >>> q = torch.randn(B, T, H, K, device='cuda')
+        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
+        >>> v = torch.randn(B, T, H, V, device='cuda')
+        >>> g = F.logsigmoid(torch.rand(B, T, H, device='cuda'))
+        >>> beta = torch.rand(B, T, H, device='cuda').sigmoid()
+        >>> h0 = torch.randn(B, H, K, V, device='cuda')
+        >>> o, ht = fused_recurrent_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+            output_final_state=True,
+        )
+        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
+        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
+        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+            output_final_state=True,
+            cu_seqlens=cu_seqlens
+        )
+        >>> assert o.allclose(o_var.view(o.shape))
+        >>> assert ht.allclose(ht_var)
+    """
+    if cu_seqlens is not None:
+        if q.shape[0] != 1:
+            raise ValueError(
+                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
+                f"Please flatten variable-length inputs before processing." 
+ ) + if head_first: + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + if beta is None: + beta = torch.ones_like(q[..., 0]) + if head_first: + q, k, v, g, beta = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, g, beta)) + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + use_qk_l2norm_in_kernel + ) + if head_first: + o = rearrange(o, 'b t h v -> b h t v') + return o, final_state diff --git a/fla/ops/gated_delta_rule/wy_fast.py b/fla/ops/gated_delta_rule/wy_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..f80b2251f32e60dda83735f74183546b15ef45a0 --- /dev/null +++ b/fla/ops/gated_delta_rule/wy_fast.py @@ -0,0 +1,620 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import safe_exp +from fla.utils import check_shared_mem + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'BT', 'BK', 'BC', 'HEAD_FIRST', 'USE_OFFSETS'], +) +@triton.jit(do_not_specialize=['T']) +def fwd_prepare_wy_repr_kernel_chunk32( + k, + g, + beta, + Aw, + Au, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BC: tl.constexpr, + 
HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + b_Aw = tl.zeros([BC, BC], dtype=tl.float32) + if HEAD_FIRST: + p_beta = tl.make_block_ptr(beta + i_bh*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + + b_beta = tl.load(p_beta, boundary_check=(0,)) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + b_Aw += tl.dot(b_kb, tl.trans(b_k)) + + b_Aw = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_Aw, 0) + + if HEAD_FIRST: + p_g = tl.make_block_ptr(g + i_bh*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_g = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + + b_g = tl.load(p_g, boundary_check=(0,)) + b_Au = b_Aw * safe_exp(b_g[:, None] - b_g[None, :]) + + for i in range(1, BC): + mask = tl.arange(0, BC) == i + b_aw = tl.sum(tl.where(mask[:, None], b_Aw, 0), 0) + b_au = tl.sum(tl.where(mask[:, None], b_Au, 0), 0) + b_aw = b_aw + tl.sum(b_aw[:, None] * b_Aw, 0) * (tl.arange(0, BC) < i) + b_au = b_au + tl.sum(b_au[:, None] * b_Au, 0) * (tl.arange(0, BC) < i) + b_Aw = tl.where(mask[:, None], b_aw, b_Aw) + b_Au = tl.where(mask[:, None], b_au, b_Au) + + # blockwise computation of lower triangular matrix's inverse + # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1] 
+ b_Aw += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_Au += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + if HEAD_FIRST: + p_Aw = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_Au = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + else: + p_Aw = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_Au = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + tl.store(p_Aw, b_Aw.to(p_Aw.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Au, b_Au.to(p_Au.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'BT', 'BK', 'BC', 'USE_OFFSETS', 'HEAD_FIRST'], +) +@triton.jit(do_not_specialize=['T']) +def fwd_prepare_wy_repr_kernel_chunk64( + k, + g, + beta, + Aw, + Au, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BC: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + b_Aw = tl.zeros([BC, BC], dtype=tl.float32) + b_Aw2 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aw3 = tl.zeros([BC, BC], dtype=tl.float32) + if HEAD_FIRST: + p_beta = tl.make_block_ptr(beta + i_bh*T, (T,), (1,), (i_t * BT,), (BC,), (0,)) + p_beta2 = tl.make_block_ptr(beta + i_bh*T, (T,), (1,), (i_t * BT + BC,), (BC,), (0,)) + else: + 
p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BC,), (0,)) + p_beta2 = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT + BC,), (BC,), (0,)) + + b_beta = tl.load(p_beta, boundary_check=(0,)) + b_beta2 = tl.load(p_beta2, boundary_check=(0,)) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_k2 = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_k2 = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + b_k2 = tl.load(p_k2, boundary_check=(0, 1)) + b_kb2 = (b_k2 * b_beta2[:, None]).to(b_k2.dtype) + b_Aw += tl.dot(b_kb, tl.trans(b_k)) + b_Aw2 += tl.dot(b_kb2, tl.trans(b_k2)) + b_Aw3 += tl.dot(b_kb2, tl.trans(b_k)) + + b_Aw = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_Aw, 0) + b_Aw2 = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_Aw2, 0) + + if HEAD_FIRST: + p_g = tl.make_block_ptr(g + i_bh*T, (T,), (1,), (i_t * BT,), (BC,), (0,)) + p_g2 = tl.make_block_ptr(g + i_bh*T, (T,), (1,), (i_t * BT + BC,), (BC,), (0,)) + else: + p_g = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT,), (BC,), (0,)) + p_g2 = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT + BC,), (BC,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g2 = tl.load(p_g2, boundary_check=(0,)) + + mask_c = tl.arange(0, BC)[:, None] >= tl.arange(0, BC)[None, :] + mask_g = i_t * BT + tl.arange(0, BC) < T + mask_g2 = i_t * BT + BC + tl.arange(0, BC) < T + + b_Au = tl.where(mask_g[None, :] & mask_c, b_Aw * safe_exp(b_g[:, None] - b_g[None, :]), 0) + b_Au2 = tl.where(mask_g2[None, :] & mask_c, b_Aw2 * 
safe_exp(b_g2[:, None] - b_g2[None, :]), 0) + b_Au3 = tl.where(mask_g[None, :], b_Aw3 * safe_exp(b_g2[:, None] - b_g[None, :]), 0) + + for i in range(1, BC): + mask = tl.arange(0, BC) == i + b_aw = tl.sum(tl.where(mask[:, None], b_Aw, 0), 0) + b_aw2 = tl.sum(tl.where(mask[:, None], b_Aw2, 0), 0) + b_au = tl.sum(tl.where(mask[:, None], b_Au, 0), 0) + b_au2 = tl.sum(tl.where(mask[:, None], b_Au2, 0), 0) + b_aw = b_aw + tl.sum(b_aw[:, None] * b_Aw, 0) * (tl.arange(0, BC) < i) + b_aw2 = b_aw2 + tl.sum(b_aw2[:, None] * b_Aw2, 0) * (tl.arange(0, BC) < i) + b_au = b_au + tl.sum(b_au[:, None] * b_Au, 0) * (tl.arange(0, BC) < i) + b_au2 = b_au2 + tl.sum(b_au2[:, None] * b_Au2, 0) * (tl.arange(0, BC) < i) + b_Aw = tl.where(mask[:, None], b_aw, b_Aw) + b_Aw2 = tl.where(mask[:, None], b_aw2, b_Aw2) + b_Au = tl.where(mask[:, None], b_au, b_Au) + b_Au2 = tl.where(mask[:, None], b_au2, b_Au2) + # blockwise computation of lower triangular matrix's inverse + # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1] + b_Aw += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_Aw2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + # improve precision by disallowing tf32. 
+ b_Aw3 = -tl.dot(tl.dot(b_Aw2, b_Aw3, allow_tf32=False), b_Aw, allow_tf32=False) + b_Au += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_Au2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_Au3 = -tl.dot(tl.dot(b_Au2, b_Au3, allow_tf32=False), b_Au, allow_tf32=False) + + if HEAD_FIRST: + p_Aw1 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_Aw2 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_Aw3 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_Aw4 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + p_Au1 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_Au2 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_Au3 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_Au4 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + else: + p_Aw1 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_Aw2 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_Aw3 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_Aw4 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + p_Au1 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_Au2 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_Au3 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_Au4 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, 
BC), (1, 0)) + + tl.store(p_Aw1, b_Aw.to(p_Aw1.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Aw2, b_Aw2.to(p_Aw2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Aw3, b_Aw3.to(p_Aw3.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Aw4, tl.zeros([BC, BC], dtype=tl.float32).to(p_Aw4.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Au1, b_Au.to(p_Au1.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Au2, b_Au2.to(p_Au2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Au3, b_Au3.to(p_Au3.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Au4, tl.zeros([BC, BC], dtype=tl.float32).to(p_Au4.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'], +) +@triton.jit(do_not_specialize=['T']) +def fwd_recompute_w_u_kernel( + k, + v, + beta, + w, + u, + Aw, + Au, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + if HEAD_FIRST: + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_Au = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_Au = tl.make_block_ptr(Au + (bos*H + i_h) * BT, 
(T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + b_Au = tl.load(p_Au, boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_Au, b_vb, allow_tf32=False) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + b_Au = None + if HEAD_FIRST: + p_Aw = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_Aw = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_Aw = tl.load(p_Aw, boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + b_w = tl.dot(b_Aw, b_kb) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +def fwd_prepare_wy_repr( + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + offsets: Optional[torch.LongTensor], + indices: Optional[torch.LongTensor], + head_first: 
bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build the chunkwise WY representation for the gated delta rule.

    Fills the per-chunk matrices `Aw` and `Au` (chunk width `BT`) via the
    WY-preparation kernel, then derives `w` and `u` from them through
    `fwd_recompute_w_u`.

    Returns:
        (w, u, Aw, Au).  `w`/`u` have the shapes of `k`/`v`; `Aw`/`Au` are
        `[B, H, T, BT]` when `head_first` else `[B, T, H, BT]`.
    """
    if head_first:
        B, H, T, K = k.shape
    else:
        B, T, H, K = k.shape
    # Chunk width: at most `chunk_size`, at least 16, padded to a power of 2.
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    # Number of chunks; with variable-length batches (`offsets`) the
    # precomputed `indices` already enumerate the (sequence, chunk) pairs.
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    BC = min(BT, 32)
    BK = min(triton.next_power_of_2(K), 64)
    # bf16 should be good enough.
    Aw = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=k.dtype)
    Au = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=k.dtype)

    # The 64-wide chunk kernel inverts blockwise in two BC-wide halves;
    # narrower chunks use the single-block 32-wide kernel.
    fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32
    fwd_fn[(NT, B*H)](
        k=k,
        g=g,
        beta=beta,
        Aw=Aw,
        Au=Au,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BK=BK,
        BC=BC,
        HEAD_FIRST=head_first
    )
    w, u = fwd_recompute_w_u(
        k=k,
        v=v,
        beta=beta,
        Aw=Aw,
        Au=Au,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )
    return w, u, Aw, Au


def fwd_recompute_w_u(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    Aw: torch.Tensor,
    Au: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Recompute `w = Aw @ (beta * k)` and `u = Au @ (beta * v)` per chunk
    (see `fwd_recompute_w_u_kernel`).

    Layout is `[B, H, T, ...]` when `head_first` else `[B, T, H, ...]`;
    `w` mirrors `k` and `u` mirrors `v`.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)

    u = torch.empty_like(v)
    w = torch.empty_like(k)
    fwd_recompute_w_u_kernel[(NT, B*H)](
        k=k,
        v=v,
        beta=beta,
        w=w,
        u=u,
        Aw=Aw,
        Au=Au,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return w, u


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
+}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'] +) +@triton.jit(do_not_specialize=['T']) +def bwd_prepare_wy_repr_kernel( + k, + v, + beta, + g, + Aw, + Au, + dw, + du, + dk, + dv, + dbeta, + dg, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + b_dbeta = tl.zeros([BT], dtype=tl.float32) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + if HEAD_FIRST: + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(Aw + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + else: + p_beta = tl.make_block_ptr(beta + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = 
tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False) + b_dk_beta = tl.dot(b_A, b_dw, allow_tf32=False) + b_dk = b_dk_beta * b_beta[:, None] + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), b_A) + b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) + b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty) + + if HEAD_FIRST: + p_A = tl.make_block_ptr(Au + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + else: + p_A = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_dA2 = tl.zeros([BT, BT], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype) + b_du = 
tl.load(p_du, boundary_check=(0, 1)) + b_dA2 += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False) + b_dv_beta = tl.dot(b_A, b_du, allow_tf32=False) + b_dv = b_dv_beta * b_beta[:, None] + b_dbeta += tl.sum(b_dv_beta * b_v, 1) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + b_dA2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA2, 0) + b_dA2 = tl.dot(b_dA2.to(b_A.dtype), b_A) + b_dA2 = tl.dot(b_A, b_dA2.to(b_A.dtype)) + b_dA2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA2, 0).to(k.dtype.element_ty) + if HEAD_FIRST: + p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_g = tl.make_block_ptr(g + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_dA2 *= safe_exp(b_g[:, None] - b_g[None, :]) + b_dA += b_dA2 + b_dA = b_dA.to(k.dtype.element_ty) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + b_A += tl.dot(b_k_beta, tl.trans(b_k)) + b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False) + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False) + b_dk += b_dk_beta * b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + b_dA2 *= b_A + b_dg = tl.sum(b_dA2, axis=1) - tl.sum(b_dA2, axis=0) + if HEAD_FIRST: + p_dg = tl.make_block_ptr(dg + i_bh * 
T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        p_dg = tl.make_block_ptr(dg + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_dbeta = tl.make_block_ptr(dbeta + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
    # Write back the per-timestep gate and beta gradients for this chunk.
    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))


def bwd_prepare_wy_repr(
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    Aw: torch.Tensor,
    Au: torch.Tensor,
    dw: torch.Tensor,
    du: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Backward pass of the WY representation for the gated delta rule.

    Given upstream gradients `dw`/`du` (and the forward-pass matrices
    `Aw`/`Au`), launches `bwd_prepare_wy_repr_kernel` and returns
    `(dk, dv, dbeta, dg)`, each shaped like its corresponding input.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    # Use larger K/V tiles only when enough shared memory is available.
    CONST_TILING = 64 if check_shared_mem() else 32
    BK = min(triton.next_power_of_2(K), CONST_TILING)
    BV = min(triton.next_power_of_2(V), CONST_TILING)

    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    dbeta = torch.empty_like(beta)
    dg = torch.empty_like(g)
    bwd_prepare_wy_repr_kernel[(NT, B * H)](
        k=k,
        v=v,
        beta=beta,
        g=g,
        Aw=Aw,
        Au=Au,
        dw=dw,
        du=du,
        dk=dk,
        dv=dv,
        dbeta=dbeta,
        dg=dg,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return dk, dv, dbeta, dg
diff --git a/fla/ops/generalized_delta_rule/README.md b/fla/ops/generalized_delta_rule/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f96c22f44a51ad3e6fdeb824eb2aded660223600
--- /dev/null
+++ b/fla/ops/generalized_delta_rule/README.md
@@ -0,0 +1,37 @@
# Generalized Delta Rule

In delta rule we have the recurrence:

```math
\mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{I}-\beta_t \mathbf{k}_t\mathbf{k}_t^T) + \beta_t \mathbf{v}_t\mathbf{k}_t^T
```

This repository implements a delta rule variant where $\mathbf{I}$ is not necessarily an identity matrix, and the $\mathbf{k}_t$ appearing in $\mathbf{I} - \beta_t \mathbf{k}_t\mathbf{k}_t^T$ may differ from the $\mathbf{k}_t$ appearing in $\mathbf{v}_t\mathbf{k}_t^T$.

## IPLR (Identity Plus Low Rank)

The first variant is IPLR, where we have:

```math
\mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{I}+\mathbf{a}_t\mathbf{b}_t^T) + \mathbf{v}_t\mathbf{k}_t^T
```

When $\mathbf{a}_t = -\beta_t \mathbf{k}_t$, $\mathbf{b}_t = \mathbf{k}_t$, and the value input is rescaled to $\beta_t \mathbf{v}_t$, we recover the original delta rule. Since here the transition matrix is identity-plus-low-rank, we refer to this variant as IPLR.

### Numerical Stability

$\mathbf{a}_t$ and $\mathbf{b}_t$ must point in opposite directions, that is, $\mathbf{b}_t = \lambda_t \mathbf{a}_t$ with $\lambda_t < 0$. To see why, note that the eigenvalues of the transition matrix $\mathbf{I}+\mathbf{a}_t\mathbf{b}_t^T$ are $1$ (with multiplicity $d-1$) together with $1 + \mathbf{b}_t^T\mathbf{a}_t = 1 + \lambda_t \lVert\mathbf{a}_t\rVert^2$: only for $\lambda_t < 0$ does the nontrivial eigenvalue fall below $1$ (it must also remain above $-1$), which keeps the recurrence from amplifying the state.

## DPLR (Diagonal Plus Low Rank)

The second variant is DPLR, where we have:

```math
\mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{D}_t+\mathbf{a}_t\mathbf{b}_t^T) + \mathbf{v}_t\mathbf{k}_t^T
```

Here, the identity $\mathbf{I}$ is replaced by a diagonal matrix $\mathbf{D}_t$. This transition matrix structure has been utilized in RWKV7.

## Efficient Chunkwise Implementation

For detailed information about the efficient chunkwise implementation, please refer to our [technical note](https://drive.google.com/file/d/1rJbO3dU4fe7OKG3w7Yg058z_BNIuavNF/view?usp=sharing).
diff --git a/fla/ops/generalized_delta_rule/__init__.py b/fla/ops/generalized_delta_rule/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b4155a215ca8c44ea45d6b151b1e584872ed6c --- /dev/null +++ b/fla/ops/generalized_delta_rule/__init__.py @@ -0,0 +1,9 @@ +from .dplr import chunk_dplr_delta_rule, fused_recurrent_dplr_delta_rule +from .iplr import chunk_iplr_delta_rule, fused_recurrent_iplr_delta_rule + +__all__ = [ + 'chunk_dplr_delta_rule', + 'fused_recurrent_dplr_delta_rule', + 'chunk_iplr_delta_rule', + 'fused_recurrent_iplr_delta_rule' +] diff --git a/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-312.pyc b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..341f2284c89732b572ceb387ba01cfe70a5e629f Binary files /dev/null and b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-312.pyc differ diff --git a/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_bwd.cpython-312.pyc b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_bwd.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bffc0ddb6face01258537fd6a65d86c07d901222 Binary files /dev/null and b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_bwd.cpython-312.pyc differ diff --git a/fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py b/fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..d9ff775184d4f1fa4472bb172da19fdd45553ed6 --- /dev/null +++ b/fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.utils import check_shared_mem, is_intel_alchemist, use_cuda_graph + +# 
https://github.com/intel/intel-xpu-backend-for-triton/issues/3449 +triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {} + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config(triton_config, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT', 'BK', 'BV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def bwd_prepare_wy_repr_kernel( + A_ab_inv, + A_ak, + ag, + v, + dw, + du, + dv, + dv0, + dag, + dAak, + dAab, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_Aak_t = tl.make_block_ptr(A_ak + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_dAak = tl.make_block_ptr(dAak + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dAab = tl.make_block_ptr(dAab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_Aak_t = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_dAak = tl.make_block_ptr(dAak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dAab = tl.make_block_ptr(dAab + (bos*H + i_h) * 
BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + b_A_ab_inv_t = tl.load(p_Aab_inv_t, boundary_check=(0, 1)) + b_A_ak_t = tl.load(p_Aak_t, boundary_check=(0, 1)) + b_A_ak_t = tl.where(tl.arange(0, BT)[:, None] < tl.arange(0, BT)[None, :], b_A_ak_t, 0) + b_A_ab_inv_t = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A_ab_inv_t, 0) + b_A_tmp_t = tl.dot(b_A_ak_t, b_A_ab_inv_t).to(v.dtype.element_ty) + b_dA_tmp = tl.zeros([BT, BT], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv0 = tl.make_block_ptr(dv0 + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv0 = tl.make_block_ptr(dv0 + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA_tmp += tl.dot(b_du.to(b_v.dtype), tl.trans(b_v)) + b_dv0 = tl.load(p_dv0, boundary_check=(0, 1)) + b_dv = b_dv0 + tl.dot(b_A_tmp_t, b_du) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + b_dA_tmp = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_tmp, 0) + b_dA_ak = tl.dot(b_A_ab_inv_t, b_dA_tmp) + b_dA_ak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_ak, 0) + tl.store(p_dAak, b_dA_ak, boundary_check=(0, 1)) + b_dA_ab_inv = tl.dot(b_dA_tmp, 
b_A_ak_t) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_ag = tl.make_block_ptr(ag + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dag = tl.make_block_ptr(dag + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dag = tl.make_block_ptr(dag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_ag = tl.load(p_ag, boundary_check=(0, 1)) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA_ab_inv += tl.dot(b_dw, tl.trans(b_ag)) + b_dag = tl.dot(b_A_ab_inv_t.to(b_dw.dtype), b_dw) + tl.store(p_dag, b_dag.to(p_dag.dtype.element_ty), boundary_check=(0, 1)) + + # if we know dL/dA^(-1), for dL/dA, we can use the following formula: + # dL/dA = -(A^(-1))^T @ (dL/dA^(-1)) @ (A^(-1))^T + # in the fwd pass we use fwd substitution to calculate (I-lower(A_ab))^-1. + # denote A = I - lower(A_ab), B = A^-1 + # in the backward pass. 
    # dL/dA = -(B)^T @ (dL/dB) @ B^T
    # dL/dA_ab = lower(B^T @ dL/dB @ B^T)
    b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_dA_ab_inv, 0)
    b_dA_ab_inv = tl.dot(b_A_ab_inv_t, b_dA_ab_inv)
    b_dA_ab_inv = tl.dot(b_dA_ab_inv, b_A_ab_inv_t)
    # Only the strictly lower-triangular part of dA_ab is meaningful.
    b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_ab_inv, 0)
    tl.store(p_dAab, b_dA_ab_inv, boundary_check=(0, 1))


def chunk_dplr_bwd_wy(
    A_ab_inv: torch.Tensor,
    A_ak: torch.Tensor,
    v: torch.Tensor,
    ag: torch.Tensor,
    dw: torch.Tensor,
    du: torch.Tensor,
    dv0: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Backward pass of the DPLR WY representation.

    Launches `bwd_prepare_wy_repr_kernel` and returns
    `(dA_ab, dA_ak, dv, dag)`: gradients w.r.t. the chunk-local matrices
    `A_ab`/`A_ak` (accumulated in fp32) and the inputs `v`/`ag`.
    """
    # The kernel's block pointers assume contiguous memory layouts.
    A_ab_inv, A_ak, v, ag, dw, du = map(lambda x: x.contiguous(), [A_ab_inv, A_ak, v, ag, dw, du])
    if head_first:
        B, H, T, K, V = *dw.shape, du.shape[-1]
    else:
        B, T, H, K, V = *dw.shape, du.shape[-1]
    # Chunk width: at most `chunk_size`, at least 16, padded to a power of 2.
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    BK = min(triton.next_power_of_2(K), 64)
    # Smaller V tile when shared memory is limited.
    BV = min(triton.next_power_of_2(V), 64) if check_shared_mem() else min(triton.next_power_of_2(V), 32)

    dA_ab = torch.empty_like(A_ab_inv, dtype=torch.float)
    dA_ak = torch.empty_like(A_ak, dtype=torch.float)
    dv = torch.empty_like(v)
    dag = torch.empty_like(ag)

    bwd_prepare_wy_repr_kernel[(NT, B * H)](
        A_ab_inv=A_ab_inv,
        A_ak=A_ak,
        ag=ag,
        v=v,
        dw=dw,
        du=du,
        dv=dv,
        dv0=dv0,
        dag=dag,
        dAak=dA_ak,
        dAab=dA_ab,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return dA_ab, dA_ak, dv, dag
diff --git a/fla/ops/generalized_delta_rule/iplr/wy_fast.py b/fla/ops/generalized_delta_rule/iplr/wy_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fdfa7091500873765a36c6ef86506a203f4be19
--- /dev/null
+++
b/fla/ops/generalized_delta_rule/iplr/wy_fast.py @@ -0,0 +1,338 @@ + +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.utils import check_shared_mem, is_nvidia_hopper + +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + ], + key=['BK'] +) +@triton.jit(do_not_specialize=['T']) +def fwd_prepare_wy_repr_kernel_chunk32( + a, + b, + A, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BC: tl.constexpr, # dummy placeholder + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_b = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + else: + p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + b_a = tl.load(p_a, boundary_check=(0, 1)) + b_b = tl.load(p_b, boundary_check=(0, 1)) + b_A += tl.dot(b_a, b_b) + + b_A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0) + for i in range(1, BT): + mask = tl.arange(0, BT) == i + 
b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :] + + if HEAD_FIRST: + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + ], + key=['BK'] +) +@triton.jit(do_not_specialize=['T']) +def fwd_prepare_wy_repr_kernel_chunk64( + a, + b, + A, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BC: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + b_A = tl.zeros([BC, BC], dtype=tl.float32) + b_A2 = tl.zeros([BC, BC], dtype=tl.float32) + b_A3 = tl.zeros([BC, BC], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_a1 = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_a2 = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0)) + p_b1 = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BC), (0, 1)) + p_b2 = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1)) + else: + p_a1 = tl.make_block_ptr(a + 
(bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_a2 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0)) + p_b1 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BC), (0, 1)) + p_b2 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1)) + b_a1 = tl.load(p_a1, boundary_check=(0, 1)) + b_a2 = tl.load(p_a2, boundary_check=(0, 1)) + b_b1 = tl.load(p_b1, boundary_check=(0, 1)) + b_b2 = tl.load(p_b2, boundary_check=(0, 1)) + b_A += tl.dot(b_a1, b_b1, allow_tf32=False) + b_A2 += tl.dot(b_a2, b_b2, allow_tf32=False) + b_A3 += tl.dot(b_a2, b_b1, allow_tf32=False) + + b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0) + b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0) + + for i in range(1, BC): + mask = tl.arange(0, BC) == i + b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i) + b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + b_A2 = tl.where(mask[:, None], b_a2, b_A2) + + # blockwise computation of lower triangular matrix's inverse + # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1] + b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_A3 = tl.dot(tl.dot(b_A2, b_A3, allow_tf32=False), b_A, allow_tf32=False) + + if HEAD_FIRST: + p_A1 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_A2 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_A3 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_A4 = tl.make_block_ptr(A + i_bh 
* T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + else: + p_A1 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_A2 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_A3 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_A4 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + tl.store(p_A1, b_A.to(p_A1.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A2, b_A2.to(p_A2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A3, b_A3.to(p_A3.dtype.element_ty), boundary_check=(0, 1)) + # causal mask + tl.store(p_A4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A4.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in NUM_WARPS + ], + key=['BT', 'BK', 'BV'] +) +@triton.jit(do_not_specialize=['T']) +def fwd_wu_kernel( + w, + u, + a, + k, + v, + A, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_Aak = tl.zeros([BT, BT], 
dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_a = tl.load(p_a, boundary_check=(0, 1)) + b_w = tl.dot(b_A, b_a) + b_Aak += tl.dot(b_a, tl.trans(b_k)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + b_Aak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_Aak, 0) + b_Aak = b_Aak.to(k.dtype.element_ty) + + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v = tl.dot(b_Aak, b_v).to(v.dtype.element_ty) + b_u = tl.dot(b_A, b_v) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + +def fwd_prepare_wy_repr( + a: torch.Tensor, + b: torch.Tensor, + v: torch.Tensor, + k: torch.Tensor, + offsets: Optional[torch.LongTensor], + indices: Optional[torch.LongTensor], + head_first: bool = True, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + 
    """Build the per-chunk WY-representation tensors.

    Launches the block-inverse kernel to fill `A` (one [BT, BT] lower-triangular
    mixing matrix per chunk), then derives `w` and `u` from it via `fwd_wu`.

    Returns:
        (w, u, A): `w`/`u` share the shapes of `a`/`v`; `A` is
        `[B, H, T, BT]` if `head_first` else `[B, T, H, BT]`.
    """
    if head_first:
        B, H, T, K = a.shape
    else:
        B, T, H, K = a.shape
    # Chunk length along T: a power of two clamped to [16, chunk_size].
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    # One program per chunk; in varlen mode `indices` already enumerates
    # one (sequence, chunk) pair per entry.
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    BC = min(BT, 32)
    BK = min(triton.next_power_of_2(K), 64)

    # A holds the per-chunk mixing matrices produced by the kernel below.
    A = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=a.device, dtype=a.dtype)
    # The chunk64 variant processes each BT=64 chunk as two BC=32 halves
    # (blockwise lower-triangular inverse); chunk32 handles BT <= 32.
    fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32

    fwd_fn[(NT, B * H)](
        a=a,
        b=b,
        A=A,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BK=BK,
        BC=BC,
        HEAD_FIRST=head_first
    )
    w, u = fwd_wu(
        a=a,
        v=v,
        k=k,
        A=A,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )
    return w, u, A


def fwd_wu(
    a: torch.Tensor,
    v: torch.Tensor,
    k: torch.Tensor,
    A: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute `w` and `u` from the per-chunk mixing matrices `A`.

    Per the kernel, within each chunk: `w = A @ a` and
    `u = A @ (strict_lower(a @ k^T) @ v)`.

    Args:
        a, v, k: inputs shaped `[B, H, T, K/V]` if `head_first` else `[B, T, H, K/V]`.
        A: per-chunk `[.., BT]` mixing matrices from `fwd_prepare_wy_repr`.
        offsets/indices: varlen bookkeeping (may be `None` for fixed-length).

    Returns:
        (w, u) with the shapes of `a` and `v` respectively.
    """
    if head_first:
        B, H, T, K, V = *a.shape, v.shape[-1]
    else:
        B, T, H, K, V = *a.shape, v.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    # Use larger K/V tiles only when the device has enough shared memory.
    CONST_TILING = 64 if check_shared_mem() else 32
    BK = min(triton.next_power_of_2(K), CONST_TILING)
    BV = min(triton.next_power_of_2(V), CONST_TILING)

    u = torch.empty_like(v)
    w = torch.empty_like(a)
    fwd_wu_kernel[(NT, B*H)](
        a=a,
        v=v,
        w=w,
        u=u,
        A=A,
        k=k,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return w, u
diff --git a/fla/ops/gla/fused_chunk.py b/fla/ops/gla/fused_chunk.py
new file mode 100644
index 0000000000000000000000000000000000000000..318be21402e6e1d1324a5eed4bc318f100e4c59c
--- /dev/null
+++ b/fla/ops/gla/fused_chunk.py
@@ -0,0 +1,631 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Tuple
+import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from einops import rearrange +from packaging import version + +from fla.ops.utils import chunk_local_cumsum +from fla.ops.utils.op import exp, safe_exp +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.jit(do_not_specialize=['T']) +def prepare_qg_kg( + q, + k, + g, + qg, + kg, + scale, + T, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + p_q = q + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK) + p_g = g + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK) + p_qg = qg + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK) + p_kg = kg + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK) + + mask = (i_k * BK + tl.arange(0, BK)) < K + + last_decay = tl.load(g + i_bh * T*K + (i_c * BT + BT - 1) * K + i_k * BK + tl.arange(0, BK)) + + for _ in range(BT): + b_q = tl.load(p_q, mask=mask, other=0) + b_k = tl.load(p_k, mask=mask, other=0) + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + b_q *= exp(b_g) * scale + b_k *= exp(last_decay - b_g) + tl.store(p_kg, b_k.to(p_kg.dtype.element_ty), mask=mask) + tl.store(p_qg, b_q.to(p_qg.dtype.element_ty), mask=mask) + p_q += K + p_g += K + p_k += K + p_kg += K + p_qg += K + + +@triton.jit(do_not_specialize=['T']) +def bwd_decay_global_cumsum( + dq_inner, + dq_inter, + dk_inner, + dk_inter, + q, + k, + g, + dg, + T, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + p_q = q + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_k = k + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_g = g + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_dg = dg + i_bh * T*K + i_k * BK + 
tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_dq_inner = dq_inner + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_dk_inner = dk_inner + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_dq_inter = dq_inter + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_dk_inter = dk_inter + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + cum_grad_dg = tl.zeros([BK], dtype=tl.float32) + mask = (i_k * BK + tl.arange(0, BK)) < K + last_g = tl.zeros([BK], dtype=tl.float32) + for j in range(BT-1, -1, -1): + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + if j == (BT-1): + last_g = b_g + b_dq1 = tl.load(p_dq_inner, mask=mask, other=0) + b_dq2 = tl.load(p_dq_inter, mask=mask, other=0) + b_dq2 *= exp(b_g) + b_dq = b_dq1 + b_dq2 + tl.store(p_dq_inter, b_dq, mask=mask) + b_dk1 = tl.load(p_dk_inner, mask=mask, other=0) + b_dk2 = tl.load(p_dk_inter, mask=mask, other=0) + b_dk2 *= safe_exp(last_g - b_g) + b_dk = b_dk1 + b_dk2 + tl.store(p_dk_inter, b_dk, mask=mask) + b_q = tl.load(p_q, mask=mask, other=0) + b_k = tl.load(p_k, mask=mask, other=0) + b_dg = b_dq * b_q - b_dk * b_k + cum_grad_dg += b_dg + tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask) + p_g -= K + p_k -= K + p_q -= K + p_dq_inner -= K + p_dk_inner -= K + p_dq_inter -= K + p_dk_inter -= K + p_dg -= K + + +@triton.jit(do_not_specialize=['T']) +def fused_chunk_gla_fwd_kernel( + q, + k, + v, + g, + o, + h0, + ht, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0)) + p_gn = g + i_bh * T*K + (BT 
- 1) * K + i_k * BK + tl.arange(0, BK) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0)) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + mask = (i_k * BK + tl.arange(0, BK)) < K + + for i in range(0, tl.cdiv(T, BT)): + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_gn = tl.load(p_gn, mask=mask, other=0).to(tl.float32) + if CHECK and i == 0: + b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False) + b_h = b_h * exp(b_gn)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False) + else: + b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False) + b_h = b_h * exp(b_gn)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False) + + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + p_gn += BT * K + + if STORE_FINAL_STATE: + p_final = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1)) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit(do_not_specialize=['T']) +def fused_chunk_gla_bwd_kernel( + q, k, v, g, + do, + dq, + dk, + dv, + h0, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + # clamp_min, # minimum log value of the gate for numerical stability. 
default: -5 + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + # [BV, BK] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + mask = (i_k * BK + tl.arange(0, BK)) < K + for i in range(0, tl.cdiv(T, BT)): + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = g + i_bh * T*K + ((i+1) * BT - 1) * K + i_k * BK + tl.arange(0, BK) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0)) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + # [BT, K] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gn = tl.load(p_gn, mask=mask, other=0).to(tl.float32) + + # [V, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [V, K] + if CHECK and i == 0: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h * exp(b_gn)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False) + else: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h * exp(b_gn)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False) + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # sync threads + b_h = None + tl.debug_barrier() + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + + # cum = tl.zeros([BK], dtype=tl.float32) + for i in range(1, tl.cdiv(T, BT) + 1): + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) 
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = g + i_bh * T*K + (T - (i-1) * BT - 1) * K + i_k * BK + tl.arange(0, BK) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * T*K, (T, K), + (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * T*V, (T, V), + (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + # [K, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BT, K] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, V] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_db = tl.load(p_gn, mask=mask, other=0).to(tl.float32) + + # inter-chunk + # [K, V] + if CHECK and i == 1: + b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False)) + b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False) + b_dh = b_dh * exp(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False) + else: + b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False)) + b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False) + b_dh = b_dh * exp(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def fwd_inner_chunk( + q, k, g, A, + scale, # K ** -0.5 + B: tl.constexpr, # B + H: tl.constexpr, # H + T, # T + K: tl.constexpr, # K + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. 
chunk size + BK: tl.constexpr # BLOCK SIZE along the K dimension +): + + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + + mask = (i_k * BK + tl.arange(0, BK)) < K + o_i = tl.arange(0, BT) + + p_q = q + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK) + p_gq = g + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK) + p_A = A + (i_bh + (i_k * B * H)) * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT) + + for i in range(BT): + b_q = tl.load(p_q, mask=mask, other=0) * scale + b_gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32) + s = b_q[None, :] * b_k * safe_exp(b_gq[None, :] - b_g) + score = tl.sum(s, axis=1) + score = tl.where(o_i <= i, score, 0) + tl.store(p_A, score.to(p_A.dtype.element_ty)) + p_q += K + p_gq += K + p_A += BT + + +@triton.jit +def bwd_inner_chunk( + q, + k, + g, + dA, + dq, + dk, + T, # T + K: tl.constexpr, # K + # clamp_min, # minimum log value of the gate for numerical stability. default: -5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. 
chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + + mask = (i_k * BK + tl.arange(0, BK)) < K + o_i = tl.arange(0, BT) + + p_q = q + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK) + p_dq = dq + (i_bh) * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK) + p_gq = g + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK) + p_dA = dA + i_bh * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT) + + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + + for i in range(BT): + b_q = tl.load(p_q, mask=mask, other=0) + b_gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32) + score = safe_exp(b_gq[None, :] - b_g) + score = tl.where(o_i[:, None] <= i, score, 0) + b_dA = tl.load(p_dA) + b_dA = tl.where(o_i <= i, b_dA, 0) + b_dk += (b_dA[:, None] * score * b_q[None, :]) + b_dq = tl.sum(b_dA[:, None] * score * b_k, axis=0) + tl.store(p_dq, b_dq, mask=mask) + p_q += K + p_dq += K + p_gq += K + p_dA += BT + + p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dk, b_dk.to(dk.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkGLAFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, g, scale, initial_state, output_final_state): + ctx.g_dtype = g.dtype + ctx.scale = scale + B, H, T, K, V = *k.shape, v.shape[-1] + BT = 16 # chunk_size + BK, BV = min(K, 64), min(V, 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 2 + + g_org = g + # cumulative decay should be in float32, otherwise the err will be accumulated and 
amplified. + g = chunk_local_cumsum(g_org, chunk_size=BT) + o = q.new_empty(NK, B, H, T, V) + q_g = torch.empty_like(q) + k_g = torch.empty_like(k) + + grid = (NK, triton.cdiv(T, BT), B * H) + prepare_qg_kg[grid]( + q, + k, + g, + q_g, + k_g, + scale, + T=T, + K=K, + BT=BT, + BK=BK, + num_warps=1 + ) + + if output_final_state: + final_state = q.new_empty(B, H, K, V, dtype=torch.float, requires_grad=False) + else: + final_state = None + # the bug still exists even for Triton 2.2 on H100 GPUs + # so we always enable initial checks + CHECK = True + if version.parse(triton.__version__) < version.parse('2.2.0'): + import warnings + warnings.warn( + "Triton<2.2.0 detected for running this kernel, " + "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) " + "that lead to significant precision loss. " + "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. " + "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)." 
+ ) + CHECK = True + + grid = (NV, NK, B * H) + fused_chunk_gla_fwd_kernel[grid]( + q_g, k_g, v, g, o, initial_state, final_state, + T=T, + B=B, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + CHECK=CHECK, + num_warps=num_warps, + num_stages=num_stages + ) + + o = o.sum(0) + + # intra-chunk + chunk_size = 16 + num_chunk = T // chunk_size + v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk) + BK = min(K, 64) + NK = triton.cdiv(K, BK) + A = q.new_empty(NK, B, H, triton.cdiv(T, BT), BT, BT) + grid = (NK, triton.cdiv(T, BT), B * H) + fwd_inner_chunk[grid]( + q, k, g, A, + scale, + B=B, + H=H, + T=T, + K=K, + BT=BT, + BK=BK, + num_stages=3, + num_warps=4 + ) + A = A.sum(0) + o2 = A @ v2 + o2 = rearrange(o2, 'b h n c d -> b h (n c) d') + # combine inner and inter + o.add_(o2) + ctx.save_for_backward(q, k, v, g_org, A, initial_state) + ctx.CHECK = CHECK + return o.to(v), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht=None): + q, k, v, g_org, A, initial_state = ctx.saved_tensors + B, H, T, K, V = *k.shape, v.shape[-1] + scale = ctx.scale + + # recomputation + # inter-chunk + BT = 16 # chunk_size + g = chunk_local_cumsum(g_org, chunk_size=BT) + BK, BV = min(K, 64), min(V, 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + q_g = torch.empty_like(q) + k_g = torch.empty_like(k) + grid = (NK, triton.cdiv(T, BT), B * H) + prepare_qg_kg[grid]( + q, + k, + g, + q_g, + k_g, + scale, + T=T, + K=K, + BT=BT, + BK=BK, + num_warps=1 + ) + + BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 2 + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + + grid = (NV, NK, B * H) + + fused_chunk_gla_bwd_kernel[grid]( + q_g, + k_g, + v, + g, + do, + dq, + dk, + dv, + initial_state, + 
scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + USE_INITIAL_STATE=initial_state is not None, + CHECK=ctx.CHECK, + num_warps=num_warps, + num_stages=num_stages, + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + + # intra chunk + NT = T // BT + v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=NT) + do2 = rearrange(do, 'b h (n c) d -> b h n c d', n=NT) + dA2 = (do2 @ v2.transpose(-2, -1)) * scale + dv2 = A.transpose(-1, -2) @ do2 + dv2 = rearrange(dv2, 'b h n c d -> b h (n c) d', n=NT) + + BK = min(triton.next_power_of_2(K), 16) + NK = triton.cdiv(K, BK) + dk2 = torch.empty_like(k) + dq2 = torch.empty_like(q) + + grid = (NK, NT, B * H) + bwd_inner_chunk[grid]( + q, k, g, + dA2, + dq2, + dk2, + T=T, + K=K, + BT=BT, + BK=BK, + num_warps=1, + num_stages=3 + ) + + BK = min(triton.next_power_of_2(K), 32) + NK = triton.cdiv(K, BK) + dg = torch.empty_like(g, dtype=torch.float32) + grid = (NK, triton.cdiv(T, BT), B * H) + bwd_decay_global_cumsum[grid]( + dq2, + dq, + dk2, + dk, + q, + k, + g, + dg, + T=T, + K=K, + BT=BT, + BK=BK, + num_warps=1, + num_stages=1 + ) + dg = rearrange(dg, 'b h (n c) d -> b h n c d', c=BT) + + def rev_cumsum_exclusive(x): + cumsum_x = x.cumsum(-2) + rev_cumsum_x = cumsum_x[..., -1, None, :] - cumsum_x + return rev_cumsum_x + + rev_cumsum_dg = rev_cumsum_exclusive(dg[..., 0, :]) + dg.add_(rev_cumsum_dg.unsqueeze(-2)) + dv.add_(dv2) + dg = rearrange(dg, 'b h n c d -> b h (n c) d') + + return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.g_dtype), None, None, None + + +def ceildiv(a, b): + return -(a // -b) + + +def pad(x, chunk_size=16): + T = x.shape[-2] + padded_seq_len = ceildiv(T, chunk_size) * chunk_size + if x.shape[-2] % chunk_size != 0: + x = F.pad(x, (0, 0, 0, padded_seq_len - T)) + return x + + +def fused_chunk_gla( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + scale: int = -1, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + head_first: bool = True +) -> 
Tuple[torch.Tensor, torch.Tensor]:
    """Chunkwise fused GLA.

    Pads the time dimension up to a multiple of the 16-step chunk, runs the
    fused forward, then trims the output back to the original length.
    """
    # -1 is a sentinel meaning "use the default 1/sqrt(head_dim)" scaling.
    if scale == -1:
        scale = q.shape[-1] ** -0.5
    if not head_first:
        q, k, v, g = map(lambda x: x.transpose(1, 2), (q, k, v, g))
    seq_len = q.shape[-2]
    q, k, v, g = map(lambda x: pad(x), [q, k, v, g])
    o, final_state = FusedChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state)
    # Drop the rows introduced by padding.
    o = o[..., :seq_len, :].contiguous()
    if not head_first:
        o = o.transpose(1, 2)
    return o, final_state
diff --git a/fla/ops/gla/fused_recurrent.py b/fla/ops/gla/fused_recurrent.py
new file mode 100644
index 0000000000000000000000000000000000000000..d211541d789809ee89a688b380626026b1dbed88
--- /dev/null
+++ b/fla/ops/gla/fused_recurrent.py
@@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2024, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch

from fla.ops.common.fused_recurrent import fused_recurrent


def fused_recurrent_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    gk: Optional[torch.Tensor] = None,
    gv: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        gk (torch.Tensor):
            Forget gates of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` applied to keys.
        gv (torch.Tensor):
            Forget gates of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` applied to values.
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        reverse (Optional[bool]):
            If `True`, process the state passing in reverse order. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `True`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gla import fused_recurrent_gla
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = torch.randn(B, T, H, K, device='cuda')
        >>> v = torch.randn(B, T, H, V, device='cuda')
        >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda'))
        >>> h0 = torch.randn(B, H, K, V, device='cuda')
        >>> o, ht = fused_recurrent_gla(q, k, v, g,
                                        initial_state=h0,
                                        output_final_state=True,
                                        head_first=False)
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_gla(q, k, v, g,
                                               initial_state=h0,
                                               output_final_state=True,
                                               cu_seqlens=cu_seqlens,
                                               head_first=False)
        >>> assert o.allclose(o_var.view(o.shape))
        >>> assert ht.allclose(ht_var)
    """
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            # NOTE(review): the two f-string halves concatenate without a space
            # ("...`cu_seqlens`.Please flatten..."); fix the literal upstream.
            raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                             f"Please flatten variable-length inputs before processing.")
        if head_first:
            raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
                             f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
    if scale is None:
        scale = k.shape[-1] ** -0.5
    # Thin wrapper: the generic kernel takes a scalar gate `g` (unused here)
    # plus the key/value gates `gk`/`gv`.
    o, final_state = fused_recurrent(
        q=q,
        k=k,
        v=v,
        g=None,
        gk=gk,
        gv=gv,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        reverse=reverse,
        cu_seqlens=cu_seqlens,
        head_first=head_first
    )
    return o, final_state
diff --git a/fla/ops/gla/naive.py b/fla/ops/gla/naive.py
new file mode 100644
index 0000000000000000000000000000000000000000..507a7395c0c28b0a9c54008e1735098cd3fbdc85
--- /dev/null
+++ b/fla/ops/gla/naive.py
@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-

from typing import Optional

import torch


def ceildiv(a, b):
    # ceil(a / b) via floor division on negated operands.
    return -(a // -b)


def naive_recurrent_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    gk: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False
):
    """Reference (unfused) GLA recurrence, computed step by step in float32.

    Inputs are head-first: `q`/`k`/`gk` are `[B, H, T, K]`, `v` is `[B, H, T, V]`.
    Returns the output cast back to the input dtype, plus the final `[B, H, K, V]`
    state (or `None` if `output_final_state` is False).
    """
    dtype = q.dtype
    q, k, v, gk = map(lambda x: x.float(), (q, k, v, gk))
    B, H, T, K, V = *q.shape, v.shape[-1]
    o = torch.zeros_like(v)
    scale = K ** -0.5

    h = q.new_zeros(B, H, K, V, dtype=torch.float32)
    if initial_state is not None:
        h += initial_state.float()

    for i in range(T):
        q_i = q[:, :, i] * scale
        k_i = k[:, :, i]
        v_i = v[:, :, i]
        # gk holds log-gates; exponentiate to get the per-key decay factor.
        gk_i = gk[:, :, i].exp()
        kv_i = k_i[..., None] * v_i[..., None, :]
        # Decay the state per key dim, add the rank-1 update, then read out.
        h = h * gk_i[..., None] + kv_i
        o[:, :, i] = (q_i[..., None] * h).sum(-2)

    if not output_final_state:
        h = None
    return o.to(dtype), h
diff --git a/fla/ops/gsa/__init__.py b/fla/ops/gsa/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed8a88014ddfc3143e67d3a48c38a54b75d7f3d6
--- /dev/null
+++ b/fla/ops/gsa/__init__.py
@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-

from .chunk import chunk_gsa
from .fused_recurrent import fused_recurrent_gsa

__all__ = [
    'chunk_gsa',
    'fused_recurrent_gsa'
]
diff --git a/fla/ops/hgrn/__init__.py b/fla/ops/hgrn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2012c3c15f125271df225ce755ed3b2dbe01a83
--- /dev/null
+++ b/fla/ops/hgrn/__init__.py
@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-

from .chunk import chunk_hgrn
from .fused_recurrent import fused_recurrent_hgrn

__all__ = [
    'chunk_hgrn',
    'fused_recurrent_hgrn'
]
diff --git a/fla/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc b/fla/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9e9ff339795f271bf4e3966731c812ff35e1ba6
Binary files /dev/null and b/fla/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc differ
diff --git a/fla/ops/hgrn/chunk.py b/fla/ops/hgrn/chunk.py
new file mode 100644
index 0000000000000000000000000000000000000000..6847622ebfb071230720b7ae6669f5412a42470b
--- /dev/null
+++ b/fla/ops/hgrn/chunk.py
@@ -0,0 +1,282 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

# this function implements the chunkwise form of HGRN, inspired by
# [Volodymyr Kyrylov in his blog post](https://proger.github.io/posts/scan/chunk.html)
# also refer to the `accelerated-scan` lib: https://github.com/proger/accelerated-scan

# from tests on H800, with B, D = 16, 128, we see that the chunk can be greatly faster than the recurrent:
#
# Performance:
# seq_len chunk
recurrent chunk_bwd recurrent_bwd +# 0 128.0 0.039360 0.061056 0.312160 0.205008 +# 1 256.0 0.045824 0.123712 0.308784 0.297696 +# 2 512.0 0.058688 0.241952 0.310720 0.626528 +# 3 1024.0 0.088288 0.476992 0.313184 1.333152 +# 4 2048.0 0.169472 0.943264 0.452464 2.724864 +# 5 4096.0 0.329920 1.886144 0.881600 5.551520 +# 6 8192.0 0.647872 3.755040 1.740496 11.117184 +# 7 16384.0 1.272064 7.520576 3.446608 22.362528 + +from typing import Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp +from fla.utils import input_guard + + +@triton.autotune( + configs=[ + triton.Config({'BD': 32}, num_warps=1), + triton.Config({'BD': 32}, num_warps=2), + triton.Config({'BD': 32}, num_warps=4), + triton.Config({'BD': 32}, num_warps=8), + triton.Config({'BD': 64}, num_warps=1), + triton.Config({'BD': 64}, num_warps=2), + triton.Config({'BD': 64}, num_warps=4), + triton.Config({'BD': 64}, num_warps=8), + triton.Config({'BD': 128}, num_warps=1), + triton.Config({'BD': 128}, num_warps=2), + triton.Config({'BD': 128}, num_warps=4), + triton.Config({'BD': 128}, num_warps=8), + ], + key=['D'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_hgrn_fwd_kernel_h( + x, + g, + gc, + o, + h0, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr +): + i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + p_x = x + i_b * T * D + i_t * BT * D + o_d + p_g = g + i_b * T * D + i_t * BT * D + o_d + p_gc = gc + i_b * T * D + i_t * BT * D + o_d + p_o = o + i_b * T * D + i_t * BT * D + o_d + + b_h = tl.zeros([BD], dtype=tl.float32) + b_gc = tl.zeros([BD], dtype=tl.float32) + if USE_INITIAL_STATE: + if i_t == 0: + b_h += tl.load(h0 + i_b * D + o_d, mask=mask, other=0).to(tl.float32) + for i in range(0, BT): + mask_t = mask & ((i_t * BT + i) < T) + b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32) + b_g = tl.load(p_g, mask=mask_t, 
other=0).to(tl.float32) + b_h = exp(b_g) * b_h + b_x + b_gc = b_gc + b_g + tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t) + tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t) + + p_x += D + p_g += D + p_gc += D + p_o += D + + +@triton.jit(do_not_specialize=['T']) +def chunk_hgrn_fwd_kernel_o( + gc, + o, + s_b, + s_t, + s_d, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr +): + i_d, i_b = tl.program_id(0), tl.program_id(1) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + for i_t in range(1, tl.cdiv(T, BT)): + p_gc = tl.make_block_ptr(gc + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_o = tl.make_block_ptr(o + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + + # [BD,] + b_h0 = tl.load(o + i_b * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32) + # [BT, BD] + b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32) + b_o = b_o + exp(b_gc) * b_h0[None, :] + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({'BD': BD}, num_warps=num_warps) + for BD in [32, 64, 128] + for num_warps in [1, 2, 4, 8] + ], + key=['D'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_hgrn_bwd_kernel_h( + g, + gc, + dx, + do, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr +): + i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + BC = min(BT, T - i_t * BT) + NT = tl.num_programs(1) + + p_g = g + (i_b * T + i_t * BT + BC - 1) * D + o_d + p_gc = gc + (i_b * T + i_t * BT + BC - 1) * D + o_d + p_dx = dx + (i_b * T + i_t * BT + BC - 1) * D + o_d + p_do = do + (i_b * T + i_t * BT + BC - 1) * D + o_d + + if i_t == NT - 1: + b_gc = tl.zeros([BD], dtype=tl.float32) + else: + b_gc = tl.load(g + (i_b * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32) + b_dh = 
tl.zeros([BD], dtype=tl.float32) + for _ in range(BC - 1, -1, -1): + tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask) + + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32) + + b_gc = b_gc + b_g + b_dh = b_dh + b_do + b_dx = b_dh + b_dh = b_dh * exp(b_g) + + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask) + + p_g -= D + p_gc -= D + p_dx -= D + p_do -= D + + +@triton.jit(do_not_specialize=['T']) +def chunk_hgrn_bwd_kernel_o( + g, + gc, + o, + dx, + dg, + s_b, + s_t, + s_d, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr +): + i_d, i_b = tl.program_id(0), tl.program_id(1) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + for i_t in range(tl.cdiv(T, BT) - 1, -1, -1): + p_g = tl.make_block_ptr(g + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_gc = tl.make_block_ptr(gc + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_o = tl.make_block_ptr(o + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0)) + p_dx = tl.make_block_ptr(dx + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_dg = tl.make_block_ptr(dg + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + + # [BD,] + mask_t = mask & ((i_t + 1) * BT < T) + b_ht = tl.load(dx + i_b * T * D + (i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32) + # [BT, BD] + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32) + b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32) + + b_dx = b_dx + exp(b_gc) * b_ht[None, :] + b_dg = b_o * b_dx * exp(b_g) + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +class ChunkHGRNFunction(torch.autograd.Function): + + @staticmethod + @input_guard 
+ def forward(ctx, x, g, initial_state=None, output_final_state=False): + B, T, D = x.shape + BT, BD = 128, min(64, triton.next_power_of_2(D)) + num_warps = 8 if BD == 64 else 4 + + gc = torch.empty_like(g, dtype=torch.float) + o = torch.empty_like(x, dtype=torch.float) + def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B) + chunk_hgrn_fwd_kernel_h[grid]( + x, g, gc, o, initial_state, + T=T, D=D, BT=BT, + USE_INITIAL_STATE=initial_state is not None + ) + def grid(meta): return (triton.cdiv(D, meta['BD']), B) + chunk_hgrn_fwd_kernel_o[grid]( + gc, o, + o.stride(-3), o.stride(-2), o.stride(-1), + T=T, D=D, BT=BT, BD=BD, + num_warps=num_warps + ) + final_state = None + if output_final_state: + final_state = o[:, -1].clone() + o = o.to(x.dtype) + ctx.save_for_backward(g, o, initial_state) + return o, final_state + + @staticmethod + @input_guard + def backward(ctx, do, dht=None): + g, o, initial_state = ctx.saved_tensors + B, T, D = do.shape + BT, BD = 128, min(64, triton.next_power_of_2(D)) + num_warps = 8 if BD == 64 else 4 + + gc = torch.empty_like(g, dtype=torch.float) + dx = torch.empty_like(o, dtype=torch.float) + def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B) + chunk_hgrn_bwd_kernel_h[grid]( + g, gc, dx, do, + T=T, D=D, BT=BT + ) + + dg = torch.empty_like(g, dtype=torch.float) + def grid(meta): return (triton.cdiv(D, meta['BD']), B) + chunk_hgrn_bwd_kernel_o[grid]( + g, gc, o, dx, dg, + o.stride(-3), o.stride(-2), o.stride(-1), + T=T, D=D, BT=BT, BD=BD, + num_warps=num_warps + ) + if initial_state is not None: + dg[:, 0] = (initial_state * dx[:, 0] * g[:, 0].float().exp()).to(dg.dtype) + + return dx.to(o.dtype), dg, None, None + + +@torch.compiler.disable +def chunk_hgrn( + x: torch.Tensor, + g: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + return ChunkHGRNFunction.apply(x, g, initial_state, output_final_state) 
diff --git a/fla/ops/hgrn/fused_recurrent.py b/fla/ops/hgrn/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..a6a70f0c7e4e12fc3648f1f0c19fc946fb85eb97 --- /dev/null +++ b/fla/ops/hgrn/fused_recurrent.py @@ -0,0 +1,308 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp +from fla.utils import input_guard + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BD': BD}, num_warps=num_warps) + for BD in [32, 64, 128] + for num_warps in [1, 2, 4, 8] + ], + key=['D'] +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_hgrn_fwd_kernel( + x, + g, + o, + h0, + ht, + offsets, + T, + D: tl.constexpr, + BD: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_d, i_n = tl.program_id(0), tl.program_id(1) + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + p_x = x + bos * D + o_d + p_g = g + bos * D + o_d + p_o = o + bos * D + o_d + + b_h = tl.zeros([BD], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_n * D + o_d + b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32) + for _ in range(0, T): + b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32) + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + b_h = exp(b_g) * b_h + b_x + tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask) + + p_x += D + p_g += D + p_o += D + + if STORE_FINAL_STATE: + p_ht = ht + i_n * D + o_d + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), 
mask=mask) + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BD': BD}, num_warps=num_warps) + for BD in [32, 64, 128] + for num_warps in [1, 2, 4, 8] + ], + key=['D'] +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_hgrn_bwd_kernel( + g, + o, + h0, + dx, + dg, + do, + dht, + dh0, + offsets, + T, + D: tl.constexpr, + BD: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_d, i_n = tl.program_id(0), tl.program_id(1) + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + p_g = g + (bos + T - 1) * D + o_d + p_o = o + (bos + T - 2) * D + o_d + p_dx = dx + (bos + T - 1) * D + o_d + p_dg = dg + (bos + T - 1) * D + o_d + p_do = do + (bos + T - 1) * D + o_d + + b_dh = tl.zeros([BD], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_dht = dht + i_n * D + o_d + b_dh += tl.load(p_dht, mask=mask, other=0).to(tl.float32) + + for i in range(T - 1, -1, -1): + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32) + if i > 0: + b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32) + elif USE_INITIAL_STATE: + b_o = tl.load(h0 + i_n * D + o_d, mask=mask, other=0).to(tl.float32) + else: + b_o = tl.zeros([BD], dtype=tl.float32) + + b_dh = b_dh + b_do + b_dx = b_dh + b_dh = b_dh * exp(b_g) + b_dg = b_dh * b_o + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask) + + p_g -= D + p_o -= D + p_dx -= D + p_dg -= D + p_do -= D + + if USE_INITIAL_STATE: + p_dh0 = dh0 + i_n * D + o_d + tl.store(p_dh0, 
b_dh.to(p_dh0.dtype.element_ty), mask=mask) + + +def fused_recurrent_hgrn_fwd( + x: torch.Tensor, + g: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + offsets: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, D = x.shape + N = B if offsets is None else len(offsets) - 1 + + o = torch.empty_like(x) + final_state = x.new_empty(N, D) if output_final_state else None + + def grid(meta): return (triton.cdiv(D, meta['BD']), N) + fused_recurrent_hgrn_fwd_kernel[grid]( + x=x, + g=g, + o=o, + h0=initial_state, + ht=final_state, + offsets=offsets, + T=T, + D=D + ) + return o, final_state + + +def fused_recurrent_hgrn_bwd( + g: torch.Tensor, + o: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor = None, + initial_state: torch.Tensor = None, + offsets: Optional[torch.LongTensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, D = do.shape + N = B if offsets is None else len(offsets) - 1 + + dx = torch.empty_like(o, dtype=torch.float) + dg = torch.empty_like(g, dtype=torch.float) + dh0 = torch.empty_like(initial_state, dtype=torch.float) if initial_state is not None else None + def grid(meta): return (triton.cdiv(D, meta['BD']), N) + fused_recurrent_hgrn_bwd_kernel[grid]( + g=g, + o=o, + h0=initial_state, + dx=dx, + dg=dg, + do=do, + dht=dht, + dh0=dh0, + offsets=offsets, + T=T, + D=D + ) + return dx, dg, dh0 + + +class FusedRecurrentHGRNFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + x: torch.Tensor, + g: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + offsets: Optional[torch.LongTensor] = None + ): + o, ht = fused_recurrent_hgrn_fwd( + x=x, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + offsets=offsets + ) + ctx.save_for_backward(g, o, initial_state) + ctx.offsets = offsets + return o, ht + + @staticmethod + @input_guard + def backward(ctx, do, dht=None): + g, o, 
initial_state = ctx.saved_tensors + offsets = ctx.offsets + + dx, dg, dh0 = fused_recurrent_hgrn_bwd( + g=g, + o=o, + do=do, + dht=dht, + initial_state=initial_state, + offsets=offsets + ) + return dx, dg, dh0, None, None + + +@torch.compiler.disable +def fused_recurrent_hgrn( + x: torch.Tensor, + g: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + x (torch.Tensor): + inputs of shape `[B, T, D]. + g (torch.Tensor): + Forget gates of shape `[B, T, D]`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, D]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, D]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, D]`. + final_state (torch.Tensor): + Final state of shape `[N, D]` if `output_final_state=True` else `None`. 
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.hgrn import fused_recurrent_hgrn + # inputs with equal lengths + >>> B, T, D = 4, 2048, 512 + >>> x = torch.randn(B, T, D, device='cuda') + >>> g = F.logsigmoid(torch.randn(B, T, D, device='cuda')) + >>> h0 = torch.randn(B, D, device='cuda') + >>> o, ht = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> x, g = map(lambda x: rearrange(x, 'b t d -> 1 (b t) d'), (x, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = x.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True, cu_seqlens=cu_seqlens) + >>> assert o.allclose(o_var.view(o.shape)) + >>> assert ht.allclose(ht_var) + """ + return FusedRecurrentHGRNFunction.apply( + x, + g, + initial_state, + output_final_state, + cu_seqlens + ) diff --git a/fla/ops/hgrn/naive.py b/fla/ops/hgrn/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..9bcddc1967b31c5181d330704c7b5ff2127e9d68 --- /dev/null +++ b/fla/ops/hgrn/naive.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch + + +def naive_recurrent_hgrn( + x: torch.Tensor, + g: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + output_final_state: Optional[bool] = False +) -> torch.Tensor: + dtype = x.dtype + x, g = map(lambda i: i.float(), (x, g)) + B, T, D = x.shape + + h = torch.zeros(B, D, dtype=torch.float, device=x.device) + o = torch.zeros_like(x) + + final_state = None + if initial_state is not None: + h += initial_state + + for i in range(T): + h = g[:, i].exp() * h + x[:, i] + o[:, i] = h + + if output_final_state: + final_state = h + return o.to(dtype), final_state + + +def naive_chunk_hgrn( 
+ x: torch.Tensor, + g: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + output_final_state: Optional[bool] = False, + chunk_size: int = 64 +) -> torch.Tensor: + dtype = x.dtype + x, g = map(lambda i: i.float(), (x, g)) + B, T, D = x.shape + + gc = g.view(B, chunk_size, D).cumsum(-2).view_as(g) + h = torch.zeros(B, D, dtype=torch.float, device=x.device) + o = torch.zeros_like(x) + + final_state = None + if initial_state is not None: + h += initial_state + + for i in range(0, T, chunk_size): + hp = h + h = torch.zeros(B, D, dtype=torch.float, device=x.device) + for j in range(i, i + chunk_size): + h = g[:, j].exp() * h + x[:, j] + o[:, j] = hp * gc[:, j].exp() + h + h = o[:, j].clone() + + if output_final_state: + final_state = h + return o.to(dtype), final_state diff --git a/fla/ops/lightning_attn/__init__.py b/fla/ops/lightning_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c28c3af59f61d32cbb68a63926ac67fa2bb73447 --- /dev/null +++ b/fla/ops/lightning_attn/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_lightning_attn +from .fused_recurrent import fused_recurrent_lightning_attn + +__all__ = [ + 'chunk_lightning_attn', + 'fused_recurrent_lightning_attn' +] diff --git a/fla/ops/linear_attn/__init__.py b/fla/ops/linear_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a981054aaf9ab98b30ac08fa525bde73e68e7e4 --- /dev/null +++ b/fla/ops/linear_attn/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_linear_attn +from .fused_chunk import fused_chunk_linear_attn +from .fused_recurrent import fused_recurrent_linear_attn + +__all__ = [ + 'chunk_linear_attn', + 'fused_chunk_linear_attn', + 'fused_recurrent_linear_attn' +] diff --git a/fla/ops/linear_attn/__pycache__/__init__.cpython-312.pyc b/fla/ops/linear_attn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..04bb633bdd6c1c65e6a1678dcde72573ad74b563 Binary files /dev/null and b/fla/ops/linear_attn/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla/ops/linear_attn/__pycache__/fused_chunk.cpython-312.pyc b/fla/ops/linear_attn/__pycache__/fused_chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0421d26f28a3482ce7188b7d358b0039857ea06 Binary files /dev/null and b/fla/ops/linear_attn/__pycache__/fused_chunk.cpython-312.pyc differ diff --git a/fla/ops/linear_attn/chunk.py b/fla/ops/linear_attn/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..8283e707923389e5c0f4e8294f7c491277f7243d --- /dev/null +++ b/fla/ops/linear_attn/chunk.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Yu Zhang, Songlin Yang + +from typing import Optional, Tuple + +import torch + +from fla.ops.linear_attn.utils import normalize_output +from fla.ops.simple_gla import chunk_simple_gla + + +@torch.compiler.disable +def chunk_linear_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + normalize: bool = True, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + scale (Optional[int]): + Scale factor for the linear attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[B, H, K, V]`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[B, H, K, V]`. Default: `False`. 
+ normalize (bool): + Whether to normalize the output. Default: `True`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + final_state (torch.Tensor): + Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None` + """ + + if scale is None: + scale = k.shape[-1] ** -0.5 + + o, final_state = chunk_simple_gla( + q=q, + k=k, + v=v, + scale=scale, + g=None, + initial_state=initial_state, + output_final_state=output_final_state, + head_first=head_first + ) + if normalize: + o = normalize_output(q * scale, k, o) + return o, final_state diff --git a/fla/ops/linear_attn/fused_chunk.py b/fla/ops/linear_attn/fused_chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..bfcc1212a534aa3debb5bb2d1cdbbce5f95f06e4 --- /dev/null +++ b/fla/ops/linear_attn/fused_chunk.py @@ -0,0 +1,318 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl +from packaging import version + +from fla.ops.linear_attn.utils import normalize_output +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.jit +def fused_chunk_linear_attn_fwd_kernel( + q, # query [B, H, T, K] + k, # key [B, H, T, V] + v, # value [B, H, T, V] + o, # output [B, H, T, V] + h0, + ht, + scale, + B, # batch size + H, # H + T, # T + K: tl.constexpr, # K + V: tl.constexpr, # V + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. 
chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0)) + + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # [BT, BT] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BT, BV] + b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + if CHECK and i == 0: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False) + else: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 
0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def fused_chunk_linear_attn_bwd_kernel( + q, # query [B, H, T, K] + k, # key [B, H, T, V] + v, # value [B, H, T, V] + do, # gradient of output [B, H, T, V] + dq, # gradient of query [NV, B, H, T, K] + dk, # gradient of key [NV, B, H, T, K] + dv, # gradient of value [NK, B, H, T, V] + h0, # initial state of the chunk [B, H, K, V] + scale, # K ** -0.5 + B, # B + H, # H + T, # T + K: tl.constexpr, # K + V: tl.constexpr, # V + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_i = tl.arange(0, BT) + + m_s = o_i[:, None] >= o_i[None, :] + # [BV, BK] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0)) + + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [V, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0) + # [BT, BK] + b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False) + # [BV, BK] + if CHECK and i == 0: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), 
allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + else: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # sync threads + b_h = None + tl.debug_barrier() + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + m_s = o_i[:, None] <= o_i[None, :] + for i in range(1, tl.cdiv(T, BT) + 1): + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * T*K, (T, K), (K, 1), (T - i*BT, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (T - i*BT, i_v*BV), (BT, BV), (1, 0)) + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + b_s = tl.dot(b_k, b_q, allow_tf32=False) + b_s = tl.where(m_s, b_s, 0).to(b_q.dtype) + # [BT, BT] + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype) + # [BT, BK] + b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) + # [BT, BV] + b_dv = tl.dot(b_s, b_do, allow_tf32=False) + if CHECK and i == 1: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + else: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), 
allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkLinearAttentionFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, scale, initial_state, output_final_state): + B, H, T, K, V = *k.shape, v.shape[-1] + BT = 64 + BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_warps = 4 + num_stages = 1 + + o = q.new_empty(NK, B, H, T, V) + final_state = q.new_empty(B, H, K, V, dtype=torch.float) if output_final_state else None + # the bug still exists even for Triton 2.2 on H100 GPUs + # so we always enable initial checks + CHECK = True + if version.parse(triton.__version__) < version.parse('2.2.0'): + import warnings + warnings.warn( + "Triton<2.2.0 detected for running this kernel, " + "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) " + "that lead to significant precision loss. " + "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. " + "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)." 
+ ) + CHECK = True + + grid = (NV, NK, B * H) + fused_chunk_linear_attn_fwd_kernel[grid]( + q, k, v, o, initial_state, final_state, + scale, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + CHECK=CHECK, + num_warps=num_warps, + num_stages=num_stages + ) + o = o.sum(0) if NK > 1 else o[0] + + ctx.save_for_backward(q, k, v, initial_state) + ctx.scale = scale + ctx.CHECK = CHECK + return o.to(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht=None): + q, k, v, initial_state = ctx.saved_tensors + B, H, T, K, V = *k.shape, v.shape[-1] + scale = ctx.scale + + BT = 64 + BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_warps = 4 + num_stages = 1 + + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + grid = (NV, NK, B * H) + + fused_chunk_linear_attn_bwd_kernel[grid]( + q, k, v, do, dq, dk, dv, initial_state, + scale, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + CHECK=ctx.CHECK, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None + + +def fused_chunk_linear_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + normalize: bool = True, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + scale 
(Optional[int]): + Scale factor for linear attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[B, H, K, V]`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[B, H, K, V]`. Default: `False`. + normalize (bool): + Whether to normalize the output. Default: `True`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + final_state (torch.Tensor): + Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None` + """ + if scale is None: + scale = q.shape[-1] ** -0.5 + if not head_first: + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state) + if normalize: + o = normalize_output(q * scale, k, o) + if not head_first: + o = o.transpose(1, 2) + return o, final_state diff --git a/fla/ops/linear_attn/fused_recurrent.py b/fla/ops/linear_attn/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..b50b8c7bfb470b69be5ba3327de24ed07ffa974d --- /dev/null +++ b/fla/ops/linear_attn/fused_recurrent.py @@ -0,0 +1,251 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.linear_attn.utils import normalize_output +from fla.utils import input_guard + + +@triton.jit +def fused_recurrent_linear_attn_fwd_kernel( + q, # query [B, H, L, K] + k, # key [B, H, L, V] + v, # value [B, H, L, V] + o, # output [B, H, L, V] + h0, + ht, # final hidden state [B, H, K, V] + + s_k_h, # stride size: L * K + s_v_h, # stride size: L * V + + scale, + B, # batch size + H, # H + T, # T + K: tl.constexpr, # K + V: 
@triton.jit
def fused_recurrent_linear_attn_fwd_kernel(
    q,  # query [B, H, L, K]
    k,  # key [B, H, L, K]
    v,  # value [B, H, L, V]
    o,  # output [NK, B, H, L, V]; one partial output slab per K-tile
    h0,  # initial hidden state [B, H, K, V], only read if USE_INITIAL_STATE
    ht,  # final hidden state [B, H, K, V]
    s_k_h,  # stride size: L * K
    s_v_h,  # stride size: L * V
    scale,
    B,  # batch size
    H,  # H
    T,  # T
    K: tl.constexpr,  # K
    V: tl.constexpr,  # V
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
):
    # One program per (V-tile, K-tile, batch*head); state is updated one
    # token at a time (true recurrence, O(T) sequential steps).
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK)
    p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK)
    p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV)
    # Output pointer offsets by i_k so each K-tile writes its own slab.
    p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV)

    mask_bk = (i_k * BK + tl.arange(0, BK)) < K
    mask_bv = (i_v * BV + tl.arange(0, BV)) < V
    mask_kv = mask_bk[None, :] & mask_bv[:, None]

    # Running state for this (K-tile, V-tile), laid out [BV, BK] in fp32.
    b_h = tl.zeros([BV, BK], dtype=tl.float32)

    if USE_INITIAL_STATE:
        p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)

    for _ in range(0, T):
        b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
        b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale

        # h_t = h_{t-1} + k_t v_t^T ; o_t = h_t q_t (summed over K)
        b_h += b_k[None, :] * b_v[:, None]
        b_o = b_h * b_q[None, :]
        b_o = tl.sum(b_o, axis=1)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)

        p_q += K
        p_k += K
        p_o += V
        p_v += V

    if STORE_FINAL_STATE:
        p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv)


# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
@triton.jit
def fused_recurrent_linear_attn_bwd_kernel(
    q,  # query [B, H, L, K]
    k,  # key [B, H, L, K]
    v,  # value [B, H, L, V]
    do,  # gradient of output [B, H, L, V]
    dq,  # gradient of query [NV, B, H, L, K]
    dk,  # gradient of key [NV, B, H, L, K]
    dv,  # gradient of value [NK, B, H, L, V]
    h0,  # initial hidden state initialization [B, H, K, V]
    s_k_h,  # stride size: L * K
    s_v_h,  # stride size: L * V
    scale,  # K ** -0.5
    B,  # B
    H,  # H
    T,  # T
    K: tl.constexpr,  # K
    V: tl.constexpr,  # V
    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
):
    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK)
    p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK)
    p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV)
    p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV)

    p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK)
    mask_bk = i_k * BK + tl.arange(0, BK) < K
    mask_bv = i_v * BV + tl.arange(0, BV) < V

    # Phase 1: forward-direction sweep rebuilding h_t to produce dq.
    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    if USE_INITIAL_STATE:
        mask_kv = mask_bk[:, None] & mask_bv[None, :]
        p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
        b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)

    for _ in range(0, T):
        b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)

        b_h += b_k[:, None] * b_v[None, :]
        # dq_t = (h_t do_t) summed over V, then scaled.
        _d_q = b_h * b_do[None, :]
        d_q = tl.sum(_d_q, axis=1) * scale
        tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)

        p_k += K
        p_do += V
        p_v += V
        p_dq += K

    # sync threads
    tl.debug_barrier()

    # Phase 2: reverse sweep accumulating dL/dh to produce dk and dv.
    p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K
    p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K
    p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V
    p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V
    p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K
    p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V
    d_h = tl.zeros([BK, BV], dtype=tl.float32)

    for _ in range(T):
        b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
        b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
        b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
        # d_h accumulates q_t do_t^T while stepping backwards in time.
        d_h += b_q[:, None] * b_do[None, :]
        d_k = tl.sum(d_h * b_v[None, :], axis=1)
        d_v = tl.sum(d_h * b_k[:, None], axis=0)

        tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
        tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)

        p_do -= V
        p_q -= K
        p_k -= K
        p_v -= V
        p_dk -= K
        p_dv -= V
fused_recurrent_linear_attn_bwd_kernel[grid]( + q, k, v, do, dq, dk, dv, initial_state, + q.stride(1), + v.stride(1), + scale, + B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq, dk, dv, None, None, None + + +def fused_recurrent_linear_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + normalize: bool = False, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + if scale is None: + scale = q.shape[-1] ** -0.5 + if not head_first: + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + o, final_state = FusedRecurrentLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state) + if normalize: + o = normalize_output(q * scale, k, o) + if not head_first: + o = o.transpose(1, 2) + return o, final_state diff --git a/fla/ops/linear_attn/naive.py b/fla/ops/linear_attn/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..b6ecf2718fcac8eef80f445ed02b95f36329f3c4 --- /dev/null +++ b/fla/ops/linear_attn/naive.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- + +from typing import Optional, Tuple + +import torch +from einops import rearrange + +from fla.ops.linear_attn.utils import normalize_output + + +def naive_chunk_linear_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + normalize: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + if scale is None: + scale = q.shape[-1] ** -0.5 + chunk_size = 64 + q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) * scale + k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size) + v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size) + kv = k.transpose(-1, -2) @ v + kv = kv.cumsum(2) + kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], 
dim=2) + inter = q @ kv + intra = (( + q @ k.transpose(-1, -2)).masked_fill_( + torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), + 0 + )) @ v + o = inter + intra + if normalize: + o = normalize_output(q * scale, k, o) + return rearrange(o, 'b h n c d -> b h (n c) d') diff --git a/fla/ops/linear_attn/utils.py b/fla/ops/linear_attn/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b444376833f5d512af6fc2db387db75a43a92e5d --- /dev/null +++ b/fla/ops/linear_attn/utils.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- + +import torch + + +@torch.jit.script +def normalize_output(q, k, o): + k = k.cumsum(-2) + z = (q * k).sum(-1, keepdim=True) + return o / (z + 1e-10) diff --git a/fla/ops/nsa/__init__.py b/fla/ops/nsa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..941a1be41e1650961af0d28e64837421826ffab2 --- /dev/null +++ b/fla/ops/nsa/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .naive import naive_nsa +from .parallel import parallel_nsa + +__all__ = [ + 'naive_nsa', + 'parallel_nsa' +] diff --git a/fla/ops/nsa/__pycache__/utils.cpython-312.pyc b/fla/ops/nsa/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b823dc06ea9a7c40e262d5daa20a19f7c863f58c Binary files /dev/null and b/fla/ops/nsa/__pycache__/utils.cpython-312.pyc differ diff --git a/fla/ops/nsa/parallel.py b/fla/ops/nsa/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..7e89d964c7357ceeabaaeb9500849ce6cbdecfad --- /dev/null +++ b/fla/ops/nsa/parallel.py @@ -0,0 +1,1435 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import warnings +from typing import Optional, Union + +import torch +import triton +import triton.language as tl +from einops import rearrange + +from fla.ops.common.utils import prepare_chunk_indices, prepare_chunk_offsets, prepare_lens, prepare_token_indices +from 
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK', 'BV'],
)
@triton.jit
def parallel_nsa_compression_fwd_kernel(
    q,
    k,
    v,
    o,
    lse,
    scale,
    offsets,
    token_indices,
    chunk_offsets,
    T,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BC: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
):
    # Attention of each query token over the *compressed* (mean-pooled)
    # key/value representations, using an online-softmax accumulation.
    # One program per (token, V-tile, batch*kv-head); G query heads share
    # one kv head (GQA).
    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        # Variable-length mode: recover (sequence id, in-sequence position).
        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        boc = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
        boc = i_b * tl.cdiv(T, BS)

    p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))

    # the Q block is kept in the shared memory throughout the whole kernel
    # [G, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    # the number of compression representations in total
    TC = tl.cdiv(T, BS)
    # the number of compression representations required to iterate over
    # incomplete compression blocks are not included
    NC = (i_t + 1) // BS

    p_o = tl.make_block_ptr(o + (bos + i_t) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
    # [G, BV]
    b_o = tl.zeros([G, BV], dtype=tl.float32)
    # max scores for the current block
    b_m = tl.full([G], float('-inf'), dtype=tl.float32)
    # lse = log(acc) + m
    b_acc = tl.zeros([G], dtype=tl.float32)

    for i_c in range(0, NC, BC):
        o_c = i_c + tl.arange(0, BC)

        p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
        p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c, i_v * BV), (BC, BV), (1, 0))
        # [BK, BC]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BC, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [G, BC]
        b_s = tl.dot(b_q, b_k)
        b_s = tl.where((o_c < NC)[None, :], b_s, float('-inf'))

        # Online-softmax rescaling: new running max, correction factor b_r.
        # [G]
        b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
        b_r = exp(b_mp - b_m)
        # [G, BC]
        b_p = exp(b_s - b_m[:, None])
        # [G]
        b_acc = b_acc * b_r + tl.sum(b_p, 1)

        # [G, BV]
        b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)

        b_mp = b_m
    if NC == 0:
        # No complete compression block is visible yet -> zero output, lse=0.
        b_lse = tl.zeros([G], dtype=tl.float32)
    else:
        b_o = b_o / b_acc[:, None]
        b_lse = b_m + log(b_acc)

    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
    # Only the first V-tile writes the (V-independent) log-sum-exp.
    if i_v == 0:
        tl.store(lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G), b_lse.to(lse.dtype.element_ty))


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_nsa_compression_bwd_kernel_dq(
    q,
    k,
    v,
    lse,
    delta,
    do,
    dq,
    scale,
    offsets,
    token_indices,
    chunk_offsets,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BC: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # Query-gradient pass for the compression branch; mirrors the forward
    # iteration but uses the stored lse and delta (= rowsum(o * do)).
    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        boc = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
        boc = i_b * tl.cdiv(T, BS)

    q += (bos + i_t) * HQ*K
    do += (bos + i_t) * HQ*V
    lse += (bos + i_t) * HQ
    delta += (bos + i_t) * HQ
    # dq has one partial slab per V-tile (leading i_v dimension).
    dq += (i_v * B * T + bos + i_t) * HQ*K

    p_q = tl.make_block_ptr(q, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
    p_dq = tl.make_block_ptr(dq, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))

    # [G, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    p_do = tl.make_block_ptr(do, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
    p_lse = lse + i_h * G + tl.arange(0, G)
    p_delta = delta + i_h * G + tl.arange(0, G)

    # the number of compression representations in total
    TC = tl.cdiv(T, BS)
    # the number of compression representations required to iterate over
    # incomplete compression blocks are not included
    NC = (i_t + 1) // BS

    # [G, BV]
    b_do = tl.load(p_do, boundary_check=(0, 1))
    # [G]
    b_lse = tl.load(p_lse)
    b_delta = tl.load(p_delta)

    # [G, BK]
    b_dq = tl.zeros([G, BK], dtype=tl.float32)
    for i_c in range(0, NC, BC):
        o_c = i_c + tl.arange(0, BC)
        p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
        p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (V, TC), (1, H*V), (i_v * BV, i_c), (BV, BC), (0, 1))
        # [BK, BC]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BC]
        b_v = tl.load(p_v, boundary_check=(0, 1))

        # [G, BC] — probabilities reconstructed from stored lse.
        b_s = tl.dot(b_q, b_k)
        b_p = exp(b_s - b_lse[:, None])
        b_p = tl.where((o_c < NC)[None, :], b_p, 0)

        # [G, BV] @ [BV, BC] -> [G, BC]
        b_dp = tl.dot(b_do, b_v)
        b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
        # [G, BC] @ [BC, BK] -> [G, BK]
        b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
    b_dq *= scale
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_nsa_compression_bwd_kernel_dkv(
    q,
    k,
    v,
    lse,
    delta,
    do,
    dk,
    dv,
    offsets,
    chunk_indices,
    chunk_offsets,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BC: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # Key/value-gradient pass for the compression branch: each program owns
    # one block of BC compressed representations and loops over all query
    # tokens that can attend to it.
    i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        i_n, i_c = tl.load(chunk_indices + i_c * 2).to(tl.int32), tl.load(chunk_indices + i_c * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        boc = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
        boc = i_b * tl.cdiv(T, BS)

    # the number of compression representations in total
    TC = tl.cdiv(T, BS)

    p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (TC, K), (H*K, 1), (i_c * BC, 0), (BC, BK), (1, 0))
    p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c * BC, i_v * BV), (BC, BV), (1, 0))
    p_dk = tl.make_block_ptr(dk + (i_v * B*T*H + boc * H + i_h) * K, (TC, K), (H*K, 1), (i_c * BC, 0), (BC, BK), (1, 0))
    p_dv = tl.make_block_ptr(dv + (i_v * B*T*H + boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c * BC, i_v * BV), (BC, BV), (1, 0))

    # [BC, BK]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_dk = tl.zeros([BC, BK], dtype=tl.float32)
    # [BC, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_dv = tl.zeros([BC, BV], dtype=tl.float32)

    # A compressed block is only visible to tokens at or after its end;
    # the earliest such token for this program is i_c * BC * BS.
    for i in range(i_c * BC * BS, T):
        o_c = i_c * BC + tl.arange(0, BC)

        p_q = tl.make_block_ptr(q + (bos + i) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
        # [G, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)

        p_do = tl.make_block_ptr(do + (bos + i) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
        p_lse = lse + (bos + i) * HQ + i_h * G + tl.arange(0, G)
        p_delta = delta + (bos + i) * HQ + i_h * G + tl.arange(0, G)
        # [G, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [G]
        b_lse = tl.load(p_lse)
        b_delta = tl.load(p_delta)
        # [BC, G]
        b_s = tl.dot(b_k, tl.trans(b_q))
        b_p = exp(b_s - b_lse[None, :])
        # Zero out rows whose compression block ends after token i.
        b_p = tl.where((i >= max(0, (o_c + 1) * BS - 1))[:, None], b_p, 0)
        # [BC, G] @ [G, BV] -> [BC, BV]
        b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
        # [BC, BV] @ [BV, G] -> [BC, G]
        b_dp = tl.dot(b_v, tl.trans(b_do))
        # [BC, G]
        b_ds = b_p * (b_dp - b_delta[None, :])
        # [BC, G] @ [G, BK] -> [BC, BK]
        b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)

    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK'],
)
@triton.jit
def parallel_nsa_kernel_topk(
    q,
    k,
    lse,
    scale,
    block_indices,
    offsets,
    token_indices,
    chunk_offsets,
    T,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    S: tl.constexpr,
    BC: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
):
    # Selects, per token and kv head, the S most important compressed blocks
    # (by attention probability mass summed over the G query heads) using an
    # in-register bitonic merge network. Writes 0-based block ids (or -1)
    # to block_indices.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        boc = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
        boc = i_b * tl.cdiv(T, BS)

    p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))

    # the Q block is kept in the shared memory throughout the whole kernel
    # [G, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    # the number of compression representations in total
    TC = tl.cdiv(T, BS)
    # the number of compression representations required to iterate over
    # incomplete compression blocks are not included
    NC = (i_t + 1) // BS
    ################################
    # 1. lse computation
    ################################
    if lse is not None:
        # Reuse the lse produced by the compression forward pass.
        b_lse = tl.load(lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G))
    else:
        # Recompute lse with the same online-softmax loop as the fwd kernel.
        # max scores for the current block
        b_m = tl.full([G], float('-inf'), dtype=tl.float32)
        # lse = log(acc) + m
        b_acc = tl.zeros([G], dtype=tl.float32)
        for i_c in range(0, NC, BC):
            o_c = i_c + tl.arange(0, BC)

            p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
            # [BK, BC]
            b_k = tl.load(p_k, boundary_check=(0, 1))

            # [G, BC]
            b_s = tl.dot(b_q, b_k)
            b_s = tl.where((o_c < NC)[None, :], b_s, float('-inf'))

            # [G]
            b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
            b_r = exp(b_mp - b_m)
            # [G, BC]
            b_p = exp(b_s - b_m[:, None])
            # [G]
            b_acc = b_acc * b_r + tl.sum(b_p, 1)

            b_mp = b_m
        if NC == 0:
            b_lse = tl.zeros([G], dtype=tl.float32)
        else:
            b_lse = b_m + log(b_acc)

    ################################
    # 2. topk selection
    ################################
    # b_i holds scores, o_i holds 1-based block ids (0 marks "no block");
    # m_i masks the top half kept across merge rounds.
    # [BC]
    b_i = tl.full([BC], -1, dtype=tl.float32)
    o_i = tl.zeros([BC], dtype=tl.int32)
    m_i = tl.arange(0, BC) < BC//2
    for i_c in range(0, i_t // BS + 1, BC):
        o_c = i_c + tl.arange(0, BC)

        p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
        # [BK, BC]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [G, BC]
        b_s = tl.dot(b_q, b_k)
        b_s = tl.where((i_t // BS > o_c)[None, :], b_s, float('-inf'))
        # The block containing the current token always gets weight 1.
        # [G, BC]
        b_p = tl.where((i_t // BS == o_c)[None, :], float(1.0), exp(b_s - b_lse[:, None]))
        # the importance scores of the current block
        # [BC]
        b_i, b_ip = tl.sum(b_p, 0), b_i
        o_i, o_ip = tl.where(o_c <= i_t // BS, o_c + 1, 0), o_i

        n_dims: tl.constexpr = tl.standard._log2(b_i.shape[0])
        # Sort the freshly computed scores with a bitonic network.
        for i in tl.static_range(1, n_dims):
            b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), i, 2, n_dims)

        if i_c != 0:
            # Merge with the running top half kept from previous iterations.
            b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), n_dims, False, n_dims)
            b_i_new = b_ip * m_i + b_i * (1 - m_i)
            o_i_new = o_ip * m_i + o_i * (1 - m_i)
            b_i, o_i = _bitonic_merge(b_i_new, o_i_new.to(tl.int32), n_dims, True, n_dims)
        else:
            b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), n_dims, True, n_dims)

    # Take the S best ids; convert 1-based ids back to 0-based (-1 = none).
    m_top = tl.arange(0, BC//S) == 0
    b_top = tl.sum(m_top[:, None] * tl.reshape(o_i - 1, [BC//S, S]), 0)

    p_b = tl.make_block_ptr(block_indices + (bos + i_t) * H*S, (H*S,), (1,), (i_h * S,), (S,), (0,))
    tl.store(p_b, b_top.to(p_b.dtype.element_ty))
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
    'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor),
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK', 'BV'],
)
@triton.jit
def parallel_nsa_fwd_kernel(
    q,
    k,
    v,
    o,
    lse,
    scale,
    block_indices,
    block_counts,
    offsets,
    token_indices,
    T,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    S: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    USE_BLOCK_COUNTS: tl.constexpr
):
    # Sparse-selection attention forward: each query token attends only to
    # the (up to S) key/value blocks listed in block_indices, with causal
    # masking inside each block, via online softmax.
    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    k += (bos * H + i_h) * K
    v += (bos * H + i_h) * V
    block_indices += (bos + i_t) * H*S + i_h * S

    # Per-token number of selected blocks, or the static maximum S.
    if USE_BLOCK_COUNTS:
        NS = tl.load(block_counts + (bos + i_t) * H + i_h)
    else:
        NS = S

    p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
    # the Q block is kept in the shared memory throughout the whole kernel
    # [G, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    p_o = tl.make_block_ptr(o + (bos + i_t) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
    p_lse = lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
    # [G, BV]
    b_o = tl.zeros([G, BV], dtype=tl.float32)

    b_m = tl.full([G], float('-inf'), dtype=tl.float32)
    b_acc = tl.zeros([G], dtype=tl.float32)
    for i in range(NS):
        i_s = tl.load(block_indices + i).to(tl.int32) * BS
        # Skip padding indices (< 0) and blocks entirely in the future.
        if i_s <= i_t and i_s >= 0:
            p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
            p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
            # [BK, BS]
            b_k = tl.load(p_k, boundary_check=(0, 1))
            # [BS, BV]
            b_v = tl.load(p_v, boundary_check=(0, 1))
            # [G, BS]
            b_s = tl.dot(b_q, b_k)
            b_s = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s, float('-inf'))

            # [G]
            b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
            b_r = exp(b_mp - b_m)
            # [G, BS]
            b_p = exp(b_s - b_m[:, None])
            # [G]
            b_acc = b_acc * b_r + tl.sum(b_p, 1)
            # [G, BV]
            b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)

            b_mp = b_m
    # NOTE(review): if no block is selected, b_acc is 0 and this divides
    # by zero — presumably the block containing token i_t is always
    # selected upstream; confirm.
    b_o = b_o / b_acc[:, None]
    b_m += log(b_acc)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_lse, b_m.to(p_lse.dtype.element_ty))


@triton.heuristics({
    'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor)
})
@triton.jit
def parallel_nsa_kernel_mask(
    block_indices,
    block_counts,
    block_mask,
    T: tl.constexpr,
    H: tl.constexpr,
    S: tl.constexpr,
    BS: tl.constexpr,
    NS: tl.constexpr,
    USE_BLOCK_COUNTS: tl.constexpr
):
    # Scatters the per-token selected block ids into a dense boolean mask
    # block_mask[b, t, h, block], used by the dkv backward kernel.
    i_t, i_b, i_hs = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_h, i_s = i_hs // S, i_hs % S

    b_i = tl.load(block_indices + i_b * T * H * S + i_t * H * S + i_h * S + i_s)
    if USE_BLOCK_COUNTS:
        # A slot is live only if within this token's block budget.
        b_m = b_i * BS <= i_t and i_s < tl.load(block_counts + i_b * T * H + i_t * H + i_h)
    else:
        b_m = b_i * BS <= i_t

    if b_i < NS and b_i >= 0:
        tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i, b_m.to(block_mask.dtype.element_ty))


@triton.jit
def parallel_nsa_bwd_kernel_preprocess(
    o,
    do,
    delta,
    B: tl.constexpr,
    V: tl.constexpr
):
    # delta[n] = sum_v o[n, v] * do[n, v] — the softmax-backward correction
    # term, computed per (token, head) row. B is the padded block width >= V.
    i_n = tl.program_id(0)
    o_d = tl.arange(0, B)
    m_d = o_d < V

    b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
    b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
    b_delta = tl.sum(b_o * b_do)

    tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
    'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor)
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_nsa_bwd_kernel_dq(
    q,
    k,
    v,
    lse,
    delta,
    do,
    dq,
    scale,
    block_indices,
    block_counts,
    offsets,
    token_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    S: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    USE_BLOCK_COUNTS: tl.constexpr
):
    # Query-gradient pass of the sparse-selection branch: revisits exactly
    # the selected blocks, reconstructing probabilities from stored lse.
    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    q += (bos + i_t) * HQ*K
    do += (bos + i_t) * HQ*V
    lse += (bos + i_t) * HQ
    delta += (bos + i_t) * HQ
    # dq carries one partial slab per V-tile (leading i_v dimension).
    dq += (i_v * B * T + bos + i_t) * HQ*K
    block_indices += (bos + i_t) * H*S + i_h * S

    if USE_BLOCK_COUNTS:
        NS = tl.load(block_counts + (bos + i_t) * H + i_h)
    else:
        NS = S

    k += (bos * H + i_h) * K
    v += (bos * H + i_h) * V

    p_q = tl.make_block_ptr(q, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
    p_dq = tl.make_block_ptr(dq, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))

    # [G, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    p_do = tl.make_block_ptr(do, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
    p_lse = lse + i_h * G + tl.arange(0, G)
    p_delta = delta + i_h * G + tl.arange(0, G)

    # [G, BV]
    b_do = tl.load(p_do, boundary_check=(0, 1))
    # [G]
    b_lse = tl.load(p_lse)
    b_delta = tl.load(p_delta)

    # [G, BK]
    b_dq = tl.zeros([G, BK], dtype=tl.float32)
    for i in range(NS):
        i_s = tl.load(block_indices + i).to(tl.int32) * BS
        # Skip padding indices (< 0) and fully-future blocks (same rule as fwd).
        if i_s <= i_t and i_s >= 0:
            p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
            p_v = tl.make_block_ptr(v, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
            # [BK, BS]
            b_k = tl.load(p_k, boundary_check=(0, 1))
            # [BV, BS]
            b_v = tl.load(p_v, boundary_check=(0, 1))

            # [G, BS]
            b_s = tl.dot(b_q, b_k)
            b_p = exp(b_s - b_lse[:, None])
            b_p = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_p, 0)

            # [G, BV] @ [BV, BS] -> [G, BS]
            b_dp = tl.dot(b_do, b_v)
            b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
            # [G, BS] @ [BS, BK] -> [G, BK]
            b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
    b_dq *= scale

    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.load(p_v, boundary_check=(0, 1)) + b_dv = tl.zeros([BS, BV], dtype=tl.float32) + + for i in range(i_s * BS, T): + b_m = tl.load(block_mask + (bos + i) * H*M + i_h * M + i_s) + if b_m: + p_q = tl.make_block_ptr(q + (bos + i) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_do = tl.make_block_ptr(do + (bos + i) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse = lse + (bos + i) * HQ + i_h * G + tl.arange(0, G) + p_delta = delta + (bos + i) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [G] + b_lse = tl.load(p_lse) + b_delta = tl.load(p_delta) + # [BS, G] + b_s = tl.dot(b_k, tl.trans(b_q)) + b_p = exp(b_s - b_lse[None, :]) + b_p = tl.where((i >= (i_s * BS + tl.arange(0, BS)))[:, None], b_p, 0) + # [BS, G] @ [G, BV] -> [BS, BV] + b_dv += tl.dot(b_p.to(b_do.dtype), b_do) + # [BS, BV] @ [BV, G] -> [BS, G] + b_dp = tl.dot(b_v, tl.trans(b_do)) + # [BS, G] + b_ds = b_p * (b_dp - b_delta[None, :]) + # [BS, G] @ [G, BK] -> [BS, BK] + b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +def parallel_nsa_compression_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_size: int, + scale: float, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, HQ, K, V = *q.shape, v.shape[-1] + H = k.shape[2] + G = HQ // H + BC = BS = block_size + if check_shared_mem('hopper', q.device.index): + BK = min(256, triton.next_power_of_2(K)) + BV = min(256, triton.next_power_of_2(V)) + else: + BK = min(128, triton.next_power_of_2(K)) + BV = min(128, triton.next_power_of_2(V)) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, "The key dimension can not be larger than 256" + + chunk_offsets = 
prepare_chunk_offsets(offsets, BS) if offsets is not None else None + + grid = (T, NV, B * H) + o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) + lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + + parallel_nsa_compression_fwd_kernel[grid]( + q=q, + k=k, + v=v, + o=o, + lse=lse, + scale=scale, + offsets=offsets, + token_indices=token_indices, + chunk_offsets=chunk_offsets, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BC=BC, + BS=BS, + BK=BK, + BV=BV, + ) + return o, lse + + +def parallel_nsa_compression_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + lse: torch.Tensor, + do: torch.Tensor, + block_size: int = 64, + scale: float = None, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, HQ, K, V = *q.shape, v.shape[-1] + H = k.shape[2] + G = HQ // H + BC = BS = block_size + BK = triton.next_power_of_2(K) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + NV = triton.cdiv(V, BV) + if offsets is not None: + lens = prepare_lens(offsets) + chunk_indices = torch.cat([torch.arange(n) for n in triton.cdiv(triton.cdiv(lens, BS), BC).tolist()]) + chunk_indices = torch.stack([chunk_indices.eq(0).cumsum(0) - 1, chunk_indices], 1).to(offsets) + chunk_offsets = prepare_chunk_offsets(offsets, BS) + NC = len(chunk_indices) + else: + chunk_indices, chunk_offsets = None, None + NC = triton.cdiv(triton.cdiv(T, BS), BC) + + delta = parallel_nsa_bwd_preprocess(o, do) + + dq = torch.empty(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device) + grid = (T, NV, B * H) + parallel_nsa_compression_bwd_kernel_dq[grid]( + q=q, + k=k, + v=v, + lse=lse, + delta=delta, + do=do, + dq=dq, + scale=scale, + offsets=offsets, + token_indices=token_indices, + chunk_offsets=chunk_offsets, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BC=BC, + BS=BS, + BK=BK, + BV=BV + ) + dq = dq.sum(0) + + dk = torch.empty(NV, *k.shape, dtype=k.dtype if NV == 1 else 
torch.float, device=q.device) + dv = torch.empty(v.shape, dtype=v.dtype, device=q.device) + + grid = (NV, NC, B * H) + parallel_nsa_compression_bwd_kernel_dkv[grid]( + q=q, + k=k, + v=v, + lse=lse, + delta=delta, + do=do, + dk=dk, + dv=dv, + offsets=offsets, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BC=BC, + BS=BS, + BK=BK, + BV=BV + ) + dk = dk.sum(0) + return dq, dk, dv + + +class ParallelNSACompressionFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward( + ctx, + q, + k, + v, + block_size, + scale, + offsets + ): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + o, lse = parallel_nsa_compression_fwd( + q=q, + k=k, + v=v, + block_size=block_size, + scale=scale, + offsets=offsets, + token_indices=token_indices + ) + ctx.save_for_backward(q, k, v, o, lse) + ctx.offsets = offsets + ctx.token_indices = token_indices + ctx.block_size = block_size + ctx.scale = scale + return o.to(q.dtype), lse + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, *args): + q, k, v, o, lse = ctx.saved_tensors + dq, dk, dv = parallel_nsa_compression_bwd( + q=q, + k=k, + v=v, + o=o, + lse=lse, + do=do, + block_size=ctx.block_size, + scale=ctx.scale, + offsets=ctx.offsets, + token_indices=ctx.token_indices + ) + return dq.to(q), dk.to(k), dv.to(v), None, None, None + + +def parallel_nsa_topk( + q: torch.Tensor, + k: torch.Tensor, + lse: torch.Tensor, + block_counts: Union[torch.LongTensor, int], + block_size: int = 64, + scale: float = None, + offsets: Optional[torch.LongTensor] = None, 
+) -> torch.LongTensor: + B, T, HQ, K = q.shape + H = k.shape[2] + G = HQ // H + S = block_counts if isinstance(block_counts, int) else block_counts.max().item() + S = triton.next_power_of_2(S) + # here we set BC = BS, but beware that they are actually decoupled + BC = BS = block_size + BK = triton.next_power_of_2(K) + + block_indices = torch.zeros(B, T, H, S, dtype=torch.int32, device=q.device) + token_indices = prepare_token_indices(offsets) if offsets is not None else None + chunk_offsets = prepare_chunk_offsets(offsets, BS) if offsets is not None else None + grid = (T, B * H) + parallel_nsa_kernel_topk[grid]( + q=q, + k=k, + lse=lse, + scale=scale, + block_indices=block_indices, + offsets=offsets, + token_indices=token_indices, + chunk_offsets=chunk_offsets, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + S=S, + BC=BC, + BS=BS, + BK=BK + ) + return block_indices + + +def parallel_nsa_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Union[torch.LongTensor, int], + block_size: int, + scale: float, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + HQ = q.shape[2] + G = HQ // H + BS = block_size + if check_shared_mem('hopper', q.device.index): + BK = min(256, triton.next_power_of_2(K)) + BV = min(256, triton.next_power_of_2(V)) + else: + BK = min(128, triton.next_power_of_2(K)) + BV = min(128, triton.next_power_of_2(V)) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, "The key dimension can not be larger than 256" + + grid = (T, NV, B * H) + o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) + lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + + parallel_nsa_fwd_kernel[grid]( + q=q, + k=k, + v=v, + o=o, + lse=lse, + scale=scale, + block_indices=block_indices, + block_counts=block_counts, + offsets=offsets, + token_indices=token_indices, + T=T, + 
H=H, + HQ=HQ, + G=G, + K=K, + V=V, + S=S, + BS=BS, + BK=BK, + BV=BV, + ) + return o, lse + + +def parallel_nsa_block_mask( + block_indices: torch.LongTensor, + block_counts: Union[torch.LongTensor, int], + offsets: torch.LongTensor, + block_size: int, +): + B, T, H, S = block_indices.shape + BS = block_size + if offsets is not None: + NS = triton.cdiv(prepare_lens(offsets).max().item(), BS) + else: + NS = triton.cdiv(T, BS) + block_mask = torch.zeros(B, T, H, NS, dtype=torch.bool, device=block_indices.device) + + parallel_nsa_kernel_mask[(T, B, H*S)]( + block_indices=block_indices, + block_counts=block_counts, + block_mask=block_mask, + T=T, + H=H, + S=S, + BS=BS, + NS=NS + ) + return block_mask + + +def parallel_nsa_bwd_preprocess( + o: torch.Tensor, + do: torch.Tensor +): + V = o.shape[-1] + delta = torch.empty_like(o[..., 0], dtype=torch.float32) + parallel_nsa_bwd_kernel_preprocess[(delta.numel(),)]( + o=o, + do=do, + delta=delta, + B=triton.next_power_of_2(V), + V=V, + ) + return delta + + +def parallel_nsa_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + lse: torch.Tensor, + do: torch.Tensor, + block_indices: torch.Tensor, + block_counts: Union[torch.LongTensor, int], + block_size: int = 64, + scale: float = None, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + HQ = q.shape[2] + G = HQ // H + BS = block_size + BK = triton.next_power_of_2(K) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + NV = triton.cdiv(V, BV) + + delta = parallel_nsa_bwd_preprocess(o, do) + + dq = torch.empty(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device) + grid = (T, NV, B * H) + parallel_nsa_bwd_kernel_dq[grid]( + q=q, + k=k, + v=v, + lse=lse, + delta=delta, + do=do, + dq=dq, + block_indices=block_indices, + block_counts=block_counts, + offsets=offsets, + token_indices=token_indices, + scale=scale, + 
T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + S=S, + BS=BS, + BK=BK, + BV=BV + ) + dq = dq.sum(0) + + if offsets is not None: + chunk_indices = prepare_chunk_indices(offsets, BS) + NS = len(chunk_indices) + else: + chunk_indices = None + NS = triton.cdiv(T, BS) + + # [B, T, H, M] + block_mask = parallel_nsa_block_mask(block_indices, block_counts, offsets, block_size) + dk = torch.empty(NV, *k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device) + dv = torch.empty(v.shape, dtype=v.dtype, device=q.device) + + grid = (NV, NS, B * H) + parallel_nsa_bwd_kernel_dkv[grid]( + q=q, + k=k, + v=v, + lse=lse, + delta=delta, + do=do, + dk=dk, + dv=dv, + block_mask=block_mask, + offsets=offsets, + chunk_indices=chunk_indices, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + M=block_mask.shape[-1], + BS=BS, + BK=BK, + BV=BV + ) + dk = dk.sum(0) + return dq, dk, dv + + +@torch.compile +class ParallelNSAFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, block_indices, block_counts, block_size, scale, offsets): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + o, lse = parallel_nsa_fwd( + q=q, + k=k, + v=v, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + scale=scale, + offsets=offsets, + token_indices=token_indices + ) + ctx.save_for_backward(q, k, v, o, lse) + ctx.block_indices = block_indices + ctx.block_counts = block_counts + ctx.offsets = offsets + ctx.token_indices = token_indices + ctx.block_size = block_size + ctx.scale = scale + return o.to(q.dtype) + + @staticmethod + @contiguous + 
@autocast_custom_bwd + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + dq, dk, dv = parallel_nsa_bwd( + q=q, + k=k, + v=v, + o=o, + lse=lse, + do=do, + block_indices=ctx.block_indices, + block_counts=ctx.block_counts, + block_size=ctx.block_size, + scale=ctx.scale, + offsets=ctx.offsets, + token_indices=ctx.token_indices + ) + return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None + + +def parallel_nsa_compression( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_size: int = 64, + scale: float = None, + offsets: Optional[torch.LongTensor] = None +): + return ParallelNSACompressionFunction.apply( + q, + k, + v, + block_size, + scale, + offsets + ) + + +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_cmp: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: Optional[torch.LongTensor] = None, + block_counts: Union[torch.LongTensor, int] = 16, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g_cmp (torch.Tensor): + Gate score for compressed attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_slc (torch.Tensor): + Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_swa (torch.Tensor): + Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. 
+ block_indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the number of selected blocks for each query token, which is set to 16 in the paper. + If `g_cmp` is provided, the passed `block_indices` will be ignored. + block_counts (Optional[Union[torch.LongTensor, int]]): + Number of selected blocks for each query. + If a tensor is provided, with shape `[B, T, H]` if `head_first=False` else `[B, H, T]`, + each query can select a different number of blocks; + if an integer is provided, all queries select the same number of blocks. + If not provided, it will default to 16. + block_size (int): + Selected block size. Default: 64. + window_size (int): + Sliding window size. Default: 0. + scale (Optional[float]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
+ """ + assert block_counts is not None, "block counts must be provided for selection" + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v)) + g_cmp, g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h') if x is not None else None, (g_cmp, g_slc, g_swa)) + if not isinstance(block_counts, int): + block_counts = rearrange(block_counts, 'b h t -> b t h') + assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" + + k_cmp, v_cmp = mean_pooling(k, block_size, cu_seqlens), mean_pooling(v, block_size, cu_seqlens) + o_cmp, lse_cmp = None, None + if g_cmp is not None: + o_cmp, lse_cmp = parallel_nsa_compression( + q=q, + k=k_cmp, + v=v_cmp, + block_size=block_size, + scale=scale, + offsets=cu_seqlens + ) + if block_indices is not None: + warnings.warn("`block_indices` will be ignored when `g_cmp` is provided") + block_indices = parallel_nsa_topk( + q=q, + k=k_cmp, + lse=lse_cmp, + block_counts=block_counts, + block_size=block_size, + scale=scale, + offsets=cu_seqlens + ) + o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, scale, cu_seqlens) + o = o_slc * g_slc.unsqueeze(-1) + if o_cmp is not None: + o = torch.addcmul(o, o_cmp, g_cmp.unsqueeze(-1)) + if window_size > 0: + if cu_seqlens is not None: + max_seqlen = q.shape[1] + o_swa = flash_attn_varlen_func( + q.squeeze(0), k.squeeze(0), v.squeeze(0), + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + window_size=(window_size-1, 0) + ).unsqueeze(0) + else: + o_swa = flash_attn_func( + q, k, v, + causal=True, + window_size=(window_size-1, 0) + ) + o = torch.addcmul(o, o_swa, g_swa.unsqueeze(-1)) + if head_first: + o = rearrange(o, 'b t h d -> b h t d') + return o diff --git 
a/fla/ops/nsa/utils.py b/fla/ops/nsa/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..73e54138b750a280c4f8edd04ca36ffb3f58705f --- /dev/null +++ b/fla/ops/nsa/utils.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# Implements argsort based on bitonic sort. +# [What is bitonic sort?](https://en.wikipedia.org/wiki/Bitonic_sorter) + +# Code adapted from https://github.com/triton-lang/triton/issues/3698#issuecomment-2067681396 + + +import triton +import triton.language as tl + +from fla.ops.utils.op import log2 + + +@triton.jit +def _compare_and_swap( + x, + ids, + flip, + i: tl.constexpr, + n_dims: tl.constexpr, +): + n_outer: tl.constexpr = x.numel >> n_dims + shape: tl.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)] + y = tl.reshape(x, shape) + # slice left/right with 'stride' 2**(n_dims - i - 1) + mask = tl.arange(0, 2)[None, :, None] + left = tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(y.dtype) + right = tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(y.dtype) + left = tl.reshape(left, x.shape) + right = tl.reshape(right, x.shape) + # idx + y_idx = tl.reshape(ids, shape) + left_idx = tl.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape) + right_idx = tl.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape) + left_idx = tl.reshape(left_idx, x.shape).to(y_idx.dtype) + right_idx = tl.reshape(right_idx, x.shape).to(y_idx.dtype) + # actual compare-and-swap + idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + ileft = left.to(idtype, bitcast=True) + iright = right.to(idtype, bitcast=True) + ix = x.to(idtype, bitcast=True) + + cond = (left > right) != flip + ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix)) + new_ids = ids ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(ids)) + return ret.to(x.dtype, bitcast=True), new_ids + + +@triton.jit +def _bitonic_merge( + x, + ids, + stage: tl.constexpr, 
+ order: tl.constexpr, + n_dims: tl.constexpr, +): + n_outer: tl.constexpr = x.numel >> n_dims + tl.static_assert(stage <= n_dims) + # flip denotes whether to re-arrange sub-sequences of elements in ascending or + # descending order. + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage + # if flip = 00110011... then all the elements will be re-arranged alternatingly (with + # a stride of 2) at this stage + if order == 2: + shape: tl.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage] + flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape) + else: + flip = order + # perform `stage` rounds of `compare-and-swap` + for i in tl.static_range(stage): + x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims) + return x, ids + + +@triton.jit +def argsort( + x, + ids, + dim: tl.constexpr = None, + descending: tl.constexpr = tl.core.CONSTEXPR_0, +): + # handle default dimension or check that it is the most minor dim + _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim + tl.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported") + # iteratively run bitonic merge-sort steps + n_dims: tl.constexpr = log2(x.shape[_dim]) + + for i in tl.static_range(1, n_dims + 1): + x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims) + return x, ids diff --git a/fla/ops/rebased/__init__.py b/fla/ops/rebased/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ec6a0cb31f7f635aa528cad753d5e19196a2028 --- /dev/null +++ b/fla/ops/rebased/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .parallel import parallel_rebased + +__all__ = [ + 'parallel_rebased' +] diff --git a/fla/ops/retention/__init__.py b/fla/ops/retention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a38ab43c9982c9751bb9db146b9d9fe05663964a --- /dev/null +++ b/fla/ops/retention/__init__.py @@ -0,0 +1,13 @@ +# 
-*- coding: utf-8 -*- + +from .chunk import chunk_retention +from .fused_chunk import fused_chunk_retention +from .fused_recurrent import fused_recurrent_retention +from .parallel import parallel_retention + +__all__ = [ + 'chunk_retention', + 'fused_chunk_retention', + 'parallel_retention', + 'fused_recurrent_retention' +] diff --git a/fla/ops/retention/__pycache__/parallel.cpython-312.pyc b/fla/ops/retention/__pycache__/parallel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f7e5666f4a6fcfa1237e651290fdbc1268bf5ef Binary files /dev/null and b/fla/ops/retention/__pycache__/parallel.cpython-312.pyc differ diff --git a/fla/ops/retention/chunk.py b/fla/ops/retention/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..cca1bd290ebc5fdafebc0611bc19ae607d6328ec --- /dev/null +++ b/fla/ops/retention/chunk.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch + +from fla.ops.simple_gla.chunk import chunk_simple_gla + + +@torch.compiler.disable +def chunk_retention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + scale (Optional[int]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. 
+ For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + + """ + if head_first: + n_heads = q.shape[1] + else: + n_heads = q.shape[2] + s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log() + if head_first: + g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() + else: + g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() + return chunk_simple_gla( + q=q, + k=k, + v=v, + scale=scale, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + head_first=head_first, + cu_seqlens=cu_seqlens + ) diff --git a/fla/ops/retention/fused_recurrent.py b/fla/ops/retention/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..b84eb83e739d16ad44485c8a7166be7e9e08e775 --- /dev/null +++ b/fla/ops/retention/fused_recurrent.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch + +from fla.ops.simple_gla.fused_recurrent import fused_recurrent_simple_gla + + +def fused_recurrent_retention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + 
reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + n_heads = q.shape[1] + else: + n_heads = q.shape[2] + s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log() + if head_first: + g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() + else: + g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() + return fused_recurrent_simple_gla( + q=q, + k=k, + v=v, + g=g, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + reverse=reverse, + cu_seqlens=cu_seqlens, + head_first=head_first + ) diff --git a/fla/ops/retention/parallel.py b/fla/ops/retention/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..8186fc78d43674d777bd9732980e31701004b2b3 --- /dev/null +++ b/fla/ops/retention/parallel.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch + +from fla.ops.simple_gla.parallel import parallel_simple_gla + + +def parallel_retention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + output_attentions: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + output_attentions (bool): + Whether to output the materialized attention scores of shape [B, H, T, T]. Default: `False`. 
+ cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + attn (torch.Tensor): + Attention scores of shape `[B, H, T, T]` if `output_attentions=True` else `None` + """ + if head_first: + n_heads = q.shape[1] + else: + n_heads = q.shape[2] + s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log() + if head_first: + g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() + else: + g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() + + return parallel_simple_gla( + q=q, + k=k, + v=v, + scale=scale, + g=g, + output_attentions=output_attentions, + head_first=head_first, + cu_seqlens=cu_seqlens + ) diff --git a/fla/ops/rwkv4/__init__.py b/fla/ops/rwkv4/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49de2cf83aeec67069b67e0972cfccef8a81383a --- /dev/null +++ b/fla/ops/rwkv4/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .fused_recurrent import fused_recurrent_rwkv4 + +__all__ = [ + 'fused_recurrent_rwkv4' +] diff --git a/fla/ops/rwkv6/__init__.py b/fla/ops/rwkv6/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b3c7c218eb873a1a2115b5587530fe55f29a9d02 --- /dev/null +++ b/fla/ops/rwkv6/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_rwkv6 +from .fused_recurrent import fused_recurrent_rwkv6 + +__all__ = [ + 'chunk_rwkv6', + 'fused_recurrent_rwkv6' +] diff --git a/fla/ops/rwkv6/chunk.py b/fla/ops/rwkv6/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..b495dbb21f3023b5af2f744793b0fa58c90dc9e8 --- /dev/null +++ b/fla/ops/rwkv6/chunk.py 
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from fla.ops.common.chunk_h import chunk_fwd_h
from fla.ops.gla.chunk import chunk_gla_bwd_dA, chunk_gla_bwd_dv, chunk_gla_fwd_o_gk
from fla.ops.utils.op import exp
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard, use_cuda_graph

# Smaller tiles on devices with limited shared memory.
BK_LIST = [32, 64] if check_shared_mem() else [16, 32]
BV_LIST = [32, 64] if check_shared_mem() else [16, 32]


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BS': BS}, num_warps=num_warps, num_stages=num_stages)
        for BS in [16, 32, 64]
        for num_warps in [4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=['S', 'BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_rwkv6_fwd_cumsum_kernel(
    s,
    oi,
    oe,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
):
    # Per-chunk inclusive (oi) and exclusive (oe) cumulative sums of s over time.
    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Variable-length mode: map the chunk id to (sequence, in-sequence chunk).
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, BT)
    # Lower-triangular masks: m_i includes the diagonal, m_e excludes it.
    m_i = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.).to(tl.float32)
    m_e = tl.where(o_i[:, None] > o_i[None, :], 1., 0.).to(tl.float32)

    if HEAD_FIRST:
        p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_oi = tl.make_block_ptr(oi + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_oe = tl.make_block_ptr(oe + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    else:
        p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_oi = tl.make_block_ptr(oi + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_oe = tl.make_block_ptr(oe + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    # [BT, BS]
    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
    b_oi = tl.dot(m_i, b_s)
    b_oe = tl.dot(m_e, b_s)
    tl.store(p_oi, b_oi.to(p_oi.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_oe, b_oe.to(p_oe.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))


def chunk_rwkv6_fwd_cumsum(
    g: torch.Tensor,
    chunk_size: int,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute per-chunk inclusive and exclusive cumulative sums of the decay `g`.

    Args:
        g: decay tensor of shape `[B, H, T, S]` (head-first) or `[B, T, H, S]`.
        chunk_size: time-chunk length `BT`.
        offsets/indices: optional variable-length metadata (FlashAttention-style).
        head_first: layout flag for `g`.

    Returns:
        (gi, ge): inclusive and exclusive cumsums, both fp32 and shaped like `g`.
    """
    # FIX: the original annotation said `-> torch.Tensor`, but two tensors are returned.
    if head_first:
        B, H, T, S = g.shape
    else:
        B, T, H, S = g.shape
    BT = chunk_size
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)

    # Keep the cumulative normalizer in fp32 for numerical stability.
    gi, ge = torch.empty_like(g, dtype=torch.float), torch.empty_like(g, dtype=torch.float)

    def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)
    chunk_rwkv6_fwd_cumsum_kernel[grid](
        g,
        gi,
        ge,
        offsets,
        indices,
        T=T,
        H=H,
        S=S,
        BT=BT,
        HEAD_FIRST=head_first
    )
    return gi, ge
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BC'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_rwkv6_fwd_A_kernel_intra_sub_inter(
    q,
    k,
    gi,  # cumulative decay, inclusive
    ge,  # cumulative decay, exclusive
    A,
    offsets,
    indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Intra-chunk attention scores between DIFFERENT sub-chunks (i_i > i_j).
    i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    i_i, i_j = i_c // NC, i_c % NC
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if i_t * BT + i_i * BC >= T:
        return
    # Only the strictly-lower-triangular sub-chunk pairs are handled here;
    # the diagonal pairs are computed by the sub_intra kernels.
    if i_i <= i_j:
        return

    m_i = i_t * BT + i_i * BC + tl.arange(0, BC) < T

    b_A = tl.zeros([BC, BC], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        o_k = i_k * BK + tl.arange(0, BK)
        m_k = o_k < K

        if HEAD_FIRST:
            p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_gq = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
            p_gk = tl.make_block_ptr(gi + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
            p_gn = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + i_t * BT + i_i * BC - 1) * K + o_k, BK), BK)
        else:
            p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_gq = tl.make_block_ptr(ge + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
            p_gk = tl.make_block_ptr(gi + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
            p_gn = gi + (bos + i_t * BT + i_i * BC - 1) * H*K + i_h * K + o_k

        # [BK,] decay normalizer just before the query sub-chunk starts.
        b_gn = tl.load(p_gn, mask=m_k, other=0)
        # [BC, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_gq = tl.where(m_i[:, None] & m_k, tl.load(p_gq, boundary_check=(0, 1)), float('-inf'))
        b_qg = b_q * exp(b_gq - b_gn[None, :]) * scale
        # [BK, BC]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_kg = b_k * exp(b_gn[:, None] - b_gk)
        # [BC, BC]; tf32 dot improves precision here.
        b_A += tl.dot(b_qg, b_kg)

    if HEAD_FIRST:
        p_A = tl.make_block_ptr(A + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
    else:
        p_A = tl.make_block_ptr(A + (bos*H + i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
    tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
(T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_qj = tl.max_contiguous(tl.multiple_of(q + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) + p_kj = tl.max_contiguous(tl.multiple_of(k + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) + p_gk = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) + else: + o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(ge + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_qj = q + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + p_kj = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + p_gk = gi + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + + p_u = tl.make_block_ptr(u + i_h * K, (K,), (1,), (0,), (BK,), (0,)) + b_u = tl.load(p_u, boundary_check=(0,)) + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32) + b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) + b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32) + b_A = tl.sum(b_q * b_kj[None, :] * exp(b_g - b_gk[None, :]), 1) + b_A = tl.where(o_i > j, b_A * scale, 0.) 
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=['BC', 'BK'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_rwkv6_fwd_A_kernel_intra_sub_intra_split(
    q,
    k,
    gi,
    ge,
    u,
    A,
    offsets,
    indices,
    scale,
    B: tl.constexpr,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Diagonal sub-chunk scores for the large-K path: each program handles one
    # K-slice (i_k); partial results are summed later by the merge kernel.
    i_k, i_tc, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    i_t, i_i = i_tc // NC, i_tc % NC
    i_j = i_i
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T_all = T  # total timesteps across the packed batch (renamed from `all`, which shadows the builtin)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
        T_all = B * T

    if i_t * BT + i_i * BC >= T:
        return

    o_i = tl.arange(0, BC)
    o_k = i_k * BK + tl.arange(0, BK)
    m_k = o_k < K
    m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T

    if HEAD_FIRST:
        o_A = (i_k * B*H + i_bh) * T * BC + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BC
        p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_g = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_qj = tl.max_contiguous(tl.multiple_of(q + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK)
        p_kj = tl.max_contiguous(tl.multiple_of(k + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK)
        p_gk = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK)
    else:
        o_A = (i_k * T_all + bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BC + i_h * BC
        p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_g = tl.make_block_ptr(ge + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_qj = q + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k
        p_kj = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k
        p_gk = gi + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k

    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_g = tl.load(p_g, boundary_check=(0, 1))

    p_u = tl.make_block_ptr(u + i_h * K, (K,), (1,), (i_k * BK), (BK,), (0,))
    b_u = tl.load(p_u, boundary_check=(0,))
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32)
        b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32)
        b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
        b_A = tl.sum(b_q * b_kj[None, :] * exp(b_g - b_gk[None, :]), 1)
        b_A = tl.where(o_i > j, b_A * scale, 0.)
        b_A = tl.where(o_i != j, b_A, tl.sum(b_qj * b_kj * b_u * scale))
        tl.store(A + o_A + j, b_A, mask=m_A)
        p_qj += K if HEAD_FIRST else H*K
        p_kj += K if HEAD_FIRST else H*K
        p_gk += K if HEAD_FIRST else H*K
+ b_A = tl.where(o_i != j, b_A, tl.sum(b_qj * b_kj * b_u * scale)) + tl.store(A + o_A + j, b_A, mask=m_A) + p_qj += K if HEAD_FIRST else H*K + p_kj += K if HEAD_FIRST else H*K + p_gk += K if HEAD_FIRST else H*K + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['BC'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_rwkv6_fwd_A_kernel_intra_sub_intra_merge( + A, + A2, + offsets, + indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + NK: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + all = T + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + all = B * T + + if i_t * BT + i_c * BC >= T: + return + + b_A = tl.zeros([BC, BC], dtype=tl.float32) + for i_k in range(0, NK): + if HEAD_FIRST: + p_A = tl.make_block_ptr(A + (i_k*B*H+i_bh)*T*BC, (T, BC), (BC, 1), (i_t*BT + i_c*BC, 0), (BC, BC), (1, 0)) + else: + p_A = tl.make_block_ptr(A + (i_k*all+bos)*H*BC+i_h*BC, (T, BC), (H*BC, 1), (i_t*BT + i_c*BC, 0), (BC, BC), (1, 0)) + b_A += tl.load(p_A, boundary_check=(0, 1)) + if HEAD_FIRST: + p_A2 = tl.make_block_ptr(A2 + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0)) + else: + p_A2 = tl.make_block_ptr(A2 + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0)) + tl.store(p_A2, b_A.to(A2.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'STORE_INITIAL_STATE_GRADIENT': 
@triton.heuristics({
    'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in BK_LIST
        for BV in BV_LIST
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_rwkv6_bwd_kernel_dh(
    q,
    gi,
    ge,
    do,
    dh,
    dht,
    dh0,
    offsets,
    chunk_offsets,
    scale,
    T,
    HQ: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr,  # presumably the GQA group size (HQ // H) — matches the wrapper; verify
    STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Backward recurrence over chunks (newest -> oldest) accumulating the
    # hidden-state gradient dh, seeded from dht when provided.
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_bg = i_nh // NG
    i_n, i_hq = i_nh // HQ, i_nh % HQ
    i_h = i_hq // NG
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV]
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT - 1, -1, -1):
        if HEAD_FIRST:
            p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        # State gradient written BEFORE this chunk's contribution is folded in.
        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
        last_idx = min(i_t * BT + BT, T) - 1
        # [BK, BT]
        if HEAD_FIRST:
            p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BT, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))

        if HEAD_FIRST:
            p_gk = tl.make_block_ptr(ge + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_gk_last = gi + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
            p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
        else:
            p_gk = tl.make_block_ptr(ge + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_gk_last = gi + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)

        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_q = (b_q * exp(b_gk) * scale).to(b_q.dtype)
        b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
        # Decay the carried state across this chunk, then add its contribution.
        b_dh *= exp(b_gk_last)[:, None]
        b_dh += tl.dot(b_q, b_do)

    if STORE_INITIAL_STATE_GRADIENT:
        p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BK', 'NC', 'BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_rwkv6_bwd_kernel_intra(
    q,
    k,
    gi,
    ge,
    dA,
    dq,
    dk,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Intra-chunk backward: accumulate dq (from earlier sub-chunks + the
    # diagonal) and dk (from later sub-chunks + the diagonal).
    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    i_t, i_i = i_c // NC, i_c % NC
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos
    if i_t * BT + i_i * BC >= T:
        return

    o_k = i_k * BK + tl.arange(0, BK)
    m_k = o_k < K

    if HEAD_FIRST:
        p_ge = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    else:
        p_ge = tl.make_block_ptr(ge + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    # [BC, BK]
    b_ge = tl.load(p_ge, boundary_check=(0, 1))
    b_dq = tl.zeros([BC, BK], dtype=tl.float32)
    if i_i > 0:
        if HEAD_FIRST:
            p_gn = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + i_t * BT + i_i * BC - 1) * K + o_k, BK), BK)
        else:
            p_gn = gi + (bos + i_t * BT + i_i * BC - 1) * H*K + i_h*K + o_k
        # [BK,]
        b_gn = tl.load(p_gn, mask=m_k, other=0)
        # dq from all earlier sub-chunks of the same chunk.
        for i_j in range(0, i_i):
            if HEAD_FIRST:
                p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
                p_gk = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
                p_dA = tl.make_block_ptr(dA + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
            else:
                p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0))
                p_gk = tl.make_block_ptr(gi+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0))
                p_dA = tl.make_block_ptr(dA+(bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0))
            # [BC, BK]
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_gk = tl.load(p_gk, boundary_check=(0, 1))
            b_kg = b_k * exp(b_gn[None, :] - b_gk)
            # [BC, BC]
            b_dA = tl.load(p_dA, boundary_check=(0, 1))
            # [BC, BK]
            b_dq += tl.dot(b_dA, b_kg)
        b_dq *= exp(b_ge - b_gn[None, :])

    o_i = tl.arange(0, BC)
    m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
    if HEAD_FIRST:
        o_dA = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
        p_kj = tl.max_contiguous(tl.multiple_of(k + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK)
        p_gkj = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK)
        p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    else:
        o_dA = bos*H*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_i * BC
        p_kj = k + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k
        p_gkj = gi + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k
        p_dq = tl.make_block_ptr(dq + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))

    # Diagonal sub-chunk contribution to dq, one key position at a time.
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        # [BC,]
        b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
        # [BK,]
        b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32)
        b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32)
        m_i = o_i[:, None] > j
        # (SY 09/17) important to not use bf16 here to have a good precision.
        b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * exp(b_ge - b_gkj[None, :]), 0.)
        p_kj += K if HEAD_FIRST else H*K
        p_gkj += K if HEAD_FIRST else H*K
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))

    tl.debug_barrier()
    if HEAD_FIRST:
        p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_gk = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    else:
        p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
        p_gk = tl.make_block_ptr(gi + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))

    # [BC, BK]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_gk = tl.load(p_gk, boundary_check=(0, 1))
    b_dk = tl.zeros([BC, BK], dtype=tl.float32)

    NC = min(NC, tl.cdiv(T - i_t * BT, BC))
    if i_i < NC - 1:
        if HEAD_FIRST:
            p_gn = gi + i_bh * T*K + (min(i_t * BT + i_i * BC + BC, T) - 1)*K + o_k
            p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK)
        else:
            p_gn = gi + (bos + min(i_t * BT + i_i * BC + BC, T) - 1) * H*K + i_h*K + o_k

        # [BK,]
        b_gn = tl.load(p_gn, mask=m_k, other=0)
        # dk from all later sub-chunks of the same chunk.
        for i_j in range(i_i + 1, NC):
            m_j = (i_t * BT + i_j * BC + tl.arange(0, BC)) < T
            if HEAD_FIRST:
                p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
                p_gq = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k*BK), (BC, BK), (1, 0))
                p_dA = tl.make_block_ptr(dA + i_bh * T*BT, (BT, T), (1, BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
            else:
                p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k*BK), (BC, BK), (1, 0))
                p_gq = tl.make_block_ptr(ge + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k*BK), (BC, BK), (1, 0))
                p_dA = tl.make_block_ptr(dA + (bos*H+i_h)*BT, (BT, T), (1, H*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
            # [BC, BK]
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_gq = tl.where(m_j[:, None] & m_k, tl.load(p_gq, boundary_check=(0, 1)), float('-inf'))
            b_qg = b_q * exp(b_gq - b_gn[None, :])
            # [BC, BC]
            b_dA = tl.load(p_dA, boundary_check=(0, 1))
            # (SY 09/17) important to not use bf16 here to have a good precision.
            b_dk += tl.dot(b_dA, b_qg)
        b_dk *= exp(b_gn[None, :] - b_gk)
    if HEAD_FIRST:
        o_dA = i_bh * T*BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC)
        p_qj = tl.max_contiguous(tl.multiple_of(q + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK)
        p_gqj = tl.max_contiguous(tl.multiple_of(ge + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK)
        p_dk = tl.make_block_ptr(dk + i_bh*T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    else:
        o_dA = bos*H*BT + (i_t * BT + i_i * BC) * H*BT + i_h * BT + i_i * BC + tl.arange(0, BC)
        p_qj = q + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k
        p_gqj = ge + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k
        p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
    # Diagonal sub-chunk contribution to dk.
    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        # [BC,]
        b_dA = tl.load(dA + o_dA + j * (1 if HEAD_FIRST else H) * BT)
        # [BK,]
        b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32)
        b_gqj = tl.load(p_gqj, mask=m_k, other=0).to(tl.float32)
        m_i = o_i[:, None] < j
        b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * exp(b_gqj[None, :] - b_gk), 0.)
        p_qj += K if HEAD_FIRST else H*K
        p_gqj += K if HEAD_FIRST else H*K
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps)
        for BK in BK_LIST
        for BV in BV_LIST
        for num_warps in [2, 4, 8]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_rwkv6_bwd_kernel_inter(
    q,
    k,
    v,
    h,
    gi,
    ge,
    u,
    do,
    dh,
    dA,
    dq,
    dk,
    dq2,
    dk2,
    dg,
    du,
    offsets,
    indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Inter-chunk backward: fold the hidden-state contributions into dq/dk,
    # derive the decay gradient dg, and the diagonal bonus gradient du.
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T
    o_k = i_k * BK + tl.arange(0, BK)
    m_k = o_k < K

    if HEAD_FIRST:
        p_gk = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_gi = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_gn = tl.max_contiguous(tl.multiple_of(gi + i_bh * T*K + (min(T, i_t * BT + BT)-1) * K + o_k, BK), BK)
    else:
        p_gk = tl.make_block_ptr(ge + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_gi = tl.make_block_ptr(gi + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_gn = gi + (bos + min(T, i_t * BT + BT)-1) * H*K + i_h * K + o_k
    # [BK,] decay at the last valid position of the chunk.
    b_gn = tl.load(p_gn, mask=m_k, other=0)
    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    b_dgk = tl.zeros([BK,], dtype=tl.float32)

    for i_v in range(tl.cdiv(V, BV)):
        if HEAD_FIRST:
            p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
            p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        else:
            p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
            p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # [BK]
        b_dgk += tl.sum(b_h * b_dh, axis=0)
        # [BT, BK]
        b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
        b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))
    b_dgk *= exp(b_gn)
    b_dq *= scale
    b_gk = tl.load(p_gk, boundary_check=(0, 1))
    b_gi = tl.load(p_gi, boundary_check=(0, 1))
    b_dq = b_dq * exp(b_gk)
    b_dk = b_dk * exp(b_gn[None, :] - b_gi)

    o_i = tl.arange(0, BT)
    if HEAD_FIRST:
        p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dA_dig = dA + (i_bh * T + i_t * BT + o_i) * BT + o_i
    else:
        p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dq = tl.make_block_ptr(dq + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dA_dig = dA + ((bos + i_t * BT + o_i) * H + i_h) * BT + o_i
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_dgk += tl.sum(b_dk * b_k, axis=0)

    # Merge with the intra-chunk partial gradients computed earlier.
    b_dq += tl.load(p_dq, boundary_check=(0, 1))
    b_dk += tl.load(p_dk, boundary_check=(0, 1))
    b_dg = b_q * b_dq - b_k * b_dk
    b_dg = b_dg - tl.cumsum(b_dg, axis=0) + tl.sum(b_dg, axis=0)[None, :] + b_dgk[None, :] - b_q * b_dq
    # [BT,]
    b_dA_dig = tl.load(p_dA_dig, mask=(i_t * BT + o_i) < T, other=0)

    p_u = tl.make_block_ptr(u + i_h * K, (K,), (1,), (i_k * BK,), (BK,), (0,))
    b_u = tl.load(p_u, boundary_check=(0,))
    # scale is already applied to b_dA_diag.
    b_dq += (b_dA_dig[:, None] * b_u[None, :] * b_k)
    b_dk += (b_dA_dig[:, None] * b_u[None, :] * b_q)
    b_du = tl.sum(b_dA_dig[:, None] * b_q * b_k, axis=0)
    p_du = tl.make_block_ptr(du + (i_tg * H + i_h) * K, (K,), (1,), (i_k * BK,), (BK,), (0,))
    tl.store(p_du, b_du, boundary_check=(0,))

    if HEAD_FIRST:
        p_dq = tl.make_block_ptr(dq2 + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dk = tl.make_block_ptr(dk2 + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dg = tl.make_block_ptr(dg + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    else:
        p_dq = tl.make_block_ptr(dq2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dk = tl.make_block_ptr(dk2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_dg = tl.make_block_ptr(dg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
def chunk_rwkv6_fwd_intra(
    q: torch.Tensor,
    k: torch.Tensor,
    gi: torch.Tensor,
    ge: torch.Tensor,
    u: torch.Tensor,
    scale: float,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
):
    """Compute the intra-chunk attention matrix A for the RWKV6 forward pass.

    Launches the inter-sub-chunk kernel first, then the intra-sub-chunk kernel
    (split+merge variant when the head dim is too large to fit in SRAM).
    Returns A of shape [B, H, T, BT] (head-first) or [B, T, H, BT].
    """
    if head_first:
        B, H, T, K = k.shape
    else:
        B, T, H, K = k.shape
    # Block sizes: BT = chunk length, BC = sub-chunk length.
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    BC = min(16, BT)
    NC = triton.cdiv(BT, BC)

    # A is accumulated in fp32; the extra precision is cheap relative to the
    # rest of the pipeline.
    A = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float)
    chunk_rwkv6_fwd_A_kernel_intra_sub_inter[(NT, NC * NC, B * H)](
        q, k, gi, ge, A, offsets, indices, scale,
        T=T, H=H, K=K, BT=BT, BC=BC, NC=NC, HEAD_FIRST=head_first
    )

    if K <= 256:
        # The whole [BC, K] tile fits in SRAM: single-pass kernel.
        chunk_rwkv6_fwd_A_kernel_intra_sub_intra[(NT, NC, B * H)](
            q, k, gi, ge, u, A, offsets, indices, scale,
            T=T, H=H, K=K, BT=BT, BC=BC, BK=triton.next_power_of_2(K), HEAD_FIRST=head_first
        )
    else:
        # Head dim too large: split over K, then merge the partial results.
        BK = min(128, triton.next_power_of_2(K))
        NK = triton.cdiv(K, BK)
        A_intra = q.new_empty(NK, B, *((H, T) if head_first else (T, H)), BC, dtype=torch.float)
        chunk_rwkv6_fwd_A_kernel_intra_sub_intra_split[(NK, NT * NC, B * H)](
            q, k, gi, ge, u, A_intra, offsets, indices, scale,
            B=B, T=T, H=H, K=K, BT=BT, BC=BC, BK=BK, NC=NC, HEAD_FIRST=head_first
        )
        chunk_rwkv6_fwd_A_kernel_intra_sub_intra_merge[(NT, NC, B * H)](
            A_intra, A, offsets, indices,
            B=B, T=T, H=H, BT=BT, BC=BC, NK=NK, HEAD_FIRST=head_first
        )
    return A


def chunk_rwkv6_bwd_dh(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    gi: torch.Tensor,
    ge: torch.Tensor,
    do: torch.Tensor,
    h0: torch.Tensor,
    dht: torch.Tensor,
    scale: float,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    chunk_size: int = 64,
    states_in_fp32: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Backward pass for the chunked hidden states.

    Returns (dh, dh0): per-chunk state gradients and the gradient w.r.t. the
    initial state (None when no initial state was given).
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
        HQ = q.shape[1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
        HQ = q.shape[2]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    # N: actual number of sequences (equal- or variable-length).
    if offsets is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT = len(offsets) - 1, len(indices)
        chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1)
    # NG: number of query groups per KV head (GQA).
    NG = HQ // H

    state_dtype = torch.float if states_in_fp32 else k.dtype
    if head_first:
        dh = k.new_empty(B, HQ, NT, K, V, dtype=state_dtype)
    else:
        dh = k.new_empty(B, NT, HQ, K, V, dtype=state_dtype)
    dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None

    def grid(meta):
        return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)

    chunk_rwkv6_bwd_kernel_dh[grid](
        q=q, gi=gi, ge=ge, do=do, dh=dh, dht=dht, dh0=dh0,
        offsets=offsets, chunk_offsets=chunk_offsets, scale=scale,
        T=T, HQ=HQ, H=H, K=K, V=V, BT=BT, NG=NG, HEAD_FIRST=head_first
    )
    return dh, dh0
Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + if head_first: + B, H, T, K = q.shape + else: + B, T, H, K = q.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BC = min(16, BT) + BK = min(64, triton.next_power_of_2(K)) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + NC = triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + + dq = torch.empty_like(q, dtype=torch.float) + dk = torch.empty_like(k, dtype=torch.float) + grid = (NK, NT * NC, B * H) + chunk_rwkv6_bwd_kernel_intra[grid]( + q, + k, + gi, + ge, + dA, + dq, + dk, + offsets, + indices, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + NC=NC, + HEAD_FIRST=head_first + ) + return dq, dk + + +def chunk_rwkv6_bwd_dqkgu( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: torch.Tensor, + gi: torch.Tensor, + ge: torch.Tensor, + u: torch.Tensor, + do: torch.Tensor, + dh: torch.Tensor, + dA: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + scale: float, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + dq2 = torch.empty_like(dq) + dk2 = torch.empty_like(dk) + dg = torch.empty_like(g) + du = u.new_empty(B * NT, H, K, dtype=torch.float) + def grid(meta): return (triton.cdiv(K, meta['BK']), NT, B * H) + chunk_rwkv6_bwd_kernel_inter[grid]( + q, + k, + v, + h, + gi, + ge, + u, + do, + dh, + dA, + dq, + dk, + dq2, + dk2, + dg, + du, + offsets, + indices, + scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + HEAD_FIRST=head_first + ) + du = du.sum(0) + return dq2, dk2, dg, du + + +def chunk_rwkv6_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + u: torch.Tensor, + scale: 
float, + initial_state: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + gi, ge = chunk_rwkv6_fwd_cumsum(g, chunk_size=chunk_size, offsets=offsets, indices=indices, head_first=head_first) + h, ht = chunk_fwd_h( + k=k, + v=v, + g=None, + gk=gi, + gv=None, + h0=initial_state, + output_final_state=output_final_state, + offsets=offsets, + head_first=head_first, + chunk_size=chunk_size, + states_in_fp32=True + ) + # the intra A is kept in fp32 + # the computation has very marginal effect on the entire throughput + A = chunk_rwkv6_fwd_intra( + q=q, + k=k, + gi=gi, + ge=ge, + u=u, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + + o = chunk_gla_fwd_o_gk( + q=q, + v=v, + g=ge, + A=A, + h=h, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + return A, h, ht, o + + +def chunk_rwkv6_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + u: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + A: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + gi, ge = chunk_rwkv6_fwd_cumsum(g, chunk_size=chunk_size, offsets=offsets, indices=indices, head_first=head_first) + h, _ = chunk_fwd_h( + k=k, + v=v, + g=None, + gk=gi, + gv=None, + h0=initial_state, + output_final_state=False, + offsets=offsets, + head_first=head_first, + chunk_size=chunk_size, + states_in_fp32=True + ) + dh, dh0 = chunk_rwkv6_bwd_dh( + q=q, + k=k, + v=v, + gi=gi, + ge=ge, + do=do, + h0=initial_state, + dht=dht, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size, + 
states_in_fp32=True + ) + + # dq dk in fp32 + dA = chunk_gla_bwd_dA( + v=v, + do=do, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + dv = chunk_gla_bwd_dv( + k=k, + g=gi, + A=A, + do=do, + dh=dh, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + dq, dk = chunk_rwkv6_bwd_dqk_intra( + q=q, + k=k, + gi=gi, + ge=ge, + dA=dA, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + dq, dk, dg, du = chunk_rwkv6_bwd_dqkgu( + q=q, + k=k, + v=v, + h=h, + g=g, + gi=gi, + ge=ge, + u=u, + do=do, + dh=dh, + dA=dA, + dq=dq, + dk=dk, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + return dq, dk, dv, dg, du, dh0 + + +class ChunkRWKV6Function(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q, + k, + v, + g, + u, + scale, + initial_state, + output_final_state, + offsets, + head_first + ): + T = q.shape[2] if head_first else q.shape[1] + chunk_size = min(32, max(32, triton.next_power_of_2(T))) if check_shared_mem() \ + else min(64, max(32, triton.next_power_of_2(T))) + + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = None + if offsets is not None: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()]) + indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets) + + A, h, ht, o = chunk_rwkv6_fwd( + q=q, + k=k, + v=v, + g=g, + u=u, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + 
) + + ctx.save_for_backward(q, k, v, g, initial_state, A, u) + + ctx.chunk_size = chunk_size + ctx.scale = scale + ctx.offsets = offsets + ctx.indices = indices + ctx.head_first = head_first + return o, ht + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht): + q, k, v, g, initial_state, A, u = ctx.saved_tensors + chunk_size, scale, offsets, indices, head_first = ctx.chunk_size, ctx.scale, ctx.offsets, ctx.indices, ctx.head_first + dq, dk, dv, dg, du, dh0 = chunk_rwkv6_bwd( + q=q, + k=k, + v=v, + g=g, + u=u, + scale=scale, + initial_state=initial_state, + A=A, + do=do, + dht=dht, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + return dq.to(q), dk.to(k), dv.to(v), dg.to(g), du.to(u), None, dh0, None, None, None + + +@torch.compiler.disable +def chunk_rwkv6( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + u: torch.Tensor, + scale: Optional[int] = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + g (torch.Tensor): + Forget gates of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` applied to keys. + u (torch.Tensor): + bonus representations of shape `[H]`. + scale (Optional[int]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. 
+ output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + final_state (Optional[torch.Tensor]): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.rwkv6 import chunk_rwkv6 + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = torch.randn(B, T, H, K, device='cuda') + >>> v = torch.randn(B, T, H, V, device='cuda') + >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda')) + >>> u = torch.randn(H, K, device='cuda') + >>> h0 = torch.randn(B, H, K, V, device='cuda') + >>> o, ht = chunk_rwkv6(q, k, v, g, u, + initial_state=h0, + output_final_state=True, + head_first=False) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_rwkv6(q, k, v, g, u, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens, + head_first=False) + >>> assert o.allclose(o_var.view(o.shape)) + >>> assert ht.allclose(ht_var) + """ + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than 
{q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.") + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + if scale is None: + scale = q.shape[-1] ** -0.5 + o, final_state = ChunkRWKV6Function.apply( + q, + k, + v, + g, + u, + scale, + initial_state, + output_final_state, + cu_seqlens, + head_first + ) + return o, final_state diff --git a/fla/ops/rwkv6/chunk_naive.py b/fla/ops/rwkv6/chunk_naive.py new file mode 100644 index 0000000000000000000000000000000000000000..4a2ac664f5079a20eabe9b11c19c1cff6755c658 --- /dev/null +++ b/fla/ops/rwkv6/chunk_naive.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + + +def naive_chunk_rwkv6( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + chunk_size: int = 32 +): + assert q.shape[-2] % chunk_size == 0 + orig_dtype = q.dtype + num_chunk = q.shape[-2] // chunk_size + u = u.unsqueeze(0) + + q, k, v, w = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size).float(), (q, k, v, w)) + + w_cumsum = w.cumsum(-2) + + kw = k * (w_cumsum[..., -1, None, :] - w_cumsum).exp() + wkv = kw.transpose(-1, -2) @ v + + wkv_new = torch.zeros_like(wkv) + + for i in range(num_chunk - 1): + wkv_new[:, :, i+1] = (wkv_new[:, :, i] * w_cumsum[:, :, i, -1, :, None].exp()) + wkv[:, :, i] + + o_inter = torch.einsum('b h n d p, b h n c d -> b h n c p', wkv_new, (q * (w_cumsum - w).exp())) + + o_intra = torch.zeros_like(o_inter) + for i in range(chunk_size): + attn = (q[:, :, :, i, None] * k * (w_cumsum[:, :, :, i, None] - w[:, :, :, i, None] - w_cumsum).exp()).sum(-1) + mask = (torch.arange(0, 
chunk_size) < i).to(attn.device) + attn.masked_fill_(~mask, 0) + intra_inter_o = (attn.unsqueeze(-1) * v).sum(-2) + intra_intra_o = (q[:, :, :, i] * u.unsqueeze(2) * k[:, :, :, i]).sum(-1).unsqueeze(-1) * v[:, :, :, i] + o_intra[:, :, :, i] = intra_inter_o + intra_intra_o + o = o_inter + o_intra + return rearrange(o, 'b h n c d -> b h (n c) d').to(orig_dtype) diff --git a/fla/ops/rwkv6/recurrent_naive.py b/fla/ops/rwkv6/recurrent_naive.py new file mode 100644 index 0000000000000000000000000000000000000000..ba2268759b5d4ce7f9be1be1f9c2e1a2f2a8e6c3 --- /dev/null +++ b/fla/ops/rwkv6/recurrent_naive.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch + + +def naive_recurrent_rwkv6( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: Optional[bool] = False +): + orig_dtype = q.dtype + B, H, T, K, V = *q.shape, v.shape[-1] + q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u)) + h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) + o = torch.zeros_like(v) + + if scale is None: + scale = K ** -0.5 + + if initial_state is not None: + h += initial_state + + for i in range(T): + q_i = q[:, :, i, :] * scale + k_i = k[:, :, i] + v_i = v[:, :, i, :] + w_i = w[:, :, i].exp() + kv_i = k_i[..., None] * v_i[..., None, :] + o_i = (h + u[None, ..., None] * kv_i) * q_i[..., None] + o[:, :, i] = o_i.sum(-2) + h = h * w_i[..., None] + kv_i + ht = h if output_final_state else None + return o.to(orig_dtype), ht + + +@torch.no_grad +@torch.jit.script +def naive_recurrent_rwkv6_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + o: torch.Tensor, + do: torch.Tensor, + initial_state: Optional[torch.Tensor] = None +): + q, k, v, w, u, o, do = (x.to(dtype=torch.float32) for x in (q, k, v, w, u, o, do)) + B, H, T, K, V = q.shape[0], q.shape[1], 
q.shape[2], q.shape[3], v.shape[-1] + h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) + dq = torch.zeros_like(q) + dq_aux = torch.zeros_like(q) + + if initial_state is not None: + h += initial_state + + for i in range(T): + k_i = k[:, :, i] + v_i = v[:, :, i] + w_i = w[:, :, i].exp() + kv_i = k_i[..., None] * v_i[..., None, :] + h_i = (h + u[None, ..., None] * kv_i) + dq_i = (do[:, :, i, None, :] * h_i).sum(-1) + dq_aux_i = (do[:, :, i, None, :] * h).sum(-1) + dq[:, :, i] = dq_i + dq_aux[:, :, i] = dq_aux_i + h = h * w_i[..., None] + kv_i + + du = torch.zeros_like(u) + dh = torch.zeros_like(h) + dk = torch.zeros_like(k) + dk_aux = torch.zeros_like(k) + dv = torch.zeros_like(v) + + for i in range(T - 1, -1, -1): + d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None] + k_i = k[:, :, i] + v_i = v[:, :, i] + du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1) + du += du_i.sum(0) + dk_i = (dh * v_i[..., None, :]).sum(-1) + dk_aux[:, :, i] = dk_i + dk_i += (d_kv_i * u[None, ..., None] * v_i[..., None, :]).sum(-1) + dv_i = (d_kv_i * u[None, ..., None] * k_i[..., None]).sum(-2) + dv_i += (dh * k_i[..., None]).sum(-2) + + dk[:, :, i] = dk_i + dv[:, :, i] = dv_i + dh = dh * w[:, :, i, :, None].exp() + d_kv_i + + # dw = q * dq_aux - k * dk_aux + dw = torch.zeros_like(w) + for i in range(T - 2, -1, -1): + dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i] + + return dq, dk, dv, dw, du, dh diff --git a/fla/ops/rwkv7/fused_recurrent.py b/fla/ops/rwkv7/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..0ce2d15aec9c6995f2df26992c89a29182f0169d --- /dev/null +++ b/fla/ops/rwkv7/fused_recurrent.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch + +from fla.ops.generalized_delta_rule import fused_recurrent_dplr_delta_rule + + +def fused_recurrent_rwkv7( + r: torch.Tensor, + w: 
def fused_recurrent_rwkv7(
    r: torch.Tensor,
    w: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: float = 1.0,
    initial_state: torch.Tensor = None,
    output_final_state: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
):
    """RWKV7 fused recurrence, expressed as a DPLR delta-rule call.

    Args:
        r (torch.Tensor):
            r of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        w (torch.Tensor):
            log decay of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        k (torch.Tensor):
            k of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        v (torch.Tensor):
            v of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        a (torch.Tensor):
            a of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        b (torch.Tensor):
            b of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        scale (float):
            scale of the attention.
        initial_state (torch.Tensor):
            initial state of shape `[B, H, K, V]` if cu_seqlens is None else
            `[N, H, K, V]` where N = len(cu_seqlens) - 1.
        output_final_state (bool):
            whether to output the final state.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for
            variable-length training, consistent with the FlashAttention API.
        head_first (bool):
            whether to use head first. Recommended to be False to avoid extra
            transposes.
    """
    # RWKV7 maps directly onto the generalized (DPLR) delta rule with
    # q := r and the key-dimension gate gk := w.
    return fused_recurrent_dplr_delta_rule(
        q=r,
        k=k,
        v=v,
        a=a,
        b=b,
        gk=w,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        head_first=head_first,
    )
diff --git a/fla/ops/simple_gla/naive.py b/fla/ops/simple_gla/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..5fcc96ebeb720cc8b9699793ee6bdf8d3d39fdaa --- /dev/null +++ b/fla/ops/simple_gla/naive.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + + +def torch_simple_gla(q, k, v, g, chunk_size=64, scale=None): + if scale is None: + scale = (q.shape[-1] ** -0.5) + q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) * scale + k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size) + v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size) + g = rearrange(g, 'b h (n c) -> b h n c', c=chunk_size) + g = g.cumsum(-1) + kv = k.transpose(-1, -2) @ (v * (-g + g[:, :, :, -1, None]).exp()[..., None]) + S = torch.zeros_like(kv) + + for i in range(1, g.shape[-2]): + S[:, :, i] = S[:, :, i-1].clone() * g[:, :, i-1, -1, None, None].exp() + kv[:, :, i-1] + + inter = (q * g[..., None].exp()) @ S + attn = q @ k.transpose(-1, -2) + attn = attn * (g[..., None] - g[..., None, :]).exp() + attn = attn.masked_fill(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0) + intra = attn @ v + o = inter + intra + return rearrange(o, 'b h n c d -> b h (n c) d') + + +def torch_simple_gla_recurrent(q, k, v, g, scale=None, initial_state=None, output_final_state=True): + B, H, T, DK = q.shape + original_dtype = q.dtype + q, k, v, g = q.float(), k.float(), v.float(), g.float() + if scale is None: + scale = DK ** -0.5 + q = q * scale + _, _, _, DV = v.shape + if initial_state is None: + S = torch.zeros(B, H, DK, DV) + else: + S = initial_state + o = torch.zeros(B, H, T, DV).to(q) + for i in range(T): + gate = g[:, :, i].exp() + key = k[:, :, i] + value = v[:, :, i] + kv = key.unsqueeze(-1) * value.unsqueeze(-2) + S = S.clone() * gate.unsqueeze(-1).unsqueeze(-1) + kv + q_i = q[:, :, i, :] + o_i = (q_i.unsqueeze(-1) * S).sum(-2) + o[:, :, i] = o_i + if not output_final_state: + S = 
None + return o.to(original_dtype), S diff --git a/fla/ops/simple_gla/parallel.py b/fla/ops/simple_gla/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..d0ad1c8c33dd846eb1d1cf3c582836b6110017d7 --- /dev/null +++ b/fla/ops/simple_gla/parallel.py @@ -0,0 +1,722 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils import chunk_global_cumsum, chunk_local_cumsum +from fla.ops.utils.op import safe_exp +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard, is_intel_alchemist + +# https://github.com/intel/intel-xpu-backend-for-triton/issues/3449 +triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {} + + +@triton.heuristics({ + 'NV': lambda args: triton.cdiv(args['V'], args['BV']), + 'OUTPUT_ATTENTIONS': lambda args: args['attn'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'USE_G': lambda args: args['g'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16] + for num_stages in [2, 3, 4] + ], + key=["BT", "BS", "BK", "BV", "USE_G"], +) +@triton.jit +def parallel_simple_gla_fwd_kernel( + q, + k, + v, + g, + o, + attn, + scale, + offsets, + indices, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NV: tl.constexpr, + OUTPUT_ATTENTIONS: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_G: tl.constexpr +): + tl.static_assert(not (USE_OFFSETS and HEAD_FIRST), "USE_OFFSETS and HEAD_FIRST cannot be True at the same time") + i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_k, i_v = i_kv // NV, i_kv % NV + i_b, i_h = i_bh // H, i_bh % H + o += i_k * B * T * H * V + + if 
@triton.heuristics({
    'NV': lambda args: triton.cdiv(args['V'], args['BV']),
    'OUTPUT_ATTENTIONS': lambda args: args['attn'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
    'USE_G': lambda args: args['g'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=["BT", "BS", "BK", "BV", "USE_G"],
)
@triton.jit
def parallel_simple_gla_fwd_kernel(
    q,
    k,
    v,
    g,
    o,
    attn,
    scale,
    offsets,
    indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NV: tl.constexpr,
    OUTPUT_ATTENTIONS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    USE_G: tl.constexpr
):
    # Parallel (attention-style) forward for simple GLA. Each program handles
    # one (K-split, V-split) pair for one query block of one batch-head.
    tl.static_assert(not (USE_OFFSETS and HEAD_FIRST), "USE_OFFSETS and HEAD_FIRST cannot be True at the same time")
    i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_k, i_v = i_kv // NV, i_kv % NV
    i_b, i_h = i_bh // H, i_bh % H
    # Each K-split writes its partial output to its own slab of o.
    o += i_k * B * T * H * V

    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    q += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
    k += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
    v += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
    o += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
    if USE_G:
        g += i_bh * T if HEAD_FIRST else bos * H + i_h
    if OUTPUT_ATTENTIONS:
        attn += (bos * H + i_h * T) * T + i_k * B * H * T * T
    stride_qk = K if HEAD_FIRST else H * K
    stride_vo = V if HEAD_FIRST else H * V
    stride_g = 1 if HEAD_FIRST else H

    p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))

    # The Q block stays resident for the whole kernel. [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)
    b_o = tl.zeros([BT, BV], dtype=tl.float32)

    # Query / key position indices for the diagonal (masked) phase.
    o_q = i_t * BT + tl.arange(0, BT)
    o_k = i_t * BT + tl.arange(0, BS)

    if USE_G:
        p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
        # [BT,] cumulative gates of the query block
        b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32)
    else:
        b_gq = None

    # Phase 1: blocks overlapping the diagonal — causal mask required.
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))  # [BK, BS]
        b_v = tl.load(p_v, boundary_check=(0, 1))  # [BS, BV]
        m_s = o_q[:, None] >= o_k[None, :]         # [BT, BS] causal mask
        b_s = tl.dot(b_q, b_k)
        if USE_G:
            p_gk = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
            b_gk = tl.load(p_gk, boundary_check=(0,))
            b_s *= safe_exp(b_gq[:, None] - b_gk[None, :])
            b_s = tl.where(m_s, b_s, 0)
        else:
            b_s = tl.where(m_s, b_s, 0)
        if i_s >= 0:
            b_o += tl.dot(b_s.to(b_q.dtype), b_v)  # [BT, BV]
        if OUTPUT_ATTENTIONS:
            p_a = tl.make_block_ptr(attn, (T, T), (T, 1), (i_t * BT, i_s), (BT, BS), (1, 0))
            tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1))
        o_k += BS

    # Phase 2: strictly-past blocks, swept right-to-left; no mask needed.
    for i_s in range(i_t * BT - BS, -BS, -BS):
        p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_s), (BK, BS), (0, 1))
        p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))  # [BK, BS]
        b_v = tl.load(p_v, boundary_check=(0, 1))  # [BS, BV]
        b_s = tl.dot(b_q, b_k)
        if USE_G:
            p_g = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
            b_g = tl.load(p_g, boundary_check=(0,))
            b_gn = tl.load(g + (min(i_s + BS, T) - 1) * stride_g)
            b_gp = tl.load(g + (i_s-1) * stride_g) if i_s % BT > 0 else 0.
            # Rescale with the running gate so the accumulation stays bounded.
            b_s *= safe_exp(b_gq[:, None] + (b_gn - b_g)[None, :])
            b_gq += (b_gn - b_gp)
        if OUTPUT_ATTENTIONS:
            p_a = tl.make_block_ptr(attn, (T, T), (T, 1), (i_t * BT, i_s), (BT, BS), (1, 0))
            tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1))
        if i_s >= 0:
            b_o += tl.dot(b_s.to(b_v.dtype), b_v)

    p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
@triton.jit(do_not_specialize=['T'])
def parallel_simple_gla_bwd_kernel_dq(
    i_t,
    i_k,
    i_v,
    q,
    k,
    v,
    g,
    do,
    dq,
    dg,
    stride_qk,
    stride_vo,
    stride_g,
    scale,
    T,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr
):
    # dq for one query block: sweep strictly-past key blocks, then the
    # diagonal (masked) blocks; optionally accumulate the gate gradient dg.
    p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_do = tl.load(p_do, boundary_check=(0, 1))   # [BT, BV]
    b_dq = tl.zeros([BT, BK], dtype=tl.float32)   # [BT, BK]

    # Phase 1: strictly-past blocks (no causal mask).
    for i_s in range(0, i_t * BT, BS):
        p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
        p_v = tl.make_block_ptr(v, (V, T), (1, stride_vo), (i_v * BV, i_s), (BV, BS), (0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1))  # [BS, BK]
        b_v = tl.load(p_v, boundary_check=(0, 1))  # [BV, BS]
        # [BT, BV] @ [BV, BS] = [BT, BS]
        b_ds = tl.dot(b_do, b_v)
        if USE_G:
            p_g = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
            b_g = tl.load(p_g, boundary_check=(0,))
            b_gn = tl.load(g + (min(i_s + BS, T) - 1) * stride_g)
            b_gp = tl.load(g + (i_s - 1) * stride_g) if i_s % BT > 0 else 0.
            b_ds *= safe_exp(b_gn - b_g)[None, :]
            if i_s > 0:
                # Carry the accumulated dq forward under the block-boundary gate.
                b_dq *= safe_exp(b_gn - b_gp)
        # [BT, BS] @ [BS, BK] = [BT, BK]
        b_dq += tl.dot(b_ds.to(b_v.dtype), b_k)

    if USE_G:
        p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
        b_gq = tl.load(p_gq, boundary_check=(0,))  # [BT,]
        b_dq *= safe_exp(b_gq)[:, None]            # [BT, BK]

    o_q = i_t * BT + tl.arange(0, BT)  # [BT]
    o_k = i_t * BT + tl.arange(0, BS)  # [BS]
    # Phase 2: diagonal blocks — Q and K overlap, so a causal mask is needed.
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
        p_v = tl.make_block_ptr(v, (V, T), (1, stride_vo), (i_v * BV, i_s), (BV, BS), (0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1))  # [BS, BK]
        b_v = tl.load(p_v, boundary_check=(0, 1))  # [BV, BS]
        b_ds = tl.dot(b_do, b_v)                   # [BT, BS]
        if USE_G:
            p_gk = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
            b_gk = tl.load(p_gk, boundary_check=(0,))
            b_ds *= safe_exp(b_gq[:, None] - b_gk[None, :])
        b_ds = tl.where(o_q[:, None] >= o_k[None, :], b_ds, 0)
        b_dq += tl.dot(b_ds.to(b_k.dtype), b_k)    # [BT, BK]
        o_k += BS

    b_dq *= scale
    p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    if USE_G:
        p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # dg contribution from the query side: sum over the K tile.
        b_dg = tl.sum(b_dq * b_q, 1)
        p_dg = tl.make_block_ptr(dg, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
@triton.jit(do_not_specialize=['T'])
def parallel_simple_gla_bwd_kernel_dkv(
    i_t,
    i_k,
    i_v,
    q,
    k,
    v,
    g,
    do,
    dk,
    dv,
    dg,
    scale,
    stride_qk,
    stride_vo,
    stride_g,
    T,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr
):
    # dk/dv for one key block: sweep strictly-future query blocks
    # (right-to-left), then the diagonal (masked) blocks; also subtracts the
    # key-side dg contribution.
    p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    b_k = tl.load(p_k, boundary_check=(0, 1))    # [BT, BK]
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))    # [BT, BV]
    b_dv = tl.zeros([BT, BV], dtype=tl.float32)
    if USE_G:
        p_gk = tl.make_block_ptr(g, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
        b_gk = tl.load(p_gk, boundary_check=(0,))  # [BT,]
    NTS = tl.cdiv(T, BS)

    # Phase 1: strictly-future query blocks, swept from the end backwards.
    for i_s in range(NTS * BS - BS, (i_t + 1) * BT - BS, -BS):
        p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
        p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        b_q = tl.load(p_q, boundary_check=(0, 1))    # [BS, BK]
        b_do = tl.load(p_do, boundary_check=(0, 1))  # [BS, BV]
        b_ds = tl.dot(b_v, tl.trans(b_do))
        b_s = tl.dot(b_k, tl.trans(b_q))
        if USE_G:
            p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
            b_gq = tl.load(p_gq, boundary_check=(0,))
            b_gp = tl.load(g + (min(i_s + BS, T) - 1) * stride_g)
            b_gn = tl.load(g + (i_s - 1) * stride_g) if i_s % BT > 0 else 0.
            if i_s >= 0:
                # Decay the running accumulators across the block boundary.
                tmp = safe_exp(b_gp - b_gn)
                b_dk *= tmp
                b_dv *= tmp
                tmp2 = safe_exp(b_gq - b_gn)
                b_ds *= tmp2[None, :]
                b_s *= tmp2[None, :]
        b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)    # [BT, BK]
        b_dv += tl.dot(b_s.to(b_do.dtype), b_do)   # [BT, BV]

    if USE_G:
        b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * stride_g)
        if i_t >= 0:
            # Re-reference the accumulators to the end of this key chunk.
            tmp2 = safe_exp(b_g_last - b_gk)[:, None]
            b_dk *= tmp2
            b_dv *= tmp2

    o_q = i_t * BT + tl.arange(0, BS)
    o_k = i_t * BT + tl.arange(0, BT)
    # Phase 2: diagonal blocks — causal mask applied.
    for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
        p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0))
        p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
        b_q = tl.load(p_q, boundary_check=(0, 1))    # [BS, BK]
        b_do = tl.load(p_do, boundary_check=(0, 1))  # [BS, BV]
        b_ds = tl.dot(b_v, tl.trans(b_do))
        b_s = tl.dot(b_k, tl.trans(b_q))
        if USE_G:
            p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,))
            b_gq = tl.load(p_gq, boundary_check=(0,))
            if i_s >= 0:
                tmp = safe_exp(-b_gk[:, None] + b_gq[None, :])
                b_ds *= tmp
                b_s *= tmp
        m_s = o_k[:, None] <= o_q[None, :]
        b_s = tl.where(m_s, b_s, 0)
        b_ds = tl.where(m_s, b_ds, 0)
        b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
        b_dv += tl.dot(b_s.to(b_do.dtype), b_do)
        o_q += BS

    b_dk *= scale
    b_dv *= scale
    p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    if USE_G:
        # Key-side dg contribution is subtracted from the value written by the
        # dq kernel.
        p_dg = tl.make_block_ptr(dg, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
        b_dg = tl.load(p_dg, boundary_check=(0,))
        b_dg -= tl.sum(b_dk * b_k, 1)
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
lambda args: triton.cdiv(args['V'], args['BV']), + 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'USE_G': lambda args: args['g'] is not None +}) +@triton.autotune( + configs=[ + triton.Config(triton_config, num_warps=num_warps) + for num_warps in [2, 4, 8, 16] + ], + key=['BT', 'BS', 'BK', 'BV', 'USE_G'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_simple_gla_bwd_kernel( + q, + k, + v, + g, + do, + dq, + dk, + dv, + dg, + scale, + offsets, + indices, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_G: tl.constexpr +): + tl.static_assert(not (USE_OFFSETS and HEAD_FIRST), "USE_OFFSETS and HEAD_FIRST cannot be True at the same time") + i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_k, i_v = i_kv // NV, i_kv % NV + i_b, i_h = i_bh // H, i_bh % H + dq += i_v * B * H * T * K + dk += i_v * B * H * T * K + dv += i_k * B * H * T * V + if USE_G: + dg += i_kv * B * H * T + + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + q += (i_bh * T * K) if HEAD_FIRST else (bos * H + i_h) * K + k += (i_bh * T * K) if HEAD_FIRST else (bos * H + i_h) * K + v += (i_bh * T * V) if HEAD_FIRST else (bos * H + i_h) * V + do += (i_bh * T * V) if HEAD_FIRST else (bos * H + i_h) * V + dq += (i_bh * T * K) if HEAD_FIRST else (bos * H + i_h) * K + dk += (i_bh * T * K) if HEAD_FIRST else (bos * H + i_h) * K + dv += (i_bh * T * V) if HEAD_FIRST else (bos * H + i_h) * V + if USE_G: + g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h) + dg += (i_bh * T) if HEAD_FIRST else (bos * H + i_h) + stride_qk = K if HEAD_FIRST else H * K + stride_vo = V if 
HEAD_FIRST else H * V + stride_g = 1 if HEAD_FIRST else H + + parallel_simple_gla_bwd_kernel_dq( + i_t=i_t, + i_k=i_k, + i_v=i_v, + q=q, + k=k, + v=v, + g=g, + do=do, + dq=dq, + dg=dg, + scale=scale, + stride_qk=stride_qk, + stride_vo=stride_vo, + stride_g=stride_g, + T=T, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV, + USE_G=USE_G + ) + tl.debug_barrier() + parallel_simple_gla_bwd_kernel_dkv( + i_t=i_t, + i_k=i_k, + i_v=i_v, + q=q, + k=k, + v=v, + g=g, + do=do, + dk=dk, + dv=dv, + dg=dg, + scale=scale, + stride_qk=stride_qk, + stride_vo=stride_vo, + stride_g=stride_g, + T=T, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV, + USE_G=USE_G + ) + + +def parallel_simple_gla_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + scale: float, + output_attentions: bool = False, + chunk_size: int = 128, + head_first: bool = True, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, +): + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT, BS = chunk_size, 32 + if check_shared_mem('hopper', k.device.index): + BK = min(256, triton.next_power_of_2(K)) + BV = min(256, triton.next_power_of_2(V)) + elif check_shared_mem('ampere', k.device.index): + BK = min(128, triton.next_power_of_2(K)) + BV = min(128, triton.next_power_of_2(V)) + else: + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert BT % BS == 0 + + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + # local cumulative decay in log space + if g is not None: + g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first) + grid = (NK * NV, NT, B * H) + o = torch.empty(NK, *v.shape, dtype=v.dtype if NK == 1 else torch.float, device=q.device) + attn = q.new_zeros(NK, B, H, T, T) if output_attentions else None + + parallel_simple_gla_fwd_kernel[grid]( + q=q, + 
k=k, + v=v, + g=g, + o=o, + attn=attn, + scale=scale, + offsets=offsets, + indices=indices, + B=B, + H=H, + T=T, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV, + HEAD_FIRST=head_first, + ) + o = o.sum(0) + + if output_attentions: + attn = attn.sum(0) + return o, g, attn + + +def parallel_simple_gla_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + do: torch.Tensor, + scale: float, + chunk_size: int = 128, + head_first: bool = True, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, +): + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT, BS = chunk_size, 32 + if check_shared_mem('hopper', k.device.index): + BK = min(256, triton.next_power_of_2(K)) + BV = min(256, triton.next_power_of_2(V)) + elif check_shared_mem('ampere', k.device.index): + BK = min(128, triton.next_power_of_2(K)) + BV = min(128, triton.next_power_of_2(V)) + elif check_shared_mem('ada', k.device.index): + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + else: + BK = min(32, triton.next_power_of_2(K)) + BV = min(32, triton.next_power_of_2(V)) + + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert BT % BS == 0 + + dq = torch.empty(NV, * q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device) + dk = torch.empty(NV, * k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device) + dv = torch.empty(NK, * v.shape, dtype=v.dtype if NK == 1 else torch.float, device=q.device) + dg = torch.empty(NK*NV, *g.shape, dtype=torch.float, device=q.device) if g is not None else None + + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + grid = (NK * NV, NT, B * H) + parallel_simple_gla_bwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + do=do, + dq=dq, + dk=dk, + dv=dv, + dg=dg, + offsets=offsets, + indices=indices, + scale=scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV, + 
HEAD_FIRST=head_first + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + dg = chunk_global_cumsum(dg.sum(0), reverse=True, head_first=head_first, offsets=offsets) if g is not None else None + return dq, dk, dv, dg + + +class ParallelSimpleGLAFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, g, scale, output_attentions, head_first, offsets): + chunk_size = 128 + ctx.dtype = q.dtype + + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = None + if offsets is not None: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()]) + indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets) + + o, g, attn = parallel_simple_gla_fwd( + q=q, + k=k, + v=v, + g=g, + scale=scale, + output_attentions=output_attentions, + head_first=head_first, + offsets=offsets, + indices=indices, + chunk_size=chunk_size) + ctx.save_for_backward(q, k, v, g, offsets, indices) + ctx.scale = scale + ctx.chunk_size = chunk_size + ctx.head_first = head_first + return o.to(q.dtype), attn + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, da=None): + q, k, v, g, offsets, indices = ctx.saved_tensors + dq, dk, dv, dg = parallel_simple_gla_bwd( + q=q, + k=k, + v=v, + g=g, + do=do, + scale=ctx.scale, + chunk_size=ctx.chunk_size, + offsets=offsets, + indices=indices, + head_first=ctx.head_first) + return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.dtype) if dg is not None else None, None, None, None, None + + +def parallel_simple_gla( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + output_attentions: bool = 
False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + g (torch.Tensor): + Forget gates of shape `[B, H, T]` if `head_first=True` else `[B, T, H]`. + Compared to GLA, the gating is head-wise instead of elementwise. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + output_attentions (bool): + Whether to output the materialized attention scores of shape [B, H, T, T]. Default: `False`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `True`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. 
+ attn (torch.Tensor): + Attention scores of shape `[B, H, T, T]` if `output_attentions=True` else `None` + """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + assert not head_first, "head_first must be False when cu_seqlens are provided" + if g is not None: + g = g.float() + if output_attentions: + assert cu_seqlens is None, "output_attentions=True is not supported with variable-length sequences" + o, attn = ParallelSimpleGLAFunction.apply(q, k, v, g, scale, output_attentions, head_first, cu_seqlens) + return o, attn diff --git a/fla/ops/titans/__init__.py b/fla/ops/titans/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..55aa4e0588f0b27d61ba190c5987ae7d637ca3d8 --- /dev/null +++ b/fla/ops/titans/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .naive import chunk_titans_linear + +__all__ = [ + 'chunk_titans_linear' +] diff --git a/fla/ops/titans/log_impl.py b/fla/ops/titans/log_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..c05a6436ca662749f34cc51d24436e992ab6ede8 --- /dev/null +++ b/fla/ops/titans/log_impl.py @@ -0,0 +1,153 @@ +import torch + + +def cal_n_log(log_theta, log_eta, seq_len): + """ + calculate n_{i,j} in log space + log(n_{i,j}) = log(θ_j) + sum_{k=j+1}^i log(η_k) + """ + # create log(n) + log_n = torch.zeros(*log_theta.shape, seq_len, dtype=log_eta.dtype).to( + log_eta.device + ) # [batch_size, num_heads, seq_len, seq_len] + for i in range(seq_len): + for j in range(i + 1): + if i == j: + log_n[..., j, i] = log_theta[..., j] + else: + log_n[..., j, i] = log_theta[..., j] + torch.sum( + log_eta[..., j + 1: i + 1], dim=-1 + ) + + return log_n + + +def cal_f_log(log_beta, seq_len, log_m): + """ + cal_f_log(log_beta, seq_len, log_m) -> f + log(f_t) = log(sum_{i=1}^t exp(sum_{k=i+1}^t log(1-α_k) + sum_{k=1}^i log(η_k))) + """ + # create f + # f = torch.zeros_like(log_beta) 
+ # for t in range(seq_len): + # for i in range(t + 1): + # f[..., t] += torch.exp(log_beta[..., t] - log_beta[..., i] + log_m[..., i]) + log_f = torch.zeros_like(log_beta) + for t in range(seq_len): + a_i = log_beta[..., t: t + 1] - log_beta[..., : t + 1] + log_m[..., : t + 1] + log_f[..., t] = torch.logsumexp(a_i, dim=-1) + f = torch.exp(log_f) + + # this version overflow and even slower + # t_indices = torch.arange(seq_len, device=log_beta.device) + # i_indices = torch.arange(seq_len, device=log_beta.device) + # + # mask = i_indices.unsqueeze(0) <= t_indices.unsqueeze(1) + # log_beta_t = log_beta.unsqueeze(-1) # [..., seq_len, 1] + # log_beta_i = log_beta.unsqueeze(-2) # [..., 1, seq_len] + # log_m_i = log_m.unsqueeze(-2) + # a_i = log_beta_t - log_beta_i + log_m_i + # masked_a_i = torch.where(mask, a_i, torch.tensor(-float('inf'), device=a_i.device, dtype=a_i.dtype)) + # log_f = torch.logsumexp(masked_a_i, dim=-1) # [..., seq_len] + # + # f = torch.exp(log_f) + return f + + +def cal_G_log(log_beta, log_n, seq_len): + """ + calculate G_{i,j} + log(G_{i,j}) = log(sum_{k=j}^i exp(log(β_i/β_k) + log(n_{k,j}))) + """ + # G = torch.zeros(*log_beta.shape[:-1], seq_len, seq_len, device = log_beta.device) + # # Fill in the lower triangular part + # for i in range(seq_len): # row + # for j in range(i + 1): # column + # # Sum from k=j to i + # for k in range(j, i + 1): + # G[..., i, j] += torch.exp(log_beta[..., i] - log_beta[..., k] + log_n[..., j, k]) + + log_G = torch.full( + (*log_beta.shape[:-1], seq_len, seq_len), float("-inf"), device=log_beta.device + ) + # fill in the lower triangular part + for i in range(seq_len): # row + for j in range(i + 1): # column + terms = ( + log_beta[..., i: i + 1] + - log_beta[..., j: i + 1] + + log_n[..., j: j + 1, j: i + 1].squeeze(-2) + ) + # use logsumexp to avoid overflow + log_G[..., i, j] = torch.logsumexp(terms, dim=-1) + + G = torch.exp(log_G) + return G + + +def _combine_params_log(log_theta, log_alpha_complement, log_eta, 
seq_len): + """ + Update rule for Titans in log space + + Parameters: + - log_theta: log(θ) + - log_alpha_complement: log(1-α) + - log_eta: log(η) + - seq_len: sequence length + + Returns: + - log_beta, beta_T, log_f, f_T, log_g, log_G, m_T, n_T + """ + # calculate log(β_t) = sum_{k=1}^t log(1-α_k) + log_beta = torch.cumsum(log_alpha_complement, dim=-1) + + # get β_T + beta_T = torch.exp(log_beta[..., -1]) + + # calculate log(m_i) = sum_{k=1}^i log(η_k) + log_m = torch.cumsum(log_eta, dim=-1) + m_T = torch.exp(log_m[..., -1]) + + # cal log(n_{i,j}) + log_n = cal_n_log(log_theta, log_eta, seq_len) + n_T = torch.exp(log_n[..., -1]) + + # cal log(f_t) + f = cal_f_log(log_beta, seq_len, log_m) + f_T = f[..., -1] + + # cal log(G_{i,j}) + G = cal_G_log(log_beta, log_n, seq_len) + # get log(g_j) = log(G_{T,j}) + g = G[..., -1, :] + + return log_beta, beta_T, f, f_T, g, G, m_T, n_T + + +def combine_params_log(theta, alpha, eta, seq_len): + """ + log space Titians + + Parameters: + - theta: θ + - alpha: α + - eta: η + - seq_len: sequence length + + Returns: + - beta, beta_T, f, f_T, g, G, m_T, n_T + """ + # convert to log space + log_theta = torch.log(theta.squeeze(-1)) + log_alpha_complement = torch.log(1 - alpha.squeeze(-1)) + log_eta = torch.log(eta.squeeze(-1)) + + # combine params in log space + log_beta, beta_T, f, f_T, g, G, m_T, n_T = _combine_params_log( + log_theta, log_alpha_complement, log_eta, seq_len + ) + + # convert back to normal space + beta = torch.exp(log_beta) + + return beta, beta_T, f, f_T, g, G, m_T, n_T diff --git a/fla/ops/titans/naive.py b/fla/ops/titans/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..2a1bd4b0a8fe1d6f8efe03dc4be3a3e4854cde98 --- /dev/null +++ b/fla/ops/titans/naive.py @@ -0,0 +1,375 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn.functional as F + +from fla.ops.titans.log_impl import combine_params_log + + +def cal_n(theta, eta, seq_len): + n = torch.zeros(*theta.shape, seq_len, 
dtype=theta.dtype).to( + theta.device + ) # [batch_size, num_heads, seq_len, seq_len] + + # 1. deal with diagonal elements + indices = torch.arange(seq_len, device=theta.device) + n[..., indices, indices] = theta[..., indices] + + # 2. Create a cumulative product matrix + # First create a mask to mark the positions where eta needs to be multiplied + mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).to(theta.device) + # Convert mask to boolean type + mask = mask.bool() + # Expand eta to match the target shape + eta_expanded = eta.unsqueeze(-2).expand(*theta.shape[:-1], seq_len, seq_len) + # Create a matrix filled with 1s for cumulative product + cumulative = torch.ones_like(eta_expanded) + cumulative = torch.where(mask, eta_expanded, cumulative) + # Calculate the cumulative product + cumulative_prod = torch.cumprod(cumulative, dim=-1) + + # 3. Calculate non-diagonal elements + # Create an expanded version of theta + theta_expanded = theta.unsqueeze(-1).expand(*theta.shape[:-1], seq_len, seq_len) + # Create a mask to keep only the upper triangular part (excluding the diagonal) + upper_triangular = torch.triu(torch.ones_like(n), diagonal=1).bool() + # Combine theta and cumulative product + n = torch.where(upper_triangular, theta_expanded * cumulative_prod, n) + return n + + +def cal_f(beta, seq_len, m): + a = torch.tril(beta.to(torch.float32).unsqueeze(-1).expand(*beta.shape, seq_len), 0) + ratio = (m.to(torch.float32) / beta.to(torch.float32)).unsqueeze(-1) + f = torch.matmul(a, ratio).squeeze(-1) + return f.to(beta.dtype) + + +def cal_G(beta, n, seq_len): + i_indices = torch.arange(seq_len, device=beta.device) + j_indices = torch.arange(seq_len, device=beta.device) + k_indices = torch.arange(seq_len, device=beta.device) + beta_ratio = beta[..., :, None] / beta[..., None, :] # [..., i, k] + + # create mask + k_mask = (k_indices[None, None, :] >= j_indices[None, :, None]) & ( + k_indices[None, None, :] <= i_indices[:, None, None] + ) + + # use mask to filter 
out invalid values + masked_beta_ratio = beta_ratio[..., :, None, :] * k_mask # [..., i, j, k] + masked_n = n[..., None, :, :] * k_mask # [..., i, j, k] + # calculate G + G = torch.sum(masked_beta_ratio * masked_n, dim=-1) # [..., i, j] + return G + + +def combine_params(theta, alpha, eta, seq_len): + theta = theta.squeeze(-1) + eta = eta.squeeze(-1) + alpha = alpha.squeeze(-1) + beta = torch.cumprod(1 - alpha, dim=-1) # β_t = ∏(1 - α_t) in titans paper + beta_T = beta[..., -1] # β_T + # Calculate m_i = ∏(k=1 to i) η_k + m = torch.cumprod(eta, dim=-1) # [batch_size, num_heads, seq_len] + m_T = m[..., -1] # m_T + # Calculate n_{i,j} + # We need to calculate ∏(k=j+1 to i) η_k for each i,j pair + # # this may be optimized + # n = torch.zeros(*theta.shape, seq_len, dtype = theta.dtype).to( + # theta.device) # [batch_size, num_heads, seq_len, seq_len] + # for i in range(seq_len): + # for j in range(i + 1): + # if i == j: + # n[..., j, i] = theta[..., j] + # else: + # # Calculate product of eta from j+1 to i + # eta_product = torch.prod(eta[..., j + 1:i + 1], dim = -1) + # n[..., j, i] = theta[..., j] * eta_product + + n = cal_n(theta, eta, seq_len) + n_T = n[..., -1] # [batch_size, num_heads, seq_len] + # Calculate f_t = ∑(i=1 to t) (β_t/β_i) m_i + # f = torch.zeros_like(theta) + # for t in range(seq_len): + # for i in range(t + 1): + # f[..., t] += (beta[..., t] / beta[..., i]) * m[..., i] + f = cal_f(beta, seq_len, m) + f_T = f[..., -1] # [batch_size, num_heads, seq_len] + # Calculate g_j = ∑(i=j to t) (β_t/β_i) n_{i,j} + # g = torch.zeros_like(theta) # [batch_size, num_heads, seq_len] + # for j in range(seq_len): + # for i in range(j, seq_len): + # g[..., j] += (beta[..., -1] / beta[..., i]) * n[..., j, i] + # G = torch.zeros(*beta.shape[:-1], seq_len, seq_len, device = beta.device) + # # Fill in the lower triangular part + # for i in range(seq_len): # row + # for j in range(i + 1): # column + # # Sum from k=j to i + # for k in range(j, i + 1): + # G[..., i, j] += 
(beta[..., i] / beta[..., k]) * n[..., j, k] + G = cal_G(beta, n, seq_len) + g = G[:, :, -1, :] # [batch_size, num_heads, seq_len] + # g2, G2 = compute_g_and_G(beta, n, seq_len) + return beta, beta_T, f, f_T, g, G, m_T, n_T + + +def titans_linear( + q, k, v, w, b, theta, alpha, eta, eps, chunk_size, initial_state, output_final_state +): + """ + Implementation of Titans Linear function based on the update rules: + M_t = (1 - alpha_t) * M_{t-1} + S_t + S_t = eta_t * S_{t-1} - theta_t * nabla_l(M_{t-1}; x_t) + + Args: + q: Query tensor + k: Key tensor + v: Value tensor + w: Weight tensor + b: Bias tensor + theta: Learning rate tensor + alpha: Momentum decay tensor + eta: Step size tensor + eps: Epsilon for numerical stability + initial_state: Initial state M_0 + output_final_state: Whether to output the final state + + Returns: + Tuple of (output tensor, final state) + """ + B, H, T, D = q.shape + device = q.device + w = w.reshape(H, 1, D).to(torch.float32) + b = b.reshape(H, 1, D).to(torch.float32) + # Initialize states + if initial_state is None: + M_prev = torch.zeros(B, H, D, D, device=device) + else: + M_prev = initial_state + M_prev_nabla = M_prev.clone() + S_prev = torch.zeros_like(M_prev) + outputs = [] + + # Process sequence step by step + for t in range(T): + # Get current step inputs + q_t = q[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim) + k_t = k[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim) + v_t = v[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim) + theta_t = theta[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim) + alpha_t = alpha[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim) + eta_t = eta[:, :, t: t + 1, :] # (batch_size, num_heads, 1, dim) + + # Compute gradient + km = k_t @ M_prev_nabla # (batch_size, num_heads, 1, dim) + reconstruction_target = v_t - k_t + mean = km.mean(-1, keepdim=True) + var = km.var(-1, unbiased=False, keepdim=True).to(torch.float32) + rstd = torch.sqrt(var + eps).to(torch.float32) + km_hat = (km - 
mean) / rstd + + grad = w * km_hat + b - reconstruction_target + grad = grad * w + # v_new = (D * grad - grad.sum(-1, keepdim = True) - km_hat * (grad * km_hat).sum(-1, keepdim = True)) / ( + # rstd * D) + v_new = D * grad - grad.sum(-1, keepdim=True) / (rstd * D) + proj_term = km_hat * (grad * km_hat).sum(-1, keepdim=True) / (rstd * D) + v_new = v_new - proj_term + # v_new = grad + + # Update S_t + S_t = eta_t * S_prev - 2 * theta_t * k_t.transpose(-2, -1) @ v_new + + # Update M_t + M_t = (1 - alpha_t) * M_prev + S_t + + # Store output + output_t = q_t @ M_t # (batch_size, num_heads, seq_len, dim) + mean = output_t.mean(dim=-1, keepdim=True) + var = output_t.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32) + rstd = torch.sqrt(var + eps).to(torch.float32) + output_t = output_t + (output_t - mean) / rstd * w + b + outputs.append(output_t) + + # Update states for next step + if (t + 1) % chunk_size == 0: + M_prev_nabla = M_t.clone() + M_prev = M_t + S_prev = S_t + + # Stack outputs along sequence dimension + output = torch.stack(outputs, dim=-2).squeeze( + -3 + ) # (batch_size, num_heads, seq_len, dim) + + if output_final_state: + return output, M_prev + return output, None + + +def chunk_titans_linear( + q, k, v, w, b, theta, alpha, eta, eps, chunk_size, initial_state, output_final_state +): + B, H, T, D = q.shape + num_batch = T // chunk_size + # [num_batch, B, num_heads, mini_batch_size, head_dim] + _q = q.reshape(B, H, num_batch, chunk_size, D).permute(2, 0, 1, 3, 4) + _k = k.reshape(B, H, num_batch, chunk_size, D).permute(2, 0, 1, 3, 4) + _v = v.reshape(B, H, num_batch, chunk_size, D).permute(2, 0, 1, 3, 4) + # [num_batch, B, num_heads, mini_batch_size, 1] + _eta = eta.reshape(B, H, num_batch, chunk_size, 1).permute(2, 0, 1, 3, 4) + _theta = theta.reshape(B, H, num_batch, chunk_size, 1).permute(2, 0, 1, 3, 4) + _alpha = alpha.reshape(B, H, num_batch, chunk_size, 1).permute(2, 0, 1, 3, 4) + # [H, 1, D] + w = w.reshape(H, 1, D).to(torch.float32) + b = 
b.reshape(H, 1, D).to(torch.float32) + # [num_heads, 1, head_dim] + if initial_state is None: + M_prev = torch.zeros((B, H, D, D), device=v.device, dtype=v.dtype).to( + torch.float32 + ) + else: + M_prev = initial_state + + S_prev = torch.zeros_like(M_prev) + + # [num_batch, B, num_heads, mini_batch_size, head_dim] + o = torch.empty_like(_v) + + for i in range(num_batch): + q_i, k_i, v_i, eta_i, theta_i, alpha_i = [ + x[i] for x in [_q, _k, _v, _eta, _theta, _alpha] + ] + + # beta, beta_T, f, f_T, g, G, m_T, n = combine_params(theta_i, alpha_i, eta_i, chunk_size) + beta, beta_T, f, f_T, g, G, m_T, n = combine_params_log( + theta_i, alpha_i, eta_i, chunk_size + ) + + m_T = m_T.unsqueeze(-1).unsqueeze(-1) + beta_T = beta_T.unsqueeze(-1).unsqueeze(-1) + f_T = f_T.unsqueeze(-1).unsqueeze(-1) + g_diag = torch.diag_embed(g).to(q_i.dtype) + n = torch.diag_embed(n).to(q_i.dtype) + beta = torch.diag_embed(beta).to(q_i.dtype) + f = torch.diag_embed(f).to(q_i.dtype) + km = k_i @ M_prev + reconstruction_target = v_i - k_i + + mean = km.mean(-1, True) + var = km.var(-1, unbiased=False, keepdim=True).to(torch.float32) + rstd = torch.sqrt(var + eps).to(torch.float32) + km_hat = (km - mean) / rstd + + grad = w * km_hat + b - reconstruction_target + grad *= w + v_new = D * grad - grad.sum(-1, keepdim=True) / (rstd * D) + proj_term = km_hat * (grad * km_hat).sum(-1, keepdim=True) / (rstd * D) + v_new = v_new - proj_term + # v_new = (D * grad - grad.sum(-1, True)) + # print(f"Projection term stats: min={torch.abs(beta_T).min()}") + + # v_new = grad + + Attn = torch.tril(q_i @ k_i.transpose(-2, -1)) * G + + # o_i + output_t = beta @ q_i @ M_prev + f @ q_i @ S_prev - 2 * Attn @ v_new + + M_t = ( + beta_T * M_prev + + f_T * S_prev + - 2 * (g_diag @ k_i).transpose(-1, -2) @ v_new + ) + # cal S_T from S_0 + S_t = m_T * S_prev - 2 * (n @ k_i).transpose(-1, -2) @ v_new + # layer norm with residuals + mean = output_t.mean(dim=-1, keepdim=True) + var = output_t.var(dim=-1, unbiased=False, 
keepdim=True).to(torch.float32) + rstd = torch.sqrt(var + eps).to(torch.float32) + output_t = output_t + (output_t - mean) / rstd * w + b + o[i] = output_t + S_prev = S_t + M_prev = M_t + + # [B, num_mini_batch, mini_batch_size, num_heads, head_dim] + o = o.permute(1, 2, 0, 3, 4).reshape(B, H, T, D) + M_prev = M_prev if output_final_state else None + return o, M_prev + + +# most of the code is copied from ttt +def chunk_titans_linear_ref( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + b: torch.Tensor, + theta: torch.Tensor, + alpha: torch.Tensor, + eta: torch.Tensor, + eps: float = 1e-6, + chunk_size: int = 16, # chunk size + initial_state: torch.Tensor = None, + output_final_state: bool = False, + head_first: bool = True, + use_chunk: bool = True, +): + assert q.dtype == k.dtype == v.dtype + assert k.shape[-1] == v.shape[-1], "DK must equal to DV." + if not head_first: + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + eta = eta.transpose(1, 2) + alpha = alpha.transpose(1, 2) + theta = theta.transpose(1, 2) + seq_len = q.shape[-2] + pad_len = (chunk_size - (seq_len % chunk_size)) % chunk_size + if pad_len > 0: + q = F.pad(q, (0, 0, 0, pad_len)) + k = F.pad(k, (0, 0, 0, pad_len)) + v = F.pad(v, (0, 0, 0, pad_len)) + theta = F.pad(theta, (0, 0, 0, pad_len)) + alpha = F.pad(alpha, (0, 0, 0, pad_len)) + eta = F.pad(eta, (0, 0, 0, pad_len)) + theta[:, :, -1, :] = theta[:, :, -(pad_len + 1), :] + alpha[:, :, -1, :] = alpha[:, :, -(pad_len + 1), :] + eta[:, :, -1, :] = eta[:, :, -(pad_len + 1), :] + assert q.shape[-2] % chunk_size == 0, "Sequence length should be a multiple of BT." 
+ q, k, v, w, b = map(lambda x: x.to(torch.float32), [q, k, v, w, b]) + if use_chunk: + o, final_state = chunk_titans_linear( + q, + k, + v, + w, + b, + theta, + alpha, + eta, + eps, + chunk_size, + initial_state, + output_final_state, + ) + else: + o, final_state = titans_linear( + q, + k, + v, + w, + b, + theta, + alpha, + eta, + eps, + chunk_size, + initial_state, + output_final_state, + ) + o = o[:, :, :seq_len, :] + if not head_first: + o = o.transpose(1, 2) + return o, final_state diff --git a/fla/ops/ttt/__init__.py b/fla/ops/ttt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3d66e42f18785546b4e3be77abc2c91519e3bb9 --- /dev/null +++ b/fla/ops/ttt/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_ttt_linear +from .fused_chunk import fused_chunk_ttt_linear + +__all__ = [ + 'fused_chunk_ttt_linear', + 'chunk_ttt_linear' +] diff --git a/fla/ops/ttt/chunk.py b/fla/ops/ttt/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..6342364268cfc56b3b87601e902f0ab40d6a1b7f --- /dev/null +++ b/fla/ops/ttt/chunk.py @@ -0,0 +1,1539 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang, Yuqi Pan + +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl + +from fla.modules.layernorm import group_norm +from fla.ops.common.utils import prepare_chunk_indices, prepare_chunk_offsets +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8] + ], + key=['BT', 'BK', 'BV'] +) 
@triton.jit(do_not_specialize=['T'])
def chunk_ttt_linear_fwd_kernel_h(
    k,
    v,
    v_new,
    eta,
    w,
    b,
    eps,
    h,
    hb,
    h0,
    hb0,
    ht,
    hbt,
    offsets,
    chunk_offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_INITIAL_STATE_B: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    """Forward recurrence over chunks for TTT-linear.

    For each (batch, head) this program sweeps the sequence chunk by chunk,
    storing the per-chunk fast-weight state `h` ([K, V]) and bias `hb` ([V])
    BEFORE each chunk's update, writing the LayerNorm-transformed values into
    `v_new`, and optionally emitting the final state into `ht`/`hbt`.
    `w`/`b` are the per-head LayerNorm scale/shift ([H, V]); `eta` is the
    per-token learning rate; `eps` the LN epsilon.
    Grid: (NK, NV, N*H); the host asserts NK == NV == 1 so each program owns
    the whole [K, V] state. `offsets`/`chunk_offsets` enable varlen batches.
    """
    # One program per (k-block, v-block, batch*head); NK == NV == 1 in practice.
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        # Variable-length mode: derive this sequence's [bos, eos) span and
        # its starting chunk index from the precomputed offset tables.
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # Running fast-weight state, kept in fp32 for the whole sweep. [BK, BV]
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    # Running bias state. [BV]
    b_hb = tl.zeros([BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
    if USE_INITIAL_STATE_B:
        p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
        b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32)

    # Per-head LayerNorm affine parameters; masked beyond V since BV >= V.
    offs = tl.arange(0, BV)
    b_w = tl.load(w + i_h * V + offs, mask=offs < V, other=0.)
    b_b = tl.load(b + i_h * V + offs, mask=offs < V, other=0.)
    for i_t in range(NT):
        # Checkpoint the pre-update state of this chunk so the inter-chunk
        # output kernel can read it back.
        if HEAD_FIRST:
            p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
            p_hb = tl.make_block_ptr(hb + (i_nh * NT + i_t) * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
        else:
            p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
            p_hb = tl.make_block_ptr(hb + ((boh + i_t) * H + i_h) * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_hb, b_hb.to(p_hb.dtype.element_ty), boundary_check=(0,))
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k+i_nh*T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            # The state update uses only the LAST token's eta in the chunk;
            # the final (possibly partial) chunk clamps to index T-1.
            p_eta_last = eta+i_nh*T+T-1 if i_t == NT-1 else eta+i_nh*T+i_t*BT+BT-1
        else:
            p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
            p_eta_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
        b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
        b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")

        # Prediction for this chunk: k @ h + hb, masked beyond the valid V.
        b_kh = tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :]
        b_kh = tl.where((offs < V)[None, :], b_kh, 0.)
        # Fused LayerNorm statistics over the feature (V) axis.
        mean = tl.sum(b_kh, axis=1, keep_dims=True) / V
        xbar = tl.where((offs < V)[None, :], b_kh - mean, 0.)
        var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V
        rstd = 1 / tl.sqrt(var.to(tl.float32) + eps)
        b_kh_hat = (b_kh - mean) * rstd

        # Residual of the LN-affine prediction against the target:
        # LN(k@h+hb)*w + b - v + k, then scaled by w again (backprop of LN
        # through its affine output — TODO(review): confirm against the TTT
        # paper's LN^2 formulation).
        b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \
            b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k)
        b_v = tl.where((offs < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.)
        # Backprop through the normalization itself (standard LN backward form).
        b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype)
                       * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V
        tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
        # Rank-BT state update using only the chunk's last eta.
        b_eta_last = tl.load(p_eta_last)
        b_h = b_h - tl.dot(b_eta_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False)
        b_hb = b_hb - tl.sum(b_eta_last * b_v2.to(b_k.dtype), axis=0)

    if STORE_FINAL_STATE:
        # Persist the post-sweep state for downstream incremental decoding.
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_hbt = tl.make_block_ptr(hbt + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_hbt, b_hb.to(p_hbt.dtype.element_ty), boundary_check=(0,))


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3]
    ],
    key=['BT'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_ttt_linear_fwd_kernel_o(
    q,
    k,
    v,
    eta,
    h,
    hb,
    o,
    offsets,
    indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    """Compute the output `o` for one chunk from the checkpointed states.

    Combines the inter-chunk contribution q @ h(chunk) with the causal
    intra-chunk correction built from the lower-triangular attention-like
    matrix A = tril(q @ k^T), scaled per-token by `eta`, plus the bias state
    `hb`. One program per (v-block, chunk, batch*head).
    """
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        # Varlen: `indices` maps the flat chunk id to (sequence, local chunk).
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # offset calculation
    q += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
    k += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
    v += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
    eta += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
    o += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
    h += ((i_bh * NT + i_t) * K * V) if HEAD_FIRST else ((i_tg * H + i_h) * K * V)
    hb += ((i_bh * NT + i_t) * V) if HEAD_FIRST else ((i_tg * H + i_h) * V)
    stride_qk = K if HEAD_FIRST else H*K
    stride_vo = V if HEAD_FIRST else H*V
    stride_eta = 1 if HEAD_FIRST else H

    p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, 0), (BT, BK), (1, 0))
    p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (0, i_t * BT), (BK, BT), (0, 1))
    p_eta = tl.make_block_ptr(eta, (T,), (stride_eta,), (i_t * BT,), (BT,), (0,))
    p_h = tl.make_block_ptr(h, (K, V), (V, 1), (0, i_v * BV), (BK, BV), (1, 0))
    p_hb = tl.make_block_ptr(hb, (V,), (1,), (i_v * BV,), (BV,), (0,))
    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero")
    # [BK, BT]
    b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
    # [BT, 1]
    b_eta = tl.load(p_eta, boundary_check=(0,), padding_option="zero")
    # [BK, BV]
    b_h = tl.load(p_h, boundary_check=(0, 1), padding_option="zero")
    # [BV]
    b_hb = tl.load(p_hb, boundary_check=(0,), padding_option="zero")
    # Inter-chunk term: [BT, BK] @ [BK, BV] -> [BT, BV]
    b_o = tl.dot(b_q, b_h, allow_tf32=False)
    # Intra-chunk scores: [BT, BK] @ [BK, BT] -> [BT, BT]
    b_A = tl.dot(b_q, b_k, allow_tf32=False)

    # Causal masks: A keeps the lower triangle; Ae carries per-row eta.
    o_i = tl.arange(0, BT)
    m_A = o_i[:, None] >= o_i[None, :]
    b_A = tl.where(m_A, b_A, 0)
    b_Ae = tl.where(m_A, b_eta[:, None], 0.0)

    p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")
    # Subtract the eta-weighted causal correction, then add bias-state terms.
    b_o = (b_o - tl.dot(b_eta[:, None] * b_A.to(b_v.dtype), b_v, allow_tf32=False)) * scale
    b_o += b_hb[None, :] - tl.dot(b_Ae.to(b_v.dtype), b_v, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BT', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_ttt_linear_bwd_kernel_h(
    k,
    v,
    v_new,
    eta,
    w,
    b,
    eps,
    h,
    h0,
    hb0,
    x,
    y,
    r,
    offsets,
    chunk_offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_INITIAL_STATE_B: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    """Recompute-forward pass for the backward: replays the forward chunk
    recurrence and additionally materializes the LN intermediates needed by
    the gradient kernels — `x` (normalized pre-activation), `y` (pre-LN-
    backward residual), and `r` (per-token reciprocal std).
    Mirrors `chunk_ttt_linear_fwd_kernel_h` except that no final state is
    emitted.
    """
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV]
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    # [BV]
    b_hb = tl.zeros([BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
    if USE_INITIAL_STATE_B:
        p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
        b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32)

    offs = tl.arange(0, BV)
    b_w = tl.load(w + i_h * V + offs, mask=offs < V, other=0.)
    b_b = tl.load(b + i_h * V + offs, mask=offs < V, other=0.)

    for i_t in range(NT):
        if HEAD_FIRST:
            p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k+i_nh*T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_x = tl.make_block_ptr(x+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_y = tl.make_block_ptr(y+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_r = tl.make_block_ptr(r+i_nh*T, (T, 1), (1, 1), (i_t * BT, 0), (BT, 1), (1, 0))
            p_eta_last = eta+i_nh*T+T-1 if i_t == NT-1 else eta+i_nh*T+i_t*BT+BT-1
        else:
            p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
            p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
            p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
            p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 1), (1, 0))
            p_eta_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
        b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
        b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")

        # Same math as the forward kernel; see chunk_ttt_linear_fwd_kernel_h.
        b_kh = tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :]
        b_kh = tl.where((offs < V)[None, :], b_kh, 0.)
        mean = tl.sum(b_kh, axis=1, keep_dims=True) / V
        xbar = tl.where((offs < V)[None, :], b_kh - mean, 0.)
        var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V
        rstd = 1 / tl.sqrt(var.to(tl.float32) + eps)
        b_kh_hat = (b_kh - mean) * rstd

        b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \
            b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k)
        b_v = tl.where((offs < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.)
        b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype)
                       * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V
        # Stash LN intermediates for the gradient kernels.
        tl.store(p_x, b_kh_hat.to(p_x.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_y, b_v.to(p_y.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_r, rstd.to(p_r.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
        b_eta_last = tl.load(p_eta_last)
        b_h = b_h - tl.dot(b_eta_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False)
        b_hb = b_hb - tl.sum(b_eta_last * b_v2.to(b_k.dtype), axis=0)


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [4]
    ],
    key=['BT', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_ttt_linear_bwd_kernel_dv_local(
    q,
    k,
    eta,
    do,
    dv,
    offsets,
    indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    """Intra-chunk part of dL/dv_new: propagates `do` back through the causal
    correction of the output kernel (the -eta*A and -Ae terms), writing the
    result into `dv`. One program per (chunk, batch*head).
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # offset calculation
    q += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
    k += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
    eta += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
    do += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
    dv += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
    stride_qk = K if HEAD_FIRST else H*K
    stride_vo = V if HEAD_FIRST else H*V
    stride_eta = 1 if HEAD_FIRST else H

    # Accumulate k @ q^T across K tiles -> transposed intra-chunk scores.
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_q = tl.make_block_ptr(q, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_A += tl.dot(b_k, b_q)

    p_eta = tl.make_block_ptr(eta, (T,), (stride_eta,), (i_t * BT,), (BT,), (0,))
    b_eta = tl.load(p_eta, boundary_check=(0,))
    # Upper-triangular mask: row i of dv collects gradients of outputs j >= i.
    mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :])
    b_A = - tl.where(mask, b_A * scale * b_eta[None, :], 0).to(do.dtype.element_ty)
    b_Ae = - tl.where(mask, b_eta[None, :], 0).to(do.dtype.element_ty)

    for i_v in range(tl.cdiv(V, BV)):
        p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_dv = tl.dot(b_A.to(b_do.dtype), b_do) + tl.dot(b_Ae.to(b_do.dtype), b_do)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_FINAL_STATE_GRADIENT_B': lambda args: args['dhbt'] is not None,
    'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
    'USE_INITIAL_STATE_B': lambda args: args['dhb0'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [2, 4, 8, 16]
    ],
    key=['BT', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_ttt_linear_bwd_kernel_norm(
    q,
    k,
    v,
    v_new,
    x,
    y,
    r,
    w,
    b,
    eta,
    h,
    dht,
    dhbt,
    dh0,
    dhb0,
    do,
    dh,
    dhb,
    dv,
    dv_new,
    dk,
    dw,
    db,
    offsets,
    chunk_offsets,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    USE_FINAL_STATE_GRADIENT_B: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_INITIAL_STATE_B: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    """Reverse sweep over chunks: propagates the state gradients (dh, dhb)
    from the last chunk to the first while backpropagating through the fused
    LayerNorm, producing dv, dk, and the LN parameter gradients dw/db, plus
    optional initial-state gradients dh0/dhb0. Consumes the intermediates
    `x`, `y`, `r` saved by `chunk_ttt_linear_bwd_kernel_h`.
    """
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # Running gradient of the fast-weight state. [BK, BV]
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    # Running gradient of the bias state. [BV]
    b_dhb = tl.zeros([BV], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_dh += tl.load(p_dht, boundary_check=(0, 1), padding_option="zero")
    if USE_FINAL_STATE_GRADIENT_B:
        p_dhbt = tl.make_block_ptr(dhbt + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
        b_dhb += tl.load(p_dhbt, boundary_check=(0,), padding_option="zero")

    # [BV]
    offs_v = tl.arange(0, BV)
    offs_t = tl.arange(0, BT)
    b_w = tl.load(w + i_h * V + offs_v, mask=offs_v < V, other=0.)
    b_b = tl.load(b + i_h * V + offs_v, mask=offs_v < V, other=0.)
    # Per-program accumulators for the LN parameter gradients; the host sums
    # these over the batch dimension afterwards.
    b_dw = tl.zeros([BV,], dtype=b_w.dtype)
    b_db = tl.zeros([BV,], dtype=b_b.dtype)
    p_dw = tl.make_block_ptr(dw + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
    p_db = tl.make_block_ptr(db + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,))

    # Walk the chunks in reverse, mirroring the forward recurrence.
    for i_t in range(NT - 1, -1, -1):
        if HEAD_FIRST:
            p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
            p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
            p_dhb = tl.make_block_ptr(dhb + (i_nh * NT + i_t) * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
        else:
            p_h = tl.make_block_ptr(h + ((boh+i_t) * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
            p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
            p_dhb = tl.make_block_ptr(dhb + ((boh+i_t) * H + i_h) * V, (V,), (1,), (i_v * BV,), (BV,), (0,))
        # Checkpoint the state gradient seen by this chunk for the dq/dk/de kernel.
        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dhb, b_dhb.to(p_dhb.dtype.element_ty), boundary_check=(0,))
        if HEAD_FIRST:
            p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_k = tl.make_block_ptr(k + i_nh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_v_new = tl.make_block_ptr(v_new + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_x = tl.make_block_ptr(x + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_y = tl.make_block_ptr(y + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv_new = tl.make_block_ptr(dv_new + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dk = tl.make_block_ptr(dk + i_nh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_r = tl.make_block_ptr(r + i_nh * T, (T, 1), (1, 1), (i_t * BT, 0), (BT, 1), (1, 0))
            p_eta_last = eta + i_nh*T + T - 1 if i_t == NT-1 else eta + i_nh*T + i_t*BT + BT - 1
        else:
            p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
            p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv_new = tl.make_block_ptr(dv_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
            p_dk = tl.make_block_ptr(dk+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, i_k * BK), (BT, BK), (1, 0))
            p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0))
            p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 1), (1, 0))
            p_eta_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H
        b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero")
        # Fold the state-gradient contribution of the chunk update into dv_new.
        b_dv_new = tl.load(p_dv_new, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
        b_eta_last = tl.load(p_eta_last)
        b_dv_new -= tl.dot(b_eta_last * b_k, b_dh.to(b_k.dtype))
        b_dv_new -= b_eta_last * b_dhb.to(b_k.dtype)[None, :]

        b_v_new = tl.load(p_v_new, boundary_check=(0, 1), padding_option="zero")
        b_x = tl.load(p_x, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
        b_y = tl.load(p_y, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype)
        b_rstd = tl.load(p_r, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
        # LayerNorm-backward of the saved intermediates (standard LN VJP form).
        b_dy = b_rstd * (b_dv_new * V - tl.sum(b_dv_new, axis=1, keep_dims=True) -
                         b_x * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V
        b_dx = -b_rstd * (b_dv_new * tl.sum(b_x * b_y, axis=1, keep_dims=True) +
                          b_y * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V
        b_drstd = tl.sum(b_dv_new.to(b_rstd.dtype) * b_v_new.to(b_rstd.dtype) / b_rstd, axis=1, keep_dims=True)

        b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero")
        b_w = b_w.to(b_k.dtype)
        b_b = b_b.to(b_k.dtype)
        # Direct gradients through the affine residual (see forward's b_v expr).
        b_dv = -b_w * b_dy.to(b_k.dtype)
        b_dk = b_w * b_dy.to(b_k.dtype)
        b_dw += tl.sum(2 * b_w * b_x * b_dy.to(b_k.dtype) +
                       (b_b - b_v.to(b_k.dtype) + b_k) * b_dy.to(b_k.dtype), axis=0).to(b_dw.dtype)
        b_db += tl.sum(b_w * b_dy.to(b_k.dtype), axis=0).to(b_db.dtype)
        b_dx = b_dx.to(b_k.dtype) + b_w * b_w * b_dy.to(b_k.dtype)

        # d_rstd, dx --> dkh --> dk, dh
        b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero")
        b_h = tl.load(p_h, boundary_check=(0, 1), padding_option="zero")
        b_do = tl.load(p_do, boundary_check=(0, 1), padding_option="zero")
        b_q = (b_q * scale).to(b_q.dtype)
        b_dkh = b_rstd * (V * b_dx - tl.sum(b_dx, axis=1, keep_dims=True) -
                          b_x * tl.sum(b_x * b_dx, axis=1, keep_dims=True)) / V
        b_dkh -= b_rstd * b_rstd * b_drstd * b_x / V
        # Mask out padded feature columns and padded tail rows of a short chunk.
        b_dkh = tl.where((offs_v < V)[None, :] * (offs_t < T-i_t*BT)[:, None], b_dkh, 0.)
        b_dk += tl.dot(b_dkh, b_h.to(b_dkh.dtype)).to(b_k.dtype)
        # Carry the state gradient to the previous chunk: output path (q, do)
        # plus the prediction path (k, dkh).
        b_dh += tl.dot(b_q, b_do.to(b_q.dtype)) + tl.dot(tl.trans(b_k).to(b_dkh.dtype), b_dkh)
        b_dhb += tl.sum(b_do + b_dkh, axis=0)
        b_dh = tl.where((offs_v < V)[None, :], b_dh, 0.)
        b_dhb = tl.where((offs_v < V), b_dhb, 0.)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dw, b_dw.to(p_dw.dtype.element_ty), boundary_check=(0,))
        tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,))

    # After the reverse sweep, b_dh/b_dhb hold the gradients w.r.t. the
    # initial state; persist them when requested.
    if USE_INITIAL_STATE:
        p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
    if USE_INITIAL_STATE_B:
        p_dhb0 = tl.make_block_ptr(dhb0+i_nh*V, (V,), (1,), (i_v * BV,), (BV,), (0,))
        tl.store(p_dhb0, b_dhb.to(p_dhb0.dtype.element_ty), boundary_check=(0,))


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3]
    ],
    key=['BT', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_bwd_kernel_dqke(
    q,
    k,
    v,
    e,
    h,
    do,
    dh,
    dhb,
    dq,
    dk,
    de,
    offsets,
    indices,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    """Per-chunk gradients for q, k, and the learning rate e (eta).

    Uses the checkpointed forward states `h` and backward state gradients
    `dh`/`dhb`. dq comes from the output path; dk from both the intra-chunk
    scores and the state update; de collects the two places eta enters:
    the causal correction (per token) and the last-token state update
    (accumulated under `mask` onto the chunk's last valid position).
    """
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # offset calculation
    v += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
    do += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
    h += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K * V
    dh += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K * V
    dhb += (i_bh * NT + i_t) * V if HEAD_FIRST else (i_tg * H + i_h) * V
    q += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
    k += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
    dq += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
    dk += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
    e += i_bh * T if HEAD_FIRST else (bos * H + i_h)
    de += i_bh * T if HEAD_FIRST else (bos * H + i_h)
    stride_qk = K if HEAD_FIRST else H*K
    stride_vo = V if HEAD_FIRST else H*V
    stride_e = 1 if HEAD_FIRST else H

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    b_ds = tl.zeros([BT, BT], dtype=tl.float32)
    b_de = tl.zeros([BT,], dtype=tl.float32)

    p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    # Position of the chunk's last valid token (partial final chunk clamps).
    p_e_last = (e + (i_t*BT+BT-1)*stride_e) if (i_t*BT+BT) <= T else (e + (T-1)*stride_e)
    i_last = (BT-1) if (i_t*BT+BT) <= T else (T % BT-1)
    mask = (tl.arange(0, BT) == i_last)
    b_e_last = tl.load(p_e_last)

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_dhb = tl.make_block_ptr(dhb, (V,), (1,), (i_v * BV,), (BV,), (0,))
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # [BV]
        b_dhb = tl.load(p_dhb, boundary_check=(0,))
        # [BT, BV] @ [BV, BT] -> [BT, BT]
        b_ds += tl.dot(b_do, tl.trans(b_v))
        # [BT, BV] @ [BV, BK] -> [BT, BK]
        b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
        # [BT, BV] @ [BV, BK] -> [BT, BK]
        b_dk -= b_e_last * tl.dot(b_v, b_dh.to(b_v.dtype))
        # eta gradient through the state update — lands on the last token only.
        b_de -= mask * tl.sum(tl.trans(b_dh) * tl.dot(tl.trans(b_k), b_v.to(b_k.dtype)))
        b_de -= mask * tl.sum(b_dhb * tl.sum(b_v, axis=0).to(b_k.dtype))

    o_i = tl.arange(0, BT)
    p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_e = tl.make_block_ptr(e, (T,), (stride_e,), (i_t * BT,), (BT,), (0,))
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_e = tl.load(p_e, boundary_check=(0,))

    p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_de = tl.make_block_ptr(de, (T,), (stride_e,), (i_t * BT,), (BT,), (0,))

    # Causal mask, then gradients through the intra-chunk correction terms.
    b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0)
    b_ds = b_ds.to(b_k.dtype)
    b_dq -= tl.dot(b_ds, b_k) * b_e[:, None]
    b_dk -= tl.dot(tl.trans(b_ds), b_q * b_e[:, None]) * scale
    b_de -= tl.sum(scale * tl.dot(b_ds, b_k) * b_q, axis=1)
    b_de -= tl.sum(b_ds, axis=1)
    b_dq *= scale
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_de, b_de.to(p_de.dtype.element_ty), boundary_check=(0,))


def chunk_ttt_linear_fwd_h(
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    eps: float,
    initial_state: Optional[torch.Tensor] = None,
    initial_state_bias: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 16,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch the forward state-recurrence kernel.

    Returns (h, hb, v_new, final_state, final_state_bias); the last two are
    ``None`` unless ``output_final_state`` is set. Layouts follow
    ``head_first`` ([B, H, T, ...] vs [B, T, H, ...]).
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    BT = chunk_size
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if offsets is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
    BK = triton.next_power_of_2(K)
    BV = triton.next_power_of_2(V)
    assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    # The kernel keeps the full [K, V] state in one program, so the K/V grid
    # dimensions must be singletons.
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
    assert NV == 1, 'NV > 1 is not supported by TTT update rule.'

    if head_first:
        h = k.new_empty(B, H, NT, K, V)
        hb = k.new_empty(B, H, NT, 1, V)
    else:
        h = k.new_empty(B, NT, H, K, V)
        hb = k.new_empty(B, NT, H, 1, V)
    final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    final_state_bias = k.new_empty(N, H, 1, V, dtype=torch.float32) if output_final_state else None

    v_new = torch.empty_like(v)
    grid = (NK, NV, N * H)

    chunk_ttt_linear_fwd_kernel_h[grid](
        k=k,
        v=v,
        v_new=v_new,
        eta=eta,
        w=w,
        b=b,
        eps=eps,
        h=h,
        hb=hb,
        h0=initial_state,
        hb0=initial_state_bias,
        ht=final_state,
        hbt=final_state_bias,
        offsets=offsets,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        NT=NT,
        HEAD_FIRST=head_first
    )
    return h, hb, v_new, final_state, final_state_bias


def chunk_ttt_linear_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    eta: torch.Tensor,
    h: torch.Tensor,
    hb: torch.Tensor,
    scale: Optional[float] = None,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> torch.Tensor:
    """Launch the output kernel over per-chunk states; returns `o` shaped
    like `v`. `scale` defaults to 1/sqrt(K). Note `v` here is expected to be
    the transformed values (`v_new` from the state kernel) — TODO(review):
    confirm against the caller.
    """
    if head_first:
        B, H, T, K, V = *q.shape, v.shape[-1]
    else:
        B, T, H, K, V = *q.shape, v.shape[-1]
    if scale is None:
        scale = k.shape[-1] ** -0.5
    BT = chunk_size
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    BK = triton.next_power_of_2(K)
    BV = triton.next_power_of_2(V)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
    assert NV == 1, 'NV > 1 is not supported by TTT update rule.'

    o = torch.empty_like(v)

    grid = (NV, NT, B * H)
    chunk_ttt_linear_fwd_kernel_o[grid](
        q,
        k,
        v,
        eta,
        h,
        hb,
        o,
        offsets,
        indices,
        scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return o


def chunk_ttt_linear_bwd_h(
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    eps: float,
    initial_state: Optional[torch.Tensor] = None,
    initial_state_bias: Optional[torch.Tensor] = None,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 16,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch the backward recompute kernel; returns (h, v_new, x, y, rstd)
    — the forward states plus the LayerNorm intermediates needed by the
    gradient kernels.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    BT = chunk_size
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if offsets is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
    BK = triton.next_power_of_2(K)
    BV = triton.next_power_of_2(V)
    assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
    assert NV == 1, 'NV > 1 is not supported by TTT update rule.'
+ + if head_first: + h = k.new_empty(B, H, NT, K, V) + rstd = v.new_empty(B, H, T, 1, dtype=torch.float32) + else: + h = k.new_empty(B, NT, H, K, V) + rstd = v.new_empty(B, T, H, 1, dtype=torch.float32) + x = torch.empty_like(v) + y = torch.empty_like(v) + + v_new = torch.empty_like(v) + grid = (NK, NV, N * H) + + chunk_ttt_linear_bwd_kernel_h[grid]( + k=k, + v=v, + v_new=v_new, + eta=eta, + w=w, + b=b, + eps=eps, + h=h, + h0=initial_state, + hb0=initial_state_bias, + x=x, + y=y, + r=rstd, + offsets=offsets, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + NT=NT, + HEAD_FIRST=head_first + ) + return h, v_new, x, y, rstd + + +def chunk_ttt_linear_bwd_dv_local( + q: torch.Tensor, + k: torch.Tensor, + eta: torch.Tensor, + do: torch.Tensor, + scale: float, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 16 +) -> torch.Tensor: + if head_first: + B, H, T, K, V = *k.shape, do.shape[-1] + else: + B, T, H, K, V = *k.shape, do.shape[-1] + BT = chunk_size + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + BK = min(triton.next_power_of_2(K), 128) + BV = min(triton.next_power_of_2(V), 128) + + dv = torch.empty_like(do) + grid = (NT, B * H) + chunk_ttt_linear_bwd_kernel_dv_local[grid]( + q, + k, + eta, + do, + dv, + offsets, + indices, + scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return dv + + +def chunk_ttt_linear_bwd_norm( + q: torch.Tensor, # [B, H, L, D] + k: torch.Tensor, # [B, H, L, D] + v: torch.Tensor, # [B, H, L, D] + v_new: torch.Tensor, # [B, H, L, D] + x: torch.Tensor, # [B, H, L, D] + y: torch.Tensor, # [B, H, L, D] + rstd: torch.Tensor, # [B, H, L, 1] + w: torch.Tensor, # [H, D] + b: torch.Tensor, # [H, D] + eta: torch.Tensor, # [B, H, L, 1] + h0: torch.Tensor, # [B, H, D, D] + hb0: torch.Tensor, # [B, H, 1, D] + h: torch.Tensor, # [B, H, NT, D, D] + dht: 
Optional[torch.Tensor],    # [B, H, D, D] +    dhbt: Optional[torch.Tensor],    # [B, H, 1, D] +    dv_new: Optional[torch.Tensor],    # [B, H, L, D] +    do: torch.Tensor,    # [B, H, L, D] +    scale: float, +    offsets: Optional[torch.LongTensor] = None, +    indices: Optional[torch.LongTensor] = None, +    head_first: bool = True, +    chunk_size: int = 16 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +    # Triton implementation of `dkh, dw, db, dk, dv` for LN^2 (cf. the torch reference `chunk_ttt_linear_bwd_norm_ref`) +    assert offsets is None, "bwd of varlen is not implemented yet." +    if head_first: +        B, H, T, K, V = *q.shape, do.shape[-1] +    else: +        B, T, H, K, V = *q.shape, do.shape[-1] +    BT = chunk_size +    if offsets is None: +        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None +    else: +        N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT) + +    BK = triton.next_power_of_2(K) +    BV = triton.next_power_of_2(V) +    NK = triton.cdiv(K, BK) +    NV = triton.cdiv(V, BV) +    assert NK == 1, 'NK > 1 is not supported by TTT.' +    assert NV == 1, 'NV > 1 is not supported by TTT.' 
+ + if head_first: + dh = q.new_empty(B, H, NT, K, V) + dhb = q.new_empty(B, H, NT, 1, V) + else: + dh = q.new_empty(B, NT, H, K, V) + dhb = q.new_empty(B, NT, H, 1, V) + dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None + dhb0 = torch.empty_like(hb0, dtype=torch.float32) if hb0 is not None else None + dv = torch.empty_like(v) + dk = torch.empty_like(k) + dw = w.new_empty(B, H, V) + db = b.new_empty(B, H, V) + + grid = (NK, NV, N * H) + chunk_ttt_linear_bwd_kernel_norm[grid]( + q=q, + k=k, + v=v, + v_new=v_new, + x=x, + y=y, + r=rstd, + w=w, + b=b, + eta=eta, + h=h, + dht=dht, + dhbt=dhbt, + dh0=dh0, + dhb0=dhb0, + do=do, + dh=dh, + dhb=dhb, + dv=dv, + dv_new=dv_new, + dk=dk, + dw=dw, + db=db, + offsets=offsets, + chunk_offsets=chunk_offsets, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + dw = dw.sum(dim=0) + db = db.sum(dim=0) + return dh, dhb, dh0, dhb0, dv, dk, dw, db + + +def chunk_ttt_linear_bwd_norm_ref( + q: torch.Tensor, # [B, H, L, D] + k: torch.Tensor, # [B, H, L, D] + v: torch.Tensor, # [B, H, L, D] + v_new: torch.Tensor, # [B, H, L, D] + kh: torch.Tensor, # [B, H, L, D] + y: torch.Tensor, # [B, H, L, D] + w: torch.Tensor, # [H, D] + b: torch.Tensor, # [H, D] + eta: torch.Tensor, # [B, H, L, 1] + h0: torch.Tensor, # [B, H, D, D] + h: torch.Tensor, # [B, H, NT, D, D] + dht: Optional[torch.Tensor], # [B, H, D, D] + dv_new: Optional[torch.Tensor], # [B, H, L, D] + do: torch.Tensor, # [B, H, L, D] + scale: float, + eps: float, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 16 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # torch implementation of `dkh, dw, db, dk, dv` for LN^2 + assert offsets is None, "bwd of varlen is not implemented yet." 
+ if head_first: + B, H, T, K, V = *q.shape, do.shape[-1] + else: + B, T, H, K, V = *q.shape, do.shape[-1] + # [B, L, H, D] -> [B, H, L, D] + q, k, v, v_new, kh, y, h, eta, dv_new, do = [ + x.transpose(1, 2) for x in + [q, k, v, v_new, kh, y, h, eta, dv_new, do] + ] + BT = chunk_size + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + pad_len = (BT - (T % BT)) % BT + if pad_len > 0: + q, k, v, v_new, kh, y, eta, dv_new, do = [ + F.pad(x, (0, 0, 0, pad_len)) for x in + [q, k, v, v_new, kh, y, eta, dv_new, do] + ] + eta[:, :, -1, :] = eta[:, :, -(pad_len+1), :] + # [NT, B, H, BT, D] + q, k, v, v_new, kh, y, eta, dv_new, do = [ + x.reshape(B, H, NT, BT, -1).permute(2, 0, 1, 3, 4) for x in + [q, k, v, v_new, kh, y, eta, dv_new, do] + ] + h = h.permute(2, 0, 1, 3, 4) + + # allocate + dh = q.new_zeros(NT, B, H, K, V) + dv = torch.zeros_like(v) + dk = torch.zeros_like(k) + dw = torch.zeros_like(w) + db = torch.zeros_like(b) + # recurrent state + b_dh = dht if dht is not None else torch.zeros_like(dh[0]) + b_dh = b_dh.to(torch.float32) + + # [H, 1, D] + _w = w.reshape(H, 1, V).to(torch.float32) + _b = b.reshape(H, 1, V).to(torch.float32) + + # d_state passing + for i_t in range(NT - 1, -1, -1): + dh[i_t] = b_dh.to(dh.dtype) + # [B, H, BT, D] + _q, _k, _v, _v_new, _kh, _y, _h, _eta, _dv_new, _do = [ + x[i_t].to(torch.float32) for x in + (q, k, v, v_new, kh, y, h, eta, dv_new, do) + ] + _dv_new -= (_eta[:, :, -1, :, None] * _k) @ b_dh + + mean = _kh.mean(dim=-1, keepdim=True) + var = _kh.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32) + rstd = 1 / torch.sqrt(var + eps).to(torch.float32) + x = (_kh - mean) * rstd + # [B, H, BT, D] + dy = rstd * (_dv_new*V - _dv_new.sum(dim=-1, keepdim=True) - x*(x*_dv_new).sum(dim=-1, keepdim=True)) / V + dx = -rstd * (_dv_new*(x*_y).sum(dim=-1, keepdim=True) + _y*(x*_dv_new).sum(dim=-1, keepdim=True)) / V + d_rstd = (_dv_new * _v_new / rstd).sum(dim=-1, keepdim=True) + + dv[i_t] = (-_w*dy).to(dv.dtype) + dk[i_t] += 
(_w*dy).to(dk.dtype) + dw += (2*_w*x*dy+(_b-_v+_k)*dy).sum(dim=(0, 2)).to(dw.dtype) + db += (_w*dy).sum(dim=(0, 2)).to(db.dtype) + dx += _w*_w*dy + + # d_rstd, dx --> dkh --> dk, dh + dkh = rstd * (V * dx - dx.sum(dim=-1, keepdim=True) - x * (x * dx).sum(dim=-1, keepdim=True)) / V + dkh -= rstd**2 * d_rstd * x / V + dk[i_t] += (dkh @ _h.transpose(-2, -1)).to(dk.dtype) + b_dh += (_q.transpose(-2, -1) * scale) @ _do + _k.transpose(-2, -1) @ dkh + dh0 = b_dh.to(torch.float32) if h0 is not None else None + + # [NT, B, H, BT, D] -> [B, H, T, D] + dv = dv.permute(1, 2, 0, 3, 4).reshape(B, H, -1, V)[:, :, :T, :] + dk = dk.permute(1, 2, 0, 3, 4).reshape(B, H, -1, K)[:, :, :T, :] + # [B, H, NT, D, D] + dh = dh.permute(1, 2, 0, 3, 4) + if not head_first: + dv, dk, dh = [x.transpose(1, 2) for x in (dv, dk, dh)] + dh, dv, dk, dw, db = [x.contiguous() for x in (dh, dv, dk, dw, db)] + dh0 = dh0.contiguous() if h0 is not None else None + return dh, dh0, dv, dk, dw, db + + +def chunk_ttt_linear_bwd_dqke( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + eta: torch.Tensor, + h: torch.Tensor, + do: torch.Tensor, + dh: torch.Tensor, + dhb: torch.Tensor, + scale: float, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 16, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = chunk_size + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + BK = triton.next_power_of_2(K) + BV = min(triton.next_power_of_2(V), 64) + NK = triton.cdiv(K, BK) + assert NK == 1, "NK > 1 is not supported." 
+ + dq = torch.empty_like(q) + dk = torch.empty_like(k) + de = torch.empty_like(eta) + grid = (NK, NT, B * H) + + chunk_bwd_kernel_dqke[grid]( + q=q, + k=k, + v=v, + e=eta, + h=h, + do=do, + dh=dh, + dhb=dhb, + dq=dq, + dk=dk, + de=de, + offsets=offsets, + indices=indices, + scale=scale, + B=B, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return dq, dk, de + + +def chunk_ttt_linear_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + b: torch.Tensor, + eta: torch.Tensor, + scale: float, + eps: float, + initial_state: torch.Tensor, + initial_state_bias: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + BT: int = 16 +): + h, hb, v_new, final_state, final_state_bias = chunk_ttt_linear_fwd_h( + k=k, + v=v, + w=w, + b=b, + eta=eta, + eps=eps, + initial_state=initial_state, + initial_state_bias=initial_state_bias, + output_final_state=output_final_state, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + o = chunk_ttt_linear_fwd_o( + q=q, + k=k, + v=v_new, + eta=eta, + h=h, + hb=hb, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + return o, final_state, final_state_bias + + +def chunk_ttt_linear_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + b: torch.Tensor, + eta: torch.Tensor, + scale: float, + eps: float, + do: torch.Tensor, + dht: torch.Tensor, + dhbt: torch.Tensor, + BT: int = 16, + initial_state: torch.Tensor = None, + initial_state_bias: torch.Tensor = None, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True +): + h, v_new, x, y, rstd = chunk_ttt_linear_bwd_h( + k=k, + v=v, + w=w, + b=b, + eta=eta, + eps=eps, + initial_state=initial_state, + initial_state_bias=initial_state_bias, + offsets=offsets, + 
indices=indices, + head_first=head_first, + chunk_size=BT + ) + dv_new = chunk_ttt_linear_bwd_dv_local( + q=q, + k=k, + eta=eta, + do=do, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dh, dhb, dh0, dhb0, dv, dk, dw, db = chunk_ttt_linear_bwd_norm( + q=q, + k=k, + v=v, + v_new=v_new, + x=x, + y=y, + rstd=rstd, + w=w, + b=b, + eta=eta, + h0=initial_state, + hb0=initial_state_bias, + h=h, + dht=dht, + dhbt=dhbt, + dv_new=dv_new, + do=do, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dq, dk2, de = chunk_ttt_linear_bwd_dqke( + q=q, + k=k, + v=v_new, + eta=eta, + h=h, + do=do, + dh=dh, + dhb=dhb, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dk.add_(dk2) + return dq, dk, dv, de, dw, db, dh0, dhb0 + + +class ChunkTTTLinearFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, w, b, BT, eta, scale, eps, initial_state, + initial_state_bias, output_final_state, offsets, head_first): + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None + o, final_state, final_state_bias = chunk_ttt_linear_fwd( + q=q, + k=k, + v=v, + w=w, + b=b, + eta=eta, + scale=scale, + eps=eps, + BT=BT, + initial_state=initial_state, + initial_state_bias=initial_state_bias, + output_final_state=output_final_state, + offsets=offsets, + indices=indices, + head_first=head_first, + ) + ctx.save_for_backward(q, k, v, eta, w, b, initial_state, initial_state_bias) + ctx.BT = BT + ctx.scale = scale + ctx.eps = eps + ctx.offsets = offsets + ctx.indices = indices + 
ctx.head_first = head_first + return o.to(q.dtype), final_state, final_state_bias + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht, dhbt): + q, k, v, eta, w, b, initial_state, initial_state_bias = ctx.saved_tensors + dq, dk, dv, de, dw, db, dh0, dhb0 = chunk_ttt_linear_bwd( + q=q, + k=k, + v=v, + w=w, + b=b, + eta=eta, + scale=ctx.scale, + eps=ctx.eps, + do=do, + dht=dht, + dhbt=dhbt, + BT=ctx.BT, + initial_state=initial_state, + initial_state_bias=initial_state_bias, + offsets=ctx.offsets, + indices=ctx.indices, + head_first=ctx.head_first + ) + return dq.to(q), dk.to(k), dv.to(v), dw.to(w), db.to(b), None, de.to(eta), None, None, dh0, dhb0, None, None, None + + +def norm_residual(x, weight, bias, eps, head_first): + # GroupNorm and Residual + if head_first: + B, H, T, D = x.shape + x = x.transpose(1, 2) + x += group_norm( + x.reshape(B, T, -1).clone(), + weight=weight.reshape(-1).clone(), + bias=bias.reshape(-1).clone(), + eps=eps, + num_groups=H, + ).reshape(x.shape) + x = x.transpose(1, 2) + else: + B, T, H, D = x.shape + x += group_norm( + x.reshape(B, T, -1).clone(), + weight=weight.reshape(-1).clone(), + bias=bias.reshape(-1).clone(), + eps=eps, + num_groups=H, + ).reshape(x.shape) + return x + + +def chunk_ttt_linear( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + b: torch.Tensor, + eta: torch.Tensor, + scale: float = None, + eps: float = 1e-6, + chunk_size: int = 16, + initial_state: torch.Tensor = None, + initial_state_bias: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True, +): + r""" + Args: + q (torch.Tensor): + queries of shape `(B, H, T, K)` + k (torch.Tensor): + keys of shape `(B, H, T, K)` + v (torch.Tensor): + values of shape `(B, H, T, V)` + w (torch.Tensor): + layer norm weight of shape `(H, V)` + b (torch.Tensor): + layer norm bias of shape `(H, V)` + eta (torch.Tensor): + Learning rate for hidden 
state, of shape `(B, H, T, 1)`. +        scale (Optional[int]): +            Scale factor for the attention scores. +            If not provided, it will default to `1 / sqrt(K)`. Default: `None`. +        chunk_size (int): +            chunk size. Default: `16`. +        initial_state (Optional[torch.Tensor]): +            Initial state of shape `(B, H, K, V)`. Default: `None`. +        initial_state_bias (Optional[torch.Tensor]): +            Initial state bias of shape `(B, H, 1, V)`. Default: `None`. +        output_final_state (Optional[bool]): +            Whether to output the final state of shape `(B, H, K, V)`. Default: `False`. +        cu_seqlens (torch.LongTensor): +            Cumulative sequence lengths of shape `[N+1]` used for variable-length training, +            consistent with the FlashAttention API. +        head_first (Optional[bool]): +            Whether the inputs are in the head-first format, which is not supported for variable-length inputs. +            Default: `True`. +    Returns: +        o (torch.Tensor): +            Outputs of shape `[B, H, T, V]` +        final_state (torch.Tensor): +            Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None` +        final_state_bias (torch.Tensor): +            Final state bias of shape `[B, H, 1, V]` if `output_final_state=True` else `None` +    """ +    assert q.dtype == k.dtype == v.dtype +    assert k.shape[-1] == v.shape[-1], "DK must equal to DV." +    if isinstance(eta, float): +        eta = torch.full_like(q[:, :, :, :1], eta) +    if cu_seqlens is not None: +        if q.shape[0] != 1: +            raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." +                             f"Please flatten variable-length inputs before processing.") +        if head_first: +            raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") +        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: +            raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " +                             f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") +    if scale is None: +        scale = k.shape[-1] ** -0.5 +    else: +        assert scale > 0, "Scale must be positive." 
+ o, final_state, final_state_bias = ChunkTTTLinearFunction.apply( + q, + k, + v, + w, + b, + chunk_size, + eta, + scale, + eps, + initial_state, + initial_state_bias, + output_final_state, + cu_seqlens, + head_first, + ) + o = norm_residual(o, w, b, eps, head_first) + return o, final_state, final_state_bias diff --git a/fla/ops/ttt/fused_chunk.py b/fla/ops/ttt/fused_chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..08850c170b3e88fa98f1818434baab7d73e93c7a --- /dev/null +++ b/fla/ops/ttt/fused_chunk.py @@ -0,0 +1,896 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang, Yuqi Pan + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.modules.layernorm import group_norm +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4) + ], + key=['BT', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def fused_chunk_ttt_linear_fwd_kernel( + q, + k, + v, + eta, + w, + b, + o, + scale, + eps, + h0, + hb0, + ht, + hbt, + offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_INITIAL_STATE_B: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + # indices + i_nh = tl.program_id(0) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + bos, eos = i_n * T, 
i_n * T + T + NT = tl.cdiv(T, BT) + + o_i = tl.arange(0, BT) + v_i = tl.arange(0, BV) + m_A = o_i[:, None] >= o_i[None, :] + b_w = tl.load(w + i_h * V + v_i, mask=v_i < V, other=0.) + b_b = tl.load(b + i_h * V + v_i, mask=v_i < V, other=0.) + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + # [BV] + b_hb = tl.zeros([BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + if USE_INITIAL_STATE_B: + p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (0,), (BV,), (0,)) + b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32) + + for i_t in range(NT): + if HEAD_FIRST: + p_q = tl.make_block_ptr(q+i_nh*T*K, (T, K), (K, 1), (i_t*BT, 0), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k+i_nh*T*K, (K, T), (1, K), (0, i_t*BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_e = tl.make_block_ptr(eta+i_nh*T, (T,), (1,), (i_t*BT,), (BT,), (0,)) + p_e_last = eta+i_nh*T+T-1 if i_t == NT-1 else eta+i_nh*T+i_t*BT+BT-1 + else: + p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (0, i_t*BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_e = tl.make_block_ptr(eta+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,)) + p_e_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero") + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero") + + # [BT, BV] + b_kh = 
tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :] + b_kh = tl.where((v_i < V)[None, :], b_kh, 0.) + mean = tl.sum(b_kh, axis=1, keep_dims=True) / V + xbar = tl.where((v_i < V)[None, :], b_kh - mean, 0.) + var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V + rstd = 1 / tl.sqrt(var.to(tl.float32) + eps) + b_kh_hat = (b_kh - mean) * rstd + + b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \ + b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k) + b_v = tl.where((v_i < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.) + b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype) + * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero") + # [BT] + b_e = tl.load(p_e, boundary_check=(0,), padding_option="zero") + b_q = (b_q * scale).to(b_k.dtype) + + # [BT, BT] + b_A = tl.dot(b_q, b_k, allow_tf32=False) + b_A = tl.where(m_A, b_A, 0) + b_Ae = tl.where(m_A, b_e[:, None], 0.0) + + b_o = - tl.dot(b_e[:, None] * b_A.to(b_v2.dtype), b_v2, allow_tf32=False) + b_o += b_hb[None, :] - tl.dot(b_Ae.to(b_v2.dtype), b_v2, allow_tf32=False) + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_e_last = tl.load(p_e_last) + b_h = b_h - tl.dot(b_e_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False) + b_hb = b_hb - tl.sum(b_e_last * b_v2.to(b_k.dtype), axis=0) + b_h = tl.where((v_i < V)[None, :], b_h, 0.) + b_hb = tl.where((v_i < V), b_hb, 0.) 
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0)) + p_hbt = tl.make_block_ptr(hbt + i_nh * V, (V,), (1,), (0,), (BV,), (0,)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_hbt, b_hb.to(p_hbt.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4) + ], + key=['BT', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def fused_chunk_ttt_linear_bwd_kernel_h( + k, + v, + v2, + x, + y, + r, + w, + b, + eta, + h0, + hb0, + h, + do, + dq, + scale, + eps, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_INITIAL_STATE_B: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + # indices + i_nh = tl.program_id(0) + i_n, i_h = i_nh // H, i_nh % H + bos, _ = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + o_i = tl.arange(0, BT) + v_i = tl.arange(0, BV) + m_A = o_i[:, None] >= o_i[None, :] + b_w = tl.load(w + i_h * V + v_i, mask=v_i < V, other=0.) + b_b = tl.load(b + i_h * V + v_i, mask=v_i < V, other=0.) 
+ + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + # [BV] + b_hb = tl.zeros([BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + if USE_INITIAL_STATE_B: + p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (0,), (BV,), (0,)) + b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32) + + for i_t in range(NT): + if HEAD_FIRST: + p_h = tl.make_block_ptr(h+(i_nh*NT+i_t)*K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0)) + p_k = tl.make_block_ptr(k+i_nh*T*K, (K, T), (1, K), (0, i_t*BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_v2 = tl.make_block_ptr(v2+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_x = tl.make_block_ptr(x+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_y = tl.make_block_ptr(y+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_r = tl.make_block_ptr(r+i_nh*T, (T, 1), (1, 1), (i_t*BT, 0), (BT, 1), (1, 0)) + p_e = tl.make_block_ptr(eta+i_nh*T, (T,), (1,), (i_t*BT,), (BT,), (0,)) + p_dq = tl.make_block_ptr(dq+i_nh*T*K, (T, K), (K, 1), (i_t*BT, 0), (BT, BK), (1, 0)) + p_do = tl.make_block_ptr(do+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_e_last = eta+i_nh*T+T-1 if i_t == NT-1 else eta+i_nh*T+i_t*BT+BT-1 + else: + p_h = tl.make_block_ptr(h+((boh+i_t)*H+i_h)*K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0)) + p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (0, i_t*BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_v2 = tl.make_block_ptr(v2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, 
BV), (1, 0)) + p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 1), (1, 0)) + p_e = tl.make_block_ptr(eta+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,)) + p_dq = tl.make_block_ptr(dq+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0)) + p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_e_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero") + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero") + + b_kh = tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :] + b_kh = tl.where((v_i < V)[None, :], b_kh, 0.) + mean = tl.sum(b_kh, axis=1, keep_dims=True) / V + xbar = tl.where((v_i < V)[None, :], b_kh - mean, 0.) + var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V + rstd = 1 / tl.sqrt(var.to(tl.float32) + eps) + b_kh_hat = (b_kh - mean) * rstd + + b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \ + b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k) + b_v = tl.where((v_i < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.) + b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype) + * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V + tl.store(p_x, b_kh_hat.to(p_x.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_y, b_v.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_r, rstd.to(p_r.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_v2, b_v2.to(p_v2.dtype.element_ty), boundary_check=(0, 1)) + + b_e = tl.load(p_e, boundary_check=(0,), padding_option="zero") + b_do = tl.load(p_do, boundary_check=(0, 1), padding_option="zero") + + b_v2 = tl.where((v_i < V)[None, :], b_v2, 0.) 
+ b_ds = tl.dot(b_do, tl.trans(b_v2).to(b_do.dtype)) + b_ds = tl.where(m_A, b_ds, 0) + b_ds = b_ds.to(b_k.dtype) + b_dq = tl.dot(b_do, tl.trans(b_h).to(b_do.dtype)) + b_dq -= tl.dot(b_ds, tl.trans(b_k)) * b_e[:, None] + b_dq *= scale + + b_e_last = tl.load(p_e_last) + b_h = b_h - tl.dot(b_e_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False) + b_hb = b_hb - tl.sum(b_e_last * b_v2.to(b_k.dtype), axis=0) + b_h = tl.where((v_i < V)[None, :], b_h, 0.) + b_hb = tl.where((v_i < V), b_hb, 0.) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None, + 'USE_INITIAL_STATE_B': lambda args: args['dhb0'] is not None, + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'USE_FINAL_STATE_GRADIENT_B': lambda args: args['dhbt'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4) + ], + key=['BT', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def fused_chunk_ttt_linear_bwd_kernel_dh( + q, + k, + v, + v2, + x, + y, + r, + w, + b, + eta, + h, + dht, + dhbt, + dh0, + dhb0, + do, + dk, + dv, + de, + dw, + db, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_INITIAL_STATE_B: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + USE_FINAL_STATE_GRADIENT_B: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + # indices + i_nh = tl.program_id(0) + i_n, i_h = i_nh // H, i_nh % H + bos, _ = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + # [BV] + b_dhb = tl.zeros([BV], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0)) + b_dh += tl.load(p_dht, boundary_check=(0, 1), padding_option="zero") + if 
USE_FINAL_STATE_GRADIENT_B: + p_dhbt = tl.make_block_ptr(dhbt + i_nh * V, (V,), (1,), (0,), (BV,), (0,)) + b_dhb += tl.load(p_dhbt, boundary_check=(0,), padding_option="zero") + + # [BV] + o_i = tl.arange(0, BT) + v_i = tl.arange(0, BV) + m_A = o_i[:, None] >= o_i[None, :] + m_A_t = o_i[:, None] <= o_i[None, :] + b_w = tl.load(w + i_h * V + v_i, mask=v_i < V, other=0.) + b_b = tl.load(b + i_h * V + v_i, mask=v_i < V, other=0.) + b_dw = tl.zeros([BV,], dtype=b_w.dtype) + b_db = tl.zeros([BV,], dtype=b_b.dtype) + p_dw = tl.make_block_ptr(dw + i_nh * V, (V,), (1,), (0,), (BV,), (0,)) + p_db = tl.make_block_ptr(db + i_nh * V, (V,), (1,), (0,), (BV,), (0,)) + + for i_t in range(NT - 1, -1, -1): + if HEAD_FIRST: + p_h = tl.make_block_ptr(h+(i_nh*NT+i_t)*K*V, (V, K), (1, V), (0, 0), (BV, BK), (0, 1)) + p_q = tl.make_block_ptr(q+i_nh*T*K, (K, T), (1, K), (0, i_t*BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k+i_nh*T*K, (T, K), (K, 1), (i_t*BT, 0), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_v2 = tl.make_block_ptr(v2+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_x = tl.make_block_ptr(x+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_y = tl.make_block_ptr(y+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_r = tl.make_block_ptr(r+i_nh*T, (T, 1), (1, 1), (i_t*BT, 0), (BT, 1), (1, 0)) + p_e = tl.make_block_ptr(eta+i_nh*T, (T,), (1,), (i_t*BT,), (BT,), (0,)) + p_dv = tl.make_block_ptr(dv+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk+i_nh*T*K, (T, K), (K, 1), (i_t*BT, 0), (BT, BK), (1, 0)) + p_do = tl.make_block_ptr(do+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_de = tl.make_block_ptr(de+i_nh*T, (T,), (1,), (i_t*BT,), (BT,), (0,)) + p_e_last = eta + i_nh*T + T - 1 if i_t == NT-1 else eta + i_nh*T + i_t*BT + BT - 1 + else: + p_h = tl.make_block_ptr(h+((boh+i_t)*H+i_h)*K*V, (V, K), (1, V), (0, 0), (BV, BK), (0, 1)) + p_q 
= tl.make_block_ptr(q+(bos*H+i_h)*K, (K, T), (1, H*K), (0, i_t*BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_v2 = tl.make_block_ptr(v2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 1), (1, 0)) + p_e = tl.make_block_ptr(eta+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,)) + p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0)) + p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_de = tl.make_block_ptr(de+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,)) + p_e_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H + b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero") + b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero") + b_e = tl.load(p_e, boundary_check=(0,), padding_option="zero") + b_do = tl.load(p_do, boundary_check=(0, 1), padding_option="zero") + b_e_last = tl.load(p_e_last) + b_A = tl.dot(b_k, b_q) + b_A = - tl.where(m_A_t, b_A * scale * b_e[None, :], 0).to(do.dtype.element_ty) + b_Ae = - tl.where(m_A_t, b_e[None, :], 0).to(do.dtype.element_ty) + b_dv_new = tl.dot(b_A.to(b_do.dtype), b_do) + tl.dot(b_Ae.to(b_do.dtype), b_do) + b_dv_new -= tl.dot(b_e_last * b_k, b_dh.to(b_k.dtype)) + b_dv_new -= b_e_last * b_dhb.to(b_k.dtype)[None, :] + + b_v2 = tl.load(p_v2, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype) + b_x = tl.load(p_x, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype) + b_y 
= tl.load(p_y, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype) + b_rstd = tl.load(p_r, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + b_dy = b_rstd * (b_dv_new * V - tl.sum(b_dv_new, axis=1, keep_dims=True) - + b_x * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V + b_dx = -b_rstd * (b_dv_new * tl.sum(b_x * b_y, axis=1, keep_dims=True) + + b_y * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V + b_drstd = tl.sum(b_dv_new.to(b_rstd.dtype) * b_v2.to(b_rstd.dtype) / b_rstd, axis=1, keep_dims=True) + + b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero") + b_w = b_w.to(b_k.dtype) + b_b = b_b.to(b_k.dtype) + b_dv = -b_w * b_dy.to(b_k.dtype) + b_dk = b_w * b_dy.to(b_k.dtype) + b_dw += tl.sum(2 * b_w * b_x * b_dy.to(b_k.dtype) + + (b_b - b_v.to(b_k.dtype) + b_k) * b_dy.to(b_k.dtype), axis=0).to(b_dw.dtype) + b_db += tl.sum(b_w * b_dy.to(b_k.dtype), axis=0).to(b_db.dtype) + b_dx = b_dx.to(b_k.dtype) + b_w * b_w * b_dy.to(b_k.dtype) + + b_h = tl.load(p_h, boundary_check=(0, 1), padding_option="zero") + b_q = (b_q * scale).to(b_q.dtype) + b_dkh = b_rstd * (V * b_dx - tl.sum(b_dx, axis=1, keep_dims=True) - + b_x * tl.sum(b_x * b_dx, axis=1, keep_dims=True)) / V + b_dkh -= b_rstd * b_rstd * b_drstd * b_x / V + b_dkh = tl.where((v_i < V)[None, :] * (o_i < T-i_t*BT)[:, None], b_dkh, 0.) 
def fused_chunk_ttt_linear_bwd_h(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    do: torch.Tensor,
    BT: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
):
    """First backward stage: recompute per-chunk hidden states and LN stats.

    Launches `fused_chunk_ttt_linear_bwd_kernel_h`, which replays the forward
    chunk recurrence to materialize the intermediates the second backward
    stage needs (chunked states `h`, updated values `v2`, normalized values
    `x`/`y`, reciprocal std `r`) and computes `dq` directly.

    Returns:
        (dq, h, v2, x, y, r)
    """
    assert offsets is None, "bwd of varlen is not implemented yet."
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    # N: the actual number of sequences in the batch with either equal or variable lengths
    N, NT = B, triton.cdiv(T, BT)
    BK, BV = triton.next_power_of_2(K), triton.next_power_of_2(V)
    assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."

    # Intermediates recomputed for the second stage; `r` holds the per-token
    # reciprocal std of the in-chunk layer norm in fp32.
    if head_first:
        h = k.new_empty(B, H, NT, K, V)
        r = v.new_empty(B, H, T, 1, dtype=torch.float32)
    else:
        h = k.new_empty(B, NT, H, K, V)
        r = v.new_empty(B, T, H, 1, dtype=torch.float32)
    v2 = torch.empty_like(v)
    x = torch.empty_like(v)
    y = torch.empty_like(v)
    dq = torch.empty_like(q)

    # One program per (sequence, head).
    grid = (N * H,)
    fused_chunk_ttt_linear_bwd_kernel_h[grid](
        k=k,
        v=v,
        v2=v2,
        x=x,
        y=y,
        r=r,
        w=w,
        b=b,
        eta=eta,
        h0=initial_state,
        hb0=initial_state_bias,
        h=h,
        do=do,
        dq=dq,
        scale=scale,
        eps=eps,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return dq, h, v2, x, y, r


def fused_chunk_ttt_linear_bwd_dh(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    v2: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    r: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    h: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    dhbt: torch.Tensor,
    BT: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
):
    """Second backward stage: reverse scan producing the remaining gradients.

    Consumes the intermediates from `fused_chunk_ttt_linear_bwd_h` and
    launches `fused_chunk_ttt_linear_bwd_kernel_dh` to compute gradients for
    k, v, eta, the LN weight/bias and the initial states.

    Returns:
        (dk, dv, de, dw, db, dh0, dhb0); dh0/dhb0 are None when the
        corresponding initial state was not provided.
    """
    assert offsets is None, "bwd of varlen is not implemented yet."
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    # N: the actual number of sequences in the batch with either equal or variable lengths
    N = B
    BK, BV = triton.next_power_of_2(K), triton.next_power_of_2(V)
    assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."

    # Initial-state grads are accumulated in fp32 for accuracy.
    dh0 = torch.empty_like(initial_state, dtype=torch.float32) if initial_state is not None else None
    dhb0 = torch.empty_like(initial_state_bias, dtype=torch.float32) if initial_state_bias is not None else None
    dk = torch.empty_like(k)
    dv = torch.emp_like(v) if False else torch.empty_like(v)  # NOTE(review): plain empty_like; kept explicit
    de = torch.empty_like(eta)
    # dw/db are accumulated per batch inside the kernel, then reduced below.
    dw = w.new_empty(B, H, V)
    db = b.new_empty(B, H, V)

    grid = (N * H,)
    fused_chunk_ttt_linear_bwd_kernel_dh[grid](
        q=q,
        k=k,
        v=v,
        v2=v2,
        x=x,
        y=y,
        r=r,
        w=w,
        b=b,
        eta=eta,
        h=h,
        dht=dht,
        dhbt=dhbt,
        dh0=dh0,
        dhb0=dhb0,
        do=do,
        dk=dk,
        dv=dv,
        de=de,
        dw=dw,
        db=db,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    # Reduce the per-batch partial sums into the (H, V) parameter grads.
    dw = dw.sum(dim=0)
    db = db.sum(dim=0)
    return dk, dv, de, dw, db, dh0, dhb0
def fused_chunk_ttt_linear_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    initial_state: torch.Tensor,
    initial_state_bias: torch.Tensor,
    output_final_state: bool,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    BT: int = 16
):
    """Forward launcher for the fused chunked TTT-linear kernel.

    Returns:
        (o, final_state, final_state_bias); the states are None unless
        `output_final_state` is set, and are kept in fp32.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    # N: the actual number of sequences in the batch with either equal or variable lengths
    N = B if offsets is None else len(offsets) - 1
    BK, BV = triton.next_power_of_2(K), triton.next_power_of_2(V)
    assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128."
    o = torch.empty_like(v)
    final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    final_state_bias = k.new_empty(N, H, 1, V, dtype=torch.float32) if output_final_state else None

    # One program per (sequence, head).
    grid = (N * H,)
    fused_chunk_ttt_linear_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        eta=eta,
        w=w,
        b=b,
        o=o,
        scale=scale,
        eps=eps,
        h0=initial_state,
        hb0=initial_state_bias,
        ht=final_state,
        hbt=final_state_bias,
        offsets=offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return o, final_state, final_state_bias


def fused_chunk_ttt_linear_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    do: torch.Tensor,
    dht: torch.Tensor,
    dhbt: torch.Tensor,
    BT: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
):
    """Full backward pass: stage 1 recomputes intermediates (and dq), stage 2
    runs the reverse scan for the remaining gradients."""
    assert offsets is None, "bwd of varlen is not implemented yet."
    dq, h, v2, x, y, rstd = fused_chunk_ttt_linear_bwd_h(
        q=q,
        k=k,
        v=v,
        w=w,
        b=b,
        eta=eta,
        scale=scale,
        eps=eps,
        do=do,
        BT=BT,
        initial_state=initial_state,
        initial_state_bias=initial_state_bias,
        offsets=offsets,
        head_first=head_first
    )
    dk, dv, de, dw, db, dh0, dhb0 = fused_chunk_ttt_linear_bwd_dh(
        q=q,
        k=k,
        v=v,
        v2=v2,
        x=x,
        y=y,
        r=rstd,
        w=w,
        b=b,
        eta=eta,
        scale=scale,
        h=h,
        do=do,
        dht=dht,
        dhbt=dhbt,
        BT=BT,
        initial_state=initial_state,
        initial_state_bias=initial_state_bias,
        offsets=offsets,
        head_first=head_first
    )
    return dq, dk, dv, de, dw, db, dh0, dhb0


class FusedChunkTTTLinearFunction(torch.autograd.Function):
    """Autograd bridge for the fused chunked TTT-linear kernels.

    The backward return tuple must stay positionally aligned with forward's
    14 inputs after `ctx`: (q, k, v, w, b, BT, eta, scale, eps,
    initial_state, initial_state_bias, output_final_state, offsets,
    head_first). Non-differentiable slots receive None.
    """

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(ctx, q, k, v, w, b, BT, eta, scale, eps, initial_state,
                initial_state_bias, output_final_state, offsets, head_first):
        o, final_state, final_state_bias = fused_chunk_ttt_linear_fwd(
            q=q,
            k=k,
            v=v,
            w=w,
            b=b,
            eta=eta,
            scale=scale,
            eps=eps,
            BT=BT,
            initial_state=initial_state,
            initial_state_bias=initial_state_bias,
            output_final_state=output_final_state,
            offsets=offsets,
            head_first=head_first
        )
        ctx.save_for_backward(q, k, v, eta, w, b, initial_state, initial_state_bias)
        ctx.BT = BT
        ctx.scale = scale
        ctx.eps = eps
        ctx.offsets = offsets
        ctx.head_first = head_first
        return o.to(q.dtype), final_state, final_state_bias

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dht, dhbt):
        q, k, v, eta, w, b, initial_state, initial_state_bias = ctx.saved_tensors
        dq, dk, dv, de, dw, db, dh0, dhb0 = fused_chunk_ttt_linear_bwd(
            q=q,
            k=k,
            v=v,
            w=w,
            b=b,
            eta=eta,
            scale=ctx.scale,
            eps=ctx.eps,
            do=do,
            dht=dht,
            dhbt=dhbt,
            BT=ctx.BT,
            initial_state=initial_state,
            initial_state_bias=initial_state_bias,
            offsets=ctx.offsets,
            head_first=ctx.head_first
        )
        # Slots: q, k, v, w, b, BT(None), eta, scale(None), eps(None),
        # initial_state, initial_state_bias, output_final_state(None),
        # offsets(None), head_first(None).
        return dq.to(q), dk.to(k), dv.to(v), dw.to(w), db.to(b), None, de.to(eta), \
            None, None, dh0, dhb0, None, None, None
def norm_residual(x, weight, bias, eps, head_first):
    # GroupNorm and Residual: x <- x + GroupNorm(x), with one group per head.
    # The addition is performed in place on `x` (a view when head_first),
    # which the caller relies on.
    if head_first:
        B, H, T, D = x.shape
        x = x.transpose(1, 2)
        x += group_norm(
            x.reshape(B, T, -1).clone(),
            weight=weight.reshape(-1).clone(),
            bias=bias.reshape(-1).clone(),
            eps=eps,
            num_groups=H,
        ).reshape(x.shape)
        x = x.transpose(1, 2)
    else:
        B, T, H, D = x.shape
        x += group_norm(
            x.reshape(B, T, -1).clone(),
            weight=weight.reshape(-1).clone(),
            bias=bias.reshape(-1).clone(),
            eps=eps,
            num_groups=H,
        ).reshape(x.shape)
    return x


def fused_chunk_ttt_linear(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float = None,
    eps: float = 1e-6,
    chunk_size: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = True,
):
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `(B, H, T, K)`
        k (torch.Tensor):
            keys of shape `(B, H, T, K)`
        v (torch.Tensor):
            values of shape `(B, H, T, V)`
        w (torch.Tensor):
            layer norm weight of shape `(H, V)`
        b (torch.Tensor):
            layer norm bias of shape `(H, V)`
        eta (torch.Tensor):
            Learning rate for hidden state, of shape `(B, H, T, 1)`.
            A plain float is also accepted and broadcast to that shape.
        scale (Optional[int]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        chunk_size (int):
            chunk size. Default: `16`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `(B, H, K, V)`. Default: `None`.
        initial_state_bias (Optional[torch.Tensor]):
            Initial state bias of shape `(B, H, 1, V)`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `True`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, H, T, V]`
        final_state (torch.Tensor):
            Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`.
        final_state_bias (torch.Tensor):
            Final state bias of shape `[B, H, 1, V]` if `output_final_state=True` else `None`.
    """
    assert q.dtype == k.dtype == v.dtype
    assert k.shape[-1] == v.shape[-1], "DK must equal to DV."
    # Broadcast a scalar learning rate to a per-token tensor.
    if isinstance(eta, float):
        eta = torch.full_like(q[:, :, :, :1], eta)
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                             f"Please flatten variable-length inputs before processing.")
        if head_first:
            raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
                             f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
    if scale is None:
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "Scale must be positive."
    o, final_state, final_state_bias = FusedChunkTTTLinearFunction.apply(
        q,
        k,
        v,
        w,
        b,
        chunk_size,
        eta,
        scale,
        eps,
        initial_state,
        initial_state_bias,
        output_final_state,
        cu_seqlens,
        head_first
    )
    # Final per-head group-norm + residual applied outside the kernel.
    o = norm_residual(o, w, b, eps, head_first)
    return o, final_state, final_state_bias
def ttt_linear(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float,
    eps: float,
    mini_batch_size: int,
    initial_state: torch.Tensor,
    initial_state_bias: torch.Tensor,
    output_final_state: bool
):
    """Naive (pure PyTorch) reference for chunked TTT-linear.

    Processes the sequence in chunks of `mini_batch_size`, updating a fast
    weight `h` of shape (B, H, D, D) and a bias `hb` of shape (B, H, 1, D)
    with an in-chunk layer-norm reconstruction objective.

    FIX: the original scaled the queries with an in-place `q *= scale` placed
    AFTER `_q` was built. That (a) mutated the caller's tensor (visible to
    the caller whenever no dtype copy happened upstream) and (b) only scaled
    `_q` at all because `reshape` happened to return a view; a non-contiguous
    `q` would silently be left unscaled. The scaling is now applied
    out-of-place before chunking, which is numerically identical for the
    contiguous case and correct in general.

    Returns:
        (o, h, hb): o of shape (B, H, T, D); h/hb are the final states if
        `output_final_state` else None.
    """
    B, H, T, D = q.shape
    BT = mini_batch_size
    NT = T // BT
    # [NT, B, H, mini_batch_size, D]; scale applied out-of-place here.
    _q = (q * scale).reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    _k = k.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    _v = v.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4)
    # [NT, B, H, BT, 1]
    _eta = eta.reshape(B, H, NT, BT, 1).permute(2, 0, 1, 3, 4)
    # [H, 1, D]
    w = w.reshape(H, 1, D).to(torch.float32)
    b = b.reshape(H, 1, D).to(torch.float32)

    h = torch.zeros((B, H, D, D), device=v.device, dtype=torch.float32) if initial_state is None else initial_state
    hb = torch.zeros((B, H, 1, D), device=v.device, dtype=torch.float32) if initial_state_bias is None else initial_state_bias
    # [NT, B, H, BT, D]
    o = torch.empty_like(_v)

    for i in range(NT):
        q_i, k_i, v_i, eta_i = [x[i] for x in [_q, _k, _v, _eta]]
        kh = k_i @ h + hb
        reconstruction_target = v_i - k_i

        # Layer-norm statistics of the fast-weight prediction.
        mean = kh.mean(-1, True)
        var = kh.var(-1, unbiased=False, keepdim=True).to(torch.float32)
        rstd = torch.sqrt(var + eps).to(torch.float32)
        kh_hat = (kh - mean) / rstd

        # Gradient of the LN reconstruction loss w.r.t. the prediction.
        g = w * kh_hat + b - reconstruction_target
        g *= w
        v_new = (D * g - g.sum(-1, True) - kh_hat * (g * kh_hat).sum(-1, True)) / (rstd * D)

        # Causal intra-chunk attention plus inter-chunk state contribution.
        Attn = torch.tril(q_i @ k_i.transpose(-2, -1))
        o_i = q_i @ h - (eta_i * Attn) @ v_new + hb - torch.tril(eta_i.expand_as(Attn)) @ v_new
        # State update uses the last token's learning rate of the chunk.
        h = h - (eta_i[:, :, -1, :, None] * k_i).transpose(-1, -2) @ v_new
        hb = hb - torch.sum(eta_i[:, :, -1, :, None] * v_new, dim=-2, keepdim=True)

        # Output layer norm with residual.
        mean = o_i.mean(dim=-1, keepdim=True)
        var = o_i.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32)
        rstd = torch.sqrt(var + eps).to(torch.float32)
        o[i] = o_i + (o_i - mean) / rstd * w + b

    # [B, H, T, D]
    o = o.permute(1, 2, 0, 3, 4).reshape(B, H, T, D)
    h = h if output_final_state else None
    hb = hb if output_final_state else None
    return o, h, hb


def chunk_ttt_linear_ref(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eta: torch.Tensor,
    scale: float = None,
    eps: float = 1e-6,
    mini_batch_size: int = 16,
    initial_state: torch.Tensor = None,
    initial_state_bias: torch.Tensor = None,
    output_final_state: bool = False,
    head_first: bool = True,
):
    """Reference entry point: normalizes layout, pads T to a multiple of
    `mini_batch_size`, runs `ttt_linear` in fp32 and un-pads the output.

    `eta` may be a float (broadcast per-token) or a `(B, H, T, 1)` tensor.
    """
    assert q.dtype == k.dtype == v.dtype
    assert k.shape[-1] == v.shape[-1], "The key and value dimension must be the same."
    if isinstance(eta, float):
        eta = torch.full_like(q[:, :, :, :1], eta)
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if not head_first:
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        eta = eta.transpose(1, 2)
    T = q.shape[-2]
    padded = (mini_batch_size - (T % mini_batch_size)) % mini_batch_size
    if padded > 0:
        q = F.pad(q, (0, 0, 0, padded))
        k = F.pad(k, (0, 0, 0, padded))
        v = F.pad(v, (0, 0, 0, padded))
        eta = F.pad(eta, (0, 0, 0, padded))
        # The last (padded) position must carry the real last token's lr,
        # since the state update reads eta at the chunk's final slot.
        eta[:, :, -1, :] = eta[:, :, -(padded+1), :]
    assert q.shape[-2] % mini_batch_size == 0, "Sequence length should be a multiple of mini_batch_size."
    q, k, v, eta, w, b = map(lambda x: x.to(torch.float32), [q, k, v, eta, w, b])
    o, final_state, final_state_bias = ttt_linear(
        q,
        k,
        v,
        w,
        b,
        eta,
        scale,
        eps,
        mini_batch_size,
        initial_state,
        initial_state_bias,
        output_final_state,
    )
    o = o[:, :, :T, :].contiguous()
    if not head_first:
        o = o.transpose(1, 2)
    return o, final_state, final_state_bias
def fp32_to_tf32_asm() -> str:
    """
    Get the assembly code for converting FP32 to TF32.

    Returns an empty string on platforms without a known conversion
    instruction.
    """
    asm_by_platform = {
        'nvidia': 'cvt.rna.tf32.f32 $0, $1;',
    }
    return asm_by_platform.get(device_platform, "")
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BS': BS}, num_warps=num_warps)
        for BS in BS_LIST
        for num_warps in [2, 4, 8]
    ],
    key=['S', 'BT'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_local_cumsum_vector_kernel(
    s,
    o,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    REVERSE: tl.constexpr
):
    """Per-chunk (local) inclusive cumsum over the time axis of a [T, S]
    input, one program per (feature-block, chunk, batch*head).

    The cumsum is realized as a triangular matmul: m_s @ b_s, where m_s is
    lower-triangular for the forward direction and upper-triangular when
    REVERSE (giving an inclusive suffix sum within the chunk)."""
    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Varlen: map the flat chunk index to (sequence, in-sequence chunk).
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, BT)
    if REVERSE:
        m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
    else:
        m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)

    if HEAD_FIRST:
        p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_o = tl.make_block_ptr(o + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    else:
        p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    # [BT, BS]; accumulate in fp32 regardless of the input dtype.
    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
    b_o = tl.dot(m_s, b_s, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BT': 16}, num_warps=2),
        triton.Config({'BT': 32}, num_warps=4),
        triton.Config({'BT': 32}, num_warps=2),
        triton.Config({'BT': 64}, num_warps=8),
        triton.Config({'BT': 64}, num_warps=4),
    ],
    key=[]
)
@triton.jit(do_not_specialize=['T'])
def chunk_global_cumsum_scalar_kernel(
    s,
    o,
    offsets,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    REVERSE: tl.constexpr
):
    """Global (whole-sequence) inclusive cumsum of a scalar-per-token input,
    one program per (batch, head). Chunks are visited sequentially with a
    running carry b_z; with REVERSE the chunks are visited back-to-front and
    the in-chunk cumsum is flipped into a suffix sum."""
    i_bh = tl.program_id(0)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_b).to(tl.int32), tl.load(offsets + i_b + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos

    b_z = tl.zeros([], dtype=tl.float32)
    NT = tl.cdiv(T, BT)
    for i_c in range(NT):
        i_t = NT-1-i_c if REVERSE else i_c
        if HEAD_FIRST:
            p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
            p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        else:
            p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
            p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
        b_o = tl.cumsum(b_s, axis=0)
        b_ss = tl.sum(b_s, 0)
        if REVERSE:
            # Convert inclusive prefix sum into inclusive suffix sum.
            b_o = -b_o + b_ss + b_s
        b_o += b_z
        if i_c >= 0:
            # Carry the chunk total into the next iteration.
            b_z += b_ss
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
def chunk_local_cumsum_scalar(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Per-chunk cumsum of a scalar-per-token tensor ([B, H, T] or
    [B, T, H]); launches one program per (chunk, batch*head)."""
    if head_first:
        B, H, T = g.shape
    else:
        B, T, H = g.shape
    if offsets is not None:
        B = len(offsets) - 1
    assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
    BT = chunk_size
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    # Write into a fresh tensor; input is left untouched.
    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
    grid = (NT, B * H)
    chunk_local_cumsum_scalar_kernel[grid](
        g_org,
        g,
        offsets,
        indices,
        T=T,
        H=H,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return g


def chunk_local_cumsum_vector(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Per-chunk cumsum of a vector-per-token tensor ([B, H, T, S] or
    [B, T, H, S])."""
    if head_first:
        B, H, T, S = g.shape
    else:
        B, T, H, S = g.shape
    BT = chunk_size
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"

    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
    def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)
    # keep cummulative normalizer in fp32
    # this kernel is equivalent to
    # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
    chunk_local_cumsum_vector_kernel[grid](
        g_org,
        g,
        offsets,
        indices,
        T=T,
        H=H,
        S=S,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return g


@input_guard
def chunk_global_cumsum_scalar(
    s: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Whole-sequence cumsum of a scalar-per-token tensor; one program per
    (batch, head)."""
    dtype = dtype or s.dtype
    if head_first:
        B, H, T = s.shape
    else:
        B, T, H = s.shape
    if offsets is not None:
        B = len(offsets) - 1
    grid = (B * H,)
    z = torch.empty_like(s, dtype=output_dtype or dtype)
    chunk_global_cumsum_scalar_kernel[grid](
        s,
        z,
        offsets,
        T=T,
        H=H,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return z


@input_guard
def chunk_global_cumsum_vector(
    s: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Whole-sequence cumsum of a vector-per-token tensor; the feature axis
    is tiled in blocks of BS."""
    dtype = dtype or s.dtype
    if head_first:
        B, H, T, S = s.shape
    else:
        B, T, H, S = s.shape
    BS = min(32, triton.next_power_of_2(S))
    if offsets is not None:
        B = len(offsets) - 1
    grid = (triton.cdiv(S, BS), B * H)
    z = torch.empty_like(s, dtype=output_dtype or dtype)
    chunk_global_cumsum_vector_kernel[grid](
        s,
        z,
        offsets,
        T=T,
        H=H,
        S=S,
        BS=BS,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return z


@input_guard
def chunk_global_cumsum(
    s: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Dispatch to the scalar (3-D) or vector (4-D) global-cumsum wrapper."""
    if offsets is not None:
        assert s.shape[0] == 1, "Only batch size 1 is supported when offsets are provided"
    if len(s.shape) == 3:
        return chunk_global_cumsum_scalar(s, dtype, reverse, offsets, head_first, output_dtype)
    elif len(s.shape) == 4:
        return chunk_global_cumsum_vector(s, dtype, reverse, offsets, head_first, output_dtype)
    else:
        raise ValueError(f"Unsupported input shape {s.shape}. "
                         f"which should be [B, H, T]/[B, H, T, D] if `head_first=True` "
                         f"or [B, T, H]/[B, T, H, D] otherwise")


@input_guard
def chunk_local_cumsum(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Dispatch to the scalar (3-D) or vector (4-D) local-cumsum wrapper."""
    if offsets is not None:
        assert g.shape[0] == 1, "Only batch size 1 is supported when offsets are provided"
    if len(g.shape) == 3:
        return chunk_local_cumsum_scalar(g, chunk_size, reverse, offsets, indices, head_first, output_dtype)
    elif len(g.shape) == 4:
        return chunk_local_cumsum_vector(g, chunk_size, reverse, offsets, indices, head_first, output_dtype)
    else:
        raise ValueError(f"Unsupported input shape {g.shape}. "
                         f"which should be (B, H, T, dim) if `head_first=True` "
                         f"or (batch_size, num_heads, seq_len) otherwise")
@triton.heuristics({
    'HAS_SCALE': lambda args: args['scale'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def logsumexp_fwd_kernel(
    x,
    z,
    scale,
    D: tl.constexpr,
    B: tl.constexpr,
    HAS_SCALE: tl.constexpr
):
    """Numerically stable partial logsumexp: each program reduces one block
    of B elements of a row of length D and writes a per-block result; the
    host folds the partial results with a second (torch) logsumexp."""
    i_n, i_d = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
    o_d = i_d * B + tl.arange(0, B)
    m_d = o_d < D

    # Out-of-range lanes are padded with -inf so they do not affect the max/sum.
    b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf'))
    if HAS_SCALE:
        b_x = b_x * scale
    # Standard max-shift for numerical stability.
    b_m = tl.max(b_x, 0)
    b_z = log(tl.sum(exp(b_x - b_m), 0)) + b_m
    tl.store(z + i_n * tl.cdiv(D, B) + i_d, b_z)


def logsumexp_fwd(
    x,
    scale: Optional[float] = None,
    dtype: Optional[torch.dtype] = None
):
    r"""
    Compute the logsumexp of the input tensor over the last dimension.

    Args:
        x (Tensor):
            The input tensor of any shape.
        scale (Optional[float]):
            The scale applied to the input tensor. Default: `None`.
        dtype (Optional[torch.dtype]):
            The data type of the output tensor. Default: `None`.
    Returns:
        Tensor: The logsumexp of the input tensor.
    """

    shape = x.shape
    x = x.view(-1, shape[-1])
    N, D = x.shape
    # Block size capped at 64K elements per program.
    B = min(triton.next_power_of_2(D), 64 * 1024)
    ND = triton.cdiv(D, B)

    # Per-block partial results in fp32, folded below.
    z = x.new_empty(N, ND, dtype=torch.float)
    logsumexp_fwd_kernel[(N, ND)](
        x=x,
        z=z,
        scale=scale,
        D=D,
        B=B
    )
    z = z.logsumexp(-1).view(*shape[:-1])
    if dtype is not None and dtype != torch.float:
        z = z.to(dtype)
    return z
num_stages=4, num_warps=4), + triton.Config({'BM': 64, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=5, num_warps=2), + triton.Config({'BM': 32, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=5, num_warps=2), + # Good config for fp8 inputs. + # triton.Config({'BM': 128, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=3, num_warps=8), + # triton.Config({'BM': 256, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=3, num_warps=8), + # triton.Config({'BM': 256, 'BK': 128, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4), + # triton.Config({'BM': 64, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=4, num_warps=4), + # triton.Config({'BM': 128, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4), + # triton.Config({'BM': 128, 'BK': 64, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4), + # triton.Config({'BM': 64, 'BK': 64, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4), + # triton.Config({'BM': 128, 'BK': 64, 'BN': 32, 'G': 4}, num_stages=4, num_warps=4) + ], + key=['M', 'N', 'K'] +) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a, + b, + c, + input, + alpha, + beta, + # Matrix dimensions + M, + N, + K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `s_am` is how much to increase `a` + # by to get the element one row down (A has M rows). + stride_ab, stride_am, stride_ak, # a: batch, M, K + stride_bk, stride_bn, # b: K, N + stride_cb, stride_cm, stride_cn, # c: batch, M, N + # Meta-parameters + BM: tl.constexpr, + BK: tl.constexpr, + BN: tl.constexpr, + G: tl.constexpr, + ACTIVATION: tl.constexpr, + HAS_INPUT: tl.constexpr, + HAS_ALPHA: tl.constexpr, + HAS_BETA: tl.constexpr, + ALLOW_TF32: tl.constexpr, + X_DIM: tl.constexpr = 1, +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. 
+ # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + i_b, i_m, i_n = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + NM, NN = tl.num_programs(1), tl.num_programs(2) + i_m, i_n = tl.swizzle2d(i_m, i_n, NM, NN, G) + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `p_a` is a block of [BM, BK] pointers + # `p_b` is a block of [BK, BN] pointers + # See above `Pointer Arithmetic` section for details + a_batch_ptr = a + i_b * stride_ab + o_am = (i_m * BM + tl.arange(0, BM)) % M + o_bn = (i_n * BN + tl.arange(0, BN)) % N + o_k = tl.arange(0, BK) + + p_a = a_batch_ptr + (o_am[:, None] * stride_am + o_k[None, :] * stride_ak) + p_b = b + (o_k[:, None] * stride_bk + o_bn[None, :] * stride_bn) + + b_acc = tl.zeros((BM, BN), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BK)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + b_a = tl.load(p_a, mask=o_k[None, :] < K - k * BK, other=0.0) + b_b = tl.load(p_b, mask=o_k[:, None] < K - k * BK, other=0.0) + # We accumulate along the K dimension. + b_acc = tl.dot(b_a, b_b, acc=b_acc, allow_tf32=ALLOW_TF32) + # Advance the ptrs to the next K block. + p_a += BK * stride_ak + p_b += BK * stride_bk + + o_cm = i_m * BM + tl.arange(0, BM) + o_cn = i_n * BN + tl.arange(0, BN) + mask = (o_cm[:, None] < M) & (o_cn[None, :] < N) + + b_c = b_acc + # You can fuse arbitrary activation functions here + # while the b_acc is still in FP32! 
+ if ACTIVATION == "leaky_relu": + b_c = leaky_relu(b_c) + elif ACTIVATION == "relu": + b_c = relu(b_c) + elif ACTIVATION == "sigmoid": + b_c = sigmoid(b_c) + elif ACTIVATION == "tanh": + b_c = tanh(b_c) + + if HAS_ALPHA: + b_c *= tl.load(alpha) + + if HAS_INPUT: + p_i = input + (stride_cm * o_cm[:, None] if X_DIM == 2 else 0) + stride_cn * o_cn[None, :] + mask_p = (o_cn[None, :] < N) if X_DIM == 1 else mask + b_i = tl.load(p_i, mask=mask_p, other=0.0).to(tl.float32) + if HAS_BETA: + b_i *= tl.load(beta) + b_c += b_i + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + c_batch_ptr = c + i_b * stride_cb + p_c = c_batch_ptr + stride_cm * o_cm[:, None] + stride_cn * o_cn[None, :] + tl.store(p_c, b_c.to(c.dtype.element_ty), mask=mask) + + +# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`. +@triton.jit +def leaky_relu(x): + return tl.where(x >= 0, x, 0.01 * x) + + +@triton.jit +def sigmoid(x): + # σ(x) = 1 / (1 + exp(-x)) + return 1.0 / (1.0 + exp(-x)) + + +@triton.jit +def tanh(x): + # tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) + # 2 * sigmoid(2x) - 1 + return (exp(x) - exp(-x)) / (exp(x) + exp(-x)) + + +@triton.jit +def relu(x): + # ReLU(x) = max(0, x) + return tl.maximum(x, 0.0) + + +@input_guard +def matmul(a, b, activation=''): + assert a.dim() in [2, 3], "a must be 2D or 3D" + assert b.dim() == 2, "b must be 2D" + assert a.shape[-1] == b.shape[0], f"Incompatible dimensions: A {a.shape}, B {b.shape}" + + if a.dim() == 2: + a_dim = 2 + a = a.unsqueeze(0).contiguous() # (1, M, K) + else: + a_dim = 3 + allow_tf32 = False if a.dtype == torch.float32 else True + + B, M, K = a.shape[0], a.shape[1], a.shape[2] + K_b, N = b.shape + assert K == K_b, f"Incompatible K dimension: A {K} vs B {K_b}" + c = a.new_empty(B, M, N) + + def grid(meta): return (B, triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN'])) + matmul_kernel[grid]( + a, b, c, None, 
None, None, + M, N, K, + a.stride(0), a.stride(1), a.stride(2), # stride_ab, stride_am, stride_ak + b.stride(0), b.stride(1), # stride_bk, stride_bn (b.dim() == 2) + c.stride(0), c.stride(1), c.stride(2), # stride_cb, stride_cm, stride_cn + ACTIVATION=activation, + ALLOW_TF32=allow_tf32, + HAS_INPUT=False, + ) + return c.squeeze(0) if a_dim == 2 else c + + +@input_guard +def addmm( + x: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + alpha: Optional[float] = None, + beta: Optional[float] = None, +) -> torch.Tensor: + assert a.dim() in [2, 3], "a must be 2D or 3D" + assert b.dim() == 2, "b must be 2D" + assert a.shape[-1] == b.shape[0], f"Incompatible dimensions: A {a.shape}, B {b.shape}" + + if a.dim() == 2: + a_dim = 2 + a = a.unsqueeze(0).contiguous() # (1, M, K) + else: + a_dim = 3 + allow_tf32 = False if a.dtype == torch.float32 else True + + B, M, K = a.shape[0], a.shape[1], a.shape[2] + K_b, N = b.shape + assert K == K_b, f"Incompatible K dimension: A {K} vs B {K_b}" + c = a.new_empty(B, M, N) + + def grid(meta): return (B, triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN'])) + matmul_kernel[grid]( + a, b, c, x, alpha, beta, + M, N, K, + a.stride(0), a.stride(1), a.stride(2), # stride_ab, stride_am, stride_ak + b.stride(0), b.stride(1), # stride_bk, stride_bn (b.dim() == 2) + c.stride(0), c.stride(1), c.stride(2), # stride_cb, stride_cm, stride_cn + ACTIVATION=None, + ALLOW_TF32=allow_tf32, + HAS_INPUT=True, + X_DIM=x.dim(), + ) + return c.squeeze(0) if a_dim == 2 else c diff --git a/fla/ops/utils/op.py b/fla/ops/utils/op.py new file mode 100644 index 0000000000000000000000000000000000000000..f0fe269ed8756b6a7b3ea396dffdfdd56b924ea9 --- /dev/null +++ b/fla/ops/utils/op.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Songlin Yang, Yu Zhang + +import os + +import triton +import triton.language as tl +import triton.language.extra.libdevice as tldevice + +from fla.utils import is_gather_supported + +if 
os.environ.get('FLA_USE_FAST_OPS', '0') == '1': + div = tldevice.fast_dividef + exp = tldevice.fast_expf + log = tldevice.fast_logf + log2 = tldevice.fast_log2f +else: + @triton.jit + def div_normal(x, y): + return x / y + div = div_normal + exp = tl.exp + log = tl.log + log2 = tl.log2 + + +@triton.jit +def safe_exp(x): + return exp(tl.where(x <= 0, x, float('-inf'))) + + +if not is_gather_supported: + def gather(*args, **kwargs): + pass +else: + gather = tl.gather diff --git a/fla/ops/utils/pooling.py b/fla/ops/utils/pooling.py new file mode 100644 index 0000000000000000000000000000000000000000..0dd9059b4abd0a87fb65e25c01fd5897452f77e0 --- /dev/null +++ b/fla/ops/utils/pooling.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.common.utils import prepare_chunk_indices +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BD': BD}, num_warps=num_warps) + for BD in [16, 32, 64, 128] + for num_warps in [1, 2, 4, 8] + ], + key=['BT'] +) +@triton.jit(do_not_specialize=['T']) +def mean_pooling_fwd_kernel( + x, + o, + offsets, + indices, + T: tl.constexpr, + H: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + NT: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + p_x = tl.make_block_ptr(x + 
(bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_o = tl.make_block_ptr(o + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,)) + # [BT, BD] + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + # [BD] + b_o = tl.sum(b_x, axis=0) / min(BT, T - i_t * BT) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BD': BD}, num_warps=num_warps) + for BD in [16, 32, 64, 128] + for num_warps in [1, 2, 4, 8] + ], + key=['BT'] +) +@triton.jit(do_not_specialize=['T']) +def mean_pooling_bwd_kernel( + do, + dx, + offsets, + indices, + T: tl.constexpr, + H: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + NT: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + p_dx = tl.make_block_ptr(dx + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_do = tl.make_block_ptr(do + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,)) + # [BD] + b_do = tl.load(p_do, boundary_check=(0,)).to(tl.float32) + # [BT, BD] + b_dx = b_do / tl.full((BT,), min(BT, T - i_t * BT), dtype=tl.float32)[:, None] + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1)) + + +def mean_pooling_fwd( + x: torch.Tensor, + chunk_size: int, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None +) -> torch.Tensor: + B, T, H, D = x.shape + BT = chunk_size + NT = triton.cdiv(T, 
BT) if offsets is None else len(indices) + + o = x.new_empty(B, NT, H, D) + def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B * H) + mean_pooling_fwd_kernel[grid]( + x, + o, + offsets, + indices, + T=T, + H=H, + D=D, + BT=BT, + NT=NT, + ) + return o + + +def mean_pooling_bwd( + do: torch.Tensor, + batch_size: int, + seq_len: int, + chunk_size: int, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None +) -> torch.Tensor: + B, T, H, D = batch_size, seq_len, *do.shape[-2:] + BT = chunk_size + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + dx = do.new_empty(B, T, H, D) + def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B * H) + mean_pooling_bwd_kernel[grid]( + do, + dx, + offsets, + indices, + T=T, + H=H, + D=D, + BT=BT, + NT=NT + ) + return dx + + +class MeanPoolingFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + x: torch.Tensor, + chunk_size: int, + offsets: Optional[torch.LongTensor] = None + ) -> torch.Tensor: + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None + o = mean_pooling_fwd(x, chunk_size, offsets, indices) + ctx.batch_size = x.shape[0] + ctx.seq_len = x.shape[1] + ctx.chunk_size = chunk_size + ctx.offsets = offsets + ctx.indices = indices + return o + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward( + ctx, do + ) -> Tuple[torch.Tensor, None, None]: + batch_size = ctx.batch_size + seq_len = ctx.seq_len + chunk_size = ctx.chunk_size + offsets = ctx.offsets + indices = ctx.indices + dx = mean_pooling_bwd(do, batch_size, seq_len, chunk_size, offsets, indices) + return dx, 
None, None + + +def mean_pooling( + x: torch.Tensor, + chunk_size: int, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> torch.Tensor: + if head_first: + x = x.transpose(1, 2) + if cu_seqlens is not None: + if x.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {x.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.") + o = MeanPoolingFunction.apply(x, chunk_size, cu_seqlens) + if head_first: + o = o.transpose(1, 2) + return o diff --git a/fla/ops/utils/softmax.py b/fla/ops/utils/softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..12c37c7a57061c8d8dfd2ab6a31b2dc33547607f --- /dev/null +++ b/fla/ops/utils/softmax.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2024, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32) + ], + key=['D'] +) +@triton.jit +def softmax_fwd_kernel( + x, + p, + D: tl.constexpr, + B: tl.constexpr +): + i_n = tl.program_id(0) + o_d = tl.arange(0, B) + m_d = o_d < D + + b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf')) + b_m = tl.max(b_x, 0) + b_x = exp(b_x - b_m) + b_p = b_x / tl.sum(b_x, 0) + + tl.store(p + i_n * D + o_d, b_p.to(p.dtype.element_ty), mask=m_d) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32) + ], + key=['D'] +) +@triton.jit +def softmax_bwd_kernel( + p, + dp, + ds, + D: tl.constexpr, + B: tl.constexpr +): + i_n = tl.program_id(0) + o_d = 
tl.arange(0, B) + m_d = o_d < D + + b_p = tl.load(p + i_n * D + o_d, mask=m_d, other=0.) + b_dp = tl.load(dp + i_n * D + o_d, mask=m_d, other=0.) + b_pp = tl.sum(b_p * b_dp, 0) + b_ds = b_p * b_dp - b_p * b_pp + tl.store(ds + i_n * D + o_d, b_ds.to(ds.dtype.element_ty), mask=m_d) + + +def softmax_fwd( + x: torch.Tensor, + dtype: Optional[torch.dtype] = torch.float +) -> torch.Tensor: + shape = x.shape + x = x.view(-1, x.shape[-1]) + + N, D = x.shape + B = triton.next_power_of_2(D) + + p = torch.empty_like(x, dtype=dtype) + softmax_fwd_kernel[(N,)]( + x=x, + p=p, + D=D, + B=B + ) + return p.view(*shape) + + +def softmax_bwd( + p: torch.Tensor, + dp: torch.Tensor, + dtype: Optional[torch.dtype] = torch.float +) -> torch.Tensor: + shape = p.shape + p = p.view(-1, p.shape[-1]) + ds = torch.empty_like(p, dtype=dtype) + + N, D = p.shape + B = triton.next_power_of_2(D) + softmax_bwd_kernel[(N,)]( + p=p, + dp=dp, + ds=ds, + D=D, + B=B + ) + return ds.view(*shape) diff --git a/fla/ops/utils/solve_tril.py b/fla/ops/utils/solve_tril.py new file mode 100644 index 0000000000000000000000000000000000000000..d0c2b66833c4d1479ec8c25ae82a85e69e96650a --- /dev/null +++ b/fla/ops/utils/solve_tril.py @@ -0,0 +1,321 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.ops.common.utils import prepare_chunk_indices +from fla.utils import input_guard + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4, 5] + ], + key=['BT'], +) +@triton.jit(do_not_specialize=['T']) +def solve_tril_16x16_kernel( + A, + Ad, + offsets, + indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), 
tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + A = A + i_bh * T * BT + Ad = Ad + i_bh * T * 16 + stride_16 = 16 + stride_BT = BT + else: + A = A + (bos*H + i_h) * BT + Ad = Ad + (bos*H + i_h) * 16 + stride_16 = H*16 + stride_BT = H*BT + + offset = (i_t * 16) % BT + p_A = tl.make_block_ptr(A, (T, BT), (stride_BT, 1), (i_t * 16, offset), (16, 16), (1, 0)) + p_Ai = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 16, 0), (16, 16), (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0) + + o_i = tl.arange(0, 16) + for i in range(1, min(16, T-i_t*16)): + b_a = -tl.load(A + (i_t * 16 + i) * stride_BT + o_i + offset) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) + mask = o_i == i + b_A = tl.where(mask[:, None], b_a, b_A) + b_A += o_i[:, None] == o_i[None, :] + tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4, 5] + ], + key=['H', 'BT', 'HEAD_FIRST', 'USE_OFFSETS'], +) +@triton.jit(do_not_specialize=['T']) +def merge_16x16_to_32x32_inverse_kernel( + A, + Ad, + Ai, + offsets, + indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + 
i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + A += (i_bh * T * 32) + Ad += (i_bh * T * 16) + Ai += (i_bh * T * 32) + stride_16 = 16 + stride_32 = 32 + else: + A += (bos*H + i_h) * 32 + Ad += (bos*H + i_h) * 16 + Ai += (bos*H + i_h) * 32 + stride_16 = 16 * H + stride_32 = 32 * H + + p_A_21 = tl.make_block_ptr(A, (T, 32), (stride_32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) + p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 32, 0), (16, 16), (1, 0)) + p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) + p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (stride_32, 1), (i_t * 32, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (stride_32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (stride_32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) + + A_21 = tl.load(p_A_21, boundary_check=(0, 1)) + Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)) + Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), Ai_11, input_precision='ieee') + tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4, 5] + ], + key=['H', 'BT', 'HEAD_FIRST', 'USE_OFFSETS'], +) +@triton.jit(do_not_specialize=['T']) +def merge_16x16_to_64x64_inverse_kernel( + A, + Ad, + Ai, + offsets, + indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + HEAD_FIRST: 
tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + A += i_bh * T * 64 + Ad += i_bh * T * 16 + Ai += i_bh * T * 64 + stride_16 = 16 + stride_64 = 64 + else: + A += (bos*H + i_h) * 64 + Ad += (bos*H + i_h) * 16 + Ai += (bos*H + i_h) * 64 + stride_16 = 16 * H + stride_64 = 64 * H + + p_A_21 = tl.make_block_ptr(A, (T, 64), (stride_64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) + p_A_32 = tl.make_block_ptr(A, (T, 64), (stride_64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)) + p_A_31 = tl.make_block_ptr(A, (T, 64), (stride_64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) + p_A_43 = tl.make_block_ptr(A, (T, 64), (stride_64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)) + p_A_42 = tl.make_block_ptr(A, (T, 64), (stride_64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)) + p_A_41 = tl.make_block_ptr(A, (T, 64), (stride_64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) + p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 64, 0), (16, 16), (1, 0)) + p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) + p_Ad_33 = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) + p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) + + A_21 = tl.load(p_A_21, boundary_check=(0, 1)) + A_32 = tl.load(p_A_32, boundary_check=(0, 1)) + A_31 = tl.load(p_A_31, boundary_check=(0, 1)) + A_43 = tl.load(p_A_43, boundary_check=(0, 1)) + A_42 = tl.load(p_A_42, boundary_check=(0, 1)) + A_41 = tl.load(p_A_41, boundary_check=(0, 1)) + + Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)) + Ai_33 = 
tl.load(p_Ad_33, boundary_check=(0, 1)) + Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)) + + Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), Ai_11, input_precision='ieee') + Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision='ieee'), Ai_22, input_precision='ieee') + Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision='ieee'), Ai_33, input_precision='ieee') + + Ai_31 = -tl.dot( + Ai_33, + tl.dot(A_31, Ai_11, input_precision='ieee') + + tl.dot(A_32, Ai_21, input_precision='ieee'), + input_precision='ieee' + ) + Ai_42 = -tl.dot( + Ai_44, + tl.dot(A_42, Ai_22, input_precision='ieee') + + tl.dot(A_43, Ai_32, input_precision='ieee'), + input_precision='ieee' + ) + Ai_41 = -tl.dot( + Ai_44, + tl.dot(A_41, Ai_11, input_precision='ieee') + + tl.dot(A_42, Ai_21, input_precision='ieee') + + tl.dot(A_43, Ai_31, input_precision='ieee'), + input_precision='ieee' + ) + + p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0)) + p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0)) + p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) + p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) + p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)) + p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) + p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)) + p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)) + tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_22, 
Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_33, Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_44, Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_31, Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_32, Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_41, Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_42, Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_43, Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +@input_guard +def solve_tril( + A: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float +) -> torch.Tensor: + """ + Compute the inverse of the lower triangular matrix + A should be strictly lower triangular, i.e., A.triu() == 0. + + Args: + A (torch.Tensor): + [B, T, H, K] if head_first else [B, H, T, K] + cu_seqlens (torch.Tensor): + The cumulative sequence lengths of the input tensor. + Default: None. + head_first (bool): + If False, the input/output tensor is in the shape of [B, T, H, K]. + If True, the input/output tensor is in the shape of [B, H, T, K]. + Default: False + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float` + + Returns: + (I + A)^-1 with the same shape as A + """ + assert A.shape[-1] in [16, 32, 64] + assert A.dtype == torch.float, "A should be float32." 
+ + if head_first: + B, H, T, BT = A.shape + Ad = torch.empty(B, H, T, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype) + else: + B, T, H, BT = A.shape + Ad = torch.empty(B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype) + + indices = prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None + NT = len(indices) if cu_seqlens is not None else triton.cdiv(T, 16) + solve_tril_16x16_kernel[NT, B * H]( + A=A, + Ad=Ad, + offsets=cu_seqlens, + indices=indices, + T=T, + H=H, + BT=BT, + HEAD_FIRST=head_first, + ) + if BT == 16: + return Ad + + if head_first: + Ai = torch.zeros(B, H, T, BT, device=A.device, dtype=output_dtype) + else: + Ai = torch.zeros(B, T, H, BT, device=A.device, dtype=output_dtype) + merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel + indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = len(indices) if cu_seqlens is not None else triton.cdiv(T, BT) + merge_fn[NT, B * H]( + A=A, + Ad=Ad, + Ai=Ai, + offsets=cu_seqlens, + indices=indices, + T=T, + H=H, + BT=BT, + HEAD_FIRST=head_first, + USE_OFFSETS=cu_seqlens is not None + ) + return Ai diff --git a/fla/ops/utils/testing.py b/fla/ops/utils/testing.py new file mode 100644 index 0000000000000000000000000000000000000000..6f4fb01202e6bfbea3351defe8a424ca648ce7d1 --- /dev/null +++ b/fla/ops/utils/testing.py @@ -0,0 +1,26 @@ +import os + +compiled_mode = os.getenv("COMPILER_MODE") == "1" +ci_env = os.getenv("CI_ENV") == "1" + + +def get_abs_err(x, y): + return (x.detach()-y.detach()).flatten().abs().max().item() + + +def get_err_ratio(x, y): + err = (x-y).flatten().square().mean().sqrt().item() + base = (x).flatten().square().mean().sqrt().item() + return err / (base + 1e-15) + + +def assert_close(prefix, ref, tri, ratio, warning=False): + msg = f"{prefix} diff: {get_abs_err(ref, tri):.6f} ratio: {get_err_ratio(ref, tri):.6f}" + print(msg) + error_rate = 
get_err_ratio(ref, tri) + if warning or str(prefix).strip().lower() == "dh0" or (ci_env and error_rate < 0.01): + if error_rate > ratio: + import warnings + warnings.warn(msg) + else: + assert error_rate < ratio, msg diff --git a/flame/__pycache__/__init__.cpython-312.pyc b/flame/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3bf7f9da46a16691681791592e3a168c4051fc8 Binary files /dev/null and b/flame/__pycache__/__init__.cpython-312.pyc differ diff --git a/flame/__pycache__/config_manager.cpython-312.pyc b/flame/__pycache__/config_manager.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fec51e6008613bc3774cae23d41e703734d9c4b9 Binary files /dev/null and b/flame/__pycache__/config_manager.cpython-312.pyc differ diff --git a/flame/__pycache__/data.cpython-312.pyc b/flame/__pycache__/data.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76410c440aa57fd98222014d83ec35fcca68443b Binary files /dev/null and b/flame/__pycache__/data.cpython-312.pyc differ diff --git a/flame/__pycache__/train.cpython-312.pyc b/flame/__pycache__/train.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82e58bffa81e8f6ed6565a68c2378d8271104d94 Binary files /dev/null and b/flame/__pycache__/train.cpython-312.pyc differ diff --git a/flame/components/__init__.py b/flame/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/flame/models/__pycache__/__init__.cpython-312.pyc b/flame/models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe8c39cbf9f93b5f16d7ae83dee7b555e3ed9c5c Binary files /dev/null and b/flame/models/__pycache__/__init__.cpython-312.pyc differ diff --git a/flame/models/fla.toml b/flame/models/fla.toml new file mode 100644 index 
0000000000000000000000000000000000000000..bbc11346144d0bfa3d8c6cf5b17a03208028ea46 --- /dev/null +++ b/flame/models/fla.toml @@ -0,0 +1,68 @@ +[model] +config = "fla-hub/transformer-1.3B-100B" +tokenizer_path = "fla-hub/transformer-1.3B-100B" + +[job] +dump_folder = "exp" +print_args = true + +[training] +batch_size = 32 +seq_len = 2048 +context_len = 2048 +gradient_accumulation_steps = 1 +steps = 20480 +max_norm = 1.0 +skip_nan_inf = true +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 1 +compile = false +dataset = "HuggingFaceFW/fineweb-edu" +dataset_name = "default" +num_workers = 32 +pin_memory = false +persistent_workers = false +prefetch_factor = 2 +seed = 42 +varlen = false +dataset_mode = "pretrain" + +[optimizer] +name = "AdamW" +eps = 1e-15 +lr = 3e-4 + +[lr_scheduler] +warmup_steps = 1024 +decay_type = "cosine" +lr_min = 0.1 + +[checkpoint] +enable_checkpoint = true +folder = "checkpoint" +interval_type = "steps" +interval = 2048 +model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 512 + +[metrics] +log_freq = 32 +enable_wandb = true + +[experimental] +context_parallel_degree = 1 +pipeline_parallel_degree = 1 + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false + +[activation_checkpoint] +mode = "none" diff --git a/flame/utils/__init__.py b/flame/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/flame/utils/__pycache__/convert_hf_to_dcp.cpython-312.pyc b/flame/utils/__pycache__/convert_hf_to_dcp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e510c4220bc8dc9968199f7f84fb789a8e447fc Binary files /dev/null and b/flame/utils/__pycache__/convert_hf_to_dcp.cpython-312.pyc 
import os
import glob
import re
import shutil

try:
    from torchtitan.tools.logging import logger
except ImportError:  # pragma: no cover - stdlib fallback so the module works standalone
    import logging

    logger = logging.getLogger(__name__)

# Checkpoint directory name patterns: DCP format is "step-N", HF format is "step-N-hf".
_STEP_RE = re.compile(r"step-(\d+)$")
_STEP_HF_RE = re.compile(r"step-(\d+)-hf$")


def _prune_stale(paths, keep_latest_k, step_re, kind):
    """Delete all but the newest ``keep_latest_k`` checkpoint directories.

    Args:
        paths: Candidate directory paths (already filtered to one format).
        keep_latest_k: Number of newest checkpoints to keep.
        step_re: Compiled regex whose group(1) is the integer step number.
        kind: Human-readable format label ("DCP" or "HF") for log messages.
    """
    # Pair each path with its parsed step; silently skip anything that does not
    # match the expected naming scheme instead of sorting it with a -1 sentinel.
    matched = []
    for path in paths:
        m = step_re.search(os.path.basename(path))
        if m:
            matched.append((int(m.group(1)), path))
    matched.sort(key=lambda t: t[0], reverse=True)  # newest first

    stale = [path for _, path in matched[keep_latest_k:]]
    if not stale:
        return
    logger.info(
        f"Deleting {len(stale)} old {kind} checkpoints: "
        f"{[os.path.basename(c) for c in stale]}"
    )
    for ckpt_path in stale:
        if os.path.isdir(ckpt_path):  # Ensure it's a directory
            try:
                shutil.rmtree(ckpt_path)
            except OSError as e:
                logger.error(f"Error removing directory {ckpt_path}: {e}")


def cleanup_local_checkpoints(checkpoint_dir: str, keep_latest_k: int) -> None:
    """Removes older checkpoint directories locally, keeping only the latest k
    for both DCP (``step-N``) and HF (``step-N-hf``) formats.

    Args:
        checkpoint_dir: Directory containing the checkpoint subdirectories.
        keep_latest_k: How many of each format to keep; <= 0 keeps everything.
    """
    if keep_latest_k <= 0:
        return  # Keep all checkpoints

    logger.info(
        f"Cleaning up local checkpoints in {checkpoint_dir}, keeping latest {keep_latest_k}"
    )

    candidates = glob.glob(os.path.join(checkpoint_dir, "step-*"))
    # DCP checkpoints (step-*): explicitly exclude the HF-format directories.
    _prune_stale(
        [p for p in candidates if not p.endswith("-hf")],
        keep_latest_k,
        _STEP_RE,
        "DCP",
    )
    # HF checkpoints (step-*-hf)
    _prune_stale(
        [p for p in candidates if p.endswith("-hf")],
        keep_latest_k,
        _STEP_HF_RE,
        "HF",
    )
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
from collections import namedtuple
from datetime import datetime
from typing import Any

import torch
from torch.utils.tensorboard import SummaryWriter
from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.tools import utils
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import Color, device_module, device_type

# named tuple for passing device memory stats for logging
DeviceMemStats = namedtuple(
    "DeviceMemStats",
    [
        "max_active_gib",
        "max_active_pct",
        "max_reserved_gib",
        "max_reserved_pct",
        "num_alloc_retries",
        "num_ooms",
    ],
)


class DeviceMemoryMonitor:
    """Tracks peak allocator statistics for a single accelerator device."""

    # NOTE: GiB (gibibyte) is 1024**3 bytes, vs GB which is 1000**3
    _BYTES_PER_GIB = 1024 * 1024 * 1024

    def __init__(self, device: str = f"{device_type}:0"):
        self.device = torch.device(device)  # device object
        self.device_name = device_module.get_device_name(self.device)
        self.device_index = device_module.current_device()
        self.device_capacity = device_module.get_device_properties(
            self.device
        ).total_memory
        self.device_capacity_gib = self._to_gib(self.device_capacity)

        device_module.reset_peak_memory_stats()
        device_module.empty_cache()

    def _to_gib(self, memory_in_bytes):
        # Convert a raw byte count into gibibytes.
        return memory_in_bytes / self._BYTES_PER_GIB

    def _to_pct(self, memory):
        # Express a byte count as a percentage of total device capacity.
        return 100 * memory / self.device_capacity

    def get_peak_stats(self):
        """Read the device allocator's peak stats, warning on retries/OOMs."""
        stats = device_module.memory_stats(self.device)

        peak_active = stats.get("active_bytes.all.peak", -1)
        peak_reserved = stats.get("reserved_bytes.all.peak", -1)
        num_retries = stats.get("num_alloc_retries", -1)
        num_ooms = stats.get("num_ooms", -1)

        if num_retries > 0:
            logger.warning(
                f"{num_retries} {device_type.upper()} memory allocation retries."
            )
        if num_ooms > 0:
            logger.warning(f"{num_ooms} {device_type.upper()} OOM errors thrown.")

        return DeviceMemStats(
            self._to_gib(peak_active),
            self._to_pct(peak_active),
            self._to_gib(peak_reserved),
            self._to_pct(peak_reserved),
            num_retries,
            num_ooms,
        )

    def reset_peak_stats(self):
        device_module.reset_peak_memory_stats()


def build_device_memory_monitor():
    """Construct a monitor for the current device type and log its capacity."""
    monitor = DeviceMemoryMonitor(device_type)
    logger.info(
        f"{device_type.upper()} capacity: {monitor.device_name} "
        f"with {monitor.device_capacity_gib:.2f}GiB memory"
    )
    return monitor


class BaseLogger:
    """Logger that does nothing, used when logging is disabled."""

    def log(self, metrics: dict[str, Any], step: int) -> None:
        pass

    def close(self) -> None:
        pass


class TensorBoardLogger(BaseLogger):
    """Logger implementation for TensorBoard."""

    def __init__(self, log_dir: str, tag: str | None = None):
        self.tag = tag
        self.writer = SummaryWriter(log_dir, max_queue=1000)
        logger.info(f"TensorBoard logging enabled. Logs will be saved at {log_dir}")

    def log(self, metrics: dict[str, Any], step: int) -> None:
        prefix = self.tag
        for name, value in metrics.items():
            self.writer.add_scalar(
                name if prefix is None else f"{prefix}/{name}", value, step
            )

    def close(self) -> None:
        self.writer.close()


class WandBLogger(BaseLogger):
    """Logger implementation for Weights & Biases."""

    def __init__(self, log_dir: str, tag: str | None = None):
        # Import wandb here to avoid startup import
        import wandb

        self.wandb = wandb
        self.tag = tag

        # Create logging directory
        os.makedirs(log_dir, exist_ok=True)

        self.wandb.init(
            project=os.getenv("WANDB_PROJECT", "torchtitan"),
            dir=log_dir,
        )
        logger.info("WandB logging enabled")

    def log(self, metrics: dict[str, Any], step: int) -> None:
        prefix = self.tag
        payload = {
            (name if prefix is None else f"{prefix}/{name}"): value
            for name, value in metrics.items()
        }
        self.wandb.log(payload, step=step)

    def close(self) -> None:
        if self.wandb.run is not None:
            self.wandb.finish()
+ """ + + # V Block Schedules return loss on rank 0 + if job_config.parallelism.pipeline_parallel_schedule == "ZBVZeroBubble": + return + + # Calculate the rank where loss is visible (first rank of the last pipeline stage) + world_size = parallel_dims.world_size + pp_size = parallel_dims.pp + loss_visible_rank = (world_size // pp_size) * (pp_size - 1) + + # Check if the loss-visible rank is included in LOG_RANK environment variable + env_logged_ranks = os.environ.get("LOG_RANK", "").split(",") + if env_logged_ranks == [""]: + env_logged_ranks = [] + + if str(loss_visible_rank) not in env_logged_ranks: + logger.warning( + f"{color.red}Pipeline Parallel loss is not visible. " + f"Please add {color.yellow}rank {loss_visible_rank}{color.red} " + f"to LOG_RANK environment variable in run_train.sh.{color.reset}" + ) + + +def _get_metrics_rank( + parallel_dims: ParallelDims, + job_config: JobConfig, +) -> int: + """ + Determines which rank should log metrics. + + Returns: + int: The rank responsible for logging metrics: + - Rank 0 for non-pipeline-parallel configs + - Rank 0 for pipeline-parallel 'ZBVZeroBubble' schedule + - The first rank of the last pipeline stage for other pipeline-parallel schedules + """ + # Early return for non-pipeline-parallel configurations + if not parallel_dims.pp_enabled: + return 0 + + # V Block Schedules return loss on rank 0 + if job_config.parallelism.pipeline_parallel_schedule == "ZBVZeroBubble": + return 0 + + # Calculate first rank of the last pipeline stage + world_size = parallel_dims.world_size + pp_size = parallel_dims.pp + return (world_size // pp_size) * (pp_size - 1) + + +def _build_metric_logger( + job_config: JobConfig, parallel_dims: ParallelDims, tag: str | None = None +) -> BaseLogger: + """ + Build an appropriate metric logger based on configuration. 
+ """ + metrics_config = job_config.metrics + + # Log initial config state + logger.debug( + f"Building logger with config: wandb={metrics_config.enable_wandb}, " + f"tensorboard={metrics_config.enable_tensorboard}" + ) + + # Check if any logging backend is enabled + has_logging_enabled = ( + metrics_config.enable_tensorboard or metrics_config.enable_wandb + ) + + # Determine if this rank should log + should_log = has_logging_enabled + if (not metrics_config.save_for_all_ranks) and should_log: + metrics_rank = _get_metrics_rank(parallel_dims, job_config) + should_log = torch.distributed.get_rank() == metrics_rank + + logger.debug( + f"Logging decision: has_logging_enabled={has_logging_enabled}, should_log={should_log}" + ) + + if not should_log: + logger.debug("Returning BaseLogger due to should_log=False") + return BaseLogger() + + # Setup logging directory + dump_dir = job_config.job.dump_folder + base_log_dir = os.path.join( + dump_dir, metrics_config.save_tb_folder, datetime.now().strftime("%Y%m%d-%H%M") + ) + + if metrics_config.save_for_all_ranks: + base_log_dir = os.path.join( + base_log_dir, f"rank_{torch.distributed.get_rank()}" + ) + + # Create loggers in priority order + if metrics_config.enable_wandb: + logger.debug("Attempting to create WandB logger") + try: + return WandBLogger(base_log_dir, tag) + except Exception as e: + if "No module named 'wandb'" in str(e): + logger.error( + "Failed to create WandB logger: No module named 'wandb'. Please install it using 'pip install wandb'." + ) + else: + logger.error(f"Failed to create WandB logger: {e}") + + if metrics_config.enable_tensorboard: + logger.debug("Creating TensorBoard logger") + return TensorBoardLogger(base_log_dir, tag) + + logger.debug("No loggers enabled, returning BaseLogger") + return BaseLogger() + + +class MetricsProcessor: + """Metrics processor to processes the metrics and log metrics. 
class MetricsProcessor:
    """Metrics processor that processes and logs training metrics.

    The current MetricsProcessor logs some metrics to STDOUT and some metrics to
    TensorBoard or WandB.

    Args:
        job_config (JobConfig): Job configuration.
        parallel_dims (ParallelDims): Parallel dimensions.
        tag (Optional[str]): Tag to use for TensorBoard or WandB. Defaults to None.
    """

    logger: BaseLogger
    parallel_dims: ParallelDims
    job_config: JobConfig
    device_memory_monitor: DeviceMemoryMonitor
    color: utils.NoColor | utils.Color

    gpu_peak_flops: int
    ntokens_since_last_log: int
    data_loading_times: list[float]
    time_last_log: float

    num_flops_per_token: int
    optimizers: OptimizersContainer | None
    lr_schedulers: LRSchedulersContainer | None

    def __init__(
        self,
        job_config: JobConfig,
        parallel_dims: ParallelDims,
        tag: str | None = None,
    ):
        self.logger = _build_metric_logger(job_config, parallel_dims, tag)
        self.parallel_dims = parallel_dims
        self.job_config = job_config
        self.device_memory_monitor = build_device_memory_monitor()
        # used for colorful printing
        self.color = (
            utils.NoColor()
            if job_config.metrics.disable_color_printing
            else utils.Color()
        )

        self.gpu_peak_flops = utils.get_peak_flops(
            self.device_memory_monitor.device_name
        )
        self.ntokens_since_last_log = 0
        self.data_loading_times = []
        self.time_last_log = time.perf_counter()
        self.device_memory_monitor.reset_peak_stats()

        # These variables have to be set later as they depend on other components or model.
        self.num_flops_per_token = -1
        self.optimizers = None
        self.lr_schedulers = None

    def should_log(self, step: int) -> bool:
        """Return True when metrics should be emitted at this step."""
        return step == 1 or step % self.job_config.metrics.log_freq == 0

    def log(
        self,
        step: int,
        global_avg_loss: float,
        global_max_loss: float,
        extra_metrics: dict[str, Any] | None = None,
    ):
        """Compute throughput/memory metrics, forward them to the backend logger,
        and print a one-line colored summary to STDOUT."""
        assert self.num_flops_per_token > 0, "num_flops_per_token must be set"

        time_delta = time.perf_counter() - self.time_last_log

        # tokens per second per device, abbreviated as tps
        tps = self.ntokens_since_last_log / (
            time_delta * self.parallel_dims.non_data_parallel_size
        )
        # model FLOPS utilization
        # For its definition and calculation, please refer to the PaLM paper:
        # https://arxiv.org/abs/2204.02311
        mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops
        tflops = self.num_flops_per_token * tps / 1e12

        time_end_to_end = time_delta / self.job_config.metrics.log_freq
        # Guard against ZeroDivisionError when no data-loading samples were
        # recorded in this logging window.
        num_loads = len(self.data_loading_times) or 1
        time_data_loading = sum(self.data_loading_times) / num_loads
        time_data_loading_pct = 100 * sum(self.data_loading_times) / time_delta

        device_mem_stats = self.device_memory_monitor.get_peak_stats()

        metrics = {
            "loss_metrics/global_avg_loss": global_avg_loss,
            "loss_metrics/global_max_loss": global_max_loss,
            "throughput(tps)": tps,
            "tflops": tflops,
            "mfu(%)": mfu,
            "time_metrics/end_to_end(s)": time_end_to_end,
            "time_metrics/data_loading(s)": time_data_loading,
            "time_metrics/data_loading(%)": time_data_loading_pct,
            "memory/max_active(GiB)": device_mem_stats.max_active_gib,
            "memory/max_active(%)": device_mem_stats.max_active_pct,
            "memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib,
            "memory/max_reserved(%)": device_mem_stats.max_reserved_pct,
            "memory/num_alloc_retries": device_mem_stats.num_alloc_retries,
            "memory/num_ooms": device_mem_stats.num_ooms,
        }

        if extra_metrics:
            metrics.update(extra_metrics)

        self.logger.log(metrics, step)

        color = self.color
        construct_string = (
            f"{color.red}step: {step:2} "
            f"{color.green}loss: {global_avg_loss:7.4f} "
            f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB"
            f"({device_mem_stats.max_reserved_pct:.2f}%) "
            f"{color.blue}tps: {round(tps):,} "
            f"{color.cyan}tflops: {tflops:,.2f} "
            f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
        )

        if extra_metrics:
            for k, v in extra_metrics.items():
                if "loss" in k:
                    # NOTE: was k.lstrip('loss_metrics/'), which strips a
                    # *character set* (flake8-bugbear B005) and can mangle keys
                    # such as "train/loss"; removeprefix strips the exact prefix.
                    construct_string += (
                        f" {color.white}{k.removeprefix('loss_metrics/')}: {v:7.4f}"
                    )
        logger.info(construct_string)

        # Reset accumulators for the next logging window.
        self.ntokens_since_last_log = 0
        self.data_loading_times.clear()
        self.time_last_log = time.perf_counter()
        self.device_memory_monitor.reset_peak_stats()

    def close(self):
        self.logger.close()


def build_metrics_processor(
    job_config: JobConfig, parallel_dims: ParallelDims, tag: str | None = None
) -> MetricsProcessor:
    """Create a metrics processor.

    Args:
        job_config (JobConfig): Job configuration.
        parallel_dims (ParallelDims): Parallel dimensions.
        tag (Optional[str]): Tag to use for TensorBoard or WandB. Defaults to None.

    Returns:
        MetricsProcessor: A metrics processor.
    """
    return MetricsProcessor(job_config, parallel_dims, tag)
+ """ + + special_tokens: dict[str, int] + + num_reserved_special_tokens = 256 + + pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501, B950 + + def __init__(self, model_path: str): + super().__init__() + assert os.path.exists( + model_path + ), f"The tokenizer path does not exist: {model_path}" + assert os.path.isfile(model_path), model_path + + mergeable_ranks = load_tiktoken_bpe(model_path) + num_base_tokens = len(mergeable_ranks) + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", # end of turn + ] + [ + f"<|reserved_special_token_{i}|>" + for i in range(5, self.num_reserved_special_tokens - 5) + ] + self.special_tokens = { + token: num_base_tokens + i for i, token in enumerate(special_tokens) + } + self.model = tiktoken.Encoding( + name=Path(model_path).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, + ) + + self._n_words: int = self.model.n_vocab + # BOS / EOS token IDs + self.bos_id: int = self.special_tokens["<|begin_of_text|>"] + self.eos_id: int = self.special_tokens["<|end_of_text|>"] + self.pad_id: int = -1 + self.stop_tokens = { + self.special_tokens["<|end_of_text|>"], + self.special_tokens["<|eot_id|>"], + } + logger.info( + f"TikTokenizer built: #words {self.n_words}, BOS ID {self.bos_id}, EOS ID {self.eos_id}" + ) + + def encode( + self, + s: str, + *, + bos: bool, + eos: bool, + allowed_special: Literal["all"] | AbstractSet[str] | None = None, + disallowed_special: Literal["all"] | Collection[str] | None = None, + ) -> list[int]: + """ + Encodes a string into a list of token IDs. + + Args: + s (str): The input string to be encoded. 
+ bos (bool): Whether to prepend the beginning-of-sequence token. + eos (bool): Whether to append the end-of-sequence token. + allowed_tokens ("all"|set[str]): allowed special tokens in string + disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string + + Returns: + list[int]: A list of token IDs. + + By default, setting disallowed_special=() encodes a string by ignoring + special tokens. Specifically: + - Setting `disallowed_special` to () will cause all text corresponding + to special tokens to be encoded as natural text (insteading of raising + an error). + - Setting `allowed_special` to "all" will treat all text corresponding + to special tokens to be encoded as special tokens. + """ + assert type(s) is str + allowed_special = allowed_special or set() + disallowed_special = disallowed_special or () + + # The tiktoken tokenizer can handle <=400k chars without + # pyo3_runtime.PanicException. + TIKTOKEN_MAX_ENCODE_CHARS = 400_000 + + # https://github.com/openai/tiktoken/issues/195 + # Here we iterate over subsequences and split if we exceed the limit + # of max consecutive non-whitespace or whitespace characters. + MAX_NO_WHITESPACES_CHARS = 25_000 + + substrs = ( + substr + for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) + for substr in self._split_whitespaces_or_nonwhitespaces( + s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS + ) + ) + t: list[int] = [] + for substr in substrs: + t.extend( + self.model.encode( + substr, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) + ) + if bos: + t.insert(0, self.bos_id) + if eos: + t.append(self.eos_id) + return t + + def decode(self, t: Sequence[int]) -> str: + """ + Decodes a list of token IDs into a string. + + Args: + t (List[int]): The list of token IDs to be decoded. + + Returns: + str: The decoded string. + """ + # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. 
+ return self.model.decode(cast(list[int], t)) + + @staticmethod + def _split_whitespaces_or_nonwhitespaces( + s: str, max_consecutive_slice_len: int + ) -> Iterator[str]: + """ + Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` + consecutive whitespaces or consecutive non-whitespaces. + """ + current_slice_len = 0 + current_slice_is_space = s[0].isspace() if len(s) > 0 else False + slice_start = 0 + + for i in range(len(s)): + is_now_space = s[i].isspace() + + if current_slice_is_space ^ is_now_space: + current_slice_len = 1 + current_slice_is_space = is_now_space + else: + current_slice_len += 1 + if current_slice_len > max_consecutive_slice_len: + yield s[slice_start:i] + slice_start = i + current_slice_len = 1 + yield s[slice_start:] + + +def build_tiktoken_tokenizer(job_config: JobConfig) -> TikTokenizer: + return TikTokenizer(job_config.model.tokenizer_path) diff --git a/torchtitan/distributed/__pycache__/__init__.cpython-312.pyc b/torchtitan/distributed/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..430eb691df574ce3e077d9a14df1459f3694f4ce Binary files /dev/null and b/torchtitan/distributed/__pycache__/__init__.cpython-312.pyc differ diff --git a/torchtitan/distributed/__pycache__/pipeline.cpython-312.pyc b/torchtitan/distributed/__pycache__/pipeline.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef7589241320ec60bf61ddc88f2a699431a48ad4 Binary files /dev/null and b/torchtitan/distributed/__pycache__/pipeline.cpython-312.pyc differ diff --git a/torchtitan/distributed/__pycache__/utils.cpython-312.pyc b/torchtitan/distributed/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ed9559f6adf321600f020336b2bf86565842624 Binary files /dev/null and b/torchtitan/distributed/__pycache__/utils.cpython-312.pyc differ diff --git 
a/torchtitan/experiments/deepseek_v3/attn_mask_utils.py b/torchtitan/experiments/deepseek_v3/attn_mask_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6a54899c34e021a43c8a7e090d854140afa8f9e7 --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/attn_mask_utils.py @@ -0,0 +1,397 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This code is based on src/transformers/modeling_attn_mask_utils.py of +# huggingface/transformers. It has been modified from its original forms to +# contain only the necessary utilities. + +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + + +@dataclass +class AttentionMaskConverter: + """ + A utility attention mask class that allows one to: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, + key_value_length) that can be multiplied with attention scores + + Examples: + + ```python + >>> import torch + >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter + + >>> converter = AttentionMaskConverter(True) + >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32) + tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]]) + ``` + + Parameters: + is_causal (`bool`): + Whether the attention mask should be a uni-directional (causal) or bi-directional mask. + + sliding_window (`int`, *optional*): + Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. 
+ """ + + is_causal: bool + sliding_window: int + + def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): + self.is_causal = is_causal + self.sliding_window = sliding_window + + if self.sliding_window is not None and self.sliding_window <= 0: + raise ValueError( + "Make sure that when passing `sliding_window` that its value is a strictly positive integer, " + f"not `{self.sliding_window}`" + ) + + def to_causal_4d( + self, + batch_size: int, + query_length: int, + key_value_length: int, + dtype: torch.dtype, + device: Union[torch.device, "str"] = "cpu", + ) -> Optional[torch.Tensor]: + """ + Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative + bias to upper right hand triangular matrix (causal mask). + """ + if not self.is_causal: + raise ValueError( + f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True." + ) + + # If shape is not cached, create a new causal mask and cache it + input_shape = (batch_size, query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if input_shape[-1] > 1 or self.sliding_window is not None: + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + + return causal_4d_mask + + def to_4d( + self, + attention_mask_2d: torch.Tensor, + query_length: int, + dtype: torch.dtype, + key_value_length: Optional[int] = None, + ) -> torch.Tensor: + """ + Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, + key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is + causal, a causal mask will be added. 
+ """ + input_shape = (attention_mask_2d.shape[0], query_length) + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + if key_value_length is None: + raise ValueError( + "This attention mask converter is causal. Make sure to pass " + "`key_value_length` to correctly create a causal mask." + ) + + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=attention_mask_2d.device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + elif self.sliding_window is not None: + raise NotImplementedError( + "Sliding window is currently only implemented for causal masking" + ) + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask( + attention_mask_2d, dtype, tgt_len=input_shape[-1] + ).to(attention_mask_2d.device) + + if causal_4d_mask is not None: + expanded_attn_mask = causal_4d_mask.masked_fill( + expanded_attn_mask.bool(), torch.finfo(dtype).min + ) + + # expanded_attn_mask + causal_4d_mask can cause some overflow + expanded_4d_mask = expanded_attn_mask + + return expanded_4d_mask + + @staticmethod + def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """ + Make causal mask used for bi-directional self-attention. 
+ """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [ + torch.zeros( + tgt_len, past_key_values_length, dtype=dtype, device=device + ), + mask, + ], + dim=-1, + ) + + # add lower triangular sliding window mask if necessary + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window - 1 + + context_mask = torch.tril( + torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal + ) + mask.masked_fill_(context_mask, torch.finfo(dtype).min) + + return mask[None, None, :, :].expand( + bsz, 1, tgt_len, tgt_len + past_key_values_length + ) + + @staticmethod + def _expand_mask( + mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None + ): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = ( + mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + ) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) + + @staticmethod + def _unmask_unattended( + expanded_mask: torch.FloatTensor, + min_dtype: float, + ): + # fmt: off + """ + Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when + using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + Details: https://github.com/pytorch/pytorch/issues/110213 + + `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. + `attention_mask` is [bsz, src_seq_len]. 
+ + The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case + of alibi attention bias. + + For example, if `expanded_mask` is (e.g. here left-padding case) + ``` + [[[[0, 0, 0], + [0, 0, 0], + [0, 0, 1]]], + [[[1, 0, 0], + [1, 1, 0], + [1, 1, 1]]], + [[[0, 0, 0], + [0, 1, 0], + [0, 1, 1]]]] + ``` + then the modified `expanded_mask` will be + ``` + [[[[1, 1, 1], <-- modified + [1, 1, 1], <-- modified + [0, 0, 1]]], + [[[1, 0, 0], + [1, 1, 0], + [1, 1, 1]]], + [[[1, 1, 1], <-- modified + [0, 1, 0], + [0, 1, 1]]]] + ``` + """ + # fmt: on + if expanded_mask.dtype == torch.bool: + raise ValueError( + "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor." + ) + + return expanded_mask.mul( + ~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True) + ) + + @staticmethod + def _ignore_causal_mask_sdpa( + attention_mask: Optional[torch.Tensor], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, + is_training: bool = False, + ) -> bool: + """ + Detects whether the optional user-specified attention_mask & the automatically created causal mask can be + ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument. + + In case no token is masked in the `attention_mask` argument, if `query_length == 1` or + `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks, + allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is + passed). 
+ """ + + _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1] + key_value_length = query_length + past_key_values_length + + is_tracing = ( + torch.jit.is_tracing() + or isinstance(inputs_embeds, torch.fx.Proxy) + or is_torchdynamo_compiling() + ) + + ignore_causal_mask = False + + if attention_mask is None: + # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input + # shape, thus SDPA's `is_causal` argument is rightfully updated + # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using + # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is + # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` + # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). + # Thus, we only set `ignore_causal_mask = True` if the model is set to training. + # + # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` + # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor"). + if ( + (is_training or not is_tracing) + and (query_length == 1 or key_value_length == query_length) + and (sliding_window is None or key_value_length < sliding_window) + ): + ignore_causal_mask = True + elif sliding_window is None or key_value_length < sliding_window: + if len(attention_mask.shape) == 4: + return False + elif not is_tracing and torch.all(attention_mask == 1): + if query_length == 1 or key_value_length == query_length: + # For query_length == 1, causal attention and bi-directional attention are the same. + ignore_causal_mask = True + + # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore + # the attention mask, as SDPA causal mask generation may be wrong. 
We will set `is_causal=False` in + # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. + # Reference: https://github.com/pytorch/pytorch/issues/108108 + # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3. + + return ignore_causal_mask + + +def _prepare_4d_causal_attention_mask( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, Tuple, List], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + attention_mask (`torch.Tensor` or `None`): + A 2D attention mask of shape `(batch_size, key_value_length)` + input_shape (`tuple(int)` or `list(int)` or `torch.Size`): + The input shape should be a tuple that defines `(batch_size, query_length)`. + inputs_embeds (`torch.Tensor`): + The embedded inputs as a torch Tensor. + past_key_values_length (`int`): + The length of the key value cache. + sliding_window (`int`, *optional*): + If the model uses windowed attention, a sliding window should be passed. + """ + attn_mask_converter = AttentionMaskConverter( + is_causal=True, sliding_window=sliding_window + ) + + key_value_length = input_shape[-1] + past_key_values_length + + # 4d mask is passed through the layers + if attention_mask is not None and len(attention_mask.shape) == 2: + attention_mask = attn_mask_converter.to_4d( + attention_mask, + input_shape[-1], + key_value_length=key_value_length, + dtype=inputs_embeds.dtype, + ) + elif attention_mask is not None and len(attention_mask.shape) == 4: + expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) + if tuple(attention_mask.shape) != expected_shape: + raise ValueError( + f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." 
+ ) + else: + # if the 4D mask has correct shape - invert it and fill with negative infinity + inverted_mask = 1.0 - attention_mask + attention_mask = inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min + ) + else: + attention_mask = attn_mask_converter.to_causal_4d( + input_shape[0], + input_shape[-1], + key_value_length, + dtype=inputs_embeds.dtype, + device=inputs_embeds.device, + ) + + return attention_mask diff --git a/torchtitan/experiments/deepseek_v3/checkpoint.py b/torchtitan/experiments/deepseek_v3/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..535ac7fe069a88555841181dddc1e870c2d30934 --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/checkpoint.py @@ -0,0 +1,154 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +import logging +import os +from typing import Dict, Optional, Set, Tuple + +import torch +from safetensors import safe_open + +from transformers.utils import cached_file + + +logger = logging.getLogger(__name__) + +_DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json" + + +def read_weights_from_json(file_path: str) -> Optional[Dict[str, str]]: + try: + with open(file_path, "r") as file: + data = json.load(file) + + if "weight_map" in data and isinstance(data["weight_map"], dict): + return data["weight_map"] + else: + logger.info("No 'weight_map' dictionary found in the JSON file.") + return None + except (json.JSONDecodeError, Exception) as e: + logger.info(f"An error occurred while reading the JSON file: {str(e)}") + return None + + +def get_hf_weight_map_and_path( + model_id: str, +) -> Tuple[Dict[str, str], str]: + """Get the weight map for a given HF model id and also the cache path for loading the weights""" + try: + index_file = cached_file(model_id, 
_DEFAULT_SAFETENSOR_FILE_NAME) + except Exception as e: + logger.error( + f"Model `{model_id}` not found in HF cache. " + f"You can download the model using `python download.py {model_id}" + ) + raise e + + weight_map = read_weights_from_json(index_file) + weight_path = os.path.dirname(index_file) + logger.info(f"Loading weights from: {weight_path}") + return weight_map, weight_path + + +def get_needed_files( + state_dict: Dict[str, torch.Tensor], weight_map: Dict[str, str] +) -> Set[str]: + needed_files = set() + for param in state_dict.keys(): + file = weight_map.get(param) + if file: + needed_files.add(file) + elif param.endswith("weight"): + raise ValueError( + f"Parameter {param} not found in weight map, please check..." + ) + logger.info(f"Needed files: {needed_files}") + return needed_files + + +def load_safetensor_file( + full_path: str, device: torch.device +) -> Dict[str, torch.Tensor]: + tensors = {} + with safe_open(full_path, framework="pt", device=device) as f: + for k in f.keys(): + tensors[k] = f.get_tensor(k) + logger.info(f"Loaded {len(tensors)} tensors from {full_path}") + return tensors + + +def load_safetensor_weights( + model: torch.nn.Module, + weight_map: Dict[str, str], + file_location: str, + device: torch.device, +): + """ + Load safetensor weights into a `nn.Module`. + + Args: + model (Module): The PyTorch module to load weights into. It may be a + model chunk or a full model. + weight_map (Dict[str, str]): Mapping of model parameters to file names. + file_location (str): Directory containing the weight files. + device (torch.device): The device to load tensors onto. 
+ """ + model_state_dict = model.state_dict() + needed_files = get_needed_files(model_state_dict, weight_map) + updated_states: Set[str] = set() + + for file in needed_files: + full_path = os.path.join(file_location, file) + try: + checkpoint = load_safetensor_file(full_path, "cpu") + except FileNotFoundError: + logger.error(f"File not found: {full_path}") + except Exception as e: + logger.error(f"Error during checkpoint processing of {full_path}: {str(e)}") + + matched_keys = set(checkpoint.keys()) & set(model_state_dict.keys()) + for key in matched_keys: + # Check shape + if model_state_dict[key].shape != checkpoint[key].shape: + raise ValueError( + f"Shape mismatch for {key}: " + f"model needs {model_state_dict[key].shape}, but " + f"checkpoint has {checkpoint[key].shape}" + ) + model_state_dict[key] = checkpoint[key].to(device) + + updated_states.update(matched_keys) + + missing_keys = set(model_state_dict.keys()) - updated_states + if missing_keys: + raise RuntimeError( + f"Partially updated state dict. Missing parameters: {missing_keys}" + ) + + model.load_state_dict(model_state_dict, strict=False, assign=True) + logger.info(f"Successfully loaded {len(updated_states)} weights into model") + + +def load_weights_from_hf( + model: torch.nn.Module, + distribution: str, + device: torch.device, +): + """ + Load the weights from Hugging Face format (index file + multiple safetensor + files), and fill into `model`. Model config is needed b/c we permute + wq and wk weights based on attn heads. 
+ """ + + weight_map, weight_path = get_hf_weight_map_and_path(distribution) + + load_safetensor_weights( + model, + weight_map, + weight_path, + device, + ) diff --git a/torchtitan/experiments/deepseek_v3/download.py b/torchtitan/experiments/deepseek_v3/download.py new file mode 100644 index 0000000000000000000000000000000000000000..0b9ec3104d716cbd6142c6564d83f042f128770f --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/download.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Usage: +# Downloads a given model to the HF Cache. Pass in a listed option ala "v3" or your own custom model path. +# python download.py {model_id} [custom_model_path] +# Examples: +# python download.py v2 # Use predefined model: deepseek-ai/DeepSeek-V2 +# python download.py custom "deepseek-ai/new-model" # Download a custom model path + +# Available models: +# "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat", +# "v2-lite": "deepseek-ai/DeepSeek-V2-Lite", +# "v2": "deepseek-ai/DeepSeek-V2", +# "v3": "deepseek-ai/deepseek-v3", +# "v3-0324": "deepseek-ai/DeepSeek-V3-0324", +# "custom": None, # Placeholder for custom models + + +import sys + +from transformers import AutoModelForCausalLM + + +MODELS = { + "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat", + "v2-lite": "deepseek-ai/DeepSeek-V2-Lite", + "v2": "deepseek-ai/DeepSeek-V2", + "v3": "deepseek-ai/deepseek-v3", + "v3-0324": "deepseek-ai/DeepSeek-V3-0324", + "custom": None, # For custom (any) models +} + + +def print_usage(): + print("Usage:") + print(" python download.py [model_version]") + print(" python download.py custom [custom_model_path]") + print("\nAvailable predefined models:") + for key, model in MODELS.items(): + if key != "custom": # Skip the custom placeholder + print(f" {key}: {model}") + print("\nFor custom models:") + 
print(" custom: Specify your own model path") + print(' Example: python download.py custom "organization/model-name"') + sys.exit(1) + + +# Process command line arguments +if len(sys.argv) < 2 or sys.argv[1] not in MODELS: + print_usage() + +if sys.argv[1] == "custom": + if len(sys.argv) != 3: + print("Error: Custom model requires a model path") + print_usage() + model_id = sys.argv[2] + print(f"Using custom model: {model_id}") +else: + model_id = MODELS[sys.argv[1]] +print(f"Downloading model: {model_id}") + +model = AutoModelForCausalLM.from_pretrained( + model_id, + device_map="auto", + trust_remote_code=True, +) diff --git a/torchtitan/experiments/deepseek_v3/generate.py b/torchtitan/experiments/deepseek_v3/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7302f1078ec42d7a30dd7d92e113c9affdd650 --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/generate.py @@ -0,0 +1,308 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# torchrun --standalone --nproc-per-node 4 generate.py + +# use inference.sh "Your Question Here?" to run inference with a single prompt. + +import sys +from dataclasses import dataclass + +import torch +import torch.distributed as dist + +from checkpoint import load_weights_from_hf +from model import DeepseekForCausalLM +from model_config import deepseek_config_registry +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.pipelining import PipelineStage, ScheduleGPipe +from torchtitan.tools.utils import Color +from transformers import AutoTokenizer + +# Uncomment the model you want to run. 
+model_id, mesh_shape = "deepseek-ai/DeepSeek-V2-Lite-Chat", (1, 4) +# model_id, mesh_shape = "deepseek-ai/deepseek-v3", (8, 4) + + +def colorize_chat(text, user_color=None, assistant_color=None, output_color=None): + """Parse and colorize chat output with optional colors for each role.""" + lines = text.split("\n") + result = [] + + current_role = None + current_content = [] + + def _process_current_content(): + if not current_role or not current_content: + return None + + content = "\n".join(current_content) + if current_role == "output": + return ( + f"Output: {output_color}{content}{color.reset}" + if output_color + else f"Output: {content}" + ) + else: + try: + prefix, rest = current_content[0].split(":", 1) + role_color = user_color if current_role == "user" else assistant_color + if role_color: + formatted = f"{prefix}:{role_color}{rest}{color.reset}" + if len(current_content) > 1: + formatted += ( + f"{role_color}\n" + + "\n".join(current_content[1:]) + + f"{color.reset}" + ) + return formatted + except ValueError: + pass + return content + + for line in lines: + if line.startswith("Output:"): + if processed := _process_current_content(): + result.append(processed) + current_role = "output" + content = line[len("Output:") :].strip() + if output_color: + content = f"Output: {output_color}{content}{color.reset}" + else: + content = f"Output: {content}" + result.append(content) + current_content = [] + + elif line.startswith("User:"): + if processed := _process_current_content(): + result.append(processed) + current_role = "user" + current_content = [line] + + elif line.startswith("Assistant:"): + if processed := _process_current_content(): + result.append(processed) + current_role = "assistant" + current_content = [line] + + else: + if current_content: + current_content.append(line) + elif line.strip() and current_role is None: + # Handle system message at the beginning + current_role = "output" + if output_color: + result.append(f"Output: 
{output_color}{line.strip()}{color.reset}") + else: + result.append(f"Output: {line.strip()}") + + # Process the last segment + if processed := _process_current_content(): + result.append(processed) + + return "\n".join(result) + + +color = Color() + + +@dataclass +class DistConfig: + mesh: DeviceMesh + pp_mesh: DeviceMesh + ep_mesh: DeviceMesh + pp_size: int + ep_size: int + ep_rank: int + pp_rank: int + device: torch.device + + +def create_model(dist_config: DistConfig): + model_args = deepseek_config_registry[model_id] + model_args.ep_size = dist_config.ep_size + model_args.num_stages = dist_config.pp_size + model_args.stage_idx = dist_config.pp_rank + model_args.max_seq_len = 16384 + + with dist_config.device, dist_config.mesh: + model = DeepseekForCausalLM(model_args) + load_weights_from_hf(model, model_id, dist_config.device) + model.eval() + model.setup_symm_mem(torch.bfloat16, dist_config.device) + + stage = PipelineStage( + model, + dist_config.pp_rank, + dist_config.pp_size, + dist_config.device, + group=dist_config.pp_mesh.get_group(), + ) + pp_schedule = ScheduleGPipe(stage, dist_config.pp_size) + return model, pp_schedule + + +def create_dist_config(mesh: DeviceMesh): + rank = dist.get_rank() + device_count = torch.cuda.device_count() + device = torch.device("cuda", rank % device_count) + + dist_config = DistConfig( + mesh=mesh, + pp_mesh=mesh["pp"], + ep_mesh=mesh["ep"], + pp_rank=mesh["pp"].get_local_rank(), + pp_size=mesh["pp"].size(), + ep_size=mesh["ep"].size(), + ep_rank=mesh["ep"].get_local_rank(), + device=device, + ) + return dist_config + + +def decode(tokenizer, x): + output = tokenizer.decode(x[0]) + # Clean up the output by removing special tokens + bos = tokenizer.bos_token + output = output.replace(bos, "") + # Truncate at end of sentence token + eos_token = tokenizer.eos_token + if eos_token and eos_token in output: + output = output.split(eos_token)[0] + colored_output = colorize_chat( + output, + user_color=color.green, + 
assistant_color=color.cyan, + output_color=color.blue, + ) + return colored_output + + +@torch.inference_mode() +def generate( + model, + pp_schedule, + tokenizer, + dist_config, + messages: list[dict], + n_tokens: int = 50, +): + rank = dist.get_rank() + device = dist_config.device + x = tokenizer.apply_chat_template( + [messages] * dist_config.pp_size, + add_generation_prompt=True, + return_tensors="pt", + ) + next_idx = x.shape[-1] + x = torch.cat([x, torch.zeros(x.shape[0], n_tokens, dtype=torch.int64)], dim=-1) + x = x.to(device) + + for _ in range(n_tokens): + if dist_config.pp_size > 1: + if dist_config.pp_rank == 0: + pp_schedule.step(x) + torch.distributed.broadcast( + x, + group=dist_config.pp_mesh.get_group(), + group_src=dist_config.pp_size - 1, + ) + elif dist_config.pp_rank == dist_config.pp_size - 1: + preds = pp_schedule.step() + next_token = torch.argmax(preds[:, next_idx - 1], dim=-1) + x[:, next_idx] = next_token + torch.distributed.broadcast( + x, + group=dist_config.pp_mesh.get_group(), + group_src=dist_config.pp_size - 1, + ) + else: + pp_schedule.step() + torch.distributed.broadcast( + x, + group=dist_config.pp_mesh.get_group(), + group_src=dist_config.pp_size - 1, + ) + + next_idx += 1 + else: + preds = model(x) + next_token = torch.argmax(preds[:, next_idx - 1], dim=-1) + x[:, next_idx] = next_token + next_idx += 1 + + if rank == 0: + colored_output = decode(tokenizer, x) + print(f"Without CUDA Graph:\n{colored_output}") + + +@torch.inference_mode() +def generate_with_cuda_graph( + model, + tokenizer, + dist_config, + messages: list[dict], + n_tokens: int = 10, +): + rank = dist.get_rank() + device = dist_config.device + x = tokenizer.apply_chat_template( + [messages] * dist_config.pp_size, + add_generation_prompt=True, + return_tensors="pt", + ) + next_idx = x.shape[-1] + x = torch.cat([x, torch.zeros(x.shape[0], n_tokens, dtype=torch.int64)], dim=-1) + x = x.to(device) + + torch.cuda.synchronize() + + # Create CUDA graph + g = 
torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + preds = model(x) + + # Run CUDA graph + for _ in range(n_tokens): + g.replay() + next_token = torch.argmax(preds[:, next_idx - 1], dim=-1) + x[:, next_idx] = next_token + next_idx += 1 + + if rank == 0: + colored_output = decode(tokenizer, x) + print(f"With CUDA Graph:\n{colored_output}") + + +if __name__ == "__main__": + # Get user prompt from command line arguments + user_prompt = "What is 2+2?" # Default prompt + if len(sys.argv) > 1: + user_prompt = sys.argv[1] + + mesh = dist.init_device_mesh("cuda", mesh_shape, mesh_dim_names=("pp", "ep")) + rank = dist.get_rank() + if rank == 0: + print( + f"{color.yellow}Running inference with {model_id} on {mesh_shape} mesh{color.reset}" + ) + + dist_config = create_dist_config(mesh) + model, pp_schedule = create_model(dist_config) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": user_prompt}, + ] + + generate(model, pp_schedule, tokenizer, dist_config, messages) + generate_with_cuda_graph(model, tokenizer, dist_config, messages) + + if rank == 0: + print(f"\n{color.yellow}Closing inference mesh...{color.reset}") + + dist.destroy_process_group() diff --git a/torchtitan/experiments/deepseek_v3/indices.py b/torchtitan/experiments/deepseek_v3/indices.py new file mode 100644 index 0000000000000000000000000000000000000000..39d5946ecec40cd0e50e08beb4e2192df036c7ca --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/indices.py @@ -0,0 +1,195 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +import triton +import triton.language as tl + + +__all__ = ["generate_permute_indices"] + + +@triton.jit +def fill_indices_kernel( + tokens_per_expert_group_ptr, # *Pointer* to first input vector. + start_index_values_ptr, # *Pointer* to second input vector. + write_offsets_ptr, # *Pointer* to third input vector. + output_ptr, # *Pointer* to output vector. + experts_per_rank, # Number of experts per rank. + num_ranks, # Number of expert ranks. +): + # There are multiple 'programs' processing different data. We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # The total number of programs in the launch grid. + num_programs = tl.num_programs(axis=0) + # We map the programs (blocks) to the experts. + for expert_id in tl.range(pid, experts_per_rank, step=num_programs): + # Read this expert's write offset. + write_offset = tl.load(write_offsets_ptr + expert_id) + # Loop over the ranks. + for r in tl.range(num_ranks): + # Slot in the tokens_per_expert_group array. + i = r * experts_per_rank + expert_id + start_index = tl.load(start_index_values_ptr + i) + length = tl.load(tokens_per_expert_group_ptr + i) + # Write the indices. + for l in tl.range(length): + val = start_index + l + tl.store(output_ptr + write_offset + l, val) + write_offset += length + + +def fill_indices( + tokens_per_expert_group: torch.Tensor, + start_index_values: torch.Tensor, + write_offsets: torch.Tensor, + experts_per_rank: int, + num_ranks: int, + max_len: int, +): + # We need to preallocate the output. + permuted_indices = torch.full( + (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device + ) + # Analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]. + # In this case, we use a 1D grid where the size is the number of blocks (TODO: bump this value). 
+ grid = lambda meta: (1,) + # Each torch.tensor object is implicitly converted into a pointer to its first element. + fill_indices_kernel[grid]( + tokens_per_expert_group, + start_index_values, + write_offsets, + permuted_indices, + experts_per_rank, + num_ranks, + ) + return permuted_indices + + +def fill_indices_cpu( + tokens_per_expert_group: torch.Tensor, + start_index_values: torch.Tensor, + write_offsets: torch.Tensor, + experts_per_rank: int, + num_ranks: int, + max_len: int, +): + # We need to preallocate the output. + permuted_indices = torch.full((max_len,), -1, dtype=torch.int32) + # Fill the permuted indices + # For each local expert + for e in range(experts_per_rank): + write_start = write_offsets[e] + # For each remote rank + for r in range(num_ranks): + i = r * experts_per_rank + e + start_index = start_index_values[i] + length = tokens_per_expert_group[i] + # Fill in the indices + permuted_indices[write_start : write_start + length] = torch.arange( + start_index, start_index + length + ) + write_start += length + return permuted_indices + + +def generate_permute_indices( + tokens_per_expert_group: torch.Tensor, + experts_per_rank: int, + num_ranks: int, + max_len: int, + alignment: int, + use_cpu: bool = False, +): + # Prepare permutation indices and the number of tokens for each expert. The + # permutation indices are the indices of the tokens for each expert. The + # number of tokens for each expert is the sum of the number of tokens for + # such experts from all ranks. This number is aligned to the provided + # alignment requirement (usually comes from group gemm). + + # Args: + # tokens_per_expert_group: number of tokens for each expert from all ranks. + # experts_per_rank: number of experts per rank. + # num_ranks: number of ranks. + # max_len: maximum length of the output index vector. If greater than + # total number of tokens, the remaining indices are set to -1. + # alignment: alignment for each returned element in `m_sizes`. 
+ # use_cpu: whether to use cpu or gpu. + # Returns: + # permuted_indices: permutation indices. + # m_sizes: number of tokens for each expert. + + # `tokens_per_expert_group` is of shape (num_ranks * experts_per_rank,), for example: + # From: | rank 0 | rank 1 | + # To: | E0 | E1 | E2 | E3 | E0 | E1 | E2 | E3 | + # | 4 | 2 | 1 | 3 | 1 | 2 | 3 | 4 | + + # Prefix sum to get the start index value of each expert + start_index_values = ( + torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group + ) + # Chunk sizes for each expert + chunk_size_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0) + # Align the chunk sizes to the given alignment + m_sizes = ((chunk_size_per_expert + alignment - 1) // alignment * alignment).to( + torch.int32 + ) + # Perform another prefix sum to get the write offset of each expert in `permuted_indices` + write_offsets = torch.cumsum(m_sizes, 0) - m_sizes + # Select the method to fill the permuted indices + fill_fn = fill_indices_cpu if use_cpu else fill_indices + # Fill the permuted indices + permuted_indices = fill_fn( + tokens_per_expert_group, + start_index_values, + write_offsets, + experts_per_rank, + num_ranks, + max_len, + ) + return permuted_indices, m_sizes + + +# Below is for testing only + + +def test(): + device = torch.device("cuda", 0) + experts_per_rank = 4 + num_ranks = 4 + tokens_per_expert_group = torch.full( + (num_ranks * experts_per_rank,), 4, dtype=torch.int32, device=device + ) + max_len = 128 + alignment = 32 + # Use the GPU kernel + permuted_indices_gpu, m_sizes = generate_permute_indices( + tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment + ) + # Use the CPU method + permuted_indices_cpu, _ = generate_permute_indices( + tokens_per_expert_group, + experts_per_rank, + num_ranks, + max_len, + alignment, + use_cpu=True, + ) + # Check that the results are the same + assert torch.equal(permuted_indices_gpu.cpu(), permuted_indices_cpu) + assert torch.equal( + 
torch.remainder(m_sizes, alignment), + torch.zeros(experts_per_rank, device=device), + ) + # Print the results + print(permuted_indices_gpu) + print(m_sizes) + print("Success") + + +if __name__ == "__main__": + test() diff --git a/torchtitan/experiments/deepseek_v3/model.py b/torchtitan/experiments/deepseek_v3/model.py new file mode 100644 index 0000000000000000000000000000000000000000..0669df9528b3db0de3325db36f010312b5b3eac7 --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/model.py @@ -0,0 +1,1325 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This code is based on model definition of `deepseek-ai/DeepSeek-V3-Base` on +# Hugging Face Model Hub. Url: +# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py +# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/resolve/main/configuration_deepseek.py +# +# It has been modified from its original forms to accommodate naming convention +# and usage patterns of the TorchTitan project. + +# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" PyTorch DeepSeek model.""" +import math +from typing import Optional, Tuple + +import torch +import torch.distributed as dist + +import torch.distributed._symmetric_memory as symm_mem +import torch.nn.functional as F +import torch.utils.checkpoint + +from attn_mask_utils import _prepare_4d_causal_attention_mask +from indices import generate_permute_indices +from model_config import ModelArgs +from symm_mem_recipes import OnDeviceAllToAllV +from torch import nn +from torch.distributed._functional_collectives import all_to_all_single_autograd + +from torchtitan.experiments.kernels.triton_mg_group_gemm.torchao_pr import ( + ALIGN_SIZE_M, + grouped_gemm_forward, +) + +# Get model parallel subgroup by name: +# e.g. "pp", "ep", None +def get_group(dim_name: Optional[str] = None) -> dist.ProcessGroup: + glob = torch.distributed.device_mesh._mesh_resources.get_current_mesh() + return glob.get_group(dim_name) + + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq.to(t.device)) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class LinearScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with linear scaling. 
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Deepseek
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling.

    When the requested length exceeds the trained context, the rotary base is
    enlarged on the fly so the frequency spectrum covers the longer range.
    Credits to the Reddit users /u/bloc97 and /u/emozilla.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__, which builds the cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # NTK-aware rescaling of the base, only beyond the trained context.
            scaled_base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings)
                - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim
            self.register_buffer(
                "inv_freq", 1.0 / (scaled_base**exponents), persistent=False
            )

        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.outer(t, self.inv_freq)
        # Same concatenated (not interleaved) layout as the base class.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(
    num_rotations, dim, base=10000, max_position_embeddings=2048
):
    """Return the (fractional) rotary dimension index at which the embedding
    completes ``num_rotations`` rotations over the trained context."""
    numerator = dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
    return numerator / (2 * math.log(base))


# Find dim range bounds based on rotations
def yarn_find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    """Return integer ``(low, high)`` dim bounds for YaRN interpolation,
    clamped to ``[0, dim - 1]``."""
    low = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    )
    high = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    )
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


def yarn_get_mscale(scale=1, mscale=1):
    """YaRN attention-magnitude correction factor; 1.0 when not scaling."""
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0


def yarn_linear_ramp_mask(min, max, dim):
    """Linear ramp of length ``dim``: 0 at index ``min``, 1 at index ``max``.

    NOTE: ``min``/``max`` shadow the builtins; names kept for interface
    compatibility with existing callers.
    """
    if min == max:
        max += 0.001  # Prevent singularity
    ramp = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    return torch.clamp(ramp, 0, 1)
class YarnRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding with YaRN scaling: per-dimension NTK-by-parts
    frequency interpolation plus an attention-magnitude correction that is
    baked directly into the cached cos/sin tables."""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        # All YaRN hyper-parameters must be in place before super().__init__
        # triggers the first cache build.
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        dim = self.dim

        exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim
        # Extrapolated (original base) and interpolated (scaled) frequencies.
        freq_extra = 1.0 / (self.base**exponents)
        freq_inter = 1.0 / (self.scaling_factor * self.base**exponents)

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.original_max_position_embeddings,
        )
        # Blend the two frequency sets per dimension ("NTK by parts").
        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
            device=device, dtype=torch.float32
        )
        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(seq_len, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)

        # Magnitude-correction ratio folded into the cached tables.
        _mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale)
            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        )

        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False
        )
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q: query tensor of shape `[batch, heads, seq_len, head_dim]`.
        k: key tensor of the same layout as `q`.
        cos / sin: cached rotary tables, indexed by `position_ids`.
        position_ids: token positions, used to offset into the caches
            (e.g. when working with a KV-cache).
        unsqueeze_dim: dimension along which `cos[position_ids]` and
            `sin[position_ids]` are unsqueezed so they broadcast against
            `q`/`k` (1 for `[b, h, s, d]`, 2 for `[b, s, h, d]` layouts).

    Returns:
        `(q_embed, k_embed)` — the rotated query and key tensors.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _deinterleave(t):
        # Regroup pair-interleaved head dims so even-indexed features come
        # first — the permutation matching the concatenated cos/sin layout.
        b, h, s, d = t.shape
        return t.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q = _deinterleave(q)
    k = _deinterleave(k)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
class MLP(nn.Module):
    """SwiGLU feed-forward block: ``down(act(gate(x)) * up(x))``."""

    # Shared, stateless activation; also referenced as `MLP.act_fn` by the
    # grouped-GEMM path in MoE.moe_on_device.
    act_fn = nn.SiLU()

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        # Explicit sizes override the config (used by MoE for expert MLPs).
        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
        self.intermediate_size = (
            config.intermediate_size if intermediate_size is None else intermediate_size
        )

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)


class MoEGate(nn.Module):
    """MoE router: scores every token against each routed expert and returns
    the selected expert indices plus their (scaled) gate weights."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        self.seq_aux = config.seq_aux
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group

        # topk selection algorithm
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim))
        )
        if self.topk_method == "noaux_tc":
            # torch.rand (instead of torch.empty) so runs without real
            # weights still see a non-degenerate distribution.
            self.e_score_correction_bias = nn.Parameter(
                torch.rand((self.n_routed_experts))
            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
        # compute gating score — always in fp32
        flat = hidden_states.view(-1, h)
        logits = F.linear(
            flat.type(torch.float32), self.weight.type(torch.float32), None
        )
        if self.scoring_func == "sigmoid":
            scores = logits.sigmoid()
        elif self.scoring_func == "softmax":
            scores = logits.softmax(dim=-1, dtype=torch.float32)
        else:
            raise NotImplementedError(
                f"insupportable scoring function for MoE gating: {self.scoring_func}"
            )

        # select top-k experts
        if self.topk_method == "noaux_tc":
            # Bias-corrected scores are used only for *selection*; the
            # returned weights still come from the uncorrected `scores`.
            scores_for_choice = scores.view(
                bsz * seq_len, -1
            ) + self.e_score_correction_bias.unsqueeze(0)
            # Score each expert group by the sum of its two best experts.
            group_scores = (
                scores_for_choice.view(bsz * seq_len, self.n_group, -1)
                .topk(2, dim=-1)[0]
                .sum(dim=-1)
            )  # [n, n_group]
            group_idx = torch.topk(
                group_scores, k=self.topk_group, dim=-1, sorted=False
            )[1]  # [n, top_k_group]
            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
            # Expand the group mask back to per-expert granularity.
            score_mask = (
                group_mask.unsqueeze(-1)
                .expand(
                    bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
                )
                .reshape(bsz * seq_len, -1)
            )  # [n, e]
            tmp_scores = scores_for_choice.masked_fill(
                ~score_mask.bool(), 0.0
            )  # [n, e]
            _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
            topk_weight = scores.gather(1, topk_idx)
        elif self.topk_method == "greedy":
            topk_weight, topk_idx = torch.topk(
                scores, k=self.top_k, dim=-1, sorted=False
            )
        else:
            raise NotImplementedError(
                f"insupportable TopK function for MoE gating: {self.topk_method}"
            )

        # norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)
        topk_weight = topk_weight * self.routed_scaling_factor  # must multiply the scaling factor

        return topk_idx, topk_weight
class MoE(nn.Module):
    """
    A mixed expert module containing shared experts.

    Tokens are routed by `MoEGate`, shuffled to the expert-parallel (EP) rank
    that owns each selected expert, processed there, and shuffled back.
    """

    # Class attributes:
    # Two shuffle methods supported:
    # 1. "torch_all_to_all"
    # 2. "symm_mem" (see `setup_symm_mem` below)
    shuffle_method = "torch_all_to_all"

    # Symmetric memory buffers shared by all MoE instances across layers
    token_send_buf: Optional[torch.Tensor] = None
    token_gather_buf: Optional[torch.Tensor] = None

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok

        # ep_size is the number of ranks in expert dimension
        if config.ep_size <= 1:
            raise ValueError(
                "For code simplicity, this model only supports distributed experts, "
                "thus EP size must be > 1, please modify your model config"
            )
        self.ep_group = get_group("ep")
        assert config.ep_size == self.ep_group.size()
        self.ep_size = config.ep_size
        self.ep_rank = self.ep_group.rank()
        self.experts_per_rank = config.n_routed_experts // config.ep_size
        # Use ModuleDict instead of ModuleList to preserve absolute expert
        # IDs while avoiding `None` experts. The absolute expert IDs match
        # with checkpoint FQNs.
        self.experts = nn.ModuleDict()
        for i in range(self.experts_per_rank):
            abs_expert_id = self.ep_rank * self.experts_per_rank + i
            self.experts[str(abs_expert_id)] = MLP(
                config, intermediate_size=config.moe_intermediate_size
            )
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            # Shared experts are applied to every token, on top of routing.
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = MLP(
                config=config, intermediate_size=intermediate_size
            )

    def combine_experts(self, submod_name: str):
        """Concatenate the named linear weight of every local expert into one
        registered parameter (`f"{submod_name}_weight"`) for grouped GEMM.

        NOTE: this *moves* the weights — each expert's own linear is left
        with `weight = None` afterwards.
        """
        all_weights = []
        for expert in self.experts.values():
            lin = expert.get_submodule(submod_name)
            all_weights.append(lin.weight)
            lin.weight = None

        concat_weight = torch.cat(all_weights)
        self.register_parameter(f"{submod_name}_weight", nn.Parameter(concat_weight))

    # This function is used to create a symm mem buffer for MoE's. It is for
    # shuffling tokens fully "on-device", as compared to traditional torch
    # all_to_all APIs which require a GPU-to-CPU sync of the splits. If a user
    # calls this function, the `shuffle_method` would switch from
    # `torch_all_to_all` to `symm_mem`.
    def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
        """Switch this MoE to the symmetric-memory shuffle path and allocate
        the (class-wide) send/gather buffers on first call."""
        # Switch shuffle method
        self.shuffle_method = "symm_mem"

        # Combine expert weights
        print("Combining expert weights for Group GEMM")
        self.combine_experts("gate_proj")
        self.combine_experts("up_proj")
        self.combine_experts("down_proj")

        # Assuming worst case, 2x tokens are routed to one EP rank
        overflow = 2
        OnDeviceAllToAllV.max_output_len = (
            self.config.max_seq_len * self.num_experts_per_tok * overflow
        )

        # Symmetric memory buffers are shared by all MoE instances across
        # layers, we only need to initialize them once
        if MoE.token_send_buf is not None:
            return

        # Input buffer for DP-to-EP shuffle
        MoE.token_send_buf = symm_mem.empty(
            self.config.max_seq_len
            * self.num_experts_per_tok,  # seq len * top k (flattened)
            self.config.hidden_size,  # hidden dim
            dtype=dtype,
            device=device,
        )
        # Input buffer for EP-to-DP shuffle
        MoE.token_gather_buf = symm_mem.empty(
            self.config.max_seq_len
            * self.num_experts_per_tok  # seq len * top k (flattened)
            * overflow,
            self.config.hidden_size,  # hidden dim
            dtype=dtype,
            device=device,
        )
        print(f"EP rank [{self.ep_rank}]: Created Symmetric Memory for MoE")

    def get_send_buf(self):
        # [Why detach?] During a first forward-backward step, the buffer would
        # be included in a computational graph. In a second step, autograd will
        # return an error saying "Trying to backward through the graph a second
        # time (or directly access saved tensors more than once)". This is
        # because the buffer is still in the graph, and autograd is trying to
        # backward through the graph a second time. To avoid this, we detach the
        # buffer from the graph. `detach()` returns a new tensor, which shares
        # the same storage with the original one.
        self.token_send_buf.grad = None
        return self.token_send_buf.detach()

    def get_gather_buf(self):
        # See [Why detach?] in `get_send_buf`
        self.token_gather_buf.grad = None
        return self.token_gather_buf.detach()

    def forward(self, hidden_states):
        """Route `hidden_states` through the selected experts (plus shared
        experts, if configured) and return a tensor of the same shape."""
        identity = hidden_states
        orig_shape = hidden_states.shape
        # for each token, select top-k experts, and compute the weight for each expert
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        if self.shuffle_method == "symm_mem":
            y = self.moe_on_device(hidden_states, topk_idx, topk_weight)
        else:  # "torch_all_to_all"
            y = self.moe_forward(hidden_states, topk_idx, topk_weight)

        y = y.view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    def moe_forward(self, x, topk_ids, topk_weight):
        """Per-expert dispatch/combine using all-to-all collectives and a
        per-expert Python loop over the local experts."""
        # This part sorts the token indices so that tokens routed to the same expert reside consecutively.
        # An implication is that tokens to the same "expert group" (i.e., device) are also consecutive.
        # Since this is an "artificial" index creation (final outcome being
        # `idxs`), we don't need gradients here.
        with torch.no_grad():
            # [seq_len, n_routed_experts]
            cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts))
            # Fill 1 to the selected experts
            cnts.scatter_(1, topk_ids, 1)
            tokens_per_expert = cnts.sum(dim=0)
            # Token indices for each expert
            idxs = topk_ids.view(-1).argsort()
            sorted_tokens_shape = idxs.shape + x.shape[1:]

        sorted_tokens = x[idxs // topk_ids.shape[1]]
        assert sorted_tokens.shape == sorted_tokens_shape

        # This part exchanges the information about the number of tokens sent
        # and received by each expert. We can understand this information as
        # "side band", which is not part of the actual data. Thus no gradient
        # is needed.
        with torch.no_grad():
            # Sum the tokens over local experts, then we get tokens per EP rank,
            # which is the input splits
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0]
            )
            dist.all_to_all_single(
                tokens_per_expert_group, tokens_per_expert, group=self.ep_group
            )
            input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)

        # DP to EP token shuffle. This part needs gradient.
        if self.shuffle_method == "symm_mem":
            # Move input to the `token_send_buf` symm mem
            token_send_buf = self.get_send_buf()
            token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
            # Note: `out=` avoids copy, but it is not differentiable
            # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
            token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
                token_send_buf,
                input_splits,
                self.ep_group,
            )
            with torch.no_grad():
                # Received tokens from all other ranks. TODO: use mask instead
                received = output_splits.sum()
            # TODO: don't use `received`
            gathered_tokens = token_gather_buf[:received]
        else:  # "torch_all_to_all"
            # Prepare input and output splits
            with torch.no_grad():
                output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(
                    dim=1
                )
            gathered_tokens = all_to_all_single_autograd(
                sorted_tokens,
                output_splits.tolist(),
                input_splits.tolist(),
                self.ep_group,
            )

        # This part prepares a 1D tensor with the same length as
        # `gathered_tokens`. The 1D tensor is filled with local expert IDs which
        # the tokens in `gathered_tokens` are headed for. This part doesn't need
        # gradient.
        with torch.no_grad():
            gatherd_idxs = (
                torch.arange(
                    tokens_per_expert_group.numel(),
                    device=tokens_per_expert_group.device,
                )
                % self.experts_per_rank
            )
            gatherd_idxs = gatherd_idxs.repeat_interleave(tokens_per_expert_group)

        # Prepare buffer for tokens processed by experts
        if self.shuffle_method == "symm_mem":
            # Take necessary space from `token_gather_buf` symm mem because we are
            # going to send them out after expert processing
            processed_tokens = self.get_gather_buf()[: gathered_tokens.shape[0]]
        else:  # "torch_all_to_all"
            processed_tokens = torch.empty_like(gathered_tokens)

        # This part processes the tokens routed to the local experts.
        # TODO: can we use group GEMM here?
        for i, expert in enumerate(self.experts.values()):
            processed_tokens[gatherd_idxs == i] = expert(
                gathered_tokens[gatherd_idxs == i]
            )

        # Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
        # The input/output splits are just a reverse of the previous shuffle.
        if self.shuffle_method == "symm_mem":
            token_return_buf, _ = OnDeviceAllToAllV.apply(
                processed_tokens,
                output_splits,
                self.ep_group,
            )
            returned_tokens = token_return_buf[: sorted_tokens_shape[0]]
        else:  # "torch_all_to_all"
            returned_tokens = all_to_all_single_autograd(
                processed_tokens,
                input_splits.tolist(),
                output_splits.tolist(),
                self.ep_group,
            )

        # Undo the sort, then combine the top-k expert outputs per token,
        # weighted by the (already-scaled) gate weights.
        output_tokens = torch.empty_like(returned_tokens)
        output_tokens[idxs] = returned_tokens
        final_out = (
            output_tokens.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(returned_tokens.dtype)
        )
        return final_out

    def moe_on_device(self, x, topk_ids, topk_weight):
        """Symmetric-memory dispatch/combine path: on-device all-to-all-v plus
        grouped GEMMs over the pre-combined expert weights."""
        # This part sorts the token indices so that tokens routed to the same expert reside consecutively.
        # An implication is that tokens to the same "expert group" (i.e., device) are also consecutive.
        # Since this is an "artificial" index creation (final outcome being
        # `idxs`), we don't need gradients here.
        with torch.no_grad():
            # [seq_len, n_routed_experts]
            cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts))
            # Fill 1 to the selected experts
            cnts.scatter_(1, topk_ids, 1)
            tokens_per_expert = cnts.sum(dim=0)
            # Token indices for each expert
            idxs = topk_ids.view(-1).argsort()
            sorted_tokens_shape = idxs.shape + x.shape[1:]

        sorted_tokens = x[idxs // topk_ids.shape[1]]
        assert sorted_tokens.shape == sorted_tokens_shape

        # This part exchanges the information about the number of tokens sent
        # and received by each expert. We can understand this information as
        # "side band", which is not part of the actual data. Thus no gradient
        # is needed.
        with torch.no_grad():
            # Sum the tokens over local experts, then we get tokens per EP rank,
            # which is the input splits
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0]
            )
            dist.all_to_all_single(
                tokens_per_expert_group, tokens_per_expert, group=self.ep_group
            )
            input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)

        # Move input to the `token_send_buf` symm mem
        token_send_buf = self.get_send_buf()
        token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
        # Note: `out=` avoids copy, but it is not differentiable
        # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
        token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
            token_send_buf,
            input_splits,
            self.ep_group,
        )

        # We need to permute the received tokens so that tokens for the same expert are contiguous.
        # This part prepares a 1D tensor `permuted_indices` for such permutation.
        # This part doesn't need gradient.
        with torch.no_grad():
            permuted_indices, m_sizes = generate_permute_indices(
                tokens_per_expert_group,
                self.experts_per_rank,
                self.ep_size,
                token_gather_buf.shape[0],
                ALIGN_SIZE_M,
            )

        # Permute the received tokens so that tokens for the same expert are contiguous.
        contig_tokens = token_gather_buf[permuted_indices]

        # Run the first grouped GEMM
        w1 = self.get_parameter("gate_proj_weight")
        gate_proj = grouped_gemm_forward(contig_tokens, w1, m_sizes)

        # Run the second grouped GEMM
        w3 = self.get_parameter("up_proj_weight")
        up_proj = grouped_gemm_forward(contig_tokens, w3, m_sizes)

        # Apply activation
        hidden_outputs = MLP.act_fn(gate_proj) * up_proj

        # Run the third grouped GEMM
        w2 = self.get_parameter("down_proj_weight")
        hidden_outputs = grouped_gemm_forward(hidden_outputs, w2, m_sizes)

        # Prepare buffer for tokens processed by experts
        # Take necessary space from `token_gather_buf` symm mem because we are
        # going to send them out after expert processing
        processed_tokens = self.get_gather_buf()

        # Move into Symmetric Memory for the return shuffle
        processed_tokens[permuted_indices] = hidden_outputs

        # Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
        # The input/output splits are just a reverse of the previous shuffle.
        token_return_buf, _ = OnDeviceAllToAllV.apply(
            processed_tokens,
            output_splits,
            self.ep_group,
        )
        returned_tokens = token_return_buf[: sorted_tokens_shape[0]]

        # Undo the sort, then combine the top-k expert outputs per token,
        # weighted by the (already-scaled) gate weights.
        output_tokens = torch.empty_like(returned_tokens)
        output_tokens[idxs] = returned_tokens
        final_out = (
            output_tokens.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(returned_tokens.dtype)
        )
        return final_out
class Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    DeepSeek-style multi-head latent attention: queries and keys are split
    into a non-positional ("nope") part and a rotary ("rope") part; keys and
    values are produced from a low-rank compressed latent (`kv_lora_rank`),
    and the rotary key is shared across heads (multi-query style).
    """

    def __init__(self, config: ModelArgs, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        # Full query/key head dim = non-positional part + rotary part.
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.is_causal = True

        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(
                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
            )
        else:
            # Low-rank query path: down-project, RMS-normalize, up-project.
            self.q_a_proj = nn.Linear(
                self.hidden_size, config.q_lora_rank, bias=config.attention_bias
            )
            self.q_a_layernorm = RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(
                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
            )

        # Joint projection producing the compressed KV latent plus the shared
        # (multi-query) rotary key.
        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = RMSNorm(config.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            config.kv_lora_rank,
            self.num_heads
            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )
        self._init_rope()

        self.softmax_scale = self.q_head_dim ** (-0.5)
        if self.config.rope_scaling is not None:
            # YaRN-style attention-magnitude correction folded into the scale.
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

    def _init_rope(self):
        """Build the rotary embedding module per `config.rope_scaling`."""
        if self.config.rope_scaling is None:
            self.rotary_emb = RotaryEmbedding(
                self.qk_rope_head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LinearScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "yarn":
                # Forward only the YaRN knobs present in the config.
                kwargs = {
                    key: self.config.rope_scaling[key]
                    for key in [
                        "original_max_position_embeddings",
                        "beta_fast",
                        "beta_slow",
                        "mscale",
                        "mscale_all_dim",
                    ]
                    if key in self.config.rope_scaling
                }
                self.rotary_emb = YarnRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                    **kwargs,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Compute self-attention over `hidden_states`.

        Args:
            hidden_states: `(bsz, q_len, hidden_size)` input activations.
            attention_mask: optional user-provided 2D mask, expanded to the
                4D causal mask SDPA consumes; when `None`, SDPA's built-in
                causal masking is used instead.
            position_ids: positions used to index the rotary cos/sin caches.

        Returns:
            `(bsz, q_len, hidden_size)` attention output after `o_proj`.
            (Return annotation fixed: the method returns a single tensor,
            not the Tuple the previous annotation claimed.)
        """
        bsz, q_len, _ = hidden_states.size()

        # Query projection (optionally through the low-rank bottleneck).
        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        # Compressed KV latent plus the single shared rotary key.
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )
        kv_seq_len = value_states.shape[-2]

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        # Reassemble full heads as [nope | rope]; k_pe broadcasts over heads.
        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe

        if attention_mask is not None:
            # Attention mask was made 4D because SDPA consumes a 4D mask.
            # We probably can make this mask smarter if we want to pack
            # sequences together, instead of using padding. This optimization
            # can be used in inference. For training, if we want to pack
            # sequences, the data loader will pass in a mask containing such
            # info.
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,  # None, or user provided mask in 2D
                (bsz, q_len),
                hidden_states,
                0,  # past_key_values_length, 0 when training
            )
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query=query_states,
            key=key_states,
            value=value_states,
            attn_mask=attention_mask,
            # BUG FIX: SDPA applies `dropout_p` unconditionally, so the old
            # code dropped attention weights at eval/inference time too.
            # Gate it on `self.training`, matching standard practice.
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=attention_mask is None,
            scale=self.softmax_scale,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
        return self.o_proj(attn_output)
class DecoderLayer(nn.Module):
    """One pre-norm transformer decoder layer:
    RMSNorm -> self-attention -> residual, then RMSNorm -> MLP/MoE -> residual.
    """

    def __init__(self, config: ModelArgs, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Attention(config=config, layer_idx=layer_idx)

        # A layer hosts MoE (instead of a dense MLP) only once routed experts
        # are configured, past the first `first_k_dense_replace` layers, and
        # at the configured layer frequency.
        use_moe = (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        )
        self.mlp = MoE(config) if use_moe else MLP(config)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
        """
        # Self attention with pre-norm and residual connection.
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward (dense MLP or MoE) with pre-norm and residual.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out
# HF-style documentation of the model's forward arguments, kept as a
# module-level raw string so it can be spliced into forward() docstrings.
Deepseek_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class DeepseekModel(torch.nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DecoderLayer`]

    Args:
        config: ModelArgs
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Creating model parts related to my pipeline stage only
        assert (
            config.stage_idx < config.num_stages
        ), f"Stage {config.stage_idx} is not in the model"
        print(f"Creating model stage {config.stage_idx} of {config.num_stages}")

        # Only the first pipeline stage owns the token embedding.
        self.embed_tokens = (
            nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
            if config.stage_idx == 0
            else None
        )

        self.layers = torch.nn.ModuleDict()
        division = config.num_hidden_layers // config.num_stages
        residual = config.num_hidden_layers % config.num_stages
        # Some earlier stages may have 1 more layer than latter stages because
        # the division may have residual; this is more even than giving the
        # entire residual to the last stage.
+ layers_per_stage = [ + division + 1 if stage < residual else division + for stage in range(config.num_stages) + ] + assert sum(layers_per_stage) == config.num_hidden_layers + layer_id_start = sum(layers_per_stage[: config.stage_idx]) + layer_id_end = layer_id_start + layers_per_stage[config.stage_idx] + for layer_id in range(layer_id_start, layer_id_end): + self.layers[str(layer_id)] = DecoderLayer(config, layer_id) + + self.norm = ( + RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if config.stage_idx == config.num_stages - 1 + else None + ) + + # Initialize weights and apply final processing + self.apply(self._init_weights) + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def forward( + self, + tokens: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + # Embedding + hidden_states = ( + self.embed_tokens(tokens) if self.embed_tokens is not None else tokens + ) + + # decoder layers + for decoder_layer in self.layers.values(): + hidden_states = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + hidden_states = ( + self.norm(hidden_states) if self.norm is not None else hidden_states + ) + return hidden_states + + +class DeepseekForCausalLM(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.model = DeepseekModel(config) + self.lm_head = ( + nn.Linear(config.hidden_size, config.vocab_size, bias=False) + if config.stage_idx == config.num_stages - 1 + else None + ) + + # Initialize weights and apply final processing + # self.post_init() + + def forward( + self, + 
tokens: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Tuple: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, DeepseekForCausalLM + + >>> model = DeepseekForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + hidden_states = self.model( + tokens, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + logits = ( + self.lm_head(hidden_states) if self.lm_head is not None else hidden_states + ) + return logits + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + **kwargs, + ): + if past_key_values is not None: + # Assuming isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. 
+ elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past + + # Setup Symmetric Memory for MoE token shuffle. + # Supports inference currently. 
+ def setup_symm_mem(self, dtype: torch.dtype, device: torch.device): + for layer in self.model.layers.values(): + if not isinstance(layer.mlp, MoE): + continue + layer.mlp.setup_symm_mem(dtype, device) diff --git a/torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py b/torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..335bc2d966efbe486418525cb784078a6ec879d5 --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .triton_on_device_all_to_all_v import OnDeviceAllToAllV + +__all__ = [ + "OnDeviceAllToAllV", +] diff --git a/torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py b/torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd9b283f41daffab3f4ce4d1e0a5d844f2a2c70 --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_barrier.py @@ -0,0 +1,159 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+import triton
+import triton.language as tl
+
+from .triton_utils import get_flat_bid, get_flat_tid
+
+
+@triton.jit
+def send_signal(addrs, sem: tl.constexpr):
+    # Post a flag at every address in `addrs`: spin on CAS until the u32
+    # slot flips 0 -> 1. `sem` selects the atomic's memory ordering
+    # ("relaxed", or "acq_rel" which lowers to a release CAS here).
+    if sem == "relaxed":
+        tl.inline_asm_elementwise(
+            """
+            {
+                .reg .u32   %tmp32_<1>;
+                .reg .pred  %p<1>;
+
+                send_signal:
+                atom.global.relaxed.sys.cas.b32 %tmp32_0, [$1], 0, 1;
+                setp.eq.u32 %p0, %tmp32_0, 0;
+                @!%p0 bra send_signal;
+            }
+            """,
+            "=r, l",
+            [addrs],
+            dtype=tl.int32,
+            is_pure=False,
+            pack=1,
+        )
+    elif sem == "acq_rel":
+        tl.inline_asm_elementwise(
+            """
+            {
+                .reg .u32   %tmp32_<1>;
+                .reg .pred  %p<1>;
+
+                send_signal:
+                atom.global.release.sys.cas.b32 %tmp32_0, [$1], 0, 1;
+                setp.eq.u32 %p0, %tmp32_0, 0;
+                @!%p0 bra send_signal;
+            }
+            """,
+            "=r, l",
+            [addrs],
+            dtype=tl.int32,
+            is_pure=False,
+            pack=1,
+        )
+    else:
+        raise RuntimeError(f"Unrecognized sem: {sem}")
+
+
+@triton.jit
+def wait_signal(addrs, sem: tl.constexpr):
+    # Consume a flag at every address in `addrs`: spin on CAS until the u32
+    # slot flips 1 -> 0, i.e. until a peer's send_signal has landed.
+    # NOTE(review): the qualifier order here (`.sys.relaxed` / `.sys.acquire`)
+    # differs from send_signal's (`.relaxed.sys` / `.release.sys`) — confirm
+    # both spellings are intended/valid PTX.
+    if sem == "relaxed":
+        tl.inline_asm_elementwise(
+            """
+            {
+                .reg .u32   %tmp32_<1>;
+                .reg .pred  %p<1>;
+
+                wait_signal:
+                atom.global.sys.relaxed.cas.b32 %tmp32_0, [$1], 1, 0;
+                setp.eq.u32 %p0, %tmp32_0, 1;
+                @!%p0 bra wait_signal;
+            }
+            """,
+            "=r, l",
+            [addrs],
+            dtype=tl.int32,
+            is_pure=False,
+            pack=1,
+        )
+    elif sem == "acq_rel":
+        tl.inline_asm_elementwise(
+            """
+            {
+                .reg .u32   %tmp32_<1>;
+                .reg .pred  %p<1>;
+
+                wait_signal:
+                atom.global.sys.acquire.cas.b32 %tmp32_0, [$1], 1, 0;
+                setp.eq.u32 %p0, %tmp32_0, 1;
+                @!%p0 bra wait_signal;
+            }
+            """,
+            "=r, l",
+            [addrs],
+            dtype=tl.int32,
+            is_pure=False,
+            pack=1,
+        )
+    else:
+        raise RuntimeError(f"Unrecognized sem: {sem}")
+
+
+@triton.jit
+def blockwise_barrier(
+    signal_pad_ptrs,
+    block_id,
+    rank: tl.constexpr,
+    world_size: tl.constexpr,
+    sem: tl.constexpr,
+):
+    """
+    Synchronizes blocks with matching block_id across participating devices.
+
+    Note: the function itself is not a system level barrier/fence. It is a
+    building block for expressing different synchronization patterns.
+
+    Pattern 0: Ensures that all writes to symm_mem buffers from previous
+    kernels across all devices are visible to the current kernel:
+
+        blockwise_barrier(..., sem="relaxed")
+        sync_threads()
+
+    Pattern 1: Ensures that all writes to symm_mem buffers from the current
+    block are visible to all remote blocks with matching blockIdx:
+
+        sync_threads()
+        blockwise_barrier(..., sem="acq_rel")
+        sync_threads()
+
+    Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
+    for writing by subsequent kernels across all devices.
+
+        sync_threads()
+        blockwise_barrier(..., sem="relaxed")
+
+    CUDA graph friendliness:
+
+    This barrier operates through atomic operations on a zero-filled signal
+    pad, which resets to a zero-filled state after each successful
+    synchronization. This design eliminates the need for incrementing a
+    flag from host.
+    """
+    if block_id is None:
+        block_id = get_flat_bid()
+    flat_tid = get_flat_tid()
+
+    # Each rank's signal pad is treated as a (num_blocks, world_size) grid of
+    # u32 flags, indexed [block_id, peer_rank].
+    remote_ranks = tl.arange(0, world_size)
+    signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64))
+    remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to(
+        tl.pointer_type(tl.uint32)
+    )
+    # Slot [block_id, rank] on every peer's pad: where we announce arrival.
+    send_addrs = remote_signal_pad_addrs + block_id * world_size + rank
+
+    local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to(
+        tl.pointer_type(tl.uint32)
+    )
+    # Slots [block_id, 0..world_size) on the local pad: where peers announce.
+    wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks
+
+    # One lane per peer rank performs the signal exchange.
+    if flat_tid < world_size:
+        send_signal(send_addrs, sem)
+        wait_signal(wait_addrs, sem)
diff --git a/torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py b/torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cd023c36bd9737bfb03da22ea38ef57a448eb80
--- /dev/null
+++ b/torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py
@@ -0,0 +1,260 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed as dist +import torch.distributed._symmetric_memory as symm_mem +import triton +import triton.language as tl + +from .triton_barrier import blockwise_barrier +from .triton_utils import sync_threads + + +@triton.jit +def _exchange_row_offsets( + split_sizes_ptrs, + rank: tl.constexpr, + world_size: tl.constexpr, + BLOCKS_PER_REMOTE_RANK: tl.constexpr, +): + remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK + + # split_sizes_ptr for all ranks + # All these vector stacks into split_sizes_matrix + split_sizes_ptrs = split_sizes_ptrs.to(tl.pointer_type(tl.uint64)) + + # split_sizes_matrix[remote_rank, :] + input_split_sizes_ptr = tl.load(split_sizes_ptrs + remote_rank).to( + tl.pointer_type(tl.int64) + ) + + offsets_ = tl.arange(0, world_size) + input_split_sizes = tl.load( + input_split_sizes_ptr + offsets_, mask=offsets_ <= rank, other=0 + ) + + num_rows = tl.load(input_split_sizes_ptr + rank) + input_row_offset = tl.sum(input_split_sizes) - num_rows + + # split_sizes_matrix[:, rank] + output_split_sizes_ptrs = ( + tl.load(split_sizes_ptrs + offsets_).to(tl.pointer_type(tl.int64)) + rank + ) + output_split_sizes = tl.load( + output_split_sizes_ptrs, mask=offsets_ <= remote_rank, other=0 + ) + output_row_offset = tl.sum(output_split_sizes) - num_rows + + return input_row_offset, output_row_offset, num_rows + + +@triton.jit +def on_device_all_to_all_v_kernel( + output_ptr, + output_splits_ptr, + input_ptrs, + input_splits_ptr, + signal_pad_ptrs, + dim: tl.constexpr, # Separate dim for easier vectorization + rank: tl.constexpr, + world_size: tl.constexpr, + BLOCKS_PER_REMOTE_RANK: tl.constexpr, + UNROLL_FACTOR: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed") + sync_threads() + + remote_rank = tl.program_id(0) // 
BLOCKS_PER_REMOTE_RANK + block_offset = tl.program_id(0) % BLOCKS_PER_REMOTE_RANK + + input_row_offset, output_row_offset, num_rows = _exchange_row_offsets( + input_splits_ptr, rank, world_size, BLOCKS_PER_REMOTE_RANK + ) + + output_splits_ptr = output_splits_ptr.to(tl.pointer_type(tl.uint64)) + if block_offset == 0: + # Update output_splits + tl.store(output_splits_ptr + remote_rank, num_rows) + + input_ptr = ( + tl.load(input_ptrs.to(tl.pointer_type(tl.uint64)) + remote_rank).to( + tl.pointer_type(tl.bfloat16) + ) + + input_row_offset * dim + ) + output_ptr = output_ptr + output_row_offset * dim + + outer_loop_step = BLOCK_SIZE * UNROLL_FACTOR + outer_loop_iters_per_rank = tl.cdiv( + tl.cdiv(num_rows * dim, outer_loop_step), BLOCKS_PER_REMOTE_RANK + ) + numel_per_rank = outer_loop_step * outer_loop_iters_per_rank + offset = numel_per_rank * block_offset + end = tl.minimum(numel_per_rank * (block_offset + 1), num_rows * dim) + + unroll_region_size = (end - offset) // outer_loop_step * outer_loop_step + for i in tl.range(offset, offset + unroll_region_size, outer_loop_step): + datas = [] + for j in tl.range( + i, + i + outer_loop_step, + BLOCK_SIZE, + loop_unroll_factor=UNROLL_FACTOR, + ): + offsets = j + tl.arange(0, BLOCK_SIZE) + data = tl.load(input_ptr + offsets) + tl.store(output_ptr + offsets, data) + + offset += unroll_region_size + while offset < end: + offsets = offset + tl.arange(0, BLOCK_SIZE) + mask = offsets < num_rows * dim + data = tl.load(input_ptr + offsets, mask=mask) + tl.store(output_ptr + offsets, data, mask=mask) + offset += BLOCK_SIZE + + sync_threads() + blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed") + return + + +def _on_device_all_to_all_v( + output: torch.Tensor, + output_splits: torch.Tensor, + input: torch.Tensor, + input_splits: torch.Tensor, + group: dist.ProcessGroup = dist.group.WORLD, + BLOCKS_PER_REMOTE_RANK=8, + UNROLL_FACTOR: int = 8, + BLOCK_SIZE: int = 16384, +): + assert output.dim() == 2, 
f"{output.shape}" + assert input.dim() == 2, f"{input.shape}" + assert output.shape[1] == input.shape[1] + + dim = output.shape[1] + input_hdl = symm_mem.rendezvous(input, group=group) + input_splits_hdl = symm_mem.rendezvous(input_splits, group=group) + + num_blocks = input_hdl.world_size * BLOCKS_PER_REMOTE_RANK + kernel = on_device_all_to_all_v_kernel[(num_blocks, 1, 1)]( + output, + output_splits, + input_hdl.buffer_ptrs_dev, + input_splits_hdl.buffer_ptrs_dev, + input_hdl.signal_pad_ptrs_dev, + dim=dim, + rank=input_hdl.rank, + world_size=input_hdl.world_size, + BLOCKS_PER_REMOTE_RANK=BLOCKS_PER_REMOTE_RANK, + UNROLL_FACTOR=UNROLL_FACTOR, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=16, + ) + # log_triton_kernel(kernel) + return output + + +class OnDeviceAllToAllV(torch.autograd.Function): + # A symmetric memory holding the grad_output during backward + grad_output_buf = None + # A symmetric memory for exchanges split sizes during both forward and backward + splits_buf = None + # Maximum output length (need to be set before use of OnDeviceAllToAllV) + max_output_len = None + + @staticmethod + def forward( + ctx, + input: torch.Tensor, + input_splits: torch.Tensor, + group: dist.ProcessGroup = dist.group.WORLD, + ): + """ + Args: + input: input tensor with data for all ranks concatenated. + input_splits: input splits of shape (group.world_size,) + group: process group to scope the collective. 
+ """ + # Initialize input splits buffer (one time only) + if OnDeviceAllToAllV.splits_buf is None: + OnDeviceAllToAllV.splits_buf = symm_mem.empty( + *input_splits.shape, + dtype=input_splits.dtype, + device=input_splits.device, + ) + + if OnDeviceAllToAllV.max_output_len is None: + raise RuntimeError( + "Please set max output length via `OnDeviceAllToAllV.max_output_len = ...`" + ) + + # Allocate output buffer + output = input.new_empty(OnDeviceAllToAllV.max_output_len, *input.shape[1:]) + # Allocate output splits tensor + output_splits = torch.empty_like(input_splits) + # Copy input splits to the buffer + OnDeviceAllToAllV.splits_buf.copy_(input_splits) + + # Shuffle input to output + _on_device_all_to_all_v( + output, output_splits, input, OnDeviceAllToAllV.splits_buf, group=group + ) + + # Output splits in forward is the input splits in backward + ctx.save_for_backward(output_splits) + ctx.group = group + ctx.input_shape = input.shape + return output, output_splits + + @staticmethod + def backward(ctx, grad_output, grad_splits): + """ + Backward is implemented as a shuffle of the output's gradients to the input. + Args: + `grad_output`: output's gradients passed from the downstream. + `grad_splits`: unused. + """ + + # Initialize grad_output buffer (one time only) + if OnDeviceAllToAllV.grad_output_buf is None: + assert ( + OnDeviceAllToAllV.max_output_len is not None + ), "`max_output_len` not set" + OnDeviceAllToAllV.grad_output_buf = symm_mem.empty( + OnDeviceAllToAllV.max_output_len, + *grad_output.shape[1:], + dtype=grad_output.dtype, + device=grad_output.device, + ) + + # TODO: is there a way to tell autograd to feed grad_output directly to + # our symm_mem buffer? 
+ OnDeviceAllToAllV.grad_output_buf.narrow(0, 0, grad_output.shape[0]).copy_( + grad_output + ) + + # Size info + (grad_output_splits,) = ctx.saved_tensors + OnDeviceAllToAllV.splits_buf.copy_(grad_output_splits) + grad_input_splits = torch.empty_like(grad_output_splits) # unused + grad_input = grad_output.new_empty(*ctx.input_shape) + + # Shuffle gradients back to the input + _on_device_all_to_all_v( + grad_input, + grad_input_splits, + OnDeviceAllToAllV.grad_output_buf, + OnDeviceAllToAllV.splits_buf, + group=ctx.group, + ) + return grad_input, None, None + + +# Alias +on_device_all_to_all_v = OnDeviceAllToAllV.apply diff --git a/torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py b/torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ed00317084d85abd10e13cc4f18437d6e9337a75 --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import triton +import triton.language as tl + + +@triton.jit +def get_tid(): + return tl.inline_asm_elementwise( + """ + mov.u32 $0, %tid.x; + mov.u32 $1, %tid.y; + mov.u32 $2, %tid.z; + """, + "=r,=r,=r", + [], + dtype=(tl.uint32, tl.uint32, tl.uint32), + is_pure=True, + pack=1, + ) + + +@triton.jit +def get_ntid(): + return tl.inline_asm_elementwise( + """ + mov.u32 $0, %ntid.x; + mov.u32 $1, %ntid.y; + mov.u32 $2, %ntid.z; + """, + "=r,=r,=r", + [], + dtype=(tl.uint32, tl.uint32, tl.uint32), + is_pure=True, + pack=1, + ) + + +@triton.jit +def get_flat_tid(): + tid_x, tid_y, tid_z = get_tid() + ntid_x, ntid_y, _ = get_ntid() + return tid_z * ntid_y * ntid_x + tid_y * ntid_x + tid_x + + +@triton.jit +def get_flat_bid(): + return ( + tl.program_id(2) * tl.num_programs(1) * tl.num_programs(0) + + tl.program_id(1) * tl.num_programs(0) + + tl.program_id(0) + ) + + +@triton.jit +def sync_threads(): + tl.inline_asm_elementwise( + "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1 + ) diff --git a/torchtitan/experiments/deepseek_v3/train.py b/torchtitan/experiments/deepseek_v3/train.py new file mode 100644 index 0000000000000000000000000000000000000000..1b9ed2dd65164744686647964de3ffdfa3813771 --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/train.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# torchrun --standalone --nproc-per-node 8 run.py +import torch +import torch.distributed as dist +from checkpoint import load_weights_from_hf +from model import DeepseekForCausalLM +from model_config import deepseek_config_registry + +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import fully_shard +from torch.distributed.pipelining import PipelineStage, Schedule1F1B + + +# Use DeepSeek-V2-Lite as a proxy +model_id = "deepseek-ai/DeepSeek-V2-Lite" + + +# Run full model +def run_full_model( + mesh: DeviceMesh, +): + rank = dist.get_rank() + device_count = torch.cuda.device_count() + device = torch.device("cuda", rank % device_count) + + pp_mesh = mesh["pp"] + ep_mesh = mesh["ep"] + pp_rank = pp_mesh.get_local_rank() + ep_rank = ep_mesh.get_local_rank() + pp_size = pp_mesh.size() + ep_size = ep_mesh.size() + + # Get model configs + model_args = deepseek_config_registry[model_id] + # [Note]: I am making the model smaller for testing / avoiding OOM. If you + # have sufficient GPUs for model parallelism, you can remove this line. + model_args.num_hidden_layers = 16 + + # Apply model parallelism + model_args.ep_size = ep_size + model_args.num_stages = pp_size + model_args.stage_idx = pp_rank + print(model_args) + + # Instantiate model + with device, mesh: + model = DeepseekForCausalLM(model_args) + + # Load weights + load_weights_from_hf(model, model_id, device) + model.train() + + # Apply data parallelism + fsdp_mesh = mesh["fsdp"] + hsdp_mesh = mesh["ep", "fsdp"] + # Using `reshard_after_forward=False` to implement Zero-2, i.e. sharding the + # optimizer (Zero-1) and gradients (Zero-2), but not the model weights. + # Reason: the MoE is "sparsely activated" compared to the dense model, thus + # it will be ineconomical re-gather the weights. 
+ for layer in model.model.layers.values(): + # Apply FSDP to experts + if hasattr(layer.mlp, "experts"): + for expert in layer.mlp.experts.values(): + fully_shard(expert, mesh=fsdp_mesh, reshard_after_forward=False) + # Apply HSDP to other parts such as attention, layernorm, because they + # are doing DDP on EP dimension + fully_shard(layer, mesh=hsdp_mesh, reshard_after_forward=False) + + # Apply HSDP on root model (lm_head, embeddings, etc) + fully_shard(model, mesh=hsdp_mesh, reshard_after_forward=False) + + # Synthetic setting + microbatches = pp_size * 2 + + # Use Symmetric Memory for MoE token shuffle. + # TODO: we are rewriting `moe_on_device` function. `setup_symm_mem` is + # currently supported for forward only. See `generate.py`. + # model.setup_symm_mem(torch.bfloat16, device) + + # Example inputs + torch.manual_seed(ep_rank) + bs = 4 + seqlen = 128 + x = torch.randint(model_args.vocab_size, (microbatches * bs, seqlen), device=device) + label = torch.rand(microbatches * bs, seqlen, model_args.vocab_size, device=device) + + # Create loss function + loss_fn = torch.nn.functional.cross_entropy + + # Run forward and backward + steps = 2 + for _ in range(steps): + if pp_size > 1: + # Create pipeline stage + stage = PipelineStage( + model, + pp_rank, + pp_size, + device, + group=pp_mesh.get_group(), + ) + + # Create pipeline schedule + losses = [] + pp_schedule = Schedule1F1B(stage, microbatches, loss_fn=loss_fn) + + if pp_rank == 0: + y = pp_schedule.step(x) + elif pp_rank == pp_size - 1: + y = pp_schedule.step(target=label, losses=losses) + loss = torch.mean(torch.stack(losses)) + else: + pp_schedule.step() + else: + y = model(x) + loss = loss_fn(y, label) + loss.backward() + + if pp_rank == pp_size - 1: + print(f"logits: {y.shape}") + print(f"{loss=}") + + if pp_rank == 0: + param = model.get_parameter("model.layers.0.self_attn.q_proj.weight") + print(f"{torch.linalg.norm(param.grad)=}") + + model.zero_grad() + + print("Backward done") + + +if __name__ == 
"__main__":
+    mesh = dist.init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "ep", "fsdp"))
+
+    run_full_model(mesh)
+
+    dist.destroy_process_group()
diff --git a/torchtitan/experiments/flux/README.md b/torchtitan/experiments/flux/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2e56939b6eea7769d5130703cd3acb58f7eb5f5a
--- /dev/null
+++ b/torchtitan/experiments/flux/README.md
@@ -0,0 +1,23 @@
+# FLUX model in torchtitan
+
+## Overview
+
+## Usage
+First, download the autoencoder model from HuggingFace with your own access token:
+```bash
+python torchtitan/experiments/flux/scripts/download_autoencoder.py --repo_id black-forest-labs/FLUX.1-dev --ae_path ae.safetensors --hf_token <your_access_token>
+```
+This step will download the autoencoder model from HuggingFace and save it to the `torchtitan/experiments/flux/assets/autoencoder/ae.safetensors` file.
+
+Run the following command to train the model on a single GPU:
+```bash
+PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --nproc_per_node=1 torchtitan/experiments/flux/train.py --job.config_file torchtitan/experiments/flux/train_configs/debug_model.toml
+```
+
+## TODO
+- [ ] Support for multiple GPUs is coming soon (FSDP, etc.)
+- [ ] Implement test cases in CI for the FLUX model. Add more unit tests for the FLUX model (e.g., unit tests for the preprocessor)
+- [ ] More parallelism support (Tensor Parallelism, Context Parallelism, etc.)
+- [ ] Support for distributed checkpointing and loading
+- [ ] Implement an init_weights() function to initialize the model weights
+- [ ] Implement the num_flops_per_token calculation in the get_nparams_and_flops() function
diff --git a/torchtitan/experiments/flux/__init__.py b/torchtitan/experiments/flux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..314a8689b291c74db639669e7bc4943612b47a03
--- /dev/null
+++ b/torchtitan/experiments/flux/__init__.py
@@ -0,0 +1,122 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader +from torchtitan.experiments.flux.loss import build_mse_loss +from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams +from torchtitan.experiments.flux.parallelize_flux import parallelize_flux +from torchtitan.protocols.train_spec import register_train_spec, TrainSpec + +from .model.model import FluxModel, FluxModelArgs + +__all__ = [ + "FluxModelArgs", + "FluxModel", + "flux_configs", + "parallelize_flux", +] + + +flux_configs = { + "flux-dev": FluxModelArgs( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=512, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=(16, 56, 56), + theta=10_000, + qkv_bias=True, + guidance_embed=True, + autoencoder_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=(1, 2, 4, 4), + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-schnell": FluxModelArgs( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=(16, 56, 56), + theta=10_000, + qkv_bias=True, + guidance_embed=False, + autoencoder_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=(1, 2, 4, 4), + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-debug": FluxModelArgs( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=512, + 
hidden_size=512, + mlp_ratio=4.0, + num_heads=4, + depth=2, + depth_single_blocks=2, + axes_dim=(16, 56, 56), + theta=10_000, + qkv_bias=True, + guidance_embed=True, + autoencoder_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=(1, 2, 4, 4), + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +register_train_spec( + TrainSpec( + name="flux", + cls=FluxModel, + config=flux_configs, + parallelize_fn=parallelize_flux, + pipelining_fn=None, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_flux_dataloader, + build_tokenizer_fn=None, + build_loss_fn=build_mse_loss, + ) +) diff --git a/torchtitan/experiments/flux/dataset/flux_dataset.py b/torchtitan/experiments/flux/dataset/flux_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..995f0af3b4152052bcfb21b4331e8dcff8ddd7da --- /dev/null +++ b/torchtitan/experiments/flux/dataset/flux_dataset.py @@ -0,0 +1,267 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
def _process_cc12m_image(
    img: Image.Image,
    output_size: int = 256,
) -> Optional[torch.Tensor]:
    """Resize and random-crop a CC12M image to a square of ``output_size``.

    Returns a float tensor of shape (3, output_size, output_size) with values
    in [0, 1], or None when the image should be skipped (too small, or
    grayscale after resizing).
    """
    width, height = img.size
    # Skip low resolution images
    if width < output_size or height < output_size:
        return None

    if width >= height:
        # Resize so height == output_size, then take a random horizontal crop.
        new_width, new_height = math.ceil(output_size / height * width), output_size
        img = img.resize((new_width, new_height))
        left = random.randint(0, new_width - output_size)
        resized_img = img.crop((left, 0, left + output_size, output_size))
    else:
        # Resize so width == output_size, then take a random vertical crop.
        new_width, new_height = (
            output_size,
            math.ceil(output_size / width * height),
        )
        img = img.resize((new_width, new_height))
        # BUG FIX: the random offset must range over the *height* surplus.
        # The previous `new_width - output_size` was always 0 here (since
        # new_width == output_size), so the crop was never randomized.
        lower = random.randint(0, new_height - output_size)
        resized_img = img.crop((0, lower, output_size, lower + output_size))

    assert resized_img.size[0] == resized_img.size[1] == output_size

    # Skip grayscale images
    if resized_img.mode == "L":
        return None

    # HWC uint8 -> CHW float scaled to [0, 1]
    np_img = np.array(resized_img).transpose((2, 0, 1))
    tensor_img = torch.tensor(np_img).float() / 255.0

    return tensor_img
+ infinite (bool): Whether to loop over the dataset infinitely. + """ + + def __init__( + self, + dataset_name: str, + dataset_path: Optional[str], + t5_tokenizer: FluxTokenizer, + clip_tokenizer: FluxTokenizer, + job_config: Optional[JobConfig] = None, + dp_rank: int = 0, + dp_world_size: int = 1, + infinite: bool = False, + ) -> None: + + # Force lowercase for consistent comparison + dataset_name = dataset_name.lower() + + path, dataset_loader, data_processor = _validate_dataset( + dataset_name, dataset_path + ) + ds = dataset_loader(path) + + self.dataset_name = dataset_name + self._data = split_dataset_by_node(ds, dp_rank, dp_world_size) + + self._t5_tokenizer = t5_tokenizer + self._clip_tokenizer = clip_tokenizer + self._data_processor = data_processor + self.job_config = job_config + + self.infinite = infinite + + # Variables for checkpointing + self._sample_idx = 0 + self._all_samples: list[dict[str, Any]] = [] + + def _get_data_iter(self): + if isinstance(self._data, Dataset) and self._sample_idx == len(self._data): + return iter([]) + + it = iter(self._data) + for _ in range(self._sample_idx): + next(it) + return it + + def __iter__(self): + while True: + for sample in self._get_data_iter(): + # Use the dataset-specific preprocessor + sample_dict = self._data_processor( + sample, self._t5_tokenizer, self._clip_tokenizer, output_size=256 + ) + + # skip low quality image or image with color channel = 1 + if sample_dict["image"] is None: + logger.warning( + f"Low quality image {sample['__key__']} is skipped in Flux Dataloader" + ) + continue + + self._all_samples.extend(sample_dict) + self._sample_idx += 1 + + labels = sample_dict.pop("image") + yield sample_dict, labels + + if not self.infinite: + logger.warning(f"Dataset {self.dataset_name} has run out of data") + break + else: + # Reset offset for the next iteration + self._sample_idx = 0 + logger.warning(f"Dataset {self.dataset_name} is being re-looped") + + def load_state_dict(self, state_dict): + 
def build_flux_dataloader(
    dp_world_size: int,
    dp_rank: int,
    job_config: JobConfig,
    # This parameter is not used, keep it for compatibility
    tokenizer: FluxTokenizer | None,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build the FLUX text-to-image dataloader from the job config.

    Constructs a T5 tokenizer (length from the config) and a CLIP tokenizer
    (fixed length 77), wraps them in a FluxDataset, and returns a
    data-parallel-aware dataloader over it.
    """
    training_cfg = job_config.training
    encoder_cfg = job_config.encoder

    t5_tok = FluxTokenizer(
        encoder_cfg.t5_encoder, max_length=encoder_cfg.max_t5_encoding_len
    )
    # CLIP's token length is fixed at 77 regardless of the config.
    clip_tok = FluxTokenizer(encoder_cfg.clip_encoder, max_length=77)

    dataset = FluxDataset(
        dataset_name=training_cfg.dataset,
        dataset_path=training_cfg.dataset_path,
        t5_tokenizer=t5_tok,
        clip_tokenizer=clip_tok,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    return ParallelAwareDataloader(
        dataset=dataset,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=training_cfg.batch_size,
    )
class FluxTokenizer(Tokenizer):
    """Text tokenizer wrapper for FLUX, backed by either T5 or CLIP.

    Args:
        model_path (str): HuggingFace path of the underlying tokenizer;
            "openai/..." selects CLIP, anything else selects T5.
        max_length (int): Fixed padded/truncated token length.
    """

    def __init__(self, model_path: str = "t5-small", max_length: int = 77):
        super().__init__()
        self._n_words = 8  # TODO(jianiw): check
        self._max_length = max_length

        self.is_clip = model_path.startswith("openai")

        tokenizer_cls = CLIPTokenizer if self.is_clip else T5Tokenizer
        self._tokenizer = tokenizer_cls.from_pretrained(
            model_path, max_length=max_length
        )

    def encode(
        self,
        s: str,
    ) -> List[int]:
        """Tokenize prompt `s` into a fixed-length, padded tensor of ids."""
        encoded = self._tokenizer(
            s,
            truncation=True,
            max_length=self._max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",  # return pytorch tensors, default return List[int]
        )
        return encoded["input_ids"]

    def decode(self, t: List[int]) -> str:
        """Inverse of encode; not exercised during training."""
        return self._tokenizer.decode(t)
class AttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    Operates on (B, C, H, W) inputs; attention is computed across the H*W
    positions with C as the feature dimension, followed by a 1x1 projection
    and a residual connection.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )

        # 1x1 convs acting as per-position linear projections.
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        """Normalize, project to q/k/v, and run SDPA over flattened H*W."""
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        # Native-op equivalent of rearrange(x, "b c h w -> b 1 (h w) c"):
        # flatten spatial dims, move channels last, add a head axis.
        # (Avoids the einops dependency for this block.)
        q = q.reshape(b, c, h * w).permute(0, 2, 1).unsqueeze(1).contiguous()
        k = k.reshape(b, c, h * w).permute(0, 2, 1).unsqueeze(1).contiguous()
        v = v.reshape(b, c, h * w).permute(0, 2, 1).unsqueeze(1).contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        # Inverse: "b 1 (h w) c -> b c h w"
        return h_.squeeze(1).permute(0, 2, 1).reshape(b, c, h, w)

    def forward(self, x: Tensor) -> Tensor:
        # Residual connection around the projected attention output.
        return x + self.proj_out(self.attention(x))
class Downsample(nn.Module):
    """Halve spatial resolution with a stride-2 conv and asymmetric padding."""

    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, x: Tensor):
        # Pad one pixel on the right and bottom edges only, then convolve.
        padded = nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm( + num_groups=32, num_channels=block_in, eps=1e-6, affine=True + ) + self.conv_out = nn.Conv2d( + block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** 
(self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm( + num_groups=32, num_channels=block_in, eps=1e-6, affine=True + ) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # get dtype for proper tracing + upscale_dtype = next(self.up.parameters()).dtype + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # cast to proper dtype + h = h.to(upscale_dtype) + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim 
+ + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.params = params + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) + + +def load_ae( + ckpt_path: str, + autoencoder_params: AutoEncoderParams, + device: str | torch.device = "cuda", + dtype=torch.bfloat16, +) -> AutoEncoder: + """ + Load the autoencoder from the given model name. + Args: + name (str): The name of the autoencoder. + device (str or torch.device): The device to load the autoencoder to. + Returns: + AutoEncoder: The loaded autoencoder. + """ + # Loading the autoencoder + print("Init AE") + with torch.device(device): + ae = AutoEncoder(autoencoder_params) + + if not os.path.exists(ckpt_path): + raise ValueError( + f"Autoencoder path {ckpt_path} does not exist. Please download it first." 
class FluxEmbedder(nn.Module):
    """Frozen HuggingFace text encoder (CLIP or T5) used for conditioning.

    Args:
        version: HF model path; "openai/..." selects CLIP, otherwise T5.
        **hf_kwargs: forwarded to `from_pretrained`.
    """

    def __init__(self, version: str, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")
        # CLIP conditioning uses the pooled vector; T5 uses per-token states.
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        encoder_cls = CLIPTextModel if self.is_clip else T5EncoderModel
        self.hf_module = encoder_cls.from_pretrained(version, **hf_kwargs)

        # Inference-only: freeze parameters and switch to eval mode.
        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, batch_tokens: Tensor) -> Tensor:
        """Embed a batch of token ids.

        batch_tokens: [bsz, token_length] ids from the matching tokenizer.
        Returns the pooled vector (CLIP) or per-token hidden states (T5).
        """
        outputs = self.hf_module(
            input_ids=batch_tokens.to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
class EmbedND(nn.Module):
    """Multi-axis rotary positional embedding.

    Each trailing axis of `ids` gets its own RoPE table of width
    axes_dim[i]; the per-axis tables are concatenated and a singleton
    head dimension is inserted.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        per_axis = [
            rope(ids[..., axis], axis_dim, self.theta)
            for axis, axis_dim in enumerate(self.axes_dim[: ids.shape[-1]])
        ]
        return torch.cat(per_axis, dim=-3).unsqueeze(1)
+ """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(t.device) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return (x * rrms).to(dtype=x_dtype) * self.scale + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) # TODO(jianiw): switch to pytorch nn.RMSNorm + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + + def forward(self, x: Tensor, pe: Tensor) -> Tensor: + qkv = self.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + x = 
attention(q, k, v, pe=pe) + x = self.proj(x) + return x + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk( + self.multiplier, dim=-1 + ) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + + +class DoubleStreamBlock(nn.Module): + def __init__( + self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False + ): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention( + dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias + ) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention( + dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias + ) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + def forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor + ) -> tuple[Tensor, Tensor]: + 
    def forward(
        self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
    ) -> tuple[Tensor, Tensor]:
        """Run one dual-stream block over image and text tokens.

        img: image-token hidden states — presumably [B, L_img, hidden_size]; TODO confirm
        txt: text-token hidden states — presumably [B, L_txt, hidden_size]; TODO confirm
        vec: conditioning vector producing the adaLN modulation parameters
        pe:  rotary positional embedding for the concatenated [txt; img] sequence
        Returns the updated (img, txt) pair.
        """
        # Two modulation triplets per stream: one for attention, one for MLP.
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention: adaLN-modulate, then project to q/k/v
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(
            img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(
            txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention over the joint [txt; img] token sequence
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        # split the joint attention output back into text and image parts
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks: gated residual attention + gated MLP
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp(
            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        )

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp(
            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        )
        return img, txt
+ """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + mod, _ = self.modulation(vec) + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + qkv, mlp = torch.split( + self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1 + ) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + + # compute attention + attn = attention(q, k, v, pe=pe) + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + return x + mod.gate * output + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear( + hidden_size, patch_size * patch_size * out_channels, bias=True + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, 
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    """Rotary-embedded scaled-dot-product attention.

    q, k, v: [B, H, L, D]; pe: rotation matrices produced by `rope`.
    Returns [B, L, H*D] with the heads merged back into the channel dim.
    """
    q, k = apply_rope(q, k, pe)

    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    # Native-op equivalent of rearrange(out, "B H L D -> B L (H D)").
    # (Avoids the einops dependency for this module.)
    b, h, l, d = out.shape
    return out.transpose(1, 2).reshape(b, l, h * d)


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build 2x2 rotation matrices for rotary position embedding.

    pos: positions of shape [..., n]; dim must be even.
    Returns a float tensor of shape [..., n, dim/2, 2, 2].
    """
    assert dim % 2 == 0
    exponents = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    omega = 1.0 / (theta**exponents)
    angles = torch.einsum("...n,d->...nd", pos, omega)
    rot = torch.stack(
        [torch.cos(angles), -torch.sin(angles), torch.sin(angles), torch.cos(angles)],
        dim=-1,
    )
    # Native-op equivalent of rearrange(rot, "b n d (i j) -> b n d i j", i=2, j=2):
    # view the last 4 entries as a 2x2 rotation matrix.
    return rot.unflatten(-1, (2, 2)).float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate query/key vectors by the precomputed rotation matrices.

    Pairs of channels are treated as 2D points and rotated; computation is
    done in float32 and cast back to the inputs' dtypes.
    """

    def _rotate(t: Tensor) -> Tensor:
        pairs = t.float().reshape(*t.shape[:-1], -1, 1, 2)
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*t.shape).type_as(t)

    return _rotate(xq), _rotate(xk)
Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +import torch + +from torch import nn, Tensor +from torchtitan.components.tokenizer import Tokenizer +from torchtitan.config_manager import JobConfig + +from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams +from torchtitan.experiments.flux.model.layers import ( + DoubleStreamBlock, + EmbedND, + LastLayer, + MLPEmbedder, + SingleStreamBlock, + timestep_embedding, +) + +from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol +from torchtitan.tools.logging import logger + + +@dataclass +class FluxModelArgs(BaseModelArgs): + in_channels: int = 64 + out_channels: int = 64 + vec_in_dim: int = 768 + context_in_dim: int = 512 + hidden_size: int = 3072 + mlp_ratio: float = 4.0 + num_heads: int = 24 + depth: int = 19 + depth_single_blocks: int = 38 + axes_dim: tuple = (16, 56, 56) + theta: int = 10_000 + qkv_bias: bool = True + guidance_embed: bool = True + autoencoder_params: AutoEncoderParams = field(default_factory=AutoEncoderParams) + + def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None: + # context_in_dim is the same as the T5 embedding dimension + self.context_in_dim = job_config.encoder.max_t5_encoding_len + + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: + # TODO(jianiw): Add the number of flops for the autoencoder + nparams = sum(p.numel() for p in model.parameters()) + logger.warning("FLUX model haven't implement get_nparams_and_flops() function") + return nparams, 1 + + +class FluxModel(nn.Module, ModelProtocol): + """ + Transformer model for flow matching on sequences. + + Agrs: + model_args: FluxModelArgs. + + Attributes: + model_args (TransformerModelArgs): Model configuration arguments. 
+ """ + + def __init__(self, model_args: FluxModelArgs): + super().__init__() + + self.model_args = model_args + self.in_channels = model_args.in_channels + self.out_channels = model_args.out_channels + if model_args.hidden_size % model_args.num_heads != 0: + raise ValueError( + f"Hidden size {model_args.hidden_size} must be divisible by num_heads {model_args.num_heads}" + ) + pe_dim = model_args.hidden_size // model_args.num_heads + if sum(model_args.axes_dim) != pe_dim: + raise ValueError( + f"Got {model_args.axes_dim} but expected positional dim {pe_dim}" + ) + self.hidden_size = model_args.hidden_size + self.num_heads = model_args.num_heads + self.pe_embedder = EmbedND( + dim=pe_dim, theta=model_args.theta, axes_dim=model_args.axes_dim + ) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(model_args.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + if model_args.guidance_embed + else nn.Identity() + ) + self.txt_in = nn.Linear(model_args.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=model_args.mlp_ratio, + qkv_bias=model_args.qkv_bias, + ) + for _ in range(model_args.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + self.hidden_size, self.num_heads, mlp_ratio=model_args.mlp_ratio + ) + for _ in range(model_args.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + def init_weights(self, buffer_device=None): + # TODO(jianiw): replace placeholder with real weight init + for param in self.parameters(): + param.data.uniform_(0, 0.1) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + ) -> 
Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.model_args.guidance_embed: + if guidance is None: + raise ValueError( + "Didn't get guidance strength for guidance distilled model." + ) + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img + + @classmethod + def from_model_args(cls, model_args: FluxModelArgs) -> "FluxModel": + """ + Initialize a Flux model from a FluxModelArgs object. + + Args: + model_args (FluxModelArgs): Model configuration arguments. + + Returns: + FluxModel: FluxModel model. + + """ + return cls(model_args) diff --git a/torchtitan/experiments/flux/parallelize_flux.py b/torchtitan/experiments/flux/parallelize_flux.py new file mode 100644 index 0000000000000000000000000000000000000000..fcdde64f86899ae19fa2f0891bdd71d14b9cbe97 --- /dev/null +++ b/torchtitan/experiments/flux/parallelize_flux.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file applies the PT-D parallelisms (except pipeline parallelism) and various +# training techniques (e.g. activation checkpointing and compile) to the Llama model. 
def parallelize_flux(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """Apply PT-D parallelisms and training techniques to the Flux model.

    Currently a passthrough: no parallelism strategy is implemented yet, so
    the model is returned unchanged.

    Args:
        model: the Flux model to (eventually) parallelize.
        world_mesh: device mesh describing the distributed topology.
        parallel_dims: requested parallelism degrees.
        job_config: the full job configuration.

    Returns:
        nn.Module: the (unmodified) model.
    """
    # TODO: Add model parallel strategy here
    return model
+ ) + else: + raise e + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Download tokenizer from HuggingFace.") + parser.add_argument( + "--repo_id", + type=str, + default="black-forest-labs/FLUX.1-dev", + help="Repository ID to download from. default to Flux-dev model", + ) + parser.add_argument( + "--ae_path", + type=str, + default="ae.safetensors", + help="the autoencoder path relative to repo_id", + ) + parser.add_argument( + "--hf_token", type=str, default=None, help="HuggingFace API token" + ) + parser.add_argument( + "--local_dir", + type=str, + default="torchtitan/experiments/flux/assets/autoencoder/", + help="local directory to save the autoencoder", + ) + + args = parser.parse_args() + hf_download(args.repo_id, args.ae_path, args.local_dir, args.hf_token) diff --git a/torchtitan/experiments/flux/tests/test_flux_dataloader.py b/torchtitan/experiments/flux/tests/test_flux_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..fc87f1b8b4ae3ad7daf1558835716720127e3b42 --- /dev/null +++ b/torchtitan/experiments/flux/tests/test_flux_dataloader.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import sys + +from torchtitan.config_manager import JobConfig +from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader +from torchtitan.tools.profiling import ( + maybe_enable_memory_snapshot, + maybe_enable_profiling, +) + + +class TestFluxDataLoader: + def test_flux_dataloader(self): + dataset_name = "cc12m" + batch_size = 32 + world_size = 4 + rank = 0 + + num_steps = 10 + + path = "torchtitan.experiments.flux.flux_argparser" + sys.argv.append(f"--experimental.custom_args_module={path}") + config = JobConfig() + config.maybe_add_custom_args() + config.parse_args( + [ + # Profiling options + # "--profiling.enable_profiling", + # "--profiling.profile_freq", + # "5", + # "--profiling.enable_memory_snapshot", + # "--profiling.save_memory_snapshot_folder", + # "memory_snapshot_flux", + "--training.dataset", + dataset_name, + "--training.batch_size", + str(batch_size), + "--encoder.t5_encoder", + "google/t5-v1_1-small", + "--encoder.clip_encoder", + "openai/clip-vit-large-patch14", + "--encoder.max_t5_encoding_len", + "512", + ] + ) + + with maybe_enable_profiling( + config, global_step=0 + ) as torch_profiler, maybe_enable_memory_snapshot( + config, global_step=0 + ) as memory_profiler: + dl = self._build_dataloader( + config, + world_size, + rank, + ) + dl = iter(dl) + + for i in range(0, num_steps): + input_data, labels = next(dl) + print(f"Step {i} image size: {labels.shape}") + if torch_profiler: + torch_profiler.step() + if memory_profiler: + memory_profiler.step() + + print(len(input_data["clip_tokens"])) + for k, v in input_data.items(): + print(f"Step {i} {k} value: {type(v), v.shape}") + + assert len(input_data) == 2 # (clip_encodings, t5_encodings) + assert labels.shape == (batch_size, 3, 256, 256) + # assert input_data["clip_tokens"].shape[0] == batch_size + # assert input_data["t5_tokens"].shape == (batch_size, 512, 512) + + if torch_profiler: + torch_profiler.step() + if memory_profiler: + memory_profiler.step(exit_ctx=True) 
def time_shift(mu: float, sigma: float, t: Tensor):
    """Shift a uniform timestep schedule toward high-noise timesteps.

    Computes exp(mu) / (exp(mu) + (1 / t - 1) ** sigma) elementwise on ``t``.
    """
    e_mu = math.exp(mu)
    return e_mu / (e_mu + (1 / t - 1) ** sigma)


def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    """Return the line through (x1, y1) and (x2, y2) as a callable."""
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1
    return lambda x: slope * x + intercept
step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # estimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +class TestGenerateImage: + def test_generate_image(self): + """ + Run a forward pass of flux model to generate an image. + """ + name = "flux-dev" + img_width = 512 + img_height = 512 + seed = None + prompt = ( + "a photo of a forest with mist swirling around the tree trunks. The word " + '"FLUX" is painted over it in big, red brush strokes with visible texture' + ) + device = "cuda" + num_steps = None + loop = False + guidance = 3.5 + output_dir = "output" + add_sampling_metadata = True + + prompt = prompt.split("|") + if len(prompt) == 1: + prompt = prompt[0] + additional_prompts = None + else: + additional_prompts = prompt[1:] + prompt = prompt[0] + + assert not ( + (additional_prompts is not None) and loop + ), "Do not provide additional prompts and set loop to True" + + torch_device = torch.device(device) + if num_steps is None: + num_steps = 30 + + # allow for packing and conversion to latent space + img_height = 16 * (img_height // 16) + img_width = 16 * (img_width // 16) + + # init all components + model = FluxModel(FluxModelArgs()).to(device=torch_device, dtype=torch.bfloat16) + + ae = load_ae( + ckpt_path="assets/autoencoder/ae.safetensors", + autoencoder_params=AutoEncoderParams(), + device=torch_device, + dtype=torch.bfloat16, + ) + clip_tokenizer = FluxTokenizer( + model_path="openai/clip-vit-large-patch14", max_length=77 + ) + t5_tokenizer = FluxTokenizer(model_path="google/t5-v1_1-small", max_length=512) + clip_encoder = FluxEmbedder(version="openai/clip-vit-large-patch14").to( + torch_device, dtype=torch.bfloat16 + ) + t5_encoder = FluxEmbedder(version="google/t5-v1_1-small").to( + 
torch_device, dtype=torch.bfloat16 + ) + + rng = torch.Generator(device="cpu") + + if seed is None: + seed = rng.seed() + print(f"Generating with seed {seed}:\n{prompt}") + t0 = time.perf_counter() + output_name = os.path.join(output_dir, f"img_{seed}.jpg") + + # Tokenize the prompt, on CPU + clip_tokens = clip_tokenizer.encode(prompt) + t5_tokens = t5_tokenizer.encode(prompt) + + batch = preprocess_flux_data( + device=torch_device, + dtype=torch.bfloat16, + autoencoder=None, + clip_encoder=clip_encoder, + t5_encoder=t5_encoder, + batch={ + "clip_tokens": clip_tokens, + "t5_tokens": t5_tokens, + }, + ) + + img = self._generate_images( + device=torch_device, + dtype=torch.bfloat16, + model=model, + decoder=ae, + img_width=img_width, + img_height=img_height, + denoising_steps=num_steps, + seed=seed, + clip_encodings=batch["clip_encodings"], + t5_encodings=batch["t5_encodings"], + guidance=guidance, + ) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + t1 = time.perf_counter() + + print(f"Done in {t1 - t0:.1f}s.") + + self._save_image(name, output_name, img, add_sampling_metadata, prompt) + + def _generate_images( + self, + device: torch.device, + dtype: torch.dtype, + model: FluxModel, + decoder: AutoEncoder, + # image params: + img_width: int, + img_height: int, + # sampling params: + denoising_steps: int, + seed: int, + clip_encodings: torch.Tensor, + t5_encodings: torch.Tensor, + guidance: float = 4.0, + ): + + bsz = clip_encodings.shape[0] + latents = generate_noise_latent(bsz, img_height, img_width, device, dtype, seed) + _, latent_channels, latent_height, latent_width = latents.shape + + # create denoising schedule + timesteps = get_schedule(denoising_steps, latent_channels, shift=True) + + # create positional encodings + POSITION_DIM = 3 # constant for Flux flow model + latent_pos_enc = create_position_encoding_for_latents( + bsz, latent_height, latent_width, POSITION_DIM + ).to(latents) + text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], 
POSITION_DIM).to(latents) + + # convert img-like latents into sequences of patches + latents = pack_latents(latents) + + # this is ignored for schnell + guidance_vec = torch.full((bsz,), guidance, device=device, dtype=dtype) + for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): + t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device) + pred = model( + img=latents, + img_ids=latent_pos_enc, + txt=t5_encodings, + txt_ids=text_pos_enc, + y=clip_encodings, + timesteps=t_vec, + guidance=guidance_vec, + ) + + latents = latents + (t_prev - t_curr) * pred + + # convert sequences of patches into img-like latents + latents = unpack_latents(latents, latent_height, latent_width) + + img = decoder.decode(latents) + return img + + def _save_image( + self, + name: str, + output_name: str, + x: torch.Tensor, + add_sampling_metadata: bool, + prompt: str, + ): + print(f"Saving {output_name}") + # bring into PIL format and save + x = x.clamp(-1, 1) + x = rearrange(x[0], "c h w -> h w c") + + img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) + + exif_data = Image.Exif() + exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" + exif_data[ExifTags.Base.Make] = "Black Forest Labs" + exif_data[ExifTags.Base.Model] = name + if add_sampling_metadata: + exif_data[ExifTags.Base.ImageDescription] = prompt + img.save(output_name, exif=exif_data, quality=95, subsampling=0) diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py new file mode 100644 index 0000000000000000000000000000000000000000..064e854b2650c4792295438247fbe37e56a1d1b2 --- /dev/null +++ b/torchtitan/experiments/flux/train.py @@ -0,0 +1,224 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +from typing import Optional + +import torch + +from torchtitan.config_manager import JobConfig +from torchtitan.distributed import utils as dist_utils +from torchtitan.experiments.flux.model.autoencoder import load_ae +from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder +from torchtitan.experiments.flux.model.model import FluxModel +from torchtitan.experiments.flux.utils import ( + create_position_encoding_for_latents, + pack_latents, + preprocess_flux_data, + unpack_latents, +) +from torchtitan.tools.logging import init_logger, logger +from torchtitan.train import Trainer + + +class FluxTrainer(Trainer): + def __init__(self, job_config: JobConfig): + super().__init__(job_config) + + self.preprocess_fn = preprocess_flux_data + # self.dtype = job_config.encoder.dtype + self._dtype = torch.bfloat16 + self._seed = job_config.training.seed + self._guidance = job_config.training.guidance + + # load components + model_config = self.train_spec.config[job_config.model.flavor] + self.autoencoder = load_ae( + job_config.encoder.auto_encoder_path, + model_config.autoencoder_params, + device="cpu", + dtype=self._dtype, + ) + self.clip_encoder = FluxEmbedder(version=job_config.encoder.clip_encoder).to( + dtype=self._dtype + ) + self.t5_encoder = FluxEmbedder(version=job_config.encoder.t5_encoder).to( + dtype=self._dtype + ) + + def _predict_noise( + self, + model: FluxModel, + latents: torch.Tensor, + clip_encodings: torch.Tensor, + t5_encodings: torch.Tensor, + timesteps: torch.Tensor, + guidance: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Use Flux's flow-matching model to predict the noise in image latents. + Args: + model (FluxFlowModel): The Flux flow model. + latents (Tensor): Image encodings from the Flux autoencoder. + Shape: [bsz, 16, latent height, latent width] + clip_encodings (Tensor): CLIP text encodings. + Shape: [bsz, 768] + t5_encodings (Tensor): T5 text encodings. 
+ Shape: [bsz, sequence length, 256 or 512] + timesteps (Tensor): The amount of noise (0 to 1). + Shape: [bsz] + guidance (Optional[Tensor]): The guidance value (1.5 to 4) if guidance-enabled model. + Shape: [bsz] + Default: None + model_ctx (ContextManager): Optional context to wrap the model call (e.g. for activation offloading) + Default: nullcontext + Returns: + Tensor: The noise prediction. + Shape: [bsz, 16, latent height, latent width] + """ + bsz, _, latent_height, latent_width = latents.shape + + POSITION_DIM = 3 # constant for Flux flow model + with torch.no_grad(): + # Create positional encodings + latent_pos_enc = create_position_encoding_for_latents( + bsz, latent_height, latent_width, POSITION_DIM + ) + text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM) + + # Convert latent into a sequence of patches + latents = pack_latents(latents) + + # Predict noise + latent_noise_pred = model( + img=latents, + img_ids=latent_pos_enc.to(latents), + txt=t5_encodings.to(latents), + txt_ids=text_pos_enc.to(latents), + y=clip_encodings.to(latents), + timesteps=timesteps.to(latents), + guidance=guidance.to(latents) if guidance is not None else None, + ) + + # Convert sequence of patches to latent shape + latent_noise_pred = unpack_latents( + latent_noise_pred, latent_height, latent_width + ) + + return latent_noise_pred + + def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor): + # generate t5 and clip + input_dict["image"] = labels + input_dict = self.preprocess_fn( + device=self.device, + dtype=self._dtype, + autoencoder=self.autoencoder, + clip_encoder=self.clip_encoder, + t5_encoder=self.t5_encoder, + batch=input_dict, + offload=True, + ) + labels = input_dict["img_encodings"] + + self.optimizers.zero_grad() + + # Keep these variables local to shorten the code as these are + # the major variables that are used in the training loop. 
+ model_parts = self.model_parts + world_mesh = self.world_mesh + parallel_dims = self.parallel_dims + + # image in latent space transformed by self.auto_encoder + clip_encodings = input_dict["clip_encodings"] + t5_encodings = input_dict["t5_encodings"] + + bsz = labels.shape[0] + + with torch.no_grad(): + noise = torch.randn_like(labels) + timesteps = torch.rand((bsz,)).to(labels) + sigmas = timesteps.view(-1, 1, 1, 1) + noisy_latents = (1 - sigmas) * labels + sigmas * noise + guidance = torch.full((bsz,), self._guidance).to(labels) + + target = noise - labels + + assert len(model_parts) == 1 + # TODO(jianiw): model_parts will be wrapped by FSDP, which will cacluate + model_parts[0] = model_parts[0].to(dtype=self._dtype) + + pred = self._predict_noise( + model_parts[0], + noisy_latents, + clip_encodings, + t5_encodings, + timesteps, + guidance, + ) + loss = self.loss_fn(pred, target) + # pred.shape=(bs, seq_len, vocab_size) + # need to free to before bwd to avoid peaking memory + del (pred, noise, target) + loss.backward() + + dist_utils.clip_grad_norm_( + [p for m in model_parts for p in m.parameters()], + self.job_config.training.max_norm, + foreach=True, + pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None, + ) + self.checkpointer.maybe_wait_for_staging() + self.optimizers.step() + self.lr_schedulers.step() + + # log metrics + if not self.metrics_processor.should_log(self.step): + return + + if ( + parallel_dims.dp_replicate_enabled + or parallel_dims.dp_shard_enabled + or parallel_dims.cp_enabled + ): + loss = loss.detach() + global_avg_loss, global_max_loss = ( + dist_utils.dist_mean(loss, world_mesh["dp_cp"]), + dist_utils.dist_max(loss, world_mesh["dp_cp"]), + ) + else: + global_avg_loss = global_max_loss = loss.item() + + self.metrics_processor.log(self.step, global_avg_loss, global_max_loss) + + +if __name__ == "__main__": + init_logger() + config = JobConfig() + config.maybe_add_custom_args() + config.parse_args() + trainer: 
Optional[FluxTrainer] = None + + try: + trainer = FluxTrainer(config) + if config.checkpoint.create_seed_checkpoint: + assert int( + os.environ["WORLD_SIZE"] + ), "Must create seed checkpoint using a single device, to disable sharding." + assert ( + config.checkpoint.enable_checkpoint + ), "Must enable checkpointing when creating a seed checkpoint." + trainer.checkpointer.save(curr_step=0, force=True) + logger.info("Created seed checkpoint") + else: + trainer.train() + finally: + if trainer: + trainer.close() + + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + logger.info("Process group destroyed.") diff --git a/torchtitan/experiments/flux/train_configs/debug_model.toml b/torchtitan/experiments/flux/train_configs/debug_model.toml new file mode 100644 index 0000000000000000000000000000000000000000..250a71d60ec28028b548803bad7f14b6b3a6db62 --- /dev/null +++ b/torchtitan/experiments/flux/train_configs/debug_model.toml @@ -0,0 +1,68 @@ + +[job] +dump_folder = "./outputs" +description = "Flux debug model" +print_args = false +use_for_integration_test = true + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "flux" +flavor = "flux-debug" +norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm +# test tokenizer.model, for debug purpose only +# tokenizer_path = "./tests/assets/test_tiktoken.model" +# converters = "float8" + + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +lr_min = 0.0 + +[training] +batch_size = 32 +seq_len = 512 +max_norm = 1.0 # grad norm clipping 
+steps = 10 +compile = false +dataset = "cc12m" +guidance = 3.5 +seed = 0 + +[encoder] +t5_encoder="google/t5-v1_1-small" +clip_encoder="openai/clip-vit-large-patch14" +max_t5_encoding_len=512 +auto_encoder_path="torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = 1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[experimental] +custom_args_module = "torchtitan.experiments.flux.flux_argparser" diff --git a/torchtitan/experiments/flux/utils.py b/torchtitan/experiments/flux/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..15db50d90c81ed0fa9f5296a1c725af8005e3601 --- /dev/null +++ b/torchtitan/experiments/flux/utils.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + +from torch import Tensor + +from torchtitan.experiments.flux.model.autoencoder import AutoEncoder +from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder + + +def preprocess_flux_data( + # arguments from the recipe + device: torch.device, + dtype: torch.dtype, + *, + # arguments from the config + autoencoder: Optional[AutoEncoder], + clip_encoder: FluxEmbedder, + t5_encoder: FluxEmbedder, + batch: dict[str, Tensor], + offload: bool = False, +) -> dict[str, Tensor]: + """ + Take a batch of inputs and encoder as input and return a batch of preprocessed data. 
+ + Args: + device (torch.device): device to do preprocessing on + dtype (torch.dtype): data type to do preprocessing in + autoencoer(AutoEncoder): autoencoder to use for preprocessing + clip_encoder + t5_encoder + batch (dict[str, Tensor]): batch of data to preprocess + + Returns: + dict[str, Tensor]: batch of preprocessed data + """ + + # The input of encoder should be torch.int type + if offload: + clip_encoder.to(device) + t5_encoder.to(device) + if autoencoder is not None: + autoencoder.to(device) + + clip_tokens = batch["clip_tokens"].squeeze().to(device=device, dtype=torch.int) + t5_tokens = batch["t5_tokens"].squeeze().to(device=device, dtype=torch.int) + + clip_text_encodings = clip_encoder(clip_tokens) + t5_text_encodings = t5_encoder(t5_tokens) + + if autoencoder is not None: + images = batch["image"].to(device=device, dtype=dtype) + img_encodings = autoencoder.encode(images) + batch["img_encodings"] = img_encodings.to(device=device, dtype=dtype) + + batch["clip_encodings"] = clip_text_encodings.to(dtype) + batch["t5_encodings"] = t5_text_encodings.to(dtype) + + # offload encoders to cpu after preprocessing + if offload: + clip_encoder.to("cpu") + t5_encoder.to("cpu") + if autoencoder is not None: + autoencoder.to("cpu") + + return batch + + +def generate_noise_latent( + bsz: int, + height: int, + width: int, + device: str | torch.device, + dtype: torch.dtype, + seed: int, +) -> Tensor: + """Generate noise latents for the Flux flow model. + + Args: + bsz (int): batch_size. + height (int): The height of the image. + width (int): The width of the image. + device (str | torch.device): The device to use. + dtype (torch.dtype): The dtype to use. + seed (int): The seed to use for randomize. + + Returns: + Tensor: The noise latents. 
+ Shape: [num_samples, LATENT_CHANNELS, height // IMG_LATENT_SIZE_RATIO, width // IMG_LATENT_SIZE_RATIO] + + """ + LATENT_CHANNELS, IMAGE_LATENT_SIZE_RATIO = 16, 8 + return torch.randn( + bsz, + LATENT_CHANNELS, + height // IMAGE_LATENT_SIZE_RATIO, + width // IMAGE_LATENT_SIZE_RATIO, + dtype=dtype, + generator=torch.Generator().manual_seed(seed), + ).to(device) + + +def create_position_encoding_for_latents( + bsz: int, latent_height: int, latent_width: int, position_dim: int = 3 +) -> Tensor: + """ + Create the packed latents' position encodings for the Flux flow model. + + Args: + bsz (int): The batch size. + latent_height (int): The height of the latent. + latent_width (int): The width of the latent. + + Returns: + Tensor: The position encodings. + Shape: [bsz, (latent_height // PATCH_HEIGHT) * (latent_width // PATCH_WIDTH), POSITION_DIM) + """ + PATCH_HEIGHT, PATCH_WIDTH = 2, 2 + + height = latent_height // PATCH_HEIGHT + width = latent_width // PATCH_WIDTH + + position_encoding = torch.zeros(height, width, position_dim) + + row_indices = torch.arange(height) + position_encoding[:, :, 1] = row_indices.unsqueeze(1) + + col_indices = torch.arange(width) + position_encoding[:, :, 2] = col_indices.unsqueeze(0) + + # Flatten and repeat for the full batch + # [height, width, 3] -> [bsz, height * width, 3] + position_encoding = position_encoding.view(1, height * width, position_dim) + position_encoding = position_encoding.repeat(bsz, 1, 1) + + return position_encoding + + +def pack_latents(x: Tensor) -> Tensor: + """ + Rearrange latents from an image-like format into a sequence of patches. + Equivalent to `einops.rearrange("b c (h ph) (w pw) -> b (h w) (c ph pw)")`. + + Args: + x (Tensor): The unpacked latents. + Shape: [bsz, ch, latent height, latent width] + + Returns: + Tensor: The packed latents. 
+ Shape: (bsz, (latent_height // ph) * (latent_width // pw), ch * ph * pw) + """ + PATCH_HEIGHT, PATCH_WIDTH = 2, 2 + + b, c, latent_height, latent_width = x.shape + h = latent_height // PATCH_HEIGHT + w = latent_width // PATCH_WIDTH + + # [b, c, h*ph, w*ph] -> [b, c, h, w, ph, pw] + x = x.unfold(2, PATCH_HEIGHT, PATCH_HEIGHT).unfold(3, PATCH_WIDTH, PATCH_WIDTH) + + # [b, c, h, w, ph, PW] -> [b, h, w, c, ph, PW] + x = x.permute(0, 2, 3, 1, 4, 5) + + # [b, h, w, c, ph, PW] -> [b, h*w, c*ph*PW] + return x.reshape(b, h * w, c * PATCH_HEIGHT * PATCH_WIDTH) + + +def unpack_latents(x: Tensor, latent_height: int, latent_width: int) -> Tensor: + """ + Rearrange latents from a sequence of patches into an image-like format. + Equivalent to `einops.rearrange("b (h w) (c ph pw) -> b c (h ph) (w pw)")`. + + Args: + x (Tensor): The packed latents. + Shape: (bsz, (latent_height // ph) * (latent_width // pw), ch * ph * pw) + latent_height (int): The height of the unpacked latents. + latent_width (int): The width of the unpacked latents. + + Returns: + Tensor: The unpacked latents. 
+ Shape: [bsz, ch, latent height, latent width] + """ + PATCH_HEIGHT, PATCH_WIDTH = 2, 2 + + b, _, c_ph_pw = x.shape + h = latent_height // PATCH_HEIGHT + w = latent_width // PATCH_WIDTH + c = c_ph_pw // (PATCH_HEIGHT * PATCH_WIDTH) + + # [b, h*w, c*ph*pw] -> [b, h, w, c, ph, pw] + x = x.reshape(b, h, w, c, PATCH_HEIGHT, PATCH_WIDTH) + + # [b, h, w, c, ph, pw] -> [b, c, h, ph, w, pw] + x = x.permute(0, 3, 1, 4, 2, 5) + + # [b, c, h, ph, w, pw] -> [b, c, h*ph, w*pw] + return x.reshape(b, c, h * PATCH_HEIGHT, w * PATCH_WIDTH) diff --git a/torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py b/torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..7dbabd1317a5923545f24c9a77feca46f5a92130 --- /dev/null +++ b/torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py @@ -0,0 +1,630 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Benchmark comparing reference PyTorch vs optimized M*G group GEMM implementation + +import argparse +import logging +import time + +# from typing import Dict, List, Optional, Tuple + +import matplotlib.pyplot as plt +import numpy as np +import torch +import triton + +# import triton.language as tl + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +# Try to import the optimized implementations +try: + from torchao_pr.mg_grouped_gemm import grouped_gemm_forward + +except ImportError: + logging.error( + "Error importing MG grouped GEMM modules. Make sure the implementation files are in the correct path." 
+ ) + raise + + +def compute_reference_forward(x, w, m_sizes): + """ + Reference PyTorch implementation of M*G grouped GEMM forward pass. + + Args: + x (torch.Tensor): Input tensor of shape (M, K) + w (torch.Tensor): Weight tensor of shape (N, K) + m_sizes (torch.Tensor): Group sizes tensor of shape (G) + + Returns: + torch.Tensor: Output tensor of shape (M, N) + """ + result = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device) + + m_start = 0 + for g in range(len(m_sizes)): + m_size = m_sizes[g].item() + if m_size > 0: + m_end = m_start + m_size + + # Extract group input + x_g = x[m_start:m_end] + + # Compute group output + y_g = torch.matmul(x_g, w.T) + + # Store result + result[m_start:m_end] = y_g + + # Update start index + m_start = m_end + + return result + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["N"], # We'll vary the output dimension + x_vals=[1024, 2048, 4096, 8192, 16384], # Different output dimensions to test + # x_vals=[8192, 16384], + line_arg="provider", # We'll compare different providers + line_vals=["pytorch_reference", "M*G grouped GEMM"], + line_names=["PyTorch Reference", "M*G grouped Kernel"], + styles=[("blue", "-"), ("red", "-")], + ylabel="TFLOPS", # We'll measure TFLOPS + plot_name="mg_grouped_gemm_comparison", + args={ + "M": 8192, # Batch dimension, fixed for all tests + "K": 7168, # Hidden dimension, fixed for all tests + "G": 8, # Number of groups + "dtype": torch.float16, + "device": "cuda", + }, + ) +) +def benchmark_forward(M, K, N, G, provider, dtype=torch.float16, device="cuda"): + """ + Benchmark the forward pass of the grouped GEMM implementation. 
+ + Args: + M (int): Total batch size dimension + K (int): Hidden dimension + N (int): Output dimension + G (int): Number of groups + provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel') + dtype (torch.dtype): Data type to use + device (str): Device to use + + Returns: + float: Performance in TFLOPS + """ + # Create group sizes for M dimension (balanced across groups) + base_size = M // G + remainder = M % G + M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)] + m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32) + + print(f"N: {N}, M: {M}, K: {K}, G: {G}, dtype: {dtype}, device: {device}") + + # Create input and weight tensors + x = torch.randn(M, K, dtype=dtype, device=device) + w = torch.randn(N, K, dtype=dtype, device=device) + + # Pre-compute for PyTorch reference to ensure fair comparison + if provider == "pytorch_reference": + # Warmup + torch.cuda.synchronize() + compute_reference_forward(x, w, m_sizes) + torch.cuda.synchronize() + + # Benchmark + start_time = time.time() + for _ in range(10): # Average over 10 runs + compute_reference_forward(x, w, m_sizes) + torch.cuda.synchronize() + end_time = time.time() + else: # Optimized kernel + # Warmup + torch.cuda.synchronize() + grouped_gemm_forward(x, w, m_sizes) + torch.cuda.synchronize() + + # Benchmark + start_time = time.time() + for _ in range(10): # Average over 10 runs + grouped_gemm_forward(x, w, m_sizes) + torch.cuda.synchronize() + end_time = time.time() + + # Calculate FLOPs + # For GEMM: 2 * M * N * K FLOPs (multiply-add counts as 2 FLOPs) + flops = 2 * M * N * K + + # Convert to TFLOPS (tera-FLOPS) + avg_time = (end_time - start_time) / 10 # Average time per run + tflops = flops / avg_time / 1e12 + + return tflops + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["G"], # We'll vary the number of groups + x_vals=[1, 2, 4, 8, 16], # Different numbers of groups to test + line_arg="provider", # We'll compare different 
providers + line_vals=["pytorch_reference", "optimized_kernel"], + line_names=["PyTorch Reference", "Optimized Kernel"], + styles=[("blue", "-"), ("red", "-")], + ylabel="TFLOPS", # We'll measure TFLOPS + plot_name="mg_grouped_gemm_group_scaling", + args={ + "M": 8192, # Batch dimension, fixed for all tests + "K": 4096, # Hidden dimension, fixed for all tests + "N": 8192, # Output dimension, fixed for all tests + "dtype": torch.float16, + "device": "cuda", + }, + ) +) +def benchmark_forward_groups(M, K, N, G, provider, dtype=torch.float16, device="cuda"): + """ + Benchmark how performance scales with number of groups. + + Args: + M (int): Total batch size dimension + K (int): Hidden dimension + N (int): Output dimension + G (int): Number of groups + provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel') + dtype (torch.dtype): Data type to use + device (str): Device to use + + Returns: + float: Performance in TFLOPS + """ + # Create group sizes for M dimension (balanced across groups) + base_size = M // G + remainder = M % G + M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)] + m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32) + + # Create input and weight tensors + x = torch.randn(M, K, dtype=dtype, device=device) + w = torch.randn(N, K, dtype=dtype, device=device) + + # Benchmark logic - same as previous function + if provider == "pytorch_reference": + torch.cuda.synchronize() + compute_reference_forward(x, w, m_sizes) + torch.cuda.synchronize() + + start_time = time.time() + for _ in range(10): + compute_reference_forward(x, w, m_sizes) + torch.cuda.synchronize() + end_time = time.time() + else: + torch.cuda.synchronize() + grouped_gemm_forward(x, w, m_sizes) + torch.cuda.synchronize() + + start_time = time.time() + for _ in range(10): + grouped_gemm_forward(x, w, m_sizes) + torch.cuda.synchronize() + end_time = time.time() + + # Calculate FLOPs and TFLOPS + flops = 2 * M * N * K + avg_time = (end_time - 
start_time) / 10 + tflops = flops / avg_time / 1e12 + + return tflops + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["group_balance"], # We'll vary the group balance factor + x_vals=[ + 0.0, + 0.25, + 0.5, + 0.75, + 0.9, + ], # Different imbalance factors (0 = balanced, 1 = max imbalance) + line_arg="provider", # We'll compare different providers + line_vals=["pytorch_reference", "optimized_kernel"], + line_names=["PyTorch Reference", "Optimized Kernel"], + styles=[("blue", "-"), ("red", "-")], + ylabel="TFLOPS", # We'll measure TFLOPS + plot_name="mg_grouped_gemm_imbalance", + args={ + "M": 8192, # Batch dimension, fixed for all tests + "K": 4096, # Hidden dimension, fixed for all tests + "N": 8192, # Output dimension, fixed for all tests + "G": 4, # Number of groups + "dtype": torch.float16, + "device": "cuda", + }, + ) +) +def benchmark_imbalance( + M, K, N, G, group_balance, provider, dtype=torch.float16, device="cuda" +): + """ + Benchmark how performance is affected by imbalanced group sizes. 
+ + Args: + M (int): Total batch size dimension + K (int): Hidden dimension + N (int): Output dimension + G (int): Number of groups + group_balance (float): Balance factor from 0 to 1 (0 = balanced, 1 = max imbalance) + provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel') + dtype (torch.dtype): Data type to use + device (str): Device to use + + Returns: + float: Performance in TFLOPS + """ + # Create imbalanced group sizes for M dimension + if group_balance == 0: + # Balanced case + base_size = M // G + remainder = M % G + M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)] + else: + # Imbalanced case + # First group gets more elements, last group gets fewer + # The imbalance is controlled by the group_balance factor + remaining = M + M_sizes = [] + for g in range(G): + # Interpolate from balanced to imbalanced based on group_balance + # For balanced (group_balance=0), each group gets M/G + # For imbalanced (group_balance=1), first group gets much more than last group + balanced_size = remaining // (G - g) + + # Adjusting size based on position and imbalance factor + # First groups get more, last groups get less + if g < G // 2: + # First half of groups get more + adjustment = int(balanced_size * group_balance * (1 - g / (G - 1))) + size = balanced_size + adjustment + else: + # Second half of groups get less + adjustment = int(balanced_size * group_balance * ((g / (G - 1)) - 0.5)) + size = balanced_size - adjustment + + # Ensure we don't go below 1 or take more than remaining + size = max(1, min(size, remaining)) + M_sizes.append(size) + remaining -= size + + # Handle any remaining elements + if remaining > 0: + M_sizes[-1] += remaining + + m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32) + + # Create input and weight tensors + x = torch.randn(M, K, dtype=dtype, device=device) + w = torch.randn(N, K, dtype=dtype, device=device) + + # Benchmark logic + if provider == "pytorch_reference": + 
torch.cuda.synchronize() + compute_reference_forward(x, w, m_sizes) + torch.cuda.synchronize() + + start_time = time.time() + for _ in range(10): + compute_reference_forward(x, w, m_sizes) + torch.cuda.synchronize() + end_time = time.time() + else: + torch.cuda.synchronize() + grouped_gemm_forward(x, w, m_sizes) + torch.cuda.synchronize() + + start_time = time.time() + for _ in range(10): + grouped_gemm_forward(x, w, m_sizes) + torch.cuda.synchronize() + end_time = time.time() + + # Calculate FLOPs and TFLOPS + flops = 2 * M * N * K + avg_time = (end_time - start_time) / 10 + tflops = flops / avg_time / 1e12 + + return tflops + + +def benchmark_model_configs(): + """ + Benchmark common model configurations used in DeepSeek-like models. + """ + # Model configurations: (M, K, N, G) + configs = [ + (8192, 7168, 4096, 4), # Config 1 + (8192, 2048, 7168, 4), # Config 2 + (4096, 7168, 4096, 8), # Config 3 + (4096, 2048, 7168, 8), # Config 4 + ] + + results = [] + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float16 + + for config_idx, (M, K, N, G) in enumerate(configs): + logging.info(f"\n===== Benchmarking DeepSeek Config {config_idx + 1} =====") + logging.info(f"M={M}, K={K}, N={N}, G={G}") + + # Create group sizes for M dimension + base_size = M // G + remainder = M % G + M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)] + m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32) + + # Create tensors + x = torch.randn(M, K, dtype=dtype, device=device) + w = torch.randn(N, K, dtype=dtype, device=device) + + # Benchmark PyTorch reference + torch.cuda.synchronize() + compute_reference_forward(x, w, m_sizes) # Warmup + torch.cuda.synchronize() + + logging.info("Benchmarking PyTorch reference...") + torch.cuda.reset_peak_memory_stats() + start_time = time.time() + for _ in range(10): + compute_reference_forward(x, w, m_sizes) + torch.cuda.synchronize() + end_time = time.time() + pt_time = (end_time - 
start_time) / 10 + pt_memory = torch.cuda.max_memory_allocated() / (1024**2) # MB + + # Benchmark optimized kernel + torch.cuda.synchronize() + grouped_gemm_forward(x, w, m_sizes) # Warmup + torch.cuda.synchronize() + + logging.info("Benchmarking optimized kernel...") + torch.cuda.reset_peak_memory_stats() + start_time = time.time() + for _ in range(10): + grouped_gemm_forward(x, w, m_sizes) + torch.cuda.synchronize() + end_time = time.time() + opt_time = (end_time - start_time) / 10 + opt_memory = torch.cuda.max_memory_allocated() / (1024**2) # MB + + # Calculate FLOPs and speedup + flops = 2 * M * N * K + pt_tflops = flops / pt_time / 1e12 + opt_tflops = flops / opt_time / 1e12 + speedup = pt_time / opt_time + + # Store results + results.append( + { + "config": f"Config {config_idx + 1}", + "dimensions": f"M={M}, K={K}, N={N}, G={G}", + "pt_time_ms": pt_time * 1000, + "opt_time_ms": opt_time * 1000, + "pt_tflops": pt_tflops, + "opt_tflops": opt_tflops, + "speedup": speedup, + "pt_memory_mb": pt_memory, + "opt_memory_mb": opt_memory, + "memory_savings": ( + (pt_memory - opt_memory) / pt_memory * 100 if pt_memory > 0 else 0 + ), + } + ) + + logging.info( + f"PyTorch Reference: {pt_time * 1000:.2f} ms, {pt_tflops:.2f} TFLOPS, {pt_memory:.2f} MB" + ) + logging.info( + f"Optimized Kernel: {opt_time * 1000:.2f} ms, {opt_tflops:.2f} TFLOPS, {opt_memory:.2f} MB" + ) + logging.info( + f"Speedup: {speedup:.2f}x, Memory savings: {results[-1]['memory_savings']:.2f}%" + ) + + # Print summary table + logging.info("\n===== Benchmark Results Summary =====") + logging.info( + f"{'Config':<10} | {'Time (ms)':<20} | {'TFLOPS':<20} | {'Speedup':<10} | {'Memory (MB)':<20} | {'Memory Saved':<12}" + ) + logging.info( + f"{'':<10} | {'PyTorch':<9} {'Kernel':<9} | {'PyTorch':<9} {'Kernel':<9} | {'':<10} | " + f"{'PyTorch':<9} {'Kernel':<9} | {'':<12}" + ) + logging.info("-" * 100) + + for result in results: + logging.info( + f"{result['config']:<10} | " + f"{result['pt_time_ms']:<9.2f} 
{result['opt_time_ms']:<9.2f} | " + f"{result['pt_tflops']:<9.2f} {result['opt_tflops']:<9.2f} | " + f"{result['speedup']:<10.2f} | " + f"{result['pt_memory_mb']:<9.2f} {result['opt_memory_mb']:<9.2f} | " + f"{result['memory_savings']:<12.2f}%" + ) + + return results + + +def plot_benchmark_results(results): + """ + Plot benchmark results as bar charts. + """ + # Extract data + configs = [r["config"] for r in results] + pt_tflops = [r["pt_tflops"] for r in results] + opt_tflops = [r["opt_tflops"] for r in results] + speedups = [r["speedup"] for r in results] + + # Create figure with subplots + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5)) + + # Plot TFLOPS comparison + x = np.arange(len(configs)) + width = 0.35 + ax1.bar(x - width / 2, pt_tflops, width, label="PyTorch Reference") + ax1.bar(x + width / 2, opt_tflops, width, label="Optimized Kernel") + ax1.set_xlabel("Model Configuration") + ax1.set_ylabel("TFLOPS") + ax1.set_title("Performance Comparison (Higher is Better)") + ax1.set_xticks(x) + ax1.set_xticklabels(configs) + ax1.legend() + ax1.grid(axis="y", linestyle="--", alpha=0.7) + + # Plot speedup + ax2.bar(x, speedups, width=0.6, color="green") + ax2.set_xlabel("Model Configuration") + ax2.set_ylabel("Speedup (x)") + ax2.set_title("Speedup Factor (Higher is Better)") + ax2.set_xticks(x) + ax2.set_xticklabels(configs) + ax2.grid(axis="y", linestyle="--", alpha=0.7) + + # Add speedup values on top of bars + for i, v in enumerate(speedups): + ax2.text(i, v + 0.1, f"{v:.2f}x", ha="center") + + plt.tight_layout() + plt.savefig("mg_grouped_gemm_benchmark_results.png") + logging.info( + "Benchmark results plot saved to 'mg_grouped_gemm_benchmark_results.png'" + ) + + +def compare_mg_implementations(): + """ + Combine the M*G and N*G benchmark results for comparison. 
+ """ + # Only run this if both NG and MG benchmarks have been run + try: + import pandas as pd + + # Try to load previous benchmark results + mg_results = pd.read_csv("mg_grouped_gemm_benchmark_results.csv") + ng_results = pd.read_csv("ng_grouped_gemm_benchmark_results.csv") + + # Create comparison plot + fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + + # Plot speedup comparison + configs = mg_results["config"].unique() + mg_speedups = mg_results.groupby("config")["speedup"].mean() + ng_speedups = ng_results.groupby("config")["speedup"].mean() + + x = np.arange(len(configs)) + width = 0.35 + + axes[0].bar(x - width / 2, mg_speedups, width, label="M*G Grouping") + axes[0].bar(x + width / 2, ng_speedups, width, label="N*G Grouping") + axes[0].set_xlabel("Model Configuration") + axes[0].set_ylabel("Speedup (x)") + axes[0].set_title("Speedup Comparison: M*G vs N*G") + axes[0].set_xticks(x) + axes[0].set_xticklabels(configs) + axes[0].legend() + axes[0].grid(axis="y", linestyle="--", alpha=0.7) + + # Plot TFLOPS comparison for optimized kernels + mg_tflops = ( + mg_results[mg_results["implementation"] == "optimized"] + .groupby("config")["tflops"] + .mean() + ) + ng_tflops = ( + ng_results[ng_results["implementation"] == "optimized"] + .groupby("config")["tflops"] + .mean() + ) + + axes[1].bar(x - width / 2, mg_tflops, width, label="M*G Grouping") + axes[1].bar(x + width / 2, ng_tflops, width, label="N*G Grouping") + axes[1].set_xlabel("Model Configuration") + axes[1].set_ylabel("TFLOPS") + axes[1].set_title("Performance Comparison: M*G vs N*G") + axes[1].set_xticks(x) + axes[1].set_xticklabels(configs) + axes[1].legend() + axes[1].grid(axis="y", linestyle="--", alpha=0.7) + + plt.tight_layout() + plt.savefig("mg_vs_ng_comparison.png") + logging.info("Comparison plot saved to 'mg_vs_ng_comparison.png'") + + except Exception as e: + logging.error(f"Could not create comparison plot: {e}") + logging.info( + "Run both M*G and N*G benchmarks first to generate comparison 
plots" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Benchmark M*G Grouped GEMM implementations" + ) + parser.add_argument("--run-all", action="store_true", help="Run all benchmarks") + parser.add_argument( + "--triton-bench", action="store_true", help="Run Triton performance reports" + ) + parser.add_argument( + "--model-configs", action="store_true", help="Benchmark model configurations" + ) + parser.add_argument( + "--compare-mg-ng", + action="store_true", + help="Compare M*G and N*G implementations", + ) + args = parser.parse_args() + + # Check if CUDA is available + if not torch.cuda.is_available(): + logging.error( + "CUDA is not available. This benchmark requires a CUDA-capable GPU." + ) + exit(1) + + if args.run_all or args.model_configs: + # Benchmark model configurations + logging.info("Running benchmark for model configurations...") + results = benchmark_model_configs() + plot_benchmark_results(results) + + if args.run_all or args.triton_bench: + # Run Triton performance reports + logging.info("Running Triton performance reports...") + benchmark_forward.run(save_path="mg_grouped_gemm_benchmark_results") + benchmark_forward_groups.run(save_path="mg_grouped_gemm_benchmark_results") + benchmark_imbalance.run(save_path="mg_grouped_gemm_benchmark_results") + logging.info( + "Triton performance reports saved to 'mg_grouped_gemm_benchmark_results' directory" + ) + + if args.run_all or args.compare_mg_ng: + # Compare M*G and N*G implementations + logging.info("Comparing M*G and N*G implementations...") + compare_mg_implementations() diff --git a/torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py b/torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py new file mode 100644 index 0000000000000000000000000000000000000000..7e893a54443a6c05a548b35325421e66db321d43 --- /dev/null +++ b/torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py @@ -0,0 +1,885 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +import math +import time + +from typing import Dict, List, Tuple + +# import numpy as np +import torch # +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + +# from torchao_pr.mg_grouped_gemm import mg_grouped_gemm + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +# Try to import the optimized MG GEMM implementation +try: + from torchao_pr.mg_grouped_gemm import ( # grouped_gemm_backward, + grouped_gemm_forward, + ) + + has_mg_gemm = True +except ImportError: + logging.warning("MG GEMM implementation not found. Will use manual looping only.") + has_mg_gemm = False + + +class Router(nn.Module): + """ + Router module that assigns tokens to experts. + """ + + def __init__(self, input_dim: int, num_experts: int, top_k: int = 2): + super().__init__() + self.input_dim = input_dim + self.num_experts = num_experts + self.top_k = top_k + + # Routing layer + self.router = nn.Linear(input_dim, num_experts) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: + """ + Route input tokens to experts. 
+ + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, input_dim) + + Returns: + Tuple containing: + - router_logits: Raw routing probabilities + - dispatch_tensor: One-hot tensor indicating expert assignment + - expert_indices: List of indices for each expert's tokens + """ + batch_size, seq_len, _ = x.shape + + # Flatten batch and sequence dimensions + x_flat = x.reshape(-1, self.input_dim) # (batch_size * seq_len, input_dim) + + # Compute routing probabilities + router_logits = self.router(x_flat) # (batch_size * seq_len, num_experts) + + # Apply softmax to get probabilities + router_probs = F.softmax(router_logits, dim=-1) + + # Get top-k experts for each token + top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1) + + # Normalize top-k probabilities + top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True) + + # Create dispatch tensor (one-hot representation of assignments) + dispatch_tensor = torch.zeros_like(router_probs) + token_indices = ( + torch.arange(router_probs.size(0), device=router_probs.device) + .unsqueeze(1) + .expand(-1, self.top_k) + ) + dispatch_tensor.scatter_(1, top_k_indices, top_k_probs) # .unsqueeze(-1)) + + # For each expert, get the indices of tokens routed to it + expert_indices = [] + for expert_idx in range(self.num_experts): + # Get indices of tokens that have non-zero probability for this expert + indices = torch.nonzero(dispatch_tensor[:, expert_idx] > 0, as_tuple=True)[ + 0 + ] + expert_indices.append(indices) + + return router_logits, dispatch_tensor, expert_indices + + +class Expert(nn.Module): + """ + Individual expert module. 
+ """ + + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int): + super().__init__() + self.fc1 = nn.Linear(input_dim, hidden_dim, bias=False) + self.activation = nn.GELU() + self.fc2 = nn.Linear(hidden_dim, output_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = self.activation(x) + x = self.fc2(x) + return x + + +class MixtureOfExperts(nn.Module): + """ + Mixture of Experts layer with support for both manual looping and grouped GEMM. + """ + + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_experts: int, + top_k: int = 2, + use_mg_gemm: bool = False, + ): + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.num_experts = num_experts + self.top_k = top_k + self.use_mg_gemm = use_mg_gemm and has_mg_gemm + + # Router + self.router = Router(input_dim, num_experts, top_k) + + # Create expert modules + if self.use_mg_gemm: + # For MG GEMM, we need a single weight tensor for all experts + # First layer (input -> hidden) + self.expert_fc1_weight = nn.Parameter( + torch.randn(num_experts * hidden_dim, input_dim) / math.sqrt(input_dim) + ) + # self.expert_fc1_bias = nn.Parameter(torch.zeros(num_experts * hidden_dim)) + + # Second layer (hidden -> output) + self.expert_fc2_weight = nn.Parameter( + torch.randn(num_experts * output_dim, hidden_dim) + / math.sqrt(hidden_dim) + ) + # self.expert_fc2_bias = nn.Parameter(torch.zeros(num_experts * output_dim)) + else: + # For manual looping, create separate experts + self.experts = nn.ModuleList( + [Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)] + ) + + def forward_manual_loop(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass using manual looping over experts. 
+ """ + batch_size, seq_len, _ = x.shape + x_flat = x.reshape(-1, self.input_dim) # (batch_size * seq_len, input_dim) + + # Get routing information + router_logits, dispatch_tensor, expert_indices = self.router(x) + + # Initialize output tensor + final_output = torch.zeros( + batch_size * seq_len, self.output_dim, device=x.device + ) + + # Process each expert + for expert_idx, indices in enumerate(expert_indices): + if indices.numel() > 0: + # Get tokens routed to this expert + expert_inputs = x_flat[indices] # (num_tokens_for_expert, input_dim) + + # Process tokens through expert + expert_outputs = self.experts[expert_idx]( + expert_inputs + ) # (num_tokens_for_expert, output_dim) + + # Scale outputs by router probabilities + scaled_outputs = expert_outputs * dispatch_tensor[ + indices, expert_idx + ].unsqueeze(1) + + # Add to final output + final_output.index_add_(0, indices, scaled_outputs) + + # Reshape back to original dimensions + output = final_output.reshape(batch_size, seq_len, self.output_dim) + + return output, router_logits + + def forward_mg_gemm(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, _ = x.shape + x_flat = x.reshape(-1, self.input_dim) # (batch_size * seq_len, input_dim) + total_tokens = batch_size * seq_len + + # Get routing information + router_logits, dispatch_tensor, expert_indices = self.router(x) + + # Get token counts for each expert + token_counts = [indices.numel() for indices in expert_indices] + m_sizes = torch.tensor(token_counts, dtype=torch.int32, device=x.device) + + print(f"Token counts per expert: {token_counts}") + print(f"m_sizes: {m_sizes}") + + # Create the combined input tensor + combined_input = torch.zeros(sum(token_counts), self.input_dim, device=x.device) + + start_idx = 0 + for expert_idx, indices in enumerate(expert_indices): + if indices.numel() > 0: + end_idx = start_idx + indices.numel() + combined_input[start_idx:end_idx] = x_flat[indices] + start_idx = end_idx + + print(f"combined_input shape: 
{combined_input.shape}") + + # First layer: input -> hidden + fc1_weight_reshaped = self.expert_fc1_weight.reshape( + self.num_experts, self.hidden_dim, self.input_dim + ) + fc1_weight_combined = fc1_weight_reshaped.reshape(-1, self.input_dim) + + print(f"fc1_weight_combined shape: {fc1_weight_combined.shape}") + + # Run the grouped GEMM + hidden_outputs = grouped_gemm_forward( + combined_input, fc1_weight_combined, m_sizes + ) + + print(f"hidden_outputs shape after first GEMM: {hidden_outputs.shape}") + + # Apply activation + hidden_outputs = F.gelu(hidden_outputs) + + print(f"hidden_outputs shape after activation: {hidden_outputs.shape}") + + # Second layer: hidden -> output + # Reshape hidden_outputs to match expected dimensions + reshaped_hidden_outputs = [] + start_idx = 0 + + for expert_idx, count in enumerate(token_counts): + if count > 0: + end_idx = start_idx + count + # Take this expert's outputs and reshape to [count, hidden_dim] + expert_output = hidden_outputs[ + start_idx:end_idx, + expert_idx * self.hidden_dim : (expert_idx + 1) * self.hidden_dim, + ] + reshaped_hidden_outputs.append(expert_output) + start_idx = end_idx + + # Concatenate all reshaped outputs + hidden_outputs = torch.cat(reshaped_hidden_outputs, dim=0) + + # Reshape expert weights for second layer + fc2_weight_reshaped = self.expert_fc2_weight.reshape( + self.num_experts, self.output_dim, self.hidden_dim + ) + fc2_weight_combined = fc2_weight_reshaped.reshape(-1, self.hidden_dim) + + print(f"fc2_weight_combined shape: {fc2_weight_combined.shape}") + + # Run the second grouped GEMM + expert_outputs_combined = grouped_gemm_forward( + hidden_outputs, fc2_weight_combined, m_sizes + ) + + # Initialize final output tensor with correct shape + final_output = torch.zeros(total_tokens, self.output_dim, device=x.device) + + # Distribute the outputs back to the original token positions + start_idx = 0 + for expert_idx, indices in enumerate(expert_indices): + if indices.numel() > 0: + end_idx = 
start_idx + indices.numel() + # Get this expert's outputs + expert_outputs = expert_outputs_combined[start_idx:end_idx] + + print( + f"Expert {expert_idx} - indices shape: {indices.shape}, expert_outputs shape: {expert_outputs.shape}" + ) + + # Scale outputs by router probabilities + scaled_outputs = expert_outputs * dispatch_tensor[ + indices, expert_idx + ].unsqueeze(1) + + # Ensure dimensions match before using index_add_ + if scaled_outputs.shape[1] != final_output.shape[1]: + # print( + # f"Reshaping: Dimension mismatch: scaled_outputs {scaled_outputs.shape}, final_output {final_output.shape}" + # ) + # Reshape if needed - make sure output_dim is correct + scaled_outputs = scaled_outputs[:, : self.output_dim] + + # Add to final output + final_output.index_add_(0, indices, scaled_outputs) + + start_idx = end_idx + + # Reshape back to original dimensions + output = final_output.reshape(batch_size, seq_len, self.output_dim) + + return output, router_logits + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_mg_gemm and has_mg_gemm: + return self.forward_mg_gemm(x) + else: + return self.forward_manual_loop(x) + + +class MoEModel(nn.Module): + """ + Simple model using MoE layers. 
+ """ + + def __init__( + self, + vocab_size: int, + embed_dim: int, + hidden_dim: int, + num_experts: int, + top_k: int = 2, + use_mg_gemm: bool = False, + ): + super().__init__() + self.embedding = nn.Embedding(vocab_size, embed_dim) + self.moe_layer = MixtureOfExperts( + input_dim=embed_dim, + hidden_dim=hidden_dim, + output_dim=embed_dim, + num_experts=num_experts, + top_k=top_k, + use_mg_gemm=use_mg_gemm, + ) + self.output_layer = nn.Linear(embed_dim, vocab_size) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # x shape: (batch_size, seq_len) + embedded = self.embedding(x) # (batch_size, seq_len, embed_dim) + moe_output, router_logits = self.moe_layer( + embedded + ) # (batch_size, seq_len, embed_dim) + logits = self.output_layer(moe_output) # (batch_size, seq_len, vocab_size) + return logits, router_logits + + +def compute_load_balancing_loss( + router_logits: torch.Tensor, num_experts: int +) -> torch.Tensor: + """ + Compute the load balancing loss for MoE training. 
def compute_load_balancing_loss(
    router_logits: torch.Tensor, num_experts: int
) -> torch.Tensor:
    """
    Compute the load balancing loss for MoE training.

    The loss is ``num_experts * sum_e(f_e ** 2)`` where ``f_e`` is the
    (normalized) fraction of total routing probability assigned to expert
    ``e``. It is minimized (value 1.0) when routing is uniform across
    experts and maximized (value ``num_experts``) when all tokens collapse
    onto a single expert.

    Args:
        router_logits (torch.Tensor): Router logits of shape
            (batch_size * seq_len, num_experts)
        num_experts (int): Number of experts

    Returns:
        torch.Tensor: Scalar load balancing loss
    """
    # Get router probabilities: (batch_size * seq_len, num_experts)
    router_probs = F.softmax(router_logits, dim=-1)

    # Fraction of total routing mass per expert, normalized to sum to 1.
    router_probs_sum = router_probs.sum(dim=0)  # (num_experts,)
    router_probs_sum = router_probs_sum / router_probs_sum.sum()

    # FIX: removed the dead `mean_prob = 1.0 / num_experts` local — it was
    # computed and never used.
    load_balancing_loss = num_experts * torch.sum(router_probs_sum * router_probs_sum)

    return load_balancing_loss


def generate_sample_data(
    batch_size: int, seq_len: int, vocab_size: int, device: str = "cuda"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate random sample data for training.

    Args:
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size; token ids are drawn uniformly
            from ``[0, vocab_size)``
        device (str): Device to use

    Returns:
        Tuple of input tokens and target tokens, each of shape
        (batch_size, seq_len)
    """
    # Random input tokens
    inputs = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)

    # Random (independent) target tokens
    targets = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)

    return inputs, targets
def train_epoch(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    num_batches: int,
    device: str,
    load_balance_coef: float = 0.01,
) -> Dict[str, float]:
    """
    Train the model for one epoch on freshly sampled random batches.

    Args:
        model (nn.Module): Model to train
        optimizer (torch.optim.Optimizer): Optimizer
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        num_batches (int): Number of batches per epoch
        device (str): Device to use
        load_balance_coef (float): Coefficient for load balancing loss

    Returns:
        Dict with averaged "loss", "acc" and wall-clock "time" (seconds)
    """
    model.train()
    epoch_start = time.time()
    loss_sum = 0.0
    acc_sum = 0.0
    tokens_per_batch = batch_size * seq_len

    for step in range(num_batches):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)

        optimizer.zero_grad()
        logits, router_logits = model(inputs)

        # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for the loss.
        flat_logits = logits.reshape(-1, vocab_size)
        flat_targets = targets.reshape(-1)

        ce_loss = F.cross_entropy(flat_logits, flat_targets)
        # Auxiliary loss that pushes the router toward uniform expert usage.
        lb_loss = compute_load_balancing_loss(
            router_logits, model.moe_layer.num_experts
        )
        loss = ce_loss + load_balance_coef * lb_loss

        loss.backward()
        optimizer.step()

        correct = (flat_logits.argmax(dim=-1) == flat_targets).float().sum()
        acc = correct / tokens_per_batch

        loss_sum += loss.item()
        acc_sum += acc.item()

        if (step + 1) % 10 == 0:
            logging.info(
                f"Batch {step + 1}/{num_batches} | "
                f"Loss: {loss.item():.4f} | "
                f"CE Loss: {ce_loss.item():.4f} | "
                f"LB Loss: {lb_loss.item():.4f} | "
                f"Acc: {acc.item():.4f}"
            )

    return {
        "loss": loss_sum / num_batches,
        "acc": acc_sum / num_batches,
        "time": time.time() - epoch_start,
    }
def evaluate(
    model: nn.Module,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    num_batches: int,
    device: str,
) -> Dict[str, float]:
    """
    Evaluate the model on freshly sampled random batches.

    Args:
        model (nn.Module): Model to evaluate
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        num_batches (int): Number of batches for evaluation
        device (str): Device to use

    Returns:
        Dict with averaged "loss" and "acc"
    """
    model.eval()
    tokens_per_batch = batch_size * seq_len
    loss_sum = 0.0
    acc_sum = 0.0

    with torch.no_grad():
        for _ in range(num_batches):
            inputs, targets = generate_sample_data(
                batch_size, seq_len, vocab_size, device
            )

            # Router logits are not needed for evaluation metrics.
            logits, _ = model(inputs)

            flat_logits = logits.reshape(-1, vocab_size)
            flat_targets = targets.reshape(-1)

            loss = F.cross_entropy(flat_logits, flat_targets)
            hits = (flat_logits.argmax(dim=-1) == flat_targets).float().sum()

            loss_sum += loss.item()
            acc_sum += (hits / tokens_per_batch).item()

    return {"loss": loss_sum / num_batches, "acc": acc_sum / num_batches}
def measure_performance(
    model: nn.Module,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    num_batches: int,
    device: str,
) -> Dict[str, float]:
    """
    Measure forward and backward pass performance.

    Args:
        model (nn.Module): Model to evaluate
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        num_batches (int): Number of batches for measurement
        device (str): Device to use

    Returns:
        Dict with "forward_time", "backward_time" and "total_time" in
        milliseconds (averaged per batch)
    """
    model.train()

    def _sync() -> None:
        # FIX: the original called torch.cuda.synchronize() unconditionally,
        # which raises on CPU-only runs even though `device` may be "cpu".
        # Only synchronize when CUDA is actually available.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    # Throwaway optimizer so the backward timing includes a realistic
    # zero_grad, but no parameter update (step() is never called).
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Warmup: trigger lazy initialization / kernel compilation.
    for _ in range(5):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
        logits, router_logits = model(inputs)
        loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.zero_grad()

    # Measure forward pass time (inference mode, no autograd graph).
    _sync()
    forward_start = time.time()
    for _ in range(num_batches):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
        with torch.no_grad():
            logits, router_logits = model(inputs)
    _sync()
    forward_time = (time.time() - forward_start) / num_batches

    # Measure forward + backward time.
    _sync()
    backward_start = time.time()
    for _ in range(num_batches):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
        logits, router_logits = model(inputs)
        loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.zero_grad()
    _sync()
    backward_time = (time.time() - backward_start) / num_batches

    return {
        "forward_time": forward_time * 1000,  # Convert to ms
        "backward_time": backward_time * 1000,  # Convert to ms
        "total_time": (forward_time + backward_time) * 1000,  # Convert to ms
    }
+ """ + device = torch.device(args.device) + + # Create models + manual_model = MoEModel( + vocab_size=args.vocab_size, + embed_dim=args.embed_dim, + hidden_dim=args.hidden_dim, + num_experts=args.num_experts, + top_k=args.top_k, + use_mg_gemm=False, + ).to(device) + + if has_mg_gemm: + mg_model = MoEModel( + vocab_size=args.vocab_size, + embed_dim=args.embed_dim, + hidden_dim=args.hidden_dim, + num_experts=args.num_experts, + top_k=args.top_k, + use_mg_gemm=True, + ).to(device) + else: + mg_model = None + + # Measure performance + logging.info("Measuring performance of manual looping method...") + manual_perf = measure_performance( + manual_model, + args.batch_size, + args.seq_len, + args.vocab_size, + args.perf_batches, + device, + ) + + if mg_model is not None: + logging.info("Measuring performance of MG GEMM method...") + mg_perf = measure_performance( + mg_model, + args.batch_size, + args.seq_len, + args.vocab_size, + args.perf_batches, + device, + ) + else: + mg_perf = {"forward_time": 0, "backward_time": 0, "total_time": 0} + + # Log results + logging.info("\n===== Performance Comparison =====") + logging.info("Model Configuration:") + logging.info(f" - Batch Size: {args.batch_size}") + logging.info(f" - Sequence Length: {args.seq_len}") + logging.info(f" - Embed Dimension: {args.embed_dim}") + logging.info(f" - Hidden Dimension: {args.hidden_dim}") + logging.info(f" - Number of Experts: {args.num_experts}") + logging.info(f" - Top-K: {args.top_k}") + logging.info("") + + logging.info("Manual Looping Method:") + logging.info(f" - Forward Time: {manual_perf['forward_time']:.2f} ms") + logging.info(f" - Backward Time: {manual_perf['backward_time']:.2f} ms") + logging.info(f" - Total Time: {manual_perf['total_time']:.2f} ms") + logging.info("") + + if mg_model is not None: + logging.info("MG GEMM Method:") + logging.info(f" - Forward Time: {mg_perf['forward_time']:.2f} ms") + logging.info(f" - Backward Time: {mg_perf['backward_time']:.2f} ms") + logging.info(f" 
- Total Time: {mg_perf['total_time']:.2f} ms") + logging.info("") + + # Calculate speedup + forward_speedup = ( + manual_perf["forward_time"] / mg_perf["forward_time"] + if mg_perf["forward_time"] > 0 + else 0 + ) + backward_speedup = ( + manual_perf["backward_time"] / mg_perf["backward_time"] + if mg_perf["backward_time"] > 0 + else 0 + ) + total_speedup = ( + manual_perf["total_time"] / mg_perf["total_time"] + if mg_perf["total_time"] > 0 + else 0 + ) + + logging.info("Speedup (MG GEMM vs Manual):") + logging.info(f" - Forward Speedup: {forward_speedup:.2f}x") + logging.info(f" - Backward Speedup: {backward_speedup:.2f}x") + logging.info(f" - Total Speedup: {total_speedup:.2f}x") + else: + logging.info("MG GEMM method not available.") + + +def train_model(args): + """ + Train an MoE model. + """ + device = torch.device(args.device) + + # Create model + model = MoEModel( + vocab_size=args.vocab_size, + embed_dim=args.embed_dim, + hidden_dim=args.hidden_dim, + num_experts=args.num_experts, + top_k=args.top_k, + use_mg_gemm=args.use_mg_gemm and has_mg_gemm, + ).to(device) + + # Create optimizer + optimizer = optim.Adam(model.parameters(), lr=args.lr) + + # Log model information + logging.info("Model configuration:") + logging.info(f" - Vocabulary Size: {args.vocab_size}") + logging.info(f" - Embedding Dimension: {args.embed_dim}") + logging.info(f" - Hidden Dimension: {args.hidden_dim}") + logging.info(f" - Number of Experts: {args.num_experts}") + logging.info(f" - Top-K: {args.top_k}") + logging.info(f" - Using MG GEMM: {args.use_mg_gemm and has_mg_gemm}") + + # Training loop + for epoch in range(args.epochs): + logging.info(f"\nEpoch {epoch + 1}/{args.epochs}") + + # Train + train_metrics = train_epoch( + model=model, + optimizer=optimizer, + batch_size=args.batch_size, + seq_len=args.seq_len, + vocab_size=args.vocab_size, + num_batches=args.train_batches, + device=device, + load_balance_coef=args.load_balance_coef, + ) + + # Evaluate + eval_metrics = evaluate( + 
if __name__ == "__main__":
    # CLI entry point: parse hyperparameters, pick a device, then either
    # benchmark (--compare), train (--train), or default to the benchmark.
    parser = argparse.ArgumentParser(description="Train MoE model")

    # Model parameters
    parser.add_argument("--vocab_size", type=int, default=10000, help="Vocabulary size")
    parser.add_argument(
        "--embed_dim", type=int, default=512, help="Embedding dimension"
    )
    parser.add_argument(
        "--hidden_dim", type=int, default=1024, help="Hidden dimension in experts"
    )
    parser.add_argument("--num_experts", type=int, default=8, help="Number of experts")
    parser.add_argument(
        "--top_k", type=int, default=2, help="Top-k experts to route to"
    )

    # Training parameters
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument(
        "--train_batches",
        type=int,
        default=100,
        help="Number of training batches per epoch",
    )
    parser.add_argument(
        "--eval_batches", type=int, default=20, help="Number of evaluation batches"
    )
    parser.add_argument(
        "--perf_batches",
        type=int,
        default=50,
        help="Number of batches for performance testing",
    )
    parser.add_argument(
        "--load_balance_coef",
        type=float,
        default=0.01,
        help="Load balancing loss coefficient",
    )

    # Runtime parameters
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use (cuda or cpu)",
    )
    parser.add_argument(
        "--use_mg_gemm",
        action="store_true",
        help="Use MG GEMM implementation if available",
    )
    parser.add_argument(
        "--compare",
        action="store_true",
        help="Compare manual and MG GEMM implementations",
    )
    parser.add_argument("--train", action="store_true", help="Train the model")

    args = parser.parse_args()

    # Check for CUDA — an explicit --device cuda on a CPU-only host falls
    # back to CPU with a warning instead of crashing later.
    if args.device == "cuda" and not torch.cuda.is_available():
        logging.warning("CUDA not available, using CPU instead.")
        args.device = "cpu"

    # Log basic information
    logging.info(f"PyTorch version: {torch.__version__}")
    logging.info(f"Device: {args.device}")
    logging.info(f"MG GEMM available: {has_mg_gemm}")

    # Run the requested action
    if args.compare:
        compare_methods(args)
    elif args.train:
        train_model(args)
    else:
        # Default to comparison if no action specified
        compare_methods(args)
+ +from .mg_grouped_gemm import grouped_gemm_forward +from .tma_autotuning import ALIGN_SIZE_M + +__all__ = [ + "grouped_gemm_forward", + "ALIGN_SIZE_M", +] diff --git a/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py new file mode 100644 index 0000000000000000000000000000000000000000..76e0b12d882fa46ed1f11139352141f06d899f59 --- /dev/null +++ b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py @@ -0,0 +1,299 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +import logging + +import numpy as np +import torch + +from reference_utils import ( + analyze_tensor_differences, + compute_reference_backward, + compute_reference_forward, +) + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +# Import grouped GEMM implementations +try: + from mg_grouped_gemm import grouped_gemm_backward, grouped_gemm_forward + +except ImportError: + logging.error( + "Error importing grouped GEMM modules. Make sure the implementation files are in the correct path." + ) + raise + + +def test_forward_pass(): + """ + A simple test for the M*G grouped GEMM forward pass with detailed error handling. 
def test_forward_pass():
    """
    A simple test for the M*G grouped GEMM forward pass with detailed error handling.

    In M*G grouping:
    - M dimension is partitioned into G groups (M_total = sum(M_sizes))
    - N dimension is the same for all groups

    Returns:
        bool: True when the kernel output matches the PyTorch reference,
        False on mismatch or any exception (which is logged).
    """
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Test parameters for DeepSeek-like models.
        # FIX: removed the stale commented-out remnant "# 2048, 2048, 2048]"
        # left over from a previous multi-group configuration.
        G = 1  # Number of groups
        M_sizes = [2048]  # Per-group sizes
        M_total = sum(M_sizes)  # Total M dimension
        N = 4096  # Output dimension (same for all groups)
        K = 7168  # Hidden dimension

        # Create group sizes tensor
        m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)

        # Create input and weight tensors in float16 (half precision).
        # FIX: the original comment claimed "higher precision", which is
        # wrong — float16 is lower precision than the float32 default.
        x = torch.randn(M_total, K, dtype=torch.float16, device=device)
        w = torch.randn(N, K, dtype=torch.float16, device=device)

        # Log the setup
        logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
        logging.info(f"Group sizes: {m_sizes}")
        logging.info(f"Input x shape: {x.shape}")
        logging.info(f"Weight w shape: {w.shape}")

        # Run forward pass
        logging.info("Running forward pass with grouped GEMM")
        result = grouped_gemm_forward(x, w, m_sizes)
        logging.info(f"Forward result shape: {result.shape}")

        # Compute reference result
        logging.info("Computing reference result with PyTorch")
        reference_result = compute_reference_forward(x, w, m_sizes)

        # Compare results
        logging.info("Comparing with PyTorch reference")
        forward_close = analyze_tensor_differences(
            result, reference_result, "Forward output"
        )

        return forward_close

    except Exception as e:
        logging.error(f"Test failed with error: {e}")
        import traceback

        logging.error(traceback.format_exc())
        return False
def test_backward_pass():
    """
    A simple test for the M*G grouped GEMM backward pass with detailed error handling.

    In M*G grouping:
    - M dimension is partitioned into G groups (M_total = sum(M_sizes))
    - N dimension is the same for all groups

    Returns:
        bool: True when both grad_x and grad_w match the PyTorch reference,
        False on mismatch or any exception (which is logged).
    """
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Test parameters for DeepSeek-like models
        G = 4  # Number of groups
        M_sizes = [2048, 2048, 2048, 2048]  # Per-group sizes
        M_total = sum(M_sizes)  # Total M dimension
        N = 4096  # Output dimension (same for all groups)
        K = 7168  # Hidden dimension

        # Create group sizes tensor
        m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)

        # Create input and weight tensors in float16 (half precision).
        # FIX: the original comment claimed "higher precision", which is
        # wrong — float16 is lower precision than the float32 default.
        x = torch.randn(
            M_total, K, dtype=torch.float16, device=device, requires_grad=True
        )
        w = torch.randn(N, K, dtype=torch.float16, device=device, requires_grad=True)

        # Log the setup
        logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
        logging.info(f"Group sizes: {m_sizes}")
        logging.info(f"Input x shape: {x.shape}")
        logging.info(f"Weight w shape: {w.shape}")

        # Step 1: Run forward pass
        logging.info("Running forward pass")
        result = grouped_gemm_forward(x, w, m_sizes)
        logging.info(f"Forward result shape: {result.shape}")

        # Create a gradient for backpropagation
        grad_output = torch.randn_like(result)
        logging.info(f"Created gradient with shape: {grad_output.shape}")

        # Step 2: Run backward pass directly (not through autograd)
        logging.info("Running backward pass directly")
        grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)

        # Verify gradient shapes
        logging.info(
            f"Gradient shapes - grad_x: {grad_x.shape}, grad_w: {grad_w.shape}"
        )

        # Step 3: Verify gradient computation using PyTorch's autograd
        logging.info("Running PyTorch reference implementation")

        # Compute reference gradients
        x_ref_grad, w_ref_grad = compute_reference_backward(x, w, m_sizes, grad_output)

        # Compare gradients
        logging.info("Comparing gradients with PyTorch reference")
        grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x")
        grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w")

        # Log overall result
        if grad_x_close and grad_w_close:
            logging.info("✓ SUCCESS: Gradients match the PyTorch reference")
        else:
            logging.error("✗ FAILURE: Gradient mismatch detected")

        return grad_x_close and grad_w_close

    except Exception as e:
        logging.error(f"Test failed with error: {e}")
        import traceback

        logging.error(traceback.format_exc())
        return False
def test_multiple_deepseek_configs():
    """
    Test multiple DeepSeek model configurations with both forward and backward pass verification.

    For each (G, M, K, N) configuration this builds near-even group sizes,
    runs the grouped GEMM forward and backward, compares both against the
    PyTorch reference helpers, and logs a per-config and overall summary.

    Returns:
        bool: True only when every configuration passes both forward and
        backward verification.
    """
    # DeepSeek configurations: (G, M, K, N)
    configs = [
        (4, 8192, 7168, 4096),  # Config 1
        (4, 8192, 2048, 7168),  # Config 2
        (8, 4096, 7168, 4096),  # Config 3
        (8, 4096, 2048, 7168),  # Config 4
    ]

    # One (index, overall, forward_ok, backward_ok) tuple per config.
    results = []

    for config_idx, (G, M, K, N) in enumerate(configs):
        logging.info(f"\n\n===== Testing DeepSeek Config {config_idx+1} =====")
        logging.info(f"G={G}, M={M}, K={K}, N={N}")

        try:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            # Create near-even group sizes: the first `remainder` groups get
            # one extra row so the sizes always sum to M.
            base_size = M // G
            remainder = M % G
            M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
            m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)

            # Create float16 (half precision) input and weight tensors.
            x = torch.randn(
                M, K, dtype=torch.float16, device=device, requires_grad=True
            )
            w = torch.randn(
                N, K, dtype=torch.float16, device=device, requires_grad=True
            )

            logging.info(f"Input x shape: {x.shape}, Weight w shape: {w.shape}")

            # Run forward pass
            result = grouped_gemm_forward(x, w, m_sizes)
            logging.info(f"Forward result shape: {result.shape}")

            # ===== FORWARD PASS VERIFICATION =====
            # Compute reference forward result
            reference_result = compute_reference_forward(x, w, m_sizes)

            # Compare forward results
            forward_close = analyze_tensor_differences(
                result, reference_result, "Forward output"
            )

            # ===== BACKWARD PASS VERIFICATION =====
            # Create gradient for backpropagation
            grad_output = torch.randn_like(result)

            # Run backward pass (direct call, not through autograd)
            grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)

            # Compute reference gradients
            x_ref_grad, w_ref_grad = compute_reference_backward(
                x, w, m_sizes, grad_output
            )

            # Compare backward results
            grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x")
            grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w")

            # Overall config result
            backward_close = grad_x_close and grad_w_close
            config_success = forward_close and backward_close
            results.append(
                (config_idx + 1, config_success, forward_close, backward_close)
            )

            # Log overall config result
            if config_success:
                logging.info(f"✓ SUCCESS: Config {config_idx+1} passed all tests!")
            else:
                logging.error(
                    f"✗ FAILURE: Config {config_idx+1} failed one or more tests"
                )

        except Exception as e:
            # A failing config is recorded and the remaining configs still run.
            logging.error(f"Config {config_idx+1} test failed with error: {e}")
            import traceback

            logging.error(traceback.format_exc())
            results.append((config_idx + 1, False, False, False))

    # Summary
    logging.info("\n===== Test Results Summary =====")
    for config_idx, overall_success, forward_success, backward_success in results:
        overall_status = "✓ PASSED" if overall_success else "✗ FAILED"
        forward_status = "✓ PASSED" if forward_success else "✗ FAILED"
        backward_status = "✓ PASSED" if backward_success else "✗ FAILED"

        logging.info(f"Config {config_idx}: {overall_status}")
        logging.info(f" - Forward pass: {forward_status}")
        logging.info(f" - Backward pass: {backward_status}")

    return all(overall_success for _, overall_success, _, _ in results)
if __name__ == "__main__":
    # Script entry point: run the basic forward test, the basic backward
    # test, then the multi-config sweep, and log an overall verdict.
    logging.info(
        "Running verification for both forward and backward pass of M*G grouped GEMM"
    )

    # Run basic forward pass test
    logging.info("\n===== Running basic forward pass test =====")
    success_forward = test_forward_pass()
    logging.info(f"Basic forward test {'succeeded' if success_forward else 'failed'}")

    # Run basic backward pass test
    logging.info("\n===== Running basic backward pass test =====")
    success_backward = test_backward_pass()
    logging.info(f"Basic backward test {'succeeded' if success_backward else 'failed'}")

    # Run multiple DeepSeek configs with forward and backward verification
    logging.info("\n===== Running tests for all DeepSeek configs =====")
    success_configs = test_multiple_deepseek_configs()
    logging.info(
        f"DeepSeek configs tests {'all succeeded' if success_configs else 'had failures'}"
    )

    # Overall result: every stage must pass.
    overall_success = success_forward and success_backward and success_configs
    logging.info(
        f"\nOverall test result: {'SUCCESS' if overall_success else 'FAILURE'}"
    )
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_hopper(
    a_desc_ptr,
    b_desc_ptr,
    c_ptr,
    workspace,
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    USE_EPILOGUE_SUBTILING: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel for Hopper.
    For simplicity, we always use TMA Load and TMA Store.

    Tiles are enumerated with a single flat index across all groups; each
    program starts at its program_id and advances by NUM_SMS, so the tile
    space is shared evenly across the launched programs.

    NOTE(review): unlike the generic TMA kernel below, this variant splits
    the N dimension evenly across groups (n_size = N // G) and offsets the
    weight load by n_start = n_size * g — presumably each group owns its
    own N-slice of the weight; confirm against the host-side launcher.
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty  # output dtype

    c_desc_ptr = workspace + (tbidx * TMA_SIZE)  # for TMA Store

    M_end = 0
    M_start = 0
    processed_tiles = 0
    # Size of individual weight matrix (per-group slice of N)
    n_size = N // G
    n_start = 0

    for g in range(G):
        # Move down along groups: reset to the new M offset for group g.
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size
        n_start = n_size * g

        if m_size > 0:
            # Process this group (empty groups are skipped entirely).

            # Build an on-device TMA descriptor for this group's output
            # sub-matrix, then fence before it is used for stores.
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=c_desc_ptr,
                global_address=c_ptr + M_start * n_size,
                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                global_size=[m_size, n_size],
                element_ty=c_dtype,
            )
            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            # Claim every flat tile index that falls inside this group's range.
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                # Column-major tile order within the group.
                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                # Accumulate in fp32 regardless of the output dtype.
                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
                global_n_offset = (n_start + n_offset).to(tl.int32)

                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # input block [M, K]
                    a = tl._experimental_descriptor_load(
                        a_desc_ptr,
                        [m_offset, k_offset],
                        [BLOCK_SIZE_M, BLOCK_SIZE_K],
                        c_dtype,
                    )
                    # weight block [N, K]
                    b = tl._experimental_descriptor_load(
                        b_desc_ptr,
                        [global_n_offset, k_offset],
                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
                        c_dtype,
                    )

                    # Y tile += A @ B.T (weights are stored [N, K]).
                    accumulator += tl.dot(a, b.T)

                # Store using TMA. The store offset is relative to this
                # group's descriptor, so the M offset drops M_start here.
                m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)

                if USE_EPILOGUE_SUBTILING:
                    # Split the N tile in half and issue two smaller stores.
                    acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
                    acc = tl.permute(acc, (0, 2, 1))
                    acc0, acc1 = tl.split(acc)
                    c0 = acc0.to(c_dtype)
                    tl._experimental_descriptor_store(
                        c_desc_ptr, c0, [m_offset, n_offset]
                    )
                    c1 = acc1.to(c_dtype)
                    tl._experimental_descriptor_store(
                        c_desc_ptr, c1, [m_offset, n_offset + BLOCK_SIZE_N // 2]
                    )
                else:
                    tl._experimental_descriptor_store(
                        c_desc_ptr,
                        accumulator.to(c_dtype),
                        [m_offset, n_offset],
                    )
                # Move to next tile in group
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_tma(
    a_desc_ptr,
    b_desc_ptr,
    c_ptr,
    workspace,
    m_sizes,
    a_scale_ptr,
    b_scale_ptr,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    USE_FP8: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel.
    For simplicity, we always use TMA Load and TMA Store.

    Unlike the Hopper variant above, every group shares the full weight
    matrix (n_size = N for all groups); only the M dimension is grouped.

    NOTE(review): a_scale_ptr, b_scale_ptr, USE_FP8, USE_TMA_LOAD and
    USE_TMA_STORE are accepted but never read in this body — presumably
    kept for interface parity with an fp8/scaled variant; confirm.
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty

    c_desc_ptr = workspace + (tbidx * TMA_SIZE)

    M_end = 0
    processed_tiles = 0

    for g in range(G):
        # Move down along groups: reset to the new M offset for group g.
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size

        if m_size > 0:
            # Process this group; all groups span the full N dimension.
            n_size = N

            # TMA Store prep: per-group output descriptor, fenced before use.
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=c_desc_ptr,
                global_address=c_ptr + M_start * N,
                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                global_size=[m_size, n_size],
                element_ty=c_dtype,
            )
            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            # Claim every flat tile index that falls inside this group's range.
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                # Accumulate in fp32 regardless of the output dtype.
                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)

                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # input block [M, K]
                    a = tl._experimental_descriptor_load(
                        a_desc_ptr,
                        [m_offset, k_offset],
                        [BLOCK_SIZE_M, BLOCK_SIZE_K],
                        c_dtype,
                    )
                    # weight block [N, K]
                    b = tl._experimental_descriptor_load(
                        b_desc_ptr,
                        [n_offset, k_offset],
                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
                        c_dtype,
                    )

                    # Y tile += A @ B.T (weights are stored [N, K]).
                    accumulator += tl.dot(a, b.T)

                # Store using TMA; offset is relative to the group descriptor.
                m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
                # n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)

                tl._experimental_descriptor_store(
                    c_desc_ptr,
                    accumulator.to(c_dtype),
                    [m_offset, n_offset],
                )

                # Move to the next tile
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles
NUM_SMS + # Update the total tiles count for the next group + processed_tiles += group_num_tiles + + +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_forward_no_tma( + a_ptr, + b_ptr, + c_ptr, + workspace, + m_sizes, + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_TMA_STORE: tl.constexpr, + TMA_SIZE: tl.constexpr, + # tiles + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +) -> None: + """ + Flat index style forward kernel. + For bc and Ampere, we never use TMA Load and TMA Store + """ + tbidx = tl.program_id(0) # thread block index + + c_dtype = c_ptr.dtype.element_ty + c_desc_ptr = None + + M_end = 0 + processed_tiles = 0 + + for g in range(G): + # Move down along groups + # reset to new M offset + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + + if m_size > 0: + # Process this group + n_size = N + + # tiles for this group + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) + group_num_tiles = num_m_tiles * num_n_tiles + + while tbidx >= processed_tiles and tbidx < ( + processed_tiles + group_num_tiles + ): + group_index = tbidx - processed_tiles + + tile_m_index = group_index % num_m_tiles + tile_n_index = group_index // num_m_tiles + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32) + n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) + + offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (M_start + offs_am[:, None]) * K + offs_k[None, :] + b_ptrs = b_ptr + (offs_bn[:, 
None]) * K + offs_k[None, :] + + for k_offset in range(0, K, BLOCK_SIZE_K): + # Load with bounds checking + a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size) + b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size) + + # Main matmul + accumulator += tl.dot(a, b.T) + + # Update pointers for next block + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + + # Store without TMA + offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + c = accumulator.to(c_dtype) + + tl.store( + c_ptr + + (M_start + offs_am[:, None]) * N # Row stride is N + + offs_bn[None, :], # Column offset + c, + mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size, + ) + # Move to the next tile + tbidx += NUM_SMS + # Update the total tiles count for the next group + processed_tiles += group_num_tiles + + +""" +Backward pass for grouped GEMM with Triton, where grouping is M*G +We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`). +""" + + +# ---- dx flat linear indexed ---- +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_dx_tma( + grad_output_desc_ptr, # [MG, N] + w_desc_ptr, # [N, K] + grad_input_ptr, # output grad_x [MG, K] + workspace, # for TMA store + m_sizes, # group sizes [G] + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_TMA_STORE: tl.constexpr, + TMA_SIZE: tl.constexpr, + # tiles + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +) -> None: + """ + TMA-optimized kernel for computing gradients with respect to input (dx). + For the forward pass Y = X @ W.T, the backward for input is: + grad_X = grad_Y @ W + + This maps to [MG, N] @ [N, K] -> [MG, K] + + Key differences from forward: + 1. 
W is used directly and not transposed + 2. The reduction dimension is now N (not K) + 3. Output is [M, K] instead of [M, N] + """ + tbidx = tl.program_id(0) # thread block index + + c_dtype = grad_input_ptr.dtype.element_ty + c_desc_ptr = workspace + (tbidx * TMA_SIZE) + + M_end = 0 + processed_tiles = 0 + + for g in range(G): + # Move down along groups - same as forward + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + + if m_size > 0: + # Process this group + # tiles for this group - now producing [M, K] output + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + group_num_tiles = num_m_tiles * num_k_tiles + + # TMA Store prep for [M, K] output + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=grad_input_ptr + M_start * K, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], + global_size=[m_size, K], + element_ty=c_dtype, + ) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + while tbidx >= processed_tiles and tbidx < ( + processed_tiles + group_num_tiles + ): + group_index = tbidx - processed_tiles + + # Different tiling scheme for [M, K] output + tile_m_index = group_index % num_m_tiles + tile_k_index = group_index // num_m_tiles + + # for grad_input block [M, K] + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + + # Position in full matrix + m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32) + k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32) + + # reduce along N dimension (instead of K in forward) + for n_offset in range(0, N, BLOCK_SIZE_N): + # grad_output block [M, N] + grad_output = tl._experimental_descriptor_load( + grad_output_desc_ptr, + [m_offset, n_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + c_dtype, + ) + + # weight block [N, K] - no transpose needed + w = tl._experimental_descriptor_load( + w_desc_ptr, + [n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + c_dtype, + ) + + # grad_x 
= grad_output @ w + # reducing along N dimension + accumulator += tl.dot(grad_output, w) + + # Store using TMA + m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) + # k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32) + + tl._experimental_descriptor_store( + c_desc_ptr, + accumulator.to(c_dtype), + [m_offset, k_offset], + ) + + # Move to the next tile + tbidx += NUM_SMS + + # Update the total tiles count for the next group + processed_tiles += group_num_tiles + + +# ---- dw flat linear indexed ---- + + +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_dw_tma( + x_desc_ptr, # input descriptor [M_total, K] + grad_output_desc_ptr, # grad_output descriptor [M_total, N] + grad_weight_ptr, # output grad_w [N, K] + workspace, # workspace for TMA store + m_sizes, # group sizes [G] + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_TMA_STORE: tl.constexpr, + TMA_SIZE: tl.constexpr, + # tiles + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, # block size for reduction dimension +) -> None: + """ + Improved TMA-optimized kernel for computing gradients with respect to weights (dw). + Uses flat index structure similar to forward. 
+ + For the forward pass Y = X @ W.T, + the backward for weights is: + grad_W = grad_Y.T @ X + + Where: + - grad_Y is [MG, N] + - X is [MG, K] + - grad_W is [N, K] + - we return [N,K] + """ + # Get thread block index l + tbidx = tl.program_id(0) + + # Get output data type + c_dtype = grad_weight_ptr.dtype.element_ty + + # Calculate number of output tiles + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + total_output_tiles = num_n_tiles * num_k_tiles + + # Process tiles in strided manner across SMs + for tile_idx in range(tbidx, total_output_tiles, NUM_SMS): + # Calculate tile indices + tile_n_idx = tile_idx % num_n_tiles + tile_k_idx = tile_idx // num_n_tiles + + # Calculate global offsets + n_offset = tile_n_idx * BLOCK_SIZE_N + k_offset = tile_k_idx * BLOCK_SIZE_K + + # Initialize accumulator for this output tile [N, K] + accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32) + + # Process each group + M_end = 0 + for g in range(G): + # Get group boundaries + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + + # Only process if group is non-empty + if m_size > 0: + # Process this group in chunks along the M dimension + for m_offset in range(0, m_size, BLOCK_SIZE_M): + # Calculate actual block size (handling boundary) + m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - m_offset) + + # Only process if we have actual work to do + if m_block_size > 0: + # Global offset for this chunk + m_global_offset = M_start + m_offset + + if USE_TMA_LOAD: + # Load input chunk [M_chunk, K] using TMA + x_block = tl._experimental_descriptor_load( + x_desc_ptr, + [m_global_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + c_dtype, + ) + + # Load grad_output chunk [M_chunk, N] using TMA + grad_output_block = tl._experimental_descriptor_load( + grad_output_desc_ptr, + [m_global_offset, n_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + c_dtype, + ) + + # Apply masks for valid regions + offs_m = tl.arange(0, 
BLOCK_SIZE_M) + m_mask = offs_m < m_block_size + + # Zero out invalid elements + x_block = tl.where(m_mask[:, None], x_block, 0.0) + grad_output_block = tl.where( + m_mask[:, None], grad_output_block, 0.0 + ) + else: + # Manual load with bounds checking + offs_m = tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Create masks + m_mask = offs_m < m_block_size + n_mask = offs_n < N - n_offset + k_mask = offs_k < K - k_offset + + # Combined masks + mk_mask = m_mask[:, None] & k_mask[None, :] + mn_mask = m_mask[:, None] & n_mask[None, :] + + # Global offsets for loading + m_global_offs = m_global_offset + offs_m + + # Load x block [M_chunk, K] + x_block = tl.load( + x_desc_ptr + + m_global_offs[:, None] * K + + (k_offset + offs_k)[None, :], + mask=mk_mask, + other=0.0, + ) + + # Load grad_output block [M_chunk, N] + grad_output_block = tl.load( + grad_output_desc_ptr + + m_global_offs[:, None] * N + + (n_offset + offs_n)[None, :], + mask=mn_mask, + other=0.0, + ) + + # Compute partial contribution: grad_W += grad_Y.T @ X + # transpose grad_output for the matmul + contribution = tl.dot( + grad_output_block.to(tl.float32).T, # [N, M_chunk] + x_block.to(tl.float32), # [M_chunk, K] + ) + + # Accumulate + accumulator += contribution + + # Store the result + if USE_TMA_STORE: + # Store using TMA + tl._experimental_descriptor_store( + workspace, # TMA store descriptor + accumulator.to(c_dtype), + [n_offset, k_offset], + ) + else: + # Manual store with bounds checking + offs_n = tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Create masks for bounds checking + n_mask = offs_n < N - n_offset + k_mask = offs_k < K - k_offset + output_mask = n_mask[:, None] & k_mask[None, :] + + # Store the result + tl.store( + grad_weight_ptr + + (n_offset + offs_n)[:, None] * K + + (k_offset + offs_k)[None, :], + accumulator.to(c_dtype), + mask=output_mask, + ) + + +# ======== End Triton kernels ======== + +# 
# ======== Triton wrapper functions ========

# ----- main forward pass wrapper -----


def grouped_gemm_forward(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    M*G style grouped GEMM with TMA support.

    Computes ``y[g] = x[g] @ w.T`` for every group ``g``, where the groups are
    stacked along the leading (M) dimension of ``x`` and delimited by
    ``m_sizes``.

    Args:
        x: Input tensor, shape [M_total, K], with M_total == m_sizes.sum().
        w: Weight tensor, shape [N, K], shared by all groups.
        m_sizes: Per-group row counts, shape [G].
        tma_size: TMA descriptor size in bytes.

    Returns:
        Output tensor of shape [M_total, N].

    Raises:
        NotImplementedError: if the device has no TMA support (pre-Hopper).
    """
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Grouped GEMM without TMA is not supported yet")

    G = m_sizes.shape[0]

    assert x.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    # Total input size is [M_total, K] where M_total is the sum of all group sizes
    M_total, K = x.shape
    N = w.shape[0]  # N is the same for all groups

    assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"

    # NOTE: group sizes should be multiples of ALIGN_SIZE_M; verifying that
    # here would force a GPU-CPU sync, so the check is deliberately skipped.

    # BUGFIX: the output must be [M_total, N], not [M_total, N // G].  The
    # store kernels write full rows of length N (row stride N, TMA
    # global_size [m_size, N]), so allocating only N // G columns
    # under-sizes the buffer and corrupts adjacent memory.
    y = torch.empty((M_total, N), device=x.device, dtype=x.dtype)

    if M_total == 0:
        return y

    NUM_SMS = CudaUtils.get_num_sms()
    USE_TMA_LOAD = True
    USE_TMA_STORE = True
    USE_EPILOGUE_SUBTILING = False

    # TMA descriptor helper
    desc_helper = None
    desc_x = x
    desc_w = w
    workspace = None

    if USE_TMA_LOAD:
        desc_helper = TmaDescriptorHelper(tma_size=tma_size)
        desc_helper.init_tma_descriptor("x")
        desc_helper.init_tma_descriptor("w")
        desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
        desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    if USE_TMA_STORE:
        # One TMA store descriptor slot per SM
        workspace = torch.empty(
            NUM_SMS * desc_helper.tma_size,
            device=x.device,
            dtype=torch.uint8,
        )

    def grid(META):
        # Descriptors are filled lazily here because the block sizes are
        # only known after autotuning has picked a config.
        if USE_TMA_LOAD:
            nonlocal desc_helper
            desc_helper.fill_2d_tma_descriptor(
                "x",
                x.data_ptr(),
                M_total,
                K,
                META["BLOCK_SIZE_M"],
                META["BLOCK_SIZE_K"],
                x.element_size(),
            )
            desc_helper.fill_2d_tma_descriptor(
                "w",
                w.data_ptr(),
                N,
                K,
                META["BLOCK_SIZE_N"],
                META["BLOCK_SIZE_K"],
                w.element_size(),
            )
        return (NUM_SMS,)

    M_BUCKET = triton.next_power_of_2(M_total)

    _kernel_mg_forward_hopper[grid](
        desc_x,
        desc_w,
        y,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N,
        K,
        NUM_SMS,
        TMA_SIZE=tma_size,
        USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
    )

    return y


# ======== Improved Backward =============
def grouped_gemm_backward(
    grad_output: torch.Tensor,
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Unified backward pass for grouped GEMM with M*G grouping.

    Args:
        grad_output: Gradient of output, shape [M_total, N].
        x: Input tensor from the forward pass, shape [M_total, K].
        w: Weight tensor from the forward pass, shape [N, K].
        m_sizes: Group sizes tensor, shape [G].
        use_tma: Whether to try TMA acceleration (if available).
        tma_size: TMA descriptor size in bytes.

    Returns:
        Tuple of gradients with respect to x and w: (grad_x, grad_w).

    Raises:
        ValueError: on any dimension mismatch between the inputs.
    """
    logging.info("Starting unified grouped_gemm_backward")

    # Query once up front; the device-property lookup is comparatively expensive.
    NUM_SMS = CudaUtils.get_num_sms()

    # Basic validation
    G = m_sizes.shape[0]
    M_total, K_x = x.shape
    M_grad, N = grad_output.shape
    N_w, K_w = w.shape

    if K_x != K_w:
        raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
    if M_total != M_grad:
        raise ValueError(
            f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
        )

    # Total M must match the sum of the group sizes (this is a GPU-CPU sync).
    sum_m_sizes = m_sizes.sum().item()
    if M_total != sum_m_sizes:
        raise ValueError(
            f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
        )

    # Make sure inputs are contiguous
    grad_output = grad_output.contiguous()
    x = x.contiguous()
    w = w.contiguous()
    m_sizes = m_sizes.contiguous()

    # NOTE(review): both kernels below currently require TMA and raise on
    # pre-Hopper devices regardless of this flag — confirm intended fallback.
    can_use_tma = use_tma and CudaUtils.verify_tma()
    if use_tma and not can_use_tma:
        logging.info("TMA requested but not supported on this device")
        use_tma = False

    # Compute grad_x using the flat-linear TMA implementation
    try:
        logging.info("Computing grad_x with flat linear kernel")
        grad_x = grouped_gemm_dx_tma(
            grad_output=grad_output,
            w=w,
            m_sizes=m_sizes,
            num_sms=NUM_SMS,
            tma_size=tma_size,
        )
    except Exception as e:
        logging.error(f"Error in grad_x computation: {e}")
        raise

    # Compute grad_w using the flat-linear TMA implementation
    try:
        logging.info("Computing grad_w with flat linear kernel")
        grad_w = grouped_gemm_dw_tma(
            x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
        )
    except Exception as e:
        logging.error(f"Error in grad_w computation: {e}")
        raise

    return grad_x, grad_w


# ----- dx backward pass wrapper -----


def grouped_gemm_dx_tma(
    grad_output: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    TMA-optimized computation of the gradient with respect to the input (dx).

    For the forward pass Y = X @ W.T the input gradient is
    grad_X = grad_Y @ W, i.e. [M_total, N] @ [N, K] -> [M_total, K].

    Args:
        grad_output: Gradient of output, shape [M_total, N].
        w: Weight tensor, shape [N, K].
        m_sizes: Group sizes tensor, shape [G].
        num_sms: Number of SMs to launch over (one block per SM).
        tma_size: TMA descriptor size in bytes.

    Returns:
        grad_x: Gradient with respect to x, shape [M_total, K].
    """
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Optimized dx computation requires TMA support")

    G = m_sizes.shape[0]

    assert grad_output.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    M_total, N_grad = grad_output.shape
    N_w, K = w.shape

    assert N_grad == N_w, f"Grad_output N ({N_grad}) must match weight N ({N_w})"

    sum_m_sizes = m_sizes.sum().item()
    assert (
        M_total == sum_m_sizes
    ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"

    # Output gradient buffer [M_total, K]
    grad_x = torch.empty(
        (M_total, K), device=grad_output.device, dtype=grad_output.dtype
    )

    NUM_SMS = num_sms
    USE_TMA_LOAD = True
    USE_TMA_STORE = True

    # Set up TMA descriptors
    desc_helper = TmaDescriptorHelper(tma_size=tma_size)
    desc_helper.init_tma_descriptor("grad_output")
    desc_helper.init_tma_descriptor("w")
    desc_grad_output = desc_helper.get_tma_descriptor_kernel_param("grad_output")
    desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    # Workspace for TMA store descriptors (one slot per SM)
    workspace = torch.empty(
        NUM_SMS * desc_helper.tma_size,
        device=grad_output.device,
        dtype=torch.uint8,
    )

    def grid(META):
        # Fill TMA descriptors once the autotuner has fixed the block sizes.
        desc_helper.fill_2d_tma_descriptor(
            "grad_output",
            grad_output.data_ptr(),
            M_total,
            N_grad,
            META["BLOCK_SIZE_M"],
            META["BLOCK_SIZE_N"],
            grad_output.element_size(),
        )
        desc_helper.fill_2d_tma_descriptor(
            "w",
            w.data_ptr(),
            N_w,
            K,
            META["BLOCK_SIZE_N"],
            META["BLOCK_SIZE_K"],
            w.element_size(),
        )
        return (NUM_SMS,)

    M_BUCKET = triton.next_power_of_2(M_total)

    # Launch the flat linear kernel for computing grad_x
    _kernel_mg_dx_tma[grid](
        desc_grad_output,
        desc_w,
        grad_x,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N_grad,  # N is the reduction dimension in the backward pass
        K,
        NUM_SMS,
        USE_TMA_LOAD,
        USE_TMA_STORE,
        TMA_SIZE=tma_size,
    )

    return grad_x


# ======== dw wrapper function ==========


def grouped_gemm_dw_tma(
    x: torch.Tensor,
    grad_output: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    TMA-optimized computation of the gradient with respect to weights (dw).

    For the forward pass Y = X @ W.T the weight gradient is
    grad_W = grad_Y.T @ X, i.e. [N, M_total] @ [M_total, K] -> [N, K].

    Args:
        x: Input tensor, shape [M_total, K].
        grad_output: Gradient of output, shape [M_total, N].
        m_sizes: Group sizes tensor, shape [G].
        num_sms: Number of SMs to launch over (one block per SM).
        tma_size: TMA descriptor size in bytes.

    Returns:
        grad_w: Gradient with respect to weights, shape [N, K].
    """
    # Queried but currently unused: TMA usage is hard-coded below.
    # TODO: drive USE_TMA_LOAD / USE_TMA_STORE from this flag.
    has_tma_support = CudaUtils.verify_tma()

    G = m_sizes.shape[0]

    # Ensure contiguous tensors
    x = x.contiguous()
    grad_output = grad_output.contiguous()
    m_sizes = m_sizes.contiguous()

    M_total, K_x = x.shape
    M_grad, N = grad_output.shape

    assert M_total == M_grad, f"x M ({M_total}) must match grad_output M ({M_grad})"

    sum_m_sizes = m_sizes.sum().item()
    assert (
        sum_m_sizes == M_total
    ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"

    # grad_w accumulator [N, K]; zero-initialized because the kernel may
    # skip masked-out regions.
    grad_w = torch.zeros((N, K_x), device=x.device, dtype=x.dtype)

    NUM_SMS = num_sms

    USE_TMA_LOAD = True  # has_tma_support
    USE_TMA_STORE = True  # has_tma_support

    # Set up TMA descriptors or fall back to raw pointers
    if USE_TMA_LOAD or USE_TMA_STORE:
        desc_helper = TmaDescriptorHelper(tma_size=tma_size)

        if USE_TMA_LOAD:
            desc_helper.init_tma_descriptor("x")
            desc_helper.init_tma_descriptor("grad_output")
            x_desc = desc_helper.get_tma_descriptor_kernel_param("x")
            grad_output_desc = desc_helper.get_tma_descriptor_kernel_param(
                "grad_output"
            )
        else:
            x_desc = x
            grad_output_desc = grad_output

        if USE_TMA_STORE:
            desc_helper.init_tma_descriptor("grad_w")
            workspace = desc_helper.get_tma_descriptor_kernel_param("grad_w")
        else:
            workspace = torch.empty(1, device=x.device, dtype=torch.uint8)
    else:
        # Not using TMA at all: pass the tensors directly
        x_desc = x
        grad_output_desc = grad_output
        workspace = torch.empty(1, device=x.device, dtype=torch.uint8)

    # M_BUCKET for the autotuner cache key
    M_BUCKET = triton.next_power_of_2(M_total)

    def grid(META):
        # Fill descriptors lazily; block sizes come from the chosen config.
        if USE_TMA_LOAD:
            desc_helper.fill_2d_tma_descriptor(
                "x",
                x.data_ptr(),
                M_total,
                K_x,
                META["BLOCK_SIZE_M"],
                META["BLOCK_SIZE_K"],
                x.element_size(),
            )
            desc_helper.fill_2d_tma_descriptor(
                "grad_output",
                grad_output.data_ptr(),
                M_total,
                N,
                META["BLOCK_SIZE_M"],
                META["BLOCK_SIZE_N"],
                grad_output.element_size(),
            )
        if USE_TMA_STORE:
            desc_helper.fill_2d_tma_descriptor(
                "grad_w",
                grad_w.data_ptr(),
                N,
                K_x,
                META["BLOCK_SIZE_N"],
                META["BLOCK_SIZE_K"],
                grad_w.element_size(),
            )
        # One block per SM for balanced work distribution
        return (NUM_SMS,)

    # Launch the optimized kernel
    _kernel_mg_dw_tma[grid](
        x_desc,
        grad_output_desc,
        grad_w,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N,
        K_x,
        NUM_SMS,
        USE_TMA_LOAD,
        USE_TMA_STORE,
        TMA_SIZE=tma_size,
    )

    return grad_w


# ======== End Backwards Wrapper Functions =============
# ======== PyTorch wrapper functions ========


class GroupedGEMM_mg(torch.autograd.Function):
    """
    Autograd function for GroupedGEMM with M*G grouping.

    FP8 quantization is not implemented yet; the ``using_fp8`` flag is
    accepted for interface compatibility and currently ignored.
    """

    @staticmethod
    def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128, using_fp8=False):
        """
        Forward pass of GroupedGEMM.

        Args:
            x: Input tensor, shape [M_total, K]
            w: Weight tensor, shape [N, K]
            m_sizes: Tensor of shape [G] containing the size of each group
            use_tma: Whether to try using TMA acceleration (if available)
            tma_size: Size of TMA descriptor in bytes
            using_fp8: Reserved; FP8 is not implemented yet, so ignored.

        Returns:
            Output tensor, shape [M_total, N]
        """
        # BUGFIX: grouped_gemm_forward() takes no `using_fp8` argument;
        # passing it raised TypeError.  FP8 support will be re-added later.
        output = grouped_gemm_forward(x=x, w=w, m_sizes=m_sizes, tma_size=tma_size)

        # Save inputs and parameters for the backward pass (exactly once;
        # the original called save_for_backward twice).
        ctx.save_for_backward(x, w, m_sizes)
        ctx.use_tma = use_tma
        ctx.tma_size = tma_size

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass of M*G GroupedGEMM.

        Args:
            grad_output: Gradient of output, shape [M_total, N]

        Returns:
            One gradient per forward input:
            (grad_x, grad_w, None, None, None, None) — m_sizes, use_tma,
            tma_size and using_fp8 are not differentiable.
        """
        # Retrieve saved tensors and parameters
        x, w, m_sizes = ctx.saved_tensors
        use_tma = ctx.use_tma
        tma_size = ctx.tma_size

        # Compute gradients using the unified implementation
        grad_x, grad_w = grouped_gemm_backward(
            grad_output=grad_output,
            x=x,
            w=w,
            m_sizes=m_sizes,
            use_tma=use_tma,
            tma_size=tma_size,
        )

        # BUGFIX: autograd requires one gradient per forward input (6 here);
        # the original returned only 4 values.
        return grad_x, grad_w, None, None, None, None


def mg_grouped_gemm(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
    using_fp8: bool = False,
) -> torch.Tensor:
    """
    Unified differentiable grouped GEMM operation for M*G grouped GEMM.

    Args:
        x: Input tensor, shape [M_total, K]
        w: Weight tensor, shape [N, K]
        m_sizes: Tensor of shape [G] containing the size of each group
        use_tma: Whether to try using TMA acceleration (if available)
        tma_size: Size of TMA descriptor in bytes
        using_fp8: Reserved; FP8 quantization is not implemented yet.

    Returns:
        Output tensor, shape [M_total, N]
    """
    # BUGFIX: the original passed 6 arguments while Function.forward only
    # accepted 5; forward now takes (and ignores) using_fp8.
    return GroupedGEMM_mg.apply(x, w, m_sizes, use_tma, tma_size, using_fp8)


# --- new file: torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import logging

import numpy as np
import torch

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


def compute_reference_forward(x, w, m_sizes):
    """
    Compute the reference forward pass using plain PyTorch operations.

    Args:
        x (torch.Tensor): Input tensor of shape (M, K)
        w (torch.Tensor): Weight tensor of shape (N, K)
        m_sizes (torch.Tensor): Group sizes tensor of shape (G)

    Returns:
        torch.Tensor: Reference output tensor of shape (M, N)
    """
    result = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device)

    m_start = 0
    for g in range(len(m_sizes)):
        m_size = m_sizes[g].item()
        if m_size > 0:
            m_end = m_start + m_size

            # Group output: y_g = x_g @ w.T
            result[m_start:m_end] = torch.matmul(x[m_start:m_end], w.T)

            # Advance to the next group's row range
            m_start = m_end

    return result


def compute_reference_backward(x, w, m_sizes, grad_output):
    """
    Compute the reference backward pass using PyTorch autograd.

    Args:
        x (torch.Tensor): Input tensor of shape (M, K)
        w (torch.Tensor): Weight tensor of shape (N, K)
        m_sizes (torch.Tensor): Group sizes tensor of shape (G)
        grad_output (torch.Tensor): Gradient tensor of shape (M, N)

    Returns:
        tuple: (grad_x, grad_w) gradient tensors
    """
    # Autograd-enabled leaf copies so .grad is populated
    x_autograd = x.detach().clone().requires_grad_(True)
    w_autograd = w.detach().clone().requires_grad_(True)

    # Forward then backpropagate the supplied cotangent
    output = compute_reference_forward(x_autograd, w_autograd, m_sizes)
    output.backward(grad_output)

    return x_autograd.grad, w_autograd.grad


def analyze_tensor_differences(actual, expected, name):
    """
    Analyze and log differences between actual and expected tensors.

    Args:
        actual (torch.Tensor): Actual tensor
        expected (torch.Tensor): Expected tensor
        name (str): Name of the tensor for logging

    Returns:
        bool: True if tensors are close enough
    """
    # Loose tolerances chosen for float16 comparisons
    rtol = 0.5
    atol = 0.5

    # Locate and report the largest absolute difference
    diff = (actual - expected).abs()
    max_idx = diff.argmax().item()
    idx = np.unravel_index(max_idx, actual.shape)
    max_diff = diff.max().item()

    logging.info(f"Largest {name} difference: {max_diff} at {idx}")
    logging.info(f"Values: {actual[idx].item()} vs {expected[idx].item()}")

    is_close = torch.allclose(actual, expected, rtol=rtol, atol=atol)

    if is_close:
        logging.info(f"✓ SUCCESS: {name} matches PyTorch reference")
    else:
        logging.error(f"✗ FAILURE: {name} mismatch detected")

    # Count zeros (a high zero ratio often indicates a store bug)
    zeros_actual = (actual == 0).sum().item()
    zeros_expected = (expected == 0).sum().item()
    logging.info(
        f"Zeros in {name} (actual): {zeros_actual}/{actual.numel()} ({zeros_actual/actual.numel()*100:.2f}%)"
    )
    logging.info(
        f"Zeros in {name} (expected): {zeros_expected}/{expected.numel()} ({zeros_expected/expected.numel()*100:.2f}%)"
    )

    # Check for NaNs
    nan_actual = torch.isnan(actual).sum().item()
    if nan_actual > 0:
        logging.error(f"NaN values detected in {name}: {nan_actual}")

    return is_close


# --- new file: torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# credit - TMAHelper class, AutoTuning are derived from FBGemm:
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm

# pyre-unsafe
import functools

import os
import sys
from typing import Any, Dict, Optional, Tuple

import torch

import triton
import triton.language as tl
from triton import Config as TConfig

from triton.runtime import driver  # @manual

sys.path.append(os.path.dirname(os.path.abspath(__file__)))


# ===== Supporting utils, CUDA and TMA =====


class CudaUtils:
    """Static helpers for querying CUDA-backend and TMA capabilities."""

    @staticmethod
    def is_cuda() -> bool:
        """Check if Triton is running on CUDA backend."""
        return driver.active.get_current_target().backend == "cuda"

    @staticmethod
    def verify_tma() -> bool:
        """Check if TMA is supported on the current device.

        TMA requires compute capability >= 9.0 (Hopper or newer).
        """
        return (
            CudaUtils.is_cuda()
            and torch.cuda.is_available()
            and torch.cuda.get_device_capability()[0] >= 9
        )

    @staticmethod
    def get_num_sms() -> int:
        """Get the number of streaming multiprocessors on the current device.

        Raises:
            RuntimeError: if Triton is not on the CUDA backend or CUDA is
                unavailable.
        """
        if not CudaUtils.is_cuda():
            raise RuntimeError("Triton is not running on CUDA backend")
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        return torch.cuda.get_device_properties("cuda").multi_processor_count


class TmaDescriptorHelper:
    """Helper class for managing TMA descriptors in Triton kernels."""

    class KernelParamWrapper:
        """Wrapper to implement the TmaDescKernelParam interface."""

        def __init__(self, desc: torch.Tensor):
            self.desc = desc

        def tma_desc_cpu_ptr(self) -> int:
            """Return the CPU pointer to the TMA descriptor."""
            return self.desc.data_ptr()

    def __init__(self, tma_size: int = 128):
        """Initialize the TMA descriptor helper.

        Args:
            tma_size: Size of the TMA descriptor in bytes

        Raises:
            RuntimeError: if the device or installed Triton lacks TMA support.
        """
        if not CudaUtils.verify_tma():
            raise RuntimeError(
                "TMA not supported on this device (requires Hopper or newer)"
            )
        if "nv_tma_desc_type" not in dir(tl):
            raise RuntimeError(
                "TMA grid constant descriptors not supported in your Triton version"
            )

        self.tma_size = tma_size
        self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
        self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
        # descriptor name -> CPU int8 buffer holding the raw descriptor bytes
        self.descriptors: Dict[str, torch.Tensor] = {}

    def init_tma_descriptor(self, name: str) -> None:
        """Initialize a TMA descriptor with the given name.

        Call this method outside of the lambda function for grid size.
        """
        self.descriptors[name] = torch.empty(
            self.tma_size, device="cpu", dtype=torch.int8
        )

    def fill_1d_tma_descriptor(
        self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
    ) -> None:
        """Fill a 1D TMA descriptor.

        Call this method inside the lambda function for grid size.
        """
        if name not in self.descriptors:
            raise ValueError(f"TMA descriptor '{name}' not initialized")

        desc_x = self.descriptors[name]
        if desc_x.data_ptr() % 64 != 0:
            raise ValueError("TMA descriptor must be 64-byte aligned")
        self.fill_1d_tma_descriptor_inner(
            ptr, dim, block_dim, element_size, desc_x.data_ptr()
        )

    def fill_2d_tma_descriptor(
        self,
        name: str,
        ptr: int,
        dim1: int,
        dim0: int,
        block_dim1: int,
        block_dim0: int,
        element_size: int,
    ) -> None:
        """Fill a 2D TMA descriptor.

        Call this method inside the lambda function for grid size.
        """
        if name not in self.descriptors:
            raise ValueError(f"TMA descriptor '{name}' not initialized")

        desc_x = self.descriptors[name]
        if desc_x.data_ptr() % 64 != 0:
            raise ValueError("TMA descriptor must be 64-byte aligned")
        self.fill_2d_tma_descriptor_inner(
            ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
        )

    def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
        """Get the TMA descriptor kernel parameter for the given name."""
        if name not in self.descriptors or self.descriptors[name] is None:
            raise ValueError(f"TMA descriptor '{name}' not initialized")
        return self.KernelParamWrapper(self.descriptors[name])


# ====== Autotuning utilities ======
ALIGN_SIZE_M = 128

_NV_CONFIGS = [
    triton.Config(
        {
            "BLOCK_SIZE_M": block_size_m,
            "BLOCK_SIZE_N": block_size_n,
            "BLOCK_SIZE_K": block_size_k,
        },
        num_stages=num_stages,
        num_warps=num_warps,
        num_ctas=num_ctas,
    )
    for block_size_m in [ALIGN_SIZE_M, ]
    for block_size_n in [64, 128, 256]
    for block_size_k in [64, 128, 256]
    for num_stages in [3, 4]
    for num_warps in [4, 8]
    for num_ctas in [1]
]


def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
    """Prune autotuning configs that cannot run (or are unlikely to be fast)
    for the problem described by ``named_args``.

    Args:
        configs: Candidate ``triton.Config`` objects.
        named_args: Kernel arguments; must contain one of the recognized
            output pointers plus the ``G`` / ``M_BUCKET`` / ``N`` / ``K`` sizes.
        dtsize: Element size in bytes; inferred from the pointer arg if None.
        dtype: Element dtype; inferred from the pointer arg if None.

    Returns:
        The subset of ``configs`` that pass the feasibility heuristics.

    Raises:
        KeyError: if no recognized pointer parameter is present.
    """
    device = torch.cuda.current_device()
    # Check for all possible pointer parameter names
    if "grad_input_ptr" in named_args:
        ptr_name = "grad_input_ptr"
    elif "c_ptr" in named_args:
        ptr_name = "c_ptr"
    elif "grad_weight_ptr" in named_args:
        ptr_name = "grad_weight_ptr"
    else:
        raise KeyError("No recognized pointer parameter found in kernel arguments")

    if dtsize is None:
        dtsize = named_args[ptr_name].element_size()
    if dtype is None:
        dtype = named_args[ptr_name].dtype

    # PERF: device properties and problem sizes are loop-invariant -- query
    # them once instead of once per candidate config (was inside the loop).
    device_props = driver.active.utils.get_device_properties(device)
    max_shared_memory = device_props["max_shared_mem"]
    num_sm = device_props["multiprocessor_count"]
    G, M, N, K = (
        named_args["G"],
        named_args["M_BUCKET"],
        named_args["N"],
        named_args["K"],
    )

    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
            kw["BLOCK_SIZE_M"],
            kw["BLOCK_SIZE_N"],
            kw["BLOCK_SIZE_K"],
            config.num_stages,
        )

        # 1. make sure we have enough smem
        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory > max_shared_memory:
            continue

        M_PER_GROUP = M // G
        MIN_M_TILES = 64
        # 2. make sure we don't load M tiles that are too big
        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
            continue
        # 3. make sure we don't load M tiles that are too small
        # (BUGFIX: comment previously said "N tiles" but the condition is on BLOCK_M)
        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
            continue

        N_TILES = N // BLOCK_N
        MIN_N_TILES = 64
        # 4. make sure we don't load N tiles that are too big
        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
            continue
        # 5. make sure we don't load N tiles that are too small
        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
            continue
        # 6. make sure K can be evenly divided
        if K % BLOCK_K != 0:
            continue

        pruned_configs.append(config)

    return pruned_configs


# ======== End Autotuning utilities ========
# pyre-unsafe
import logging
import unittest
from typing import Tuple

import torch
import torch.nn as nn

from mg_grouped_gemm import (
    grouped_gemm_backward,
    grouped_gemm_dw_tma,
    grouped_gemm_dx_tma,
    grouped_gemm_forward,
    mg_grouped_gemm,
)

from reference_utils import (
    analyze_tensor_differences,
    compute_reference_backward,
    compute_reference_forward,
)


class TestMG_GroupedGEMM_Backward(unittest.TestCase):
    """Unit tests for the backward pass of the M*G grouped GEMM kernels."""

    def setUp(self) -> None:
        torch.manual_seed(2020)  # Set seed for reproducibility

    def _run_grouped_gemm_backward_test(
        self,
        shape: Tuple[int, int, int, int],
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        atol: float = 1e-5,
        rtol: float = 1.6e-2,
    ) -> None:
        """Run forward + backward for one (G, M, N, K) problem and compare
        both gradients against the PyTorch reference implementation.

        NOTE: ``atol``/``rtol`` are currently unused -- tolerance checking is
        done inside ``analyze_tensor_differences``; the parameters are kept
        for signature symmetry with the forward-test helper.
        """
        G, M, N, K = shape
        # Set up inputs for forward pass
        # In M*G grouping, input is [M*G, K] and weights are [N, K]
        a = torch.randn(M * G, K, dtype=dtype, device=device, requires_grad=True)
        b = torch.randn(N, K, dtype=dtype, device=device, requires_grad=True)

        # Create equal-sized groups for simplicity
        m_size = M
        m_sizes = torch.full((G,), m_size, device=device, dtype=torch.int32)

        # Run forward pass with our implementation
        result = grouped_gemm_forward(a, b, m_sizes)
        # Ensure result has correct shape
        self.assertTrue(result.shape == (M * G, N))

        # Compute expected result using reference implementation
        expected_result = compute_reference_forward(a, b, m_sizes)

        # Verify forward pass correctness
        forward_close = analyze_tensor_differences(
            result, expected_result, "Forward output"
        )
        self.assertTrue(forward_close)

        # Create a gradient for backpropagation
        grad_output = torch.randn_like(result)

        # Compute gradients using our custom backward implementation
        grad_a, grad_b = grouped_gemm_backward(grad_output, a, b, m_sizes)

        # Compute expected gradients using reference implementation
        expected_grad_a, expected_grad_b = compute_reference_backward(
            a, b, m_sizes, grad_output
        )

        # Verify gradient correctness
        grad_a_close = analyze_tensor_differences(grad_a, expected_grad_a, "grad_x")
        grad_b_close = analyze_tensor_differences(grad_b, expected_grad_b, "grad_w")

        self.assertTrue(grad_a_close)
        self.assertTrue(grad_b_close)

    def test_MG_grouped_gemm_backward_bf16(self) -> None:
        """Sweep group counts and M sizes in bfloat16."""
        for G in (1, 8, 16):
            for M in (512, 1024):
                print(f"Testing BF16 M*G GroupGeMM Backward with G={G}, M={M}")
                # BUGFIX: this test advertises BF16 (name + log message) but
                # previously passed dtype=torch.float16; use bfloat16 to match
                # the helper default and the forward-test suite.
                self._run_grouped_gemm_backward_test(
                    (G, M, 1024, 1024),
                    torch.device("cuda"),
                    dtype=torch.bfloat16,
                    atol=1e-2,
                    rtol=1e-2,
                )

    def test_MG_grouped_gemm_backward_deepseek_shapes(self) -> None:
        """Test backward pass with shapes from Deepseek model."""
        deepseek_shapes = [
            (4, 2048, 4096, 7168),  # G, M, N, K
            (4, 2048, 7168, 2048),
            (8, 512, 4096, 7168),
            (8, 512, 7168, 2048),
        ]

        device = torch.device("cuda")

        for shape in deepseek_shapes:
            G, M, N, K = shape
            print(
                f"Testing BF16 M*G Deepseek Backward shape: G={G}, M={M}, N={N}, K={K}"
            )
            # BUGFIX: bfloat16 (was float16) to match the advertised dtype.
            self._run_grouped_gemm_backward_test(
                shape, device, dtype=torch.bfloat16, atol=1e-2, rtol=1e-2
            )

    def test_MG_dx(self) -> None:
        """Test specifically the dx (gradient w.r.t. input) computation."""
        G, M, N, K = 4, 512, 1024, 2048
        device = torch.device("cuda")
        dtype = torch.bfloat16

        # Set up inputs
        a = torch.randn(M * G, K, dtype=dtype, device=device, requires_grad=True)
        b = torch.randn(N, K, dtype=dtype, device=device, requires_grad=True)

        # Create equal-sized groups
        m_size = M
        m_sizes = torch.full((G,), m_size, device=device, dtype=torch.int32)

        # Forward pass
        result = grouped_gemm_forward(a, b, m_sizes)

        # Create gradient for backward
        grad_output = torch.randn_like(result)

        # Compute gradient using our optimized function
        grad_a, _ = grouped_gemm_backward(grad_output, a, b, m_sizes)

        # Compute expected gradient using reference implementation
        expected_grad_a, _ = compute_reference_backward(a, b, m_sizes, grad_output)

        # Verify gradient
        dx_close = analyze_tensor_differences(grad_a, expected_grad_a, "grad_a (dx)")
        self.assertTrue(dx_close)

    def test_MG_dw(self) -> None:
        """Test specifically the dw (gradient w.r.t. weights) computation."""
        G, M, N, K = 4, 512, 1024, 2048
        device = torch.device("cuda")
        dtype = torch.bfloat16

        # Set up inputs
        a = torch.randn(M * G, K, dtype=dtype, device=device, requires_grad=True)
        b = torch.randn(N, K, dtype=dtype, device=device, requires_grad=True)

        # Create equal-sized groups
        m_size = M
        m_sizes = torch.full((G,), m_size, device=device, dtype=torch.int32)

        # Forward pass
        result = grouped_gemm_forward(a, b, m_sizes)

        # Create gradient for backward
        grad_output = torch.randn_like(result)

        # Compute gradient using our optimized function
        _, grad_b = grouped_gemm_backward(grad_output, a, b, m_sizes)

        # Compute expected gradient using reference implementation
        _, expected_grad_b = compute_reference_backward(a, b, m_sizes, grad_output)

        # Verify gradient
        dw_close = analyze_tensor_differences(grad_b, expected_grad_b, "grad_b (dw)")
        self.assertTrue(dw_close)
# pyre-unsafe
import logging
import unittest
from typing import Tuple

import torch
import torch.nn as nn

from mg_grouped_gemm import grouped_gemm_forward


class TestMG_GroupedGEMM(unittest.TestCase):
    # Forward-pass unit tests for the M*G grouped GEMM kernel.

    def setUp(self) -> None:
        # Fixed seed so torch.randn inputs are reproducible across runs.
        torch.manual_seed(2020)

    def _run_grouped_gemm_test(
        self,
        shape: Tuple[int, int, int, int],
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        atol: float = 1e-5,
        rtol: float = 1.6e-2,
    ) -> None:
        # Runs one (G, M, N, K) forward problem and compares against a
        # per-group PyTorch matmul reference.
        G, M, N, K = shape
        # In M*G grouping, input is [M*G, K] and weights are [N*G, K]
        a = torch.randn(M * G, K, dtype=dtype, device=device)
        b = torch.randn(N * G, K, dtype=dtype, device=device)

        # Create equal-sized groups for simplicity
        m_size = M
        m_sizes = torch.full((G,), m_size, device=device, dtype=torch.int32)

        result = grouped_gemm_forward(a, b, m_sizes)
        self.assertTrue(result.shape == (M * G, N))

        # Reference: group g multiplies its row slice of `a` by its own
        # [N, K] weight slice of `b` (transposed).
        expected_result = torch.zeros(M * G, N, dtype=dtype, device=device)
        m_start = 0
        for g in range(G):
            m_end = m_start + m_sizes[g]
            b_slice = b[N * g : N * (g+1), :]
            expected_result[m_start:m_end, :] = a[m_start:m_end, :] @ b_slice.T
            m_start = m_end

        # Convert result to match input dtype if needed
        result = result.to(dtype)
        torch.testing.assert_close(result, expected_result, atol=atol, rtol=rtol)

    def test_MG_grouped_gemm_bf16(self) -> None:
        # Sweep group counts and M sizes in bfloat16.
        for G in (1, 4, 16):
            for M in (128, 512, 1024):
                print(f"Testing BF16 M*G GroupGeMM with G={G}, M={M}")
                self._run_grouped_gemm_test(
                    (G, M, 1024, 1024),
                    torch.device("cuda"),
                    dtype=torch.bfloat16,
                    atol=1e-5,
                    rtol=1.6e-2,
                )

    def test_MG_grouped_gemm_deepseek_shapes(self) -> None:
        """Test with shapes from Deepseek model."""
        deepseek_shapes = [
            (4, 2048, 4096, 7168),  # G, M, N, K
            (4, 2048, 7168, 2048),
            (8, 512, 4096, 7168),
            (8, 512, 7168, 2048),
        ]

        device = torch.device("cuda")

        for shape in deepseek_shapes:
            G, M, N, K = shape
            print(f"Testing BF16 M*G Deepseek shape: G={G}, M={M}, N={N}, 
 K={K}") + self._run_grouped_gemm_test( + shape, device, dtype=torch.bfloat16, atol=1e-5, rtol=1.6e-2 + ) diff --git a/torchtitan/experiments/llama4/README.md b/torchtitan/experiments/llama4/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d912caaa80b12157f24693b63f8a9a5bd75c717f --- /dev/null +++ b/torchtitan/experiments/llama4/README.md @@ -0,0 +1,29 @@ +**The Llama 4 folder is still under development.** + +#### Available features +- Llama 4 model definition (text-only), including the MoE architecture with token-choice routing +- Basic FSDP, TP, PP, CP support +- DCP checkpoint conversion scripts + +#### Download Llama 4 tokenizer +```bash +# Llama 4 tokenizer.model +python scripts/download_tokenizer.py --repo_id meta-llama/Llama-4-Scout-17B-16E --tokenizer_path "" --hf_token=... +``` + +#### To be added +- Modeling + - iRoPE implementation + - load balance loss for token-choice MoE + - alternative expert-choice MoE + - multimodal support +- Kernel integration + - efficient bfloat16 GroupedGEMM kernels (from PyTorch core) + - efficient float8 GroupedGEMM kernels (from torchao) +- Parallelism + - performant TP implementation and torch.compile support for MoE layers + - Context Parallel support for FlexAttention, iRoPE, and multimodal inputs + - Expert Parallel support +- Testing + - performance and loss convergence tests + - CI integration diff --git a/torchtitan/experiments/llama4/__init__.py b/torchtitan/experiments/llama4/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0907e1892fa3840be81e7eefe12047d2e1cf1661 --- /dev/null +++ b/torchtitan/experiments/llama4/__init__.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree.
from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
from torchtitan.models.llama3 import pipeline_llama
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .infra.parallelize_llama import parallelize_llama
from .model.args import TransformerModelArgs
from .model.model import Transformer

__all__ = [
    "TransformerModelArgs",
    "Transformer",
    "llama4_configs",
]


# Named Llama 4 (text-only) model-size presets, selected by key at run time.
llama4_configs = {
    "debugmodel": TransformerModelArgs(
        dim=256,
        n_layers=8,
        n_heads=16,
        rope_theta=500000,
    ),
    "17bx16e": TransformerModelArgs(
        dim=5120,
        n_layers=48,
        n_heads=40,
        n_kv_heads=8,
        ffn_dim_multiplier=1.2,
        multiple_of=2048,
        rope_theta=500000,
        num_experts=16,
        interleave_moe_layer_step=1,  # MoE layer in every transformer block
    ),
    "17bx128e": TransformerModelArgs(
        dim=5120,
        n_layers=48,
        n_heads=40,
        n_kv_heads=8,
        ffn_dim_multiplier=1.2,
        multiple_of=2048,
        rope_theta=500000,
        num_experts=128,
    ),
}


# Register the "llama4" train spec. Note the pipelining function is reused
# from llama3 (imported above), while parallelization is llama4-specific.
register_train_spec(
    TrainSpec(
        name="llama4",
        cls=Transformer,
        config=llama4_configs,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_tiktoken_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)
a/torchtitan/experiments/llama4/infra/__pycache__/parallelize_llama.cpython-312.pyc b/torchtitan/experiments/llama4/infra/__pycache__/parallelize_llama.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b74c5d0c35ecdd735b252821ee2244406ecc9d6 Binary files /dev/null and b/torchtitan/experiments/llama4/infra/__pycache__/parallelize_llama.cpython-312.pyc differ diff --git a/torchtitan/experiments/llama4/infra/expert_parallel.py b/torchtitan/experiments/llama4/infra/expert_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..63945e8cd6a3f9509ca34c779b09a2f2f7581c2f --- /dev/null +++ b/torchtitan/experiments/llama4/infra/expert_parallel.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from functools import partial +from typing import Optional, Tuple + +import torch.nn as nn +from torch.distributed.tensor import ( + DeviceMesh, + distribute_module, + distribute_tensor, + DTensor, + Partial, + Replicate, + Shard, +) +from torch.distributed.tensor.parallel import ParallelStyle +from torch.distributed.tensor.placement_types import Placement + + +# implementation of Tensor Parallel on the non-shared experts in MoE +class TensorParallel(ParallelStyle): + def __init__( + self, + *, + input_layouts: Optional[Tuple[Optional[Placement]]] = None, + output_layout: Optional[Placement] = None, + use_local_output: bool = True, + ): + super().__init__() + self.input_layouts = input_layouts or (Replicate(), None) + self.output_layout = output_layout or Partial() + self.desired_input_layouts = (Replicate(), None) + self.use_local_output = use_local_output + + @staticmethod + def _prepare_input_fn( + input_layouts, desired_input_layouts, mod, inputs, device_mesh + ): + # TODO: figure out dynamo support for instance method and switch this to 
instance method + + # annotate module input placements/sharding with input_layouts + input_tensor, input_layout, desired_input_layout = ( + inputs[0], + input_layouts[0], + desired_input_layouts[0], + ) + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local( + input_tensor, device_mesh, (input_layout,), run_check=False + ) + + if input_layouts != desired_input_layouts: + input_tensor = input_tensor.redistribute( + placements=(desired_input_layout,), async_op=True + ) + return (input_tensor, *inputs[1:]) + + def _partition_fn(self, name, module, device_mesh): + module.register_parameter( + "w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(2)])) + ) # Column-wise sharding + module.register_parameter( + "w2", + nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(1)])), + ) # Row-wise sharding + module.register_parameter( + "w3", + nn.Parameter(distribute_tensor(module.w3, device_mesh, [Shard(2)])), + ) # Column-wise sharding + + @staticmethod + def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh): + if outputs.placements != (output_layout,): + outputs = outputs.redistribute(placements=(output_layout,), async_op=True) + # back to local tensor + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + self._partition_fn, + partial( + self._prepare_input_fn, self.input_layouts, self.desired_input_layouts + ), + partial(self._prepare_output_fn, self.output_layout, self.use_local_output), + ) + + +# NOTE: This is to achieve replicate computation on the gate module in the MoE router. +# It does nothing other than (1) setting the module parameters as DTensors on the given mesh +# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back. 
+# TODO: The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh, +# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation. +class NoParallel(ParallelStyle): + def __init__( + self, + *, + input_layout: Optional[Placement] = None, + output_layout: Optional[Placement] = None, + use_local_output: bool = True, + ): + super().__init__() + self.input_layout = input_layout or Replicate() + self.output_layout = output_layout or Replicate() + self.desired_input_layout = Replicate() + self.use_local_output = use_local_output + + @staticmethod + def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh): + # annotate module input placements/sharding with input_layouts + input_tensor = inputs[0] + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local( + input_tensor, device_mesh, (input_layout,), run_check=False + ) + + if input_layout != desired_input_layout: + input_tensor = input_tensor.redistribute( + placements=(desired_input_layout,), async_op=True + ) + return (input_tensor, *inputs[1:]) + + @staticmethod + def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh): + if outputs.placements != (output_layout,): + outputs = outputs.redistribute(placements=(output_layout,), async_op=True) + # back to local tensor + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + None, + partial( + self._prepare_input_fn, self.input_layout, self.desired_input_layout + ), + partial(self._prepare_output_fn, self.output_layout, self.use_local_output), + ) diff --git a/torchtitan/experiments/llama4/infra/parallelize_llama.py b/torchtitan/experiments/llama4/infra/parallelize_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..72842fc04f896896772beca4ec7b50b0ce66a7b5 --- /dev/null +++ 
import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims

from torchtitan.models.llama3.parallelize_llama import (
    apply_ac,
    apply_compile,
    apply_ddp,
    apply_fsdp,
    apply_tp,
)
from torchtitan.tools.logging import logger


def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.

    Order matters: TP first, then AC wrapping, then per-block compile, then
    FSDP/HSDP (or DDP). The (possibly wrapped) model is returned.
    """

    if parallel_dims.tp_enabled:
        if (
            job_config.parallelism.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")

        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )

        # For now, float8 all-gather with TP is only supported for tensorwise
        # float8 scaling recipes. For rowwise recipes, we use regular TP and
        # all-gather happens in high precision.
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise

        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
        )

        # MoE modules get a dedicated TP plan on top of the dense-layer TP.
        apply_moe_tp(model, world_mesh["tp"])

    if job_config.activation_checkpoint.mode != "none":
        if (
            job_config.activation_checkpoint.mode == "selective"
            and job_config.model.use_flex_attn
        ):
            raise ValueError(
                "FlexAttention is not compatible with selective AC yet. "
                "See https://github.com/pytorch/pytorch/issues/147879"
            )
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-TransformerBlock compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

        # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
        torch._dynamo.config.capture_scalar_outputs = True

    if (
        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
    ):  # apply FSDP or HSDP, potentially with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            # HSDP: replicate across dp_replicate, shard across dp_shard_cp.
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
        )

    return model


def apply_moe_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
):
    """Apply a MoE-specific tensor-parallel plan to every transformer block:
    sequence-sharded input/output around the ``moe`` module, replicated router
    gate, and TensorParallel experts / shared expert.
    """
    from torch.distributed.tensor import Partial, Replicate, Shard
    from torch.distributed.tensor.parallel import (
        parallelize_module,
        PrepareModuleInputOutput,
    )

    from .expert_parallel import NoParallel, TensorParallel

    for _, transformer_block in model.layers.items():
        moe_layer_plan = {
            # input / output sharding on the seqlen dim
            # all-gather for input, reduce-scatter for output
            "moe": PrepareModuleInputOutput(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
                use_local_input=True,
                output_layouts=(Partial(),),
                desired_output_layouts=(Shard(1),),
            ),
            # replicate computation for the router
            "moe.router.gate": NoParallel(),
            # input Replicate, output Partial
            "moe.experts": TensorParallel(),
            "moe.shared_expert": TensorParallel(),
        }
        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=moe_layer_plan,
        )
from dataclasses import dataclass
from typing import Optional

from torch import nn
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig

from torchtitan.protocols.train_spec import BaseModelArgs
from torchtitan.tools.logging import logger


@dataclass
class TransformerModelArgs(BaseModelArgs):
    """Model hyperparameters for the Llama 4 (text-only) Transformer."""

    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 10000

    max_seq_len: int = 2048
    # If `True`, then each transformer block init uses its layer ID, and if
    # `False`, each uses the total number of transformer blocks
    depth_init: bool = True
    norm_type: str = "rmsnorm"

    use_flex_attn: bool = False
    attn_mask_type: str = "causal"
    eos_id: int = 0

    # MoE args
    moe_enabled: bool = True
    num_experts: int = 8
    use_shared_expert: bool = True
    auto_scale_hidden_dim: bool = True
    # frequency of using MoE layer instead of feedforward layer in a transformer block
    interleave_moe_layer_step: int = 2
    # token-choice
    top_k: int = 1

    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        # Fill in the runtime-dependent fields from the job config / tokenizer.
        self.norm_type = job_config.model.norm_type
        self.vocab_size = tokenizer.n_words
        self.max_seq_len = job_config.training.seq_len
        self.use_flex_attn = job_config.model.use_flex_attn

    def get_nparams_and_flops(
        self, model: nn.Module, seq_len: int
    ) -> tuple[int, float]:
        # Split the parameter count into dense vs. sparse (MoE) buckets so the
        # "active" count can weight expert params by top_k / num_experts.
        nparams_embedding = 0
        nparams_moe_router = 0
        nparams_shared_expert = 0
        nparams_experts = 0
        nparams_dense = 0

        for name, p in model.named_parameters():
            # NOTE: chain order matters -- embedding params are also counted
            # in the dense bucket.
            if "embedding" in name:
                nparams_embedding += p.numel()
                nparams_dense += p.numel()
            elif "moe.shared_expert" in name:
                nparams_shared_expert += p.numel()
            elif "moe.router" in name:
                nparams_moe_router += p.numel()
            elif "moe.experts" in name:
                nparams_experts += p.numel()
            else:
                nparams_dense += p.numel()

        nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
        nparams = nparams_dense + nparams_sparse
        # Only top_k of num_experts experts run per token.
        nparams_sparse_active = (
            nparams_moe_router
            + nparams_shared_expert
            + nparams_experts * self.top_k // self.num_experts
        )

        logger.info(
            f"Total parameter count: dense {nparams_dense:,}, "
            f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}"
        )

        l, h, q, t = (
            self.n_layers,
            self.n_heads,
            self.dim // self.n_heads,
            seq_len,
        )
        # Reasoning behind the factor of 12 for the self-attention part of the formula:
        # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
        # 2. the flash attention does 1 more matmul recomputation in the backward
        # but recomputation should not be counted in calculating MFU (+0)
        # 3. each matmul performs 1 multiplication and 1 addition (*2)
        # 4. we follow the convention and do not account for sparsity in causal attention
        num_flops_per_token = (
            6 * (nparams_dense - nparams_embedding + nparams_sparse_active)
            + 12 * l * h * q * t
        )

        return nparams, num_flops_per_token


import torch
import torch.nn.functional as F
from torch import nn

from torchtitan.models.attention import build_attention, init_attention_mask
from torchtitan.models.norms import build_norm
from torchtitan.protocols.train_spec import ModelProtocol

from .args import TransformerModelArgs
from .moe import MoE


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.
+ """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert ndim > 1 + seqlen = x.shape[1] + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. 
+ + Returns: + tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + """ + Multi-head attention module. + + Args: + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + n_kv_heads (int): Number of key and value heads. + n_heads (int): Number of query heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. 
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings and grouped KV heads.

    Args:
        model_args (TransformerModelArgs): Model configuration arguments.

    Attributes:
        n_kv_heads (int): Number of key/value heads.
        n_heads (int): Number of query heads.
        n_rep (int): How many times each KV head is repeated to match queries.
        head_dim (int): Dimension size of each attention head.
        wq (nn.Linear): Query projection.
        wk (nn.Linear): Key projection.
        wv (nn.Linear): Value projection.
        wo (nn.Linear): Output projection.
    """

    def __init__(self, model_args: "TransformerModelArgs"):
        super().__init__()
        self.n_heads = model_args.n_heads
        self.n_kv_heads = (
            model_args.n_kv_heads
            if model_args.n_kv_heads is not None
            else model_args.n_heads
        )
        self.n_rep = self.n_heads // self.n_kv_heads
        self.head_dim = model_args.dim // model_args.n_heads

        q_dim = model_args.n_heads * self.head_dim
        kv_dim = self.n_kv_heads * self.head_dim
        self.wq = nn.Linear(model_args.dim, q_dim, bias=False)
        self.wk = nn.Linear(model_args.dim, kv_dim, bias=False)
        self.wv = nn.Linear(model_args.dim, kv_dim, bias=False)
        self.wo = nn.Linear(q_dim, model_args.dim, bias=False)
        self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type)

    def init_weights(self, init_std: float):
        """Truncated-normal init: fixed 0.02 std for q/k/v, ``init_std`` for wo."""
        for proj in (self.wq, self.wk, self.wv):
            nn.init.trunc_normal_(proj.weight, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
    ):
        """Apply rotary self-attention to ``x``.

        Args:
            x (torch.Tensor): Input of shape ``(bs, seqlen, dim)``.
            freqs_cis (torch.Tensor): Precomputed rotary frequency tensor.

        Returns:
            torch.Tensor: Attention output of shape ``(bs, seqlen, dim)``.
        """
        bs, seqlen, _ = x.shape
        xq = self.wq(x)
        xk = self.wk(x)
        xv = self.wv(x)

        # Infer the local head count with -1: under tensor parallelism the
        # projections above may already be sharded across ranks.
        xq = xq.view(bs, seqlen, -1, self.head_dim)
        xk = xk.view(bs, seqlen, -1, self.head_dim)
        xv = xv.view(bs, seqlen, -1, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # Expand KV heads to match the number of query heads.
        keys = repeat_kv(xk, self.n_rep)
        values = repeat_kv(xv, self.n_rep)

        # SDPA expects (bs, n_local_heads, seqlen, head_dim).
        out = self.sdpa(
            xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
        )

        # Back to (bs, seqlen, n_local_heads * head_dim).
        out = out.transpose(1, 2).contiguous().view(bs, seqlen, -1)
        return self.wo(out)
+ + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: float | None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class TransformerBlock(nn.Module): + """ + TransformerBlock Module + + Args: + layer_id (int): Identifier for the layer. + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + n_heads (int): Number of attention heads. + dim (int): Dimension size of the model. + head_dim (int): Dimension size of each attention head. + attention (Attention): Attention module. + feed_forward (FeedForward): FeedForward module. + layer_id (int): Identifier for the layer. + attention_norm (RMSNorm): Layer normalization for attention output. + ffn_norm (RMSNorm): Layer normalization for feedforward output. 
+ + """ + + def __init__(self, layer_id: int, model_args: TransformerModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.dim = model_args.dim + self.attention = Attention(model_args) + + # use MoE layer for every interleave_moe_layer_step FFN layers + self.moe_enabled = ( + model_args.moe_enabled + and (layer_id + 1) % model_args.interleave_moe_layer_step == 0 + ) + if self.moe_enabled: + self.moe = MoE(model_args) + else: + self.feed_forward = FeedForward( + dim=model_args.dim, + hidden_dim=4 * model_args.dim, + multiple_of=model_args.multiple_of, + ffn_dim_multiplier=model_args.ffn_dim_multiplier, + ) + + self.layer_id = layer_id + self.num_layers = model_args.n_layers + + self.attention_norm = build_norm( + model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps + ) + self.ffn_norm = build_norm( + model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps + ) + + if model_args.depth_init: + self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5 + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. 
+ + """ + h = x + self.attention(self.attention_norm(x), freqs_cis) + if self.moe_enabled: + out = h + self.moe(self.ffn_norm(h)) + else: + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + def init_weights(self): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + if self.moe_enabled: + self.moe.init_weights(self.weight_init_std) + else: + self.feed_forward.init_weights(self.weight_init_std) + + +class Transformer(nn.Module, ModelProtocol): + """ + Transformer Module + + Args: + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + model_args (TransformerModelArgs): Model configuration arguments. + vocab_size (int): Vocabulary size. + n_layers (int): Number of layers in the model. + tok_embeddings (ParallelEmbedding): Token embeddings. + layers (torch.nn.ModuleList): List of Transformer blocks. + norm (RMSNorm): Layer normalization for the model output. + output (ColumnParallelLinear): Linear layer for final output. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + """ + + def __init__(self, model_args: TransformerModelArgs): + super().__init__() + self.model_args = model_args + self.vocab_size = model_args.vocab_size + self.n_layers = model_args.n_layers + self.eos_id = model_args.eos_id + + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + + # TODO persistent should be set to false, since this buffer can be recomputed. + # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, + # compile or pipeline-tracer will not correctly handle non-persistent buffers, + # so we need to fix that. (2) if we initialize pipeline-parallel models from + # a seed checkpoint rather than calling init_weights, we need freqs_cis to be + # initialized by the checkpoint, or we need to add a separate initializer for + # just the non-persistent buffers that is called after loading checkpoints. 
+ self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + + self.norm = build_norm( + model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps + ) + + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) + self.init_weights() + + def init_weights( + self, + buffer_device: torch.device | None = None, + ): + """ + [Note: On ``init_weights`` vs. ``reset_parameters``] + Modules may define ``reset_parameters`` to initialize parameter values. + ``reset_parameters`` is meant to only initialize directly owned + parameters/buffers, not those of their child modules, and it can be + used to give the initial values for these tensors. + Separately, users may want custom initialization for their modules, + different from that in ``reset_parameters``. For this, we define + ``init_weights``. We only call it in the constructor of this + ``Transformer`` root module to avoid reinitializing tensors. 
+ """ + buffer_device = buffer_device or self.freqs_cis.device + with torch.device(buffer_device): + self.freqs_cis = self._precompute_freqs_cis() + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights() + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def _precompute_freqs_cis(self) -> torch.Tensor: + return precompute_freqs_cis( + self.model_args.dim // self.model_args.n_heads, + # Need to compute until at least the max token limit for generation + # TODO: explain in docs/composability.md why we removed the 2x + # relaxing in our CP enablement PR + self.model_args.max_seq_len, + self.model_args.rope_theta, + ) + + def forward(self, tokens: torch.Tensor): + """ + Perform a forward pass through the Transformer model. + + Args: + tokens (torch.Tensor): Input token indices. + + Returns: + torch.Tensor: Output logits after applying the Transformer model. + + """ + # TODO: We will to change forward() signature to allow tokens to + # be always passed in. + if self.model_args.use_flex_attn: + init_attention_mask(tokens, eos_id=self.eos_id) + + # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis) + + h = self.norm(h) if self.norm else h + output = self.output(h) if self.output else h + return output + + @classmethod + def from_model_args(cls, model_args: TransformerModelArgs) -> "Transformer": + """ + Initialize a Transformer model from a TransformerModelArgs object. 
+ + Args: + model_args (TransformerModelArgs): Model configuration arguments. + + Returns: + Transformer: Transformer model. + + """ + return cls(model_args) diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py new file mode 100644 index 0000000000000000000000000000000000000000..0b925b36207875dedc13a16be10890c3671cdabb --- /dev/null +++ b/torchtitan/experiments/llama4/model/moe.py @@ -0,0 +1,228 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from torch import nn + +from .args import TransformerModelArgs + + +class GroupedExperts(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + num_experts: int, + ): + super().__init__() + self.num_experts = num_experts + self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) + self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim)) + self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) + + def forward( + self, + x: torch.Tensor, + num_local_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_local_tokens_per_expert is not None: + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x, + split_size_or_sections=num_local_tokens_per_expert.tolist(), + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + w1, w2, w3 = ( + self.w1[expert_idx], + self.w2[expert_idx], + self.w3[expert_idx], + ) + h = F.silu(torch.matmul(x_expert, w1)) + h = h * torch.matmul(x_expert, w3) + h = torch.matmul(h, w2) + # h shape (tokens_per_expert(varying), dim) + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + # TODO:optimize with GroupedGEMM + # https://github.com/pytorch/pytorch/pull/150374 + # 
_gouped_mm requires shapes to be multiple of 8 + # offsets = torch.cumsum(num_local_tokens_per_expert, dim=0, dtype=torch.int32) + # h = F.silu(torch._grouped_mm(x, self.w1.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16)) + # h = h * torch._grouped_mm(x, self.w3.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16) + # out = torch._grouped_mm(h, self.w2.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16) + else: + # x shape (num_experts, tokens_per_expert, dim) + h = F.silu(torch.bmm(x, self.w1)) + h = h * torch.bmm(x, self.w3) + # out shape (num_experts, tokens_per_expert, dim) + out = torch.bmm(h, self.w2) + return out + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std) + + +class TokenChoiceTopKRouter(nn.Module): + """This class implements token-choice routing. In token-choice top-K routing, each token is + routed to top K experts based on the router scores. + + Args: + gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts). + dim (int): Dimension of input tokens. + num_experts (int): Number of experts in each moe layer. + top_k (int): Number of experts each token will be routed to in token-choice routing. + use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is False. + """ + + def __init__( + self, + dim: int, + num_experts: int, + top_k: int, + use_sigmoid: bool = False, + ): + super().__init__() + self.gate = nn.Linear(dim, num_experts, bias=False) + self.num_experts = num_experts + self.top_k = top_k + self.use_sigmoid = use_sigmoid + + def forward( + self, x: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``. 
# TODO: implement load balancing auxiliary loss for token-choice routing
class MoE(nn.Module):
    """Sparse MoE layer: token-choice router, grouped experts, optional shared expert.

    The expert hidden dimension follows the dense SwiGLU sizing
    (``2/3 * 4 * dim``, optionally scaled by ``ffn_dim_multiplier``); with
    ``auto_scale_hidden_dim`` it is further divided by the number of experts
    active per token (``top_k`` + shared expert) so active compute roughly
    matches a dense FFN.
    """

    def __init__(self, model_args: "TransformerModelArgs"):
        super().__init__()
        dim = model_args.dim
        hidden_dim = 4 * model_args.dim
        ffn_dim_multiplier = model_args.ffn_dim_multiplier
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)

        num_experts = model_args.num_experts

        # Merged the original's two separate `auto_scale_hidden_dim` checks
        # into one branch; `hidden_dim_denom` is only needed here. Behavior
        # is unchanged.
        if model_args.auto_scale_hidden_dim:
            # Divide by the number of experts active per token, then round up
            # to a multiple of `multiple_of`.
            hidden_dim_denom = model_args.top_k + int(model_args.use_shared_expert)
            hidden_dim = int(hidden_dim / hidden_dim_denom)
            hidden_dim += -hidden_dim % model_args.multiple_of

        self.experts = GroupedExperts(
            dim=dim, hidden_dim=hidden_dim, num_experts=num_experts
        )
        self.router = TokenChoiceTopKRouter(
            dim=dim, num_experts=num_experts, top_k=model_args.top_k
        )
        self.shared_expert = (
            GroupedExperts(dim=dim, hidden_dim=hidden_dim, num_experts=1)
            if model_args.use_shared_expert
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``.

        Returns:
            out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
        """
        bs, slen, dim = x.shape
        # top_scores and token_indices shape (bs*slen*top_k,)
        # num_local_tokens_per_expert shape (num_experts,)
        (
            top_scores,
            token_indices,
            num_local_tokens_per_expert,
        ) = self.router(x.reshape(bs * slen, dim))

        # Gather each routed token's features: (bs*slen*top_k, dim).
        token_indices = token_indices.reshape(-1, 1).expand(-1, dim)
        routed_input = torch.gather(
            x.view(-1, dim),
            dim=0,
            index=token_indices,
        )
        # Scale inputs by router scores before the experts run.
        routed_input = routed_input * top_scores.reshape(-1, 1)

        # shape (bs*slen*top_k, dim)
        routed_output = self.experts(routed_input, num_local_tokens_per_expert)

        # Base output: shared expert over all tokens, or zeros.
        if self.shared_expert is not None:
            out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape(
                bs * slen, dim
            )
        else:
            out = torch.zeros_like(x.reshape(bs * slen, dim))

        # Scatter-add expert outputs back to their source tokens.
        out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
        return out.reshape(bs, slen, dim)

    def init_weights(self, init_std: float):
        self.experts.init_weights(init_std)
        self.router.init_weights(init_std)
        if self.shared_expert is not None:
            self.shared_expert.init_weights(init_std)
self.shared_expert.init_weights(init_std) diff --git a/torchtitan/experiments/llama4/scripts/REAME.md b/torchtitan/experiments/llama4/scripts/REAME.md new file mode 100644 index 0000000000000000000000000000000000000000..c4cd6c32412522eb6efb0fa93eb09344b69ad3cc --- /dev/null +++ b/torchtitan/experiments/llama4/scripts/REAME.md @@ -0,0 +1,17 @@ +## How to convert a Llama 4 checkpoint for use in torchtitan + +To continue training from an existing model checkpoint, the checkpoint must be in the DCP format expected by the checkpoint manager. +This folder contains the scripts for converting officially released Llama 4 checkpoints into the expected DCP format, from original Meta format, or from HuggingFace format, using GPUs. + +#### Example usage + +From Meta format: +```bash +CONFIG_FILE=../train_configs/llama4_16.toml ./convert_meta_to_dcp.sh --checkpoint.enable_checkpoint --checkpoint.convert_path=[checkpoint_folder] --checkpoint.convert_load_every_n_ranks=8 +``` + + +From HuggingFace format: +```bash +CONFIG_FILE=../train_configs/llama4_16.toml ./convert_hf_to_dcp_with_gpus.sh --checkpoint.enable_checkpoint --checkpoint.convert_path=[checkpoint_folder] --checkpoint.convert_load_every_n_ranks=8 +``` diff --git a/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..99eb36ac6ffa8e546d8895358978e937088f7ee1 --- /dev/null +++ b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py @@ -0,0 +1,545 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
import json
import math
import os
import pprint
import re
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Optional

import torch
import torch.distributed as dist
from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor, Shard
from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
from torchtitan.components.checkpoint import MODEL
from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import init_logger, logger
from torchtitan.train import Trainer


def extract_layer_number(s):
    """Return the integer N from a 'layers.N' component of ``s``, or None."""
    match = re.search(r"layers\.(\d+)", s)
    return int(match.group(1)) if match else None


def convert_to_titan_fqns(fqn: str) -> list[str]:
    """Map a stored (HF) checkpoint key to the TorchTitan key(s) it feeds.

    A single stored key may map to two TorchTitan keys (the fused
    gate_up_proj splits into w1 and w3). Keys outside the language model are
    passed through unchanged.

    Raises:
        ValueError: If the key does not match any known pattern.
    """
    if "language_model." not in fqn:
        # TODO: Not support video model yet
        return [fqn]

    layer = extract_layer_number(fqn)

    if layer is None:
        # Non-layer tensors: embeddings, final norm, output head.
        if "embed_tokens.weight" in fqn:
            return ["tok_embeddings.weight"]
        if "norm.weight" in fqn:
            return ["norm.weight"]
        if "lm_head.weight" in fqn:
            return ["output.weight"]
        raise ValueError(f"Unknown fqn {fqn}")

    if "feed_forward.experts.down_proj" in fqn:
        return [f"layers.{layer}.moe.experts.w2"]
    if "feed_forward.experts.gate_up_proj" in fqn:
        # Fused gate+up projection splits into w1 and w3.
        return [f"layers.{layer}.moe.experts.w1", f"layers.{layer}.moe.experts.w3"]
    if "feed_forward.router.weight" in fqn:
        return [f"layers.{layer}.moe.router.gate.weight"]
    if "feed_forward.shared_expert.down_proj.weight" in fqn:
        return [f"layers.{layer}.moe.shared_expert.w2"]
    if "feed_forward.shared_expert.gate_proj.weight" in fqn:
        return [f"layers.{layer}.moe.shared_expert.w3"]
    if "feed_forward.shared_expert.up_proj.weight" in fqn:
        return [f"layers.{layer}.moe.shared_expert.w1"]
    if "input_layernorm.weight" in fqn:
        return [f"layers.{layer}.ffn_norm.weight"]
    if "self_attn.k_proj" in fqn:
        return [f"layers.{layer}.attention.wk.weight"]
    if "self_attn.o_proj" in fqn:
        return [f"layers.{layer}.attention.wo.weight"]
    if "self_attn.q_proj" in fqn:
        return [f"layers.{layer}.attention.wq.weight"]
    if "self_attn.v_proj" in fqn:
        return [f"layers.{layer}.attention.wv.weight"]
    if "post_attention_layernorm.weight" in fqn:
        return [f"layers.{layer}.attention_norm.weight"]
    raise ValueError(f"Unknown fqn {fqn}")


def convert_to_hf_shape(fqn: str, titan_fqns: list[str], dtensor: DTensor) -> list[str]:
    """Shape the stored (HF) tensor is expected to have for ``fqn``."""
    if "feed_forward.experts.gate_up_proj" in fqn:
        # Fused gate+up: twice as wide on the last dim as either titan half.
        assert len(titan_fqns) == 2
        shape = dtensor.shape
        return torch.Size(list(shape[:-1]) + [shape[-1] * 2])
    if "shared_expert" in fqn:
        s = dtensor.shape
        # TODO: this is not right but I have to do this to load the checkpoint.
        return torch.Size((s[2], s[1]))
    return dtensor.shape


def convert_to_titan_tensors(fqn: str, full_tensor: torch.Tensor) -> torch.Tensor:
    """Split/reshape a stored (HF) tensor into the TorchTitan tensor(s)."""
    if "feed_forward.experts.gate_up_proj" in fqn:
        # Unfuse the gate/up halves -> (w1, w3).
        full_tensors = full_tensor.chunk(2, dim=-1)
    elif "shared_expert" in fqn:
        # TODO: this is not right but I have to do this to load the checkpoint.
        full_tensors = [full_tensor.transpose(1, 0).unsqueeze(0)]
    else:
        full_tensors = [full_tensor]
    return full_tensors
@dataclass
class _Assignment:
    # One loader's work item: which safetensors file to read and which
    # tensors (with their HF shapes/dtypes) to pull from it.
    loader_id: int
    filename: str
    fqns: list[str]
    shapes: list[torch.Size]
    dtypes: list[torch.dtype]


@dataclass
class _AssignmentRound:
    # Per-round mapping: loader id -> that loader's assignment.
    loader_assignments: dict[int, _Assignment]


@dataclass
class TensorMetadata:
    fqn: str
    shape: torch.Size
    dtype: torch.dtype


class CheckpointConverter:
    """Streams an HF sharded checkpoint into a TorchTitan DTensor state_dict.

    Every ``loader_every_n_ranks``-th rank acts as a "loader" that reads
    safetensors files; loaded tensors are then communicated to all ranks and
    resharded into the target ``state_dict``.
    """

    def __init__(
        self,
        process_group: dist.ProcessGroup,
        path: str,
        token: Optional[str] = None,
        loader_every_n_ranks: int = 8,
    ) -> None:
        self.path = path
        self.token = token
        self.pg = process_group
        self.my_rank = dist.get_rank(self.pg)

        # One loader per `loader_every_n_ranks` ranks; this rank loads only if
        # it is the first rank of its loader group.
        self.loader_every_n_ranks = loader_every_n_ranks
        self.loader_id = self.my_rank // loader_every_n_ranks
        self.should_load = self.my_rank % loader_every_n_ranks == 0
        self.total_loader = dist.get_world_size(self.pg) // loader_every_n_ranks

        self.titan_fqn_to_stored_fqn: dict[str, str] = {}
        self.stored_fqn_to_titan_fqn: dict[str, list[str]] = {}
        self.total_send_bytes = 0
        self.total_recv_bytes = 0

    def convert(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Drive the full conversion: plan rounds, load, communicate, reshard.

        Mutates ``state_dict`` in place with the converted values and returns it.
        """
        begin = time.time()
        self._load_metadata()
        self._create_fqn_mappings(state_dict)
        rounds = self._get_load_assignments(state_dict)

        logger.info(f"Got {len(rounds)} rounds of assignments.")
        for idx, assignments in enumerate(rounds):
            loader_assignments = assignments.loader_assignments
            loaded_state_dict = None
            # Let each loader load its own data and move it to its GPU.
            logger.info(f"Loading round {idx}")
            for i in range(self.total_loader):
                # This loader doesn't have any loading assignment for this round.
                if i not in loader_assignments:
                    continue
                # This rank is not the loader.
                if i != self.loader_id or not self.should_load:
                    continue
                loaded_state_dict = self._load_round(loader_assignments[i])

            torch.cuda.synchronize()
            logger.info(f"Loading round {idx} finished")
            for i in range(self.total_loader):
                if i not in loader_assignments:
                    continue

                logger.info(f"Resharding round {idx} loader {i} data. ")
                if i == self.loader_id and self.should_load:
                    # This rank is the loader. It needs to send the loaded data
                    # to the other ranks.
                    assert loaded_state_dict is not None
                    results = self._reshard_send(
                        loader_assignments[i], loaded_state_dict
                    )
                else:
                    results = self._reshard_receive(loader_assignments[i], state_dict)
                torch.cuda.synchronize()

                logger.info(f"Communication round {idx} loader {i} is done.")
                # Fix: the original called `_reshard` twice back-to-back on the
                # same results, redundantly re-applying them; once suffices.
                self._reshard(results, state_dict)
                logger.info(f"Resharding round {idx} loader {i} is done.")
                torch.cuda.synchronize()

        dist.barrier()
        torch.cuda.synchronize()
        logger.info(f"Checkpoint conversion took {time.time() - begin:.2f} seconds.")
        logger.info(f"Total send bytes: {self.total_send_bytes / 1e9:.2f} GB")
        logger.info(f"Total recv bytes: {self.total_recv_bytes / 1e9:.2f} GB")
        return state_dict

    def _load_metadata(self) -> None:
        # The HF index maps each stored fqn to the safetensors file holding it.
        metadata_path = os.path.join(self.path, "model.safetensors.index.json")
        with open(metadata_path, "r") as f:
            self.metadata = json.load(f)["weight_map"]

    def _create_fqn_mappings(self, state_dict: dict[str, torch.Tensor]) -> None:
        """Build stored<->titan fqn maps, dropping keys we don't convert."""
        if not self.metadata:
            return

        # Create the mapping from the stored checkpoint keys to TorchTitan keys.
        for fqn in list(self.metadata.keys()):
            # Fix: check `_extra_state` BEFORE converting — the original
            # converted first, so an `_extra_state` key under a layer would
            # raise ValueError in convert_to_titan_fqns before it could be
            # skipped. We don't know how to process _extra_state.
            if "_extra_state" in fqn:
                self.metadata.pop(fqn)
                continue

            titan_fqns = convert_to_titan_fqns(fqn)

            if titan_fqns[0] not in state_dict:
                for titan_fqn in titan_fqns:
                    assert titan_fqn not in state_dict
                self.metadata.pop(fqn)
                continue

            self.stored_fqn_to_titan_fqn[fqn] = titan_fqns
            for titan_fqn in titan_fqns:
                self.titan_fqn_to_stored_fqn[titan_fqn] = fqn

        torchtitan_extra = sorted(
            list(set(state_dict.keys()) - set(self.titan_fqn_to_stored_fqn.keys()))
        )
        converted_extra = sorted(
            list(set(self.titan_fqn_to_stored_fqn.keys()) - set(state_dict.keys()))
        )
        assert set(state_dict.keys()) == set(self.titan_fqn_to_stored_fqn.keys()), (
            f"{pprint.pformat(torchtitan_extra)}",
            f"{pprint.pformat(converted_extra)}",
        )
+ for fqn in list(self.metadata.keys()): + titan_fqns = convert_to_titan_fqns(fqn) + # We don't know how to process _extra_state + if "_extra_state" in fqn: + self.metadata.pop(fqn) + continue + + if titan_fqns[0] not in state_dict: + for titan_fqn in titan_fqns: + assert titan_fqn not in state_dict + self.metadata.pop(fqn) + continue + + self.stored_fqn_to_titan_fqn[fqn] = titan_fqns + for titan_fqn in titan_fqns: + self.titan_fqn_to_stored_fqn[titan_fqn] = fqn + + torchtitan_extra = sorted( + list(set(state_dict.keys()) - set(self.titan_fqn_to_stored_fqn.keys())) + ) + converted_extra = sorted( + list(set(self.titan_fqn_to_stored_fqn.keys()) - set(state_dict.keys())) + ) + assert set(state_dict.keys()) == set(self.titan_fqn_to_stored_fqn.keys()), ( + f"{pprint.pformat(torchtitan_extra)}", + f"{pprint.pformat(converted_extra)}", + ) + + def _get_load_assignments( + self, state_dict: dict[str, Any] + ) -> list[_AssignmentRound]: + if self.my_rank == 0: + filename_to_metas = defaultdict(list) + for fqn, filename in self.metadata.items(): + titan_fqns = self.stored_fqn_to_titan_fqn[fqn] + shape = convert_to_hf_shape(fqn, titan_fqns, state_dict[titan_fqns[0]]) + meta = TensorMetadata( + fqn=fqn, + shape=shape, + # TODO: don't hardcode this + dtype=torch.bfloat16, + ) + filename_to_metas[filename].append(meta) + + loader_filename_to_metas = [{} for _ in range(self.total_loader)] + for idx, (filename, metas) in enumerate(filename_to_metas.items()): + loader_id = idx % self.total_loader + loader_filename_to_metas[loader_id][filename] = metas + + rounds = [] + while any(len(remain) > 0 for remain in loader_filename_to_metas): + round_assignment = _AssignmentRound(loader_assignments={}) + for loader_id in range(self.total_loader): + if not loader_filename_to_metas[loader_id]: + continue + + filename, metas = loader_filename_to_metas[loader_id].popitem() + round_assignment.loader_assignments[loader_id] = _Assignment( + filename=filename, + fqns=[meta.fqn for meta in metas], 
+ shapes=[meta.shape for meta in metas], + dtypes=[meta.dtype for meta in metas], + loader_id=loader_id, + ) + + rounds.append(round_assignment) + + object_list: list[Any] = [ + rounds, + self.titan_fqn_to_stored_fqn, + self.stored_fqn_to_titan_fqn, + ] + else: + object_list = [None, None, None] + + dist.broadcast_object_list(object_list, src=0, group=self.pg) + rounds = object_list[0] + self.titan_fqn_to_stored_fqn = object_list[1] + self.stored_fqn_to_titan_fqn = object_list[2] + return rounds + + def _load_round(self, assignment: _Assignment) -> dict[str, Any]: + from safetensors.torch import load_file as hf_load_file + + path = os.path.join(self.path, assignment.filename) + state_dict = hf_load_file(path) + return { + k: v.to(device="cuda") + for k, v in state_dict.items() + if k in assignment.fqns + } + + def _reshard_send( + self, + assignment: _Assignment, + loaded_state_dict: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + flatten_tensors = [t.flatten() for t in loaded_state_dict.values()] + flatten_tensor = torch.concat(flatten_tensors) + assert self.loader_id == assignment.loader_id + rank = self.loader_id * self.loader_every_n_ranks + assert rank == self.my_rank + logger.info( + f"Sending {assignment.filename} from {rank} {self.loader_id} " + f"{flatten_tensor.shape=} {flatten_tensor.dtype=} {loaded_state_dict.keys()=}." 
+ ) + logger.info(f"Sending {assignment}") + dist.broadcast(flatten_tensor, src=rank, group=self.pg) + self.total_send_bytes += flatten_tensor.numel() * flatten_tensor.element_size() + return loaded_state_dict + + def _reshard_receive( + self, assignment: _Assignment, state_dict: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + + flatten_tensor = torch.empty( + sum(math.prod(s) for s, d in zip(assignment.shapes, assignment.dtypes)), + dtype=assignment.dtypes[0], + device="cuda", + ) + rank = assignment.loader_id * self.loader_every_n_ranks + logger.info( + f"Receiving {assignment.filename} from {rank} " + f"{flatten_tensor.shape=} {flatten_tensor.dtype=}" + ) + logger.info(f"Receiving {assignment}") + dist.broadcast(flatten_tensor, src=rank, group=self.pg) + self.total_recv_bytes += flatten_tensor.numel() * flatten_tensor.element_size() + + ret: dict[str, torch.Tensor] = {} + loc = 0 + for fqn, shape, dtype in zip( + assignment.fqns, assignment.shapes, assignment.dtypes + ): + n_ele = math.prod(shape) + ret[fqn] = flatten_tensor[loc : loc + n_ele].view(shape) + loc += n_ele + return ret + + def _reshard( + self, + result: dict[str, torch.Tensor], + state_dict: dict[str, torch.Tensor], + ) -> None: + def _inplace_copy(fqn: str, full_tensors: list[torch.Tensor]): + titan_fqns = self.stored_fqn_to_titan_fqn[fqn] + assert len(titan_fqns) == len(full_tensors) + for titan_fqn, full_tensor in zip(titan_fqns, full_tensors): + dtensor = state_dict[titan_fqn] + assert isinstance(dtensor, DTensor) + assert dtensor.shape == full_tensor.shape, ( + (fqn, titan_fqn), + dtensor.shape, + full_tensor.shape, + ) + shape, offset = compute_local_shape_and_global_offset( + full_tensor.shape, dtensor.device_mesh, dtensor.placements + ) + slices = [ + slice(cur_offset, cur_offset + cur_shape) + for cur_shape, cur_offset in zip(shape, offset) + ] + logger.debug( + f"Copying {titan_fqn} with {slices=} {dtensor._local_tensor.shape=} " + f"{shape=} {offset=} {self.my_rank=} 
{dtensor.shape=} {full_tensor.shape=} " + f"{dtensor.placements=} {dtensor.device_mesh=} " + ) + dtensor.to_local().copy_(full_tensor[slices].to(dtensor.dtype)) + + for fqn, full_tensor in result.items(): + full_tensors = convert_to_titan_tensors(fqn, full_tensor) + _inplace_copy(fqn, full_tensors) + + +def _create_verified_state_dict( + pg: dist.ProcessGroup, mesh: DeviceMesh +) -> dict[str, torch.Tensor]: + placements = [Shard(0)] + state_dict = { + "vision_model.vision_adapter.mlp.fc1.weight": torch.rand( + 4096, 5632, device="cuda", dtype=torch.bfloat16 + ), + "vision_model.vision_adapter.mlp.fc2.weight": torch.rand( + 4096, 4096, device="cuda", dtype=torch.bfloat16 + ), + "language_model.model.layers.3.feed_forward.experts.gate_up_proj": torch.rand( + 16, 5120, 16384, device="cuda", dtype=torch.bfloat16 + ), + } + return {k: distribute_tensor(v, mesh, placements) for k, v in state_dict.items()} + + +def _verify_state_dict( + state_dict: dict[str, torch.Tensor], path: str, rank: int +) -> None: + metadata_path = os.path.join(path, "model.safetensors.index.json") + with open(metadata_path, "r") as f: + metadata = json.load(f)["weight_map"] + all_filenames = set() + for fqn, tensor in state_dict.items(): + filename = os.path.join(path, metadata[fqn]) + all_filenames.add(filename) + + stored_state_dict = {} + from safetensors.torch import load_file as hf_load_file + + for filename in all_filenames: + _sd = hf_load_file(filename) + for k in list(_sd.keys()): + if k not in state_dict: + _sd.pop(k) + else: + stored_state_dict[k] = _sd[k] + + def read_and_verify_tensor(fqn: str, dtensor: DTensor) -> None: + logger.info(f"Verifying {fqn} {dtensor.shape=} {dtensor.placements=} ") + stored_tensor = stored_state_dict[fqn] + full_tensor = dtensor.full_tensor() + logger.info(f"Gather {fqn} {full_tensor.shape} completely.") + + if rank > 0: + return + + stored_tensor = stored_tensor.to(device="cuda") + logger.info(f"Move to GPU {fqn} completely.") + + assert 
stored_tensor.shape == full_tensor.shape, fqn + assert stored_tensor.dtype == full_tensor.dtype, fqn + assert stored_tensor.device == full_tensor.device, fqn + assert torch.allclose(stored_tensor, full_tensor), fqn + + for k, v in state_dict.items(): + read_and_verify_tensor(k, v) + + +if __name__ == "__main__": + init_logger() + config = JobConfig() + config.parser.add_argument( + "--checkpoint.convert_path", + type=str, + default="", + help="""Specify the path of the target checkpoint to convert.""", + ) + config.parser.add_argument( + "--checkpoint.convert_hf_token", + type=str, + default="", + help="""Specify hf token.""", + ) + config.parser.add_argument( + "--checkpoint.convert_load_every_n_ranks", + type=int, + default=8, + help=""" + Specify the interval at which ranks are assigned to load checkpoints. + + For example, if this number is 4, then ranks 0, 4, 8, ... will load the + checkpoint. Each loader is responsible for loading one file. If there + are more loaders than files, only the first few loaders will be assigned + to load the checkpoint. The default value is 8. + """, + ) + config.parser.add_argument( + "--checkpoint.fake_model", + action="store_true", + help="""If true, the model will be fake.""", + ) + config.parse_args() + assert config.checkpoint.convert_path != "" + + trainer: Optional[Trainer] = None + + try: + trainer = Trainer(config) + if os.path.exists(trainer.checkpointer.folder): + raise RuntimeError( + "The checkpoint folder already exists. Abort to avoid overwriting " + f"the checkpoint. 
{trainer.checkpointer.folder=}" + ) + if config.checkpoint.fake_model: + state_dict = _create_verified_state_dict( + trainer.world_mesh.get_group(), trainer.world_mesh + ) + else: + state_dict = trainer.checkpointer.states[MODEL].state_dict() + + size = 0 + for v in state_dict.values(): + size += v.numel() * v.element_size() + logger.info(f"Total size of the model: {size / 1e9:.2f} GB") + + # Do not support PP yet, we will need to iterate over the PP dimension and + # extract the corresponding state_dict and device_mesh. + if "freqs_cis" in state_dict: + state_dict.pop("freqs_cis") + + # Our tokenizer is not up-to-date yet. + tok_embeddings_weight = state_dict.pop("tok_embeddings.weight") + output_weight = state_dict.pop("output.weight") + state_dict = CheckpointConverter( + process_group=trainer.world_mesh.get_group(), + path=config.checkpoint.convert_path, + token=config.checkpoint.convert_hf_token, + loader_every_n_ranks=config.checkpoint.convert_load_every_n_ranks, + ).convert(state_dict) + state_dict["tok_embeddings.weight"] = tok_embeddings_weight + state_dict["output.weight"] = output_weight + + class DummyModel: + def __init__(self, state_dict: dict[str, torch.Tensor]) -> None: + self._state_dict = state_dict + + def state_dict(self) -> dict[str, torch.Tensor]: + return self._state_dict + + if config.checkpoint.fake_model: + begin = time.time() + _verify_state_dict( + state_dict, + config.checkpoint.convert_path, + trainer.world_mesh.get_rank(), + ) + dist.barrier() + logger.info(f"Verifies state_dict {time.time() - begin}.") + else: + # oh, this is pretty bad, when can we get rid of the freqs_cis issue? 
+ state_dict["freqs_cis"] = None + trainer.checkpointer.states[MODEL] = DummyModel(state_dict) + trainer.checkpointer.model_weights_only = True + trainer.checkpointer.export_dtype = next(iter(state_dict.values())).dtype + trainer.checkpointer.save(curr_step=0, force=True) + time.sleep(2) + finally: + pass diff --git a/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.sh b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.sh new file mode 100644 index 0000000000000000000000000000000000000000..6530b8ce992c8c33ccec94614e026d73964710ee --- /dev/null +++ b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.sh @@ -0,0 +1,26 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +set -ex + +# use envs as local overrides for convenience +# e.g. +# LOG_RANK=0,1 NGPU=4 ./convert_hf_to_dcp_with_gpus.sh +NGPU=${NGPU:-"8"} +LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7} +CONFIG_FILE=${CONFIG_FILE:-"../train_configs/llama4_17bx16e.toml"} + +overrides="" +if [ $# -ne 0 ]; then + overrides="$*" +fi + +PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ +--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ +convert_hf_to_dcp_with_gpus.py --job.config_file ${CONFIG_FILE} $overrides diff --git a/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py b/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..7756afe3de1527f469a38fc6a0bdc6c62eaa2526 --- /dev/null +++ b/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py @@ -0,0 +1,536 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +import os +import time +from dataclasses import dataclass +from typing import Any, Optional + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor, Shard +from torch.distributed.tensor._utils import compute_local_shape_and_global_offset +from torchtitan.components.checkpoint import MODEL +from torchtitan.config_manager import JobConfig +from torchtitan.tools.logging import init_logger, logger +from torchtitan.train import Trainer + +# Sharding dims for MP checkpoints + +column_parallel = [ + "tok_embeddings", + "wq", + "wk", + "wv", + "wqkv", + "w_in_shared_FD", + "w_out_eF_D", + "w_swiglu_FD", + "output", + "_linear", + "c_fc", + "vision_projection", +] + +row_parallel = [ + "wo", + "w_out_shared_DF", + "w_in_eD_F", + "moe_w_swiglu_eD_F", + "c_proj", +] + + +def convert_to_titan_fqns(fqn: str) -> list[str]: + # From the stored checkpoint keys to TorchTitan keys. 
+ if "wqkv" in fqn and "layer_norm_weight" not in fqn: + ret = [] + for k in ("wq", "wk", "wv"): + ret.append(fqn.replace("wqkv", k)) + return ret + return [fqn] + + +def get_shard_dim(fqn: str) -> Optional[int]: + if "bias" in fqn: + # Some bias params are still sharded + if "resblocks" in fqn: + for k in ("wq", "wk", "wv", "c_fc"): + if k in fqn: + return 0 + return None + elif any([x in fqn for x in column_parallel]): + return 0 + elif any([x in fqn for x in row_parallel]): + return 1 + else: + return None + + +def split_fused_qkv(shards: list[torch.Tensor]) -> tuple[torch.Tensor, ...]: + qkvs = [torch.split(shard, [640, 128, 128]) for shard in shards] + q = torch.cat([qkv[0] for qkv in qkvs], dim=0) + k = torch.cat([qkv[1] for qkv in qkvs], dim=0) + v = torch.cat([qkv[2] for qkv in qkvs], dim=0) + return q, k, v + + +@dataclass +class _Assignment: + loader_id: int + filename: str + fqns: tuple[str, ...] + shapes: tuple[torch.Size, ...] + dtypes: tuple[torch.dtype, ...] + + +@dataclass +class _AssignmentRound: + loader_assignments: dict[int, _Assignment] # List of assignments for each loader + + +class CheckpointConverter: + TOTAL_SHARDS = 8 + + def __init__( + self, + process_group: dist.ProcessGroup, + path: str, + loader_every_n_ranks: int = 8, + ) -> None: + self.path = path + self.pg = process_group + self.my_rank = dist.get_rank(self.pg) + self.loader_every_n_ranks = loader_every_n_ranks + self.loader_id = self.my_rank // loader_every_n_ranks + self.should_load = ( + self.my_rank % loader_every_n_ranks == 0 + and self.loader_id < CheckpointConverter.TOTAL_SHARDS + ) + self.total_loader = CheckpointConverter.TOTAL_SHARDS + self.titan_fqn_to_stored_fqn: dict[str, str] = {} + self.stored_fqn_to_titan_fqn: dict[str, list[str]] = {} + self.total_send_bytes = 0 + self.total_recv_bytes = 0 + + def convert(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + begin = time.time() + self._load_metadata() + self._create_fqn_mappings(state_dict) + 
rounds = self._get_load_assignments(state_dict) + + for assignments in rounds: + loader_assignments = assignments.loader_assignments + loaded_state_dict = None + # Let each loader to load its own data and move to its GPU. + for i in range(self.total_loader): + # This loader doesn't have any loading assignment for this round. + if i not in loader_assignments: + continue + # This rank is not the loader + if i != self.loader_id or not self.should_load: + continue + loaded_state_dict = self._load_round(loader_assignments[i]) + + results = [] + for i in range(self.total_loader): + if i not in loader_assignments: + continue + + if i == self.loader_id and self.should_load: + # This rank is the loader. It needs to send the loaded data to + # the other ranks. + assert loaded_state_dict is not None + results.append( + self._reshard_send(loader_assignments[i], loaded_state_dict) + ) + else: + results.append( + self._reshard_receive(loader_assignments[i], state_dict) + ) + + self._reshard(results, state_dict) + + torch.cuda.synchronize() + logger.info(f"Checkpoint conversion took {time.time() - begin:.2f} seconds.") + logger.info(f"Total send bytes: {self.total_send_bytes / 1e9:.2f} GB") + logger.info(f"Total recv bytes: {self.total_recv_bytes / 1e9:.2f} GB") + return state_dict + + def _get_file_path(self, loader_id: int) -> str: + return os.path.join(self.path, f"consolidated.0{loader_id}.pth") + + def _load_metadata(self) -> None: + if not self.should_load: + self.read_dict = {} + return + self.read_dict = torch.load( + self._get_file_path(self.loader_id), + mmap=True, + weights_only=False, + ) + + def _create_fqn_mappings(self, state_dict: dict[str, torch.Tensor]) -> None: + if not self.read_dict: + return + + # Create the mapping from the stored checkpoint keys to TorchTitan keys. 
+ for fqn in list(self.read_dict.keys()): + titan_fqns = convert_to_titan_fqns(fqn) + # We don't know how to process _extra_state + if "_extra_state" in fqn: + self.read_dict.pop(fqn) + continue + + if titan_fqns[0] not in state_dict: + for titan_fqn in titan_fqns: + # Check each mapped FQN, not just the first one, so a partially + # present mapping is caught instead of silently dropped. + assert titan_fqn not in state_dict + self.read_dict.pop(fqn) + continue + self.stored_fqn_to_titan_fqn[fqn] = titan_fqns + for titan_fqn in titan_fqns: + self.titan_fqn_to_stored_fqn[titan_fqn] = fqn + + assert set(state_dict.keys()) == set(self.titan_fqn_to_stored_fqn.keys()), ( + set(state_dict.keys()) - set(self.titan_fqn_to_stored_fqn.keys()), + set(self.titan_fqn_to_stored_fqn.keys()) - set(state_dict.keys()), + ) + + def _get_load_assignments( + self, state_dict: dict[str, torch.Tensor] + ) -> list[_AssignmentRound]: + if self.my_rank == 0: + rounds: list[_AssignmentRound] = [] + size = 0 + fqns = [] + shapes = [] + dtypes = [] + + # All loader must load all the FQNs because the checkpoint is purely TP sharded. + all_keys = list(self.read_dict.keys()) + for fqn in all_keys: + fqns.append(fqn) + shapes.append(self.read_dict[fqn].shape) + dtypes.append(self.read_dict[fqn].dtype) + size += self.read_dict[fqn].numel() * self.read_dict[fqn].element_size() + if size < 1e9 and fqn != all_keys[-1]: + continue + + logger.info(f"Adding {fqns} to round {len(rounds)}") + round_assignment = _AssignmentRound(loader_assignments={}) + for loader_id in range(self.total_loader): + path = self._get_file_path(loader_id) + round_assignment.loader_assignments[loader_id] = _Assignment( + filename=path, + fqns=tuple(fqns), + shapes=tuple(shapes), + dtypes=tuple(dtypes), + loader_id=loader_id, + ) + rounds.append(round_assignment) + size = 0 + fqns.clear() + shapes.clear() + dtypes.clear() + + object_list: list[Any] = [ + rounds, + self.titan_fqn_to_stored_fqn, + self.stored_fqn_to_titan_fqn, + ] + else: + object_list = [None, None, None] + + dist.broadcast_object_list(object_list, src=0, group=self.pg) + rounds = 
object_list[0] + self.titan_fqn_to_stored_fqn = object_list[1] + self.stored_fqn_to_titan_fqn = object_list[2] + return rounds + + def _load_round(self, assignment: _Assignment) -> dict[str, torch.Tensor]: + ret = {} + assert self.read_dict + for fqn in assignment.fqns: + ret[fqn] = self.read_dict[fqn].to(device="cuda") + return ret + + def _reshard_send( + self, + assignment: _Assignment, + loaded_state_dict: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + flatten_tensors = [t.flatten() for t in loaded_state_dict.values()] + flatten_tensor = torch.concat(flatten_tensors) + assert self.loader_id == assignment.loader_id + rank = self.loader_id * self.loader_every_n_ranks + assert rank == self.my_rank + logger.info(f"Sending {assignment.filename} from {rank} {self.loader_id}") + logger.info(f"Sending {assignment.fqns}") + dist.broadcast(flatten_tensor, src=rank, group=self.pg) + self.total_send_bytes += flatten_tensor.numel() * flatten_tensor.element_size() + return loaded_state_dict + + def _reshard_receive( + self, assignment: _Assignment, state_dict: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + flatten_tensor = torch.empty( + sum(math.prod(s) for s, d in zip(assignment.shapes, assignment.dtypes)), + dtype=assignment.dtypes[0], + device="cuda", + ) + rank = assignment.loader_id * self.loader_every_n_ranks + dist.broadcast(flatten_tensor, src=rank, group=self.pg) + self.total_recv_bytes += flatten_tensor.numel() * flatten_tensor.element_size() + + ret: dict[str, torch.Tensor] = {} + loc = 0 + for fqn, shape, dtype in zip( + assignment.fqns, assignment.shapes, assignment.dtypes + ): + n_ele = math.prod(shape) + ret[fqn] = flatten_tensor[loc : loc + n_ele].view(shape) + loc += n_ele + return ret + + def _reshard( + self, + results: list[dict[str, torch.Tensor]], + state_dict: dict[str, torch.Tensor], + ) -> None: + def _inplace_copy(fqn: str, full_tensors: tuple[torch.Tensor, ...]): + titan_fqns = self.stored_fqn_to_titan_fqn[fqn] + assert 
len(titan_fqns) == len(full_tensors) + for titan_fqn, full_tensor in zip(titan_fqns, full_tensors): + dtensor = state_dict[titan_fqn] + logger.info(f"{titan_fqn} {full_tensor.sum()}") + assert isinstance(dtensor, DTensor) + shape, offset = compute_local_shape_and_global_offset( + full_tensor.shape, dtensor.device_mesh, dtensor.placements + ) + slices = [ + slice(cur_offset, cur_offset + cur_shape) + for cur_shape, cur_offset in zip(shape, offset) + ] + logger.info( + f"Copying {titan_fqn} with {slices=} {dtensor._local_tensor.shape=} " + f"{shape=} {offset=} {self.my_rank=} {dtensor.shape=} {full_tensor.shape=} " + f"{dtensor.placements=} {dtensor.device_mesh=} " + ) + dtensor.to_local().copy_(full_tensor[slices]) + + def _concat_shards(fqn, shards: list[torch.Tensor]) -> tuple[torch.Tensor, ...]: + if "wqkv" in fqn: + if "layer_norm" in fqn: + return (shards[0],) + return split_fused_qkv(shards) + + shard_dim = get_shard_dim(fqn) + if shard_dim is None: + return (shards[0],) + return (torch.cat(shards, dim=shard_dim),) + + fqns = list(results[0].keys()) + for result in results: + assert list(result.keys()) == fqns + + for fqn in fqns: + full_tensors = _concat_shards(fqn, [result[fqn] for result in results]) + _inplace_copy(fqn, full_tensors) + + +def _create_verified_state_dict( + pg: dist.ProcessGroup, mesh: DeviceMesh +) -> dict[str, torch.Tensor]: + placements = [Shard(0)] + state_dict = { + "tok_embeddings.weight": torch.rand( + 25256 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.attention.wqkv.layer_norm_weight": torch.rand( + 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.attention.wq.weight": torch.rand( + 640 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.attention.wk.weight": torch.rand( + 128 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.attention.wv.weight": torch.rand( + 128 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.attention.wo.weight": torch.rand( + 5120, 640 
* 8, device="cuda", dtype=torch.bfloat16 + ), + # "layers.47.feed_forward.router_DE": torch.rand(5120, 128, device="cuda", dtype=torch.bfloat16), + # "layers.47.feed_forward.running_gate_stats_3E": torch.rand(3, 128, device="cuda", dtype=torch.bfloat16), + # "layers.47.feed_forward.global_gate_stats_3E": torch.rand(3, 128, device="cuda", dtype=torch.bfloat16), + "layers.47.feed_forward.w_in_shared_FD.weight": torch.rand( + 1024 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.feed_forward.w_out_shared_DF.weight": torch.rand( + 5120, 1024 * 8, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.feed_forward.w_swiglu_FD.weight": torch.rand( + 1024 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.feed_forward.norm.weight": torch.rand( + 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.feed_forward.experts.moe_w_in_eD_F": torch.rand( + 655360, 1024 * 8, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.feed_forward.experts.moe_w_out_eF_D": torch.rand( + 131072 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.feed_forward.experts.moe_w_swiglu_eD_F": torch.rand( + 655360, 1024 * 8, device="cuda", dtype=torch.bfloat16 + ), + } + return {k: distribute_tensor(v, mesh, placements) for k, v in state_dict.items()} + + +def _verify_state_dict( + state_dict: dict[str, torch.Tensor], path: str, rank: int +) -> None: + stored_state_dicts = [ + torch.load( + os.path.join(path, f"consolidated.0{i}.pth"), + map_location="cpu", + weights_only=False, + mmap=True, + ) + for i in range(8) + ] + + def read_and_verify_tensor(fqn: str, dtensor: DTensor) -> None: + logger.info(f"Verifying {fqn} {dtensor.shape=} {dtensor.placements=} ") + shards = [stored_state_dicts[i][fqn] for i in range(8)] + full_tensor = dtensor.full_tensor() + logger.info(f"Gather {fqn} {full_tensor.shape} completely.") + + if rank > 0: + return + + if len(shards[0].shape) == 1: + assert full_tensor.shape == shards[0].shape, fqn + assert 
torch.allclose(shards[0].to(device="cuda"), full_tensor), fqn + return + elif shards[0].shape[0] == full_tensor.shape[0]: + concat_shards = torch.cat(shards, dim=1) + logger.info(f"Load {fqn} completely.") + elif shards[0].shape[1] == full_tensor.shape[1]: + concat_shards = torch.cat(shards, dim=0) + logger.info(f"Load {fqn} completely.") + + concat_shards = concat_shards.to(device="cuda") + logger.info(f"Move to GPU {fqn} completely.") + + assert concat_shards.shape == full_tensor.shape, fqn + assert concat_shards.dtype == full_tensor.dtype, fqn + assert concat_shards.device == full_tensor.device, fqn + assert torch.allclose(concat_shards, full_tensor), fqn + + for k, v in state_dict.items(): + if "wq" in k and "wqkv" not in k: + pass + elif "wk" in k: + pass + elif "wv" in k: + pass + else: + assert v is not None, k + read_and_verify_tensor(k, v) + + +if __name__ == "__main__": + init_logger() + config = JobConfig() + config.parser.add_argument( + "--checkpoint.convert_path", + type=str, + default="", + help="""Specify the path of the target checkpoint to convert.""", + ) + config.parser.add_argument( + "--checkpoint.convert_load_every_n_ranks", + type=int, + default=8, + help=""" + Specify the interval at which ranks are assigned to load checkpoints. + + For example, if this number is 4, then ranks 0, 4, 8, ... will load the + checkpoint. Each loader is responsible for loading one file. If there + are more loaders than files, only the first few loaders will be assigned + to load the checkpoint. The default value is 8. + """, + ) + config.parser.add_argument( + "--checkpoint.fake_model", + action="store_true", + help="""If true, the model will be fake.""", + ) + config.parse_args() + assert config.checkpoint.convert_path != "" + + trainer: Optional[Trainer] = None + + try: + trainer = Trainer(config) + if os.path.exists(trainer.checkpointer.folder): + raise RuntimeError( + "The checkpoint folder already exists. Abort to avoid overwriting " + f"the checkpoint. 
{trainer.checkpointer.folder=}" + ) + if config.checkpoint.fake_model: + state_dict = _create_verified_state_dict( + trainer.world_mesh.get_group(), trainer.world_mesh + ) + else: + state_dict = trainer.checkpointer.states[MODEL].state_dict() + + size = 0 + for v in state_dict.values(): + size += v.numel() * v.element_size() + logger.info(f"Total size of the model: {size / 1e9:.2f} GB") + + # Do not support PP yet, we will need to iterate over the PP dimension and + # extract the corresponding state_dict and device_mesh. + # NOTE: the key is "freqs_cis" — the condition must test the same key it pops. + if "freqs_cis" in state_dict: + state_dict.pop("freqs_cis") + + state_dict = CheckpointConverter( + process_group=trainer.world_mesh.get_group(), + path=config.checkpoint.convert_path, + loader_every_n_ranks=config.checkpoint.convert_load_every_n_ranks, + ).convert(state_dict) + + class DummyModel: + def __init__(self, state_dict: dict[str, torch.Tensor]) -> None: + self._state_dict = state_dict + + def state_dict(self) -> dict[str, torch.Tensor]: + return self._state_dict + + if config.checkpoint.fake_model: + begin = time.time() + _verify_state_dict( + state_dict, + config.checkpoint.convert_path, + trainer.world_mesh.get_rank(), + ) + dist.barrier() + logger.info(f"Verifies state_dict {time.time() - begin}.") + else: + # oh, this is pretty bad, when can we get rid of the freqs_cis issue? 
+ state_dict["freqs_cis"] = None + trainer.checkpointer.states[MODEL] = DummyModel(state_dict) + trainer.checkpointer.model_weights_only = True + trainer.checkpointer.export_dtype = next(iter(state_dict.values())).dtype + trainer.checkpointer.save(curr_step=0, force=True) + time.sleep(2) + finally: + pass diff --git a/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh b/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh new file mode 100644 index 0000000000000000000000000000000000000000..f3fd310934b1181ed83fa9fc4463f0c2336b46fc --- /dev/null +++ b/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh @@ -0,0 +1,25 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +# use envs as local overrides for convenience +# e.g. +# LOG_RANK=0,1 NGPU=4 ./convert_meta_to_dcp_with_gpus.sh +NGPU=${NGPU:-"8"} +LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7} +CONFIG_FILE=${CONFIG_FILE:-"../train_configs/llama4_17bx16e.toml"} + +overrides="" +if [ $# -ne 0 ]; then + overrides="$*" +fi + +PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ +--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ +convert_meta_to_dcp_with_gpus.py --job.config_file ${CONFIG_FILE} $overrides diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml new file mode 100644 index 0000000000000000000000000000000000000000..139a1f28bfff5136e1ff625ee00d6e015b7729ba --- /dev/null +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -0,0 +1,74 @@ +[job] +dump_folder = "./outputs" +description = "Llama 4 debug training" +print_args = false +use_for_integration_test = true + +[profiling] 
+enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "llama4" +flavor = "debugmodel" +norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm +# test tokenizer.model, for debug purpose only +tokenizer_path = "./tests/assets/test_tiktoken.model" +# converters = "float8" +use_flex_attn = false +attn_mask_type = "causal" # causal / block_causal + +[optimizer] +name = "AdamW" +lr = 4e-3 +eps = 1e-15 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +lr_min = 0.1 + +[training] +batch_size = 8 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 10 +compile = false +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 10 +model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = 'none' # ['none', 'selective', 'full'] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = "output,router.gate" diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml new file mode 
100644 index 0000000000000000000000000000000000000000..e947afba56fd3b8ee5bf1fe45e65160c99a6fd18 --- /dev/null +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml @@ -0,0 +1,65 @@ +# TODO: this toml config is still under development + +[job] +dump_folder = "./outputs" +description = "Llama 4 Maverick 17Bx128E training" + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 10 +enable_tensorboard = false +save_tb_folder = "tb" + +[model] +name = "llama4" +flavor = "17bx128e" +norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm +tokenizer_path = "./assets/tokenizer/tokenizer.model" +# converters = "float8" + +[optimizer] +name = "AdamW" +lr = 4e-3 +eps = 1e-15 + +[lr_scheduler] +warmup_steps = 600 +lr_min = 0.1 + +[training] +batch_size = 1 +seq_len = 8192 +max_norm = 1.0 # grad norm clipping +steps = 3000 +compile = false +dataset = "c4" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 8 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 4 +# pipeline_parallel_schedule = "interleaved1f1b" +# pipeline_parallel_microbatches = 2 +context_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 500 +model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = 'full' # ['none', 'selective', 'full'] + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = "output,router.gate" diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml new file mode 100644 index 0000000000000000000000000000000000000000..d464d2d8cfddecb0e338a48926d0650a8ecb7930 --- /dev/null +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml 
@@ -0,0 +1,63 @@ +# NOTE: this toml config is a preset for 64 H100 GPUs. + +[job] +dump_folder = "./outputs" +description = "Llama 4 Scout 17Bx16E training" + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 10 +enable_tensorboard = false +save_tb_folder = "tb" + +[model] +name = "llama4" +flavor = "17bx16e" +norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm +tokenizer_path = "./assets/tokenizer/tokenizer.model" +# converters = "float8" + +[optimizer] +name = "AdamW" +lr = 4e-3 +eps = 1e-15 + +[lr_scheduler] +warmup_steps = 600 +lr_min = 0.1 + +[training] +batch_size = 8 +seq_len = 8192 +max_norm = 1.0 # grad norm clipping +steps = 3000 +compile = false +dataset = "c4" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 8 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 500 +model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = 'full' # ['none', 'selective', 'full'] + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = "output,router.gate" diff --git a/torchtitan/experiments/multimodal/__init__.py b/torchtitan/experiments/multimodal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fe08681bbc532dc23a734fd961648890cec7d497 --- /dev/null +++ b/torchtitan/experiments/multimodal/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
# NOTE: this is a package __init__, so sibling modules must be imported
# relatively (consistent with `.model` below); a bare `from mm_dataset import`
# only works when the package directory itself is on sys.path.
from .mm_dataset import build_mm_dataloader

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
from torchtitan.models.llama3 import parallelize_llama, pipeline_llama
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .model import ModelArgs, MultimodalDecoder, VisionEncoder

__all__ = ["VisionEncoder", "ModelArgs", "MultimodalDecoder"]

llama4_mm_configs = {
    # TODO: add configs for llama4 multimodal
}

# Register the multimodal train spec so `--model.name llama4_multimodal`
# resolves to this model, its dataloader, and its training components.
register_train_spec(
    TrainSpec(
        name="llama4_multimodal",
        cls=MultimodalDecoder,
        config=llama4_mm_configs,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_mm_dataloader,
        build_tokenizer_fn=build_tiktoken_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)
import click

from mm_dataset import build_mm_dataloader
from tokenizer.tiktoken import build_tiktoken_tokenizer

from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import init_logger


@click.command()
@click.option("--dataset", default="OBELICS")
@click.option("--batch-size", default=4)
@click.option("--seq-len", default=4096)
@click.option("--tokenizer-path", required=True)
@click.option("--dp-rank", default=0)
@click.option("--dp-world-size", default=2)
@click.option("--batch-number", default=4)
@click.option(
    "--padding-idx",
    default=128004,
    help="Token id treated as text padding (was hard-coded to 128004).",
)
def main(
    dataset: str,
    batch_size: int,
    seq_len: int,
    tokenizer_path: str,
    dp_rank: int,
    dp_world_size: int,
    batch_number: int,
    padding_idx: int,
):
    """Build the multimodal dataloader and report padding statistics.

    Draws ``batch_number`` batches and, for the last one, prints the fraction
    of padded text tokens, padded images, and padded tiles.
    """
    init_logger()
    job_config = JobConfig()
    job_config.parse_args(
        [
            "--training.dataset",
            dataset,
            "--training.batch_size",
            str(batch_size),
            "--training.seq_len",
            str(seq_len),
            "--model.tokenizer_path",
            tokenizer_path,
        ]
    )
    tokenizer = build_tiktoken_tokenizer(job_config)
    dl = build_mm_dataloader(
        dp_world_size=dp_world_size,
        dp_rank=dp_rank,
        tokenizer=tokenizer,
        job_config=job_config,
    )
    dl_iter = iter(dl)

    # Only the batch produced by the final iteration is analyzed below.
    for _ in range(batch_number):
        batch = next(dl_iter)

    # --- Text tokens ---
    total_input_ids = batch["input_ids"].shape[0] * batch["input_ids"].shape[1]
    total_non_padding_tokens = total_input_ids - int(
        (batch["input_ids"] == padding_idx).sum()
    )
    total_padding_tokens = total_input_ids - total_non_padding_tokens
    print(
        f"Padding tokens in each sample: {(batch['input_ids'] == padding_idx).sum(dim=1)}"
    )
    print(
        f"Unpadded tokens: {total_non_padding_tokens}, Total tokens in batch: {total_input_ids}"
    )
    print(
        f"Padded text tokens: {total_padding_tokens}, {(total_padding_tokens) / total_input_ids * 100:.2f}%"
    )
    print(80 * "#")
    # --- Images ---
    # Heuristic: an all-zero image/tile is treated as collator padding
    # (real, normalized images are assumed never to sum to exactly zero).
    padded_images = 0
    padded_tiles = 0
    for sample in batch["encoder_input"]["images"]:
        for image in sample:
            if int(image.sum()) == 0:
                padded_images += 1
            for tile in image:
                if int(tile.sum()) == 0:
                    padded_tiles += 1

    total_images = (
        batch["encoder_input"]["images"].shape[0]
        * batch["encoder_input"]["images"].shape[1]
    )

    print(
        f"Unpadded images: {total_images - padded_images}, Total images in batch: {total_images}"
    )
    print(
        f'Padded images: {padded_images}, {padded_images / total_images * 100:.2f}% (Each image with shape {list(batch["encoder_input"]["images"][0, 0].shape)})'  # noqa: B950
    )
    print(80 * "#")
    # --- Tiles ---
    total_number_of_tiles = total_images * batch["encoder_input"]["images"].shape[2]

    print(
        f"Unpadded number of tiles: {total_number_of_tiles - padded_tiles}, Total number of tiles: {total_number_of_tiles}"
    )
    print(
        f'Padded tiles: {padded_tiles}, {padded_tiles / total_number_of_tiles * 100:.2f}% (Each with shape {list(batch["encoder_input"]["images"][0, 0, 0].shape)})'  # noqa: B950
    )
    print(80 * "#")


if __name__ == "__main__":
    main()
def padded_collate(
    batch: List[Dict[str, List[int]]],
    padding_idx: int = 0,
    ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
    """Collate ``input_ids``/``labels`` pairs into right-padded batch tensors.

    Both fields are padded to the longest sequence observed across *either*
    field, so the returned tensors always share the same trailing length.

    Args:
        batch (List[Dict[str, List[int]]]): Samples, each carrying
            ``input_ids`` and ``labels`` sequences.
        padding_idx (int): Fill value for ``input_ids``. Defaults to 0.
        ignore_idx (int): Fill value for ``labels``. Defaults to -100.

    Returns:
        Dict[str, torch.Tensor]: Collated ``input_ids`` and ``labels`` of
        shape (bsz, max_seq_len_in_batch).

    Example:
        >>> token_pairs = [
        >>>     {"input_ids": [1, 2, 3], "labels": [4, 5, 6]},
        >>>     {"input_ids": [7,], "labels": [10,]},
        >>> ]
        >>> collated = padded_collate(
        >>>     batch=token_pairs,
        >>>     padding_idx=padding_idx,
        >>>     ignore_idx=ignore_idx,
        >>> )
        >>> collated["input_ids"]
        >>> tensor([[1, 2, 3], [7, 0, 0]])
        >>> collated["labels"]
        >>> tensor([[4, 5, 6], [10, -100, -100]])
    """
    input_ids = pad_sequence(
        [sample["input_ids"] for sample in batch],
        batch_first=True,
        padding_value=padding_idx,
    )
    labels = pad_sequence(
        [sample["labels"] for sample in batch],
        batch_first=True,
        padding_value=ignore_idx,
    )

    # The two fields may end up with different max lengths; pad the shorter
    # tensor out to the longer one instead of padding both to max_seq_len,
    # which would be costly.
    target_len = max(input_ids.shape[-1], labels.shape[-1])
    if labels.shape[-1] < target_len:
        labels = F.pad(labels, (0, target_len - labels.shape[-1]), value=ignore_idx)
    if input_ids.shape[-1] < target_len:
        input_ids = F.pad(
            input_ids,
            (0, target_len - input_ids.shape[-1]),
            value=padding_idx,
        )
    return {"input_ids": input_ids, "labels": labels}
padding_idx: int = 128004 + ignore_idx: int = IGNORE_INDEX + pad_max_tiles: Optional[int] = None + pad_max_images: Optional[int] = None + + def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + """Pad a batch of text sequences, tiled image tensors, aspect ratios, + and cross attention masks. This can be used for both training and inference. + + ``batch`` is expected to be a list of sample dicts containing the following:: + - "input_ids": List[int] of length text_seq_len, varies across samples + - "labels": List[int] of length text_seq_len, varies across samples + - "encoder_input": Dict[str, List[torch.Tensor]] + - "images": List[torch.Tensor], each with shape (n_tiles, c, h, w) + - "aspect_ratio": List[torch.Tensor], each with shape (2, ) to indicate h_ratio, w_ratio + + Shape notation: + - c = channel dim + - h = height dim + - w = weight dim + + Note: + For each element in the batch, ``len(images) == len(aspect_ratio)``. + + This collater does the following: + (1) Pad text sequence and encoder mask to the longest sequence length in the batch + (2) Pad image tensors in the tile dimension with zeros to the largest number + of tiles in the batch + (3) Add empty images of zeros to samples up to max number of images in the batch + (4) Pad aspect ratios with (1,1) for all added padding images + + Args: + batch (List[Dict[str, Any]]): A list of sample dicts containing input_ids, + labels, images, and aspect_ratio. + padding_idx (int): Padding index for input token ids. Defaults to 0. + ignore_idx (int): Padding index for labels. Defaults to -100. + pad_max_tiles (Optional[int]): Maximum number of tiles to pad to. If None, will pad to the largest number of tiles + in the batch. Defaults to None. + pad_max_images (Optional[int]): Maximum number of images to pad to. If None, will pad to the largest number of images + in the batch. Defaults to None. + + Returns: + Dict[str, Tensor]: Collated tokens, labels, images, aspect_ratio tensors. 
+ - tokens: Tensor of shape (bsz, max_seq_len) + - labels: Tensor of shape (bsz, max_seq_len) + - images: Tensor of shape (bsz, max_num_images, max_num_tiles, c, h, w) + - aspect_ratio: Tensor of shape (bsz, max_num_images, 2) + + Example: + >>> image_id = 1 + >>> tokens_per_tile = 5 + >>> c, h, w = 1, 1, 1 + >>> batch = [ + ... { + ... "input_ids": [1, 2, 1, 3], "labels": [4, 5, 6, 7], + ... "encoder_input": { + ... # One image with two tiles, one image with three tiles + ... "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)], + ... "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])], + ... }, + ... }, + ... { + ... "input_ids": [1, 4], "labels": [8, 9], + ... "encoder_input": { + ... # One image with four tiles + ... "images": [torch.ones(4, c, h, w)], + ... "aspect_ratio": [torch.tensor([2, 2])], + ... }, + ... }, + ... ] + ... collator = MultiModalCollator(pad_max_tiles=4) + >>> model_inputs = collator(batch=batch) + >>> print(model_inputs["input_ids"]) + tensor([[1, 2, 1, 3], + [1, 4, 0, 0]]) + >>> print(model_inputs["labels"]) + tensor([[4, 5, 6, 7], + [8, 9, -100, -100]]) + >>> print(model_inputs["encoder_input"]["images"].shape) # (bsz, max_num_images, max_num_tiles, c, h, w) + torch.Size([2, 2, 4, 1, 1, 1]) + >>> print(model_inputs["encoder_input"]["aspect_ratio"].shape) # (bsz, max_num_images, 2) + torch.Size([2, 2, 2]) + >>> print(model_inputs["encoder_input"]["images"][0, 0, ...]) # Image with two tiles got padded to four + tensor([[[[1.]]], [[[1.]]], [[[0.]]], [[[0.]]]]) + >>> print(model_inputs["encoder_input"]["images"][0, 1, ...]) # Image with three tiles got padded to four + tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[0.]]]]) + >>> print(model_inputs["encoder_input"]["images"][1, 0, ...]) # Image with four tiles did not get padded + tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[1.]]]]) + >>> print(model_inputs["encoder_input"]["images"][1, 1, ...]) # Extra padding image was added to second sample + tensor([[[[0.]]], [[[0.]]], [[[0.]]], 
[[[0.]]]]) + """ + # Text tokens can be handled independently by existing collaters + text_only = [ + {"input_ids": sample["input_ids"], "labels": sample["labels"]} + for sample in batch + ] + collated_text = padded_collate(text_only, self.padding_idx, self.ignore_idx) + + if self.pad_max_tiles is None: + # Get max number of tiles in batch + max_num_tiles = max(sample["images_tiles"].shape[0] for sample in batch) + else: + max_num_tiles = self.pad_max_tiles + + # Pad images and aspect ratios to max number of tiles + batch_images = [] + batch_aspect_ratios = [] + + for sample in batch: + sample_images = [] + for image in sample["encoder_input"]["images"]: + # Single image in each sample has shape (n_tiles, c, h, w) + n_tiles = image.shape[0] + # Single mask in each sample corresponds to a single image and has shape (text_seq_len, image_seq_len) + # where image_seq_len = n_tiles * tokens_per_tile + padding_tiles = max_num_tiles - n_tiles + + # Image should now have shape (max_num_tiles, c, h, w) + padded_image = F.pad( + image, (0, 0, 0, 0, 0, 0, 0, padding_tiles), value=0 + ) + + sample_images.append(padded_image) + # Stack multiple images and masks per sample in num_images dimension + batch_images.append(torch.stack(sample_images)) + batch_aspect_ratios.append( + torch.stack(sample["encoder_input"]["aspect_ratio"]) + ) + # Finally, pad images, masks, aspect ratios to max number of images in batch + # (bsz, max_num_images, max_num_tiles, c, h, w) + collated_images = pad_sequence(batch_images, batch_first=True, padding_value=0) + # (bsz, max_num_images, 2) + collated_aspect_ratios = pad_sequence( + batch_aspect_ratios, batch_first=True, padding_value=1 + ) + + batch_dict = { + "input_ids": collated_text["input_ids"], + "labels": collated_text["labels"], + "encoder_input": { + "images": collated_images, + "aspect_ratio": collated_aspect_ratios, + }, + } + + return batch_dict diff --git a/torchtitan/experiments/multimodal/mm_dataset.py 
def _load_obelics_dataset(dataset_path: str):
    """Load the OBELICS dataset from ``dataset_path`` (train split, streaming)."""
    return load_dataset(dataset_path, split="train", streaming=True)


def _process_obelics_sample(
    sample: dict[str, Any], image_token: str = "<|image|>"
) -> Dict[str, List[Union[str, "PIL.Image.Image"]]]:
    """Format a raw OBELICS sample into loaded images plus interleaved text.

    ``None`` entries in ``sample["texts"]`` mark image positions and are
    replaced with ``image_token``; ``None`` entries in ``sample["images"]``
    are dropped.

    Returns:
        Dict[str, Any]: The transformed sample with the following fields:
            - images: List[PIL.Image.Image] with the loaded images
            - text: str with the text of the sample ready to be tokenized,
              including the image tokens
    Example:
        >>> formatted_sample = _process_obelics_sample(sample, image_token="<|image|>")
        >>> print(formatted_sample["text"])
        ... "<|image|><|image|><|image|> The elephant look cute!<|image|><|image|> The cats are sad :("
    """
    sample_images = [image for image in sample["images"] if image is not None]
    sample_text = [
        text if text is not None else image_token for text in sample["texts"]
    ]
    return {
        "images": [load_image(image) for image in sample_images],
        "text": "".join(map(str, sample_text)),
    }


@dataclass
class DatasetConfig:
    """Where a dataset lives and how to load and format it."""

    path: str
    loader: Callable
    sample_processor: Callable


# Add your dataset here - more information at docs/datasets.md
MM_DATASETS = {
    "obelics": DatasetConfig(
        path="HuggingFaceM4/OBELICS",
        loader=_load_obelics_dataset,
        sample_processor=_process_obelics_sample,
    ),
}


def _validate_mm_dataset(
    dataset_name: str, dataset_path: Optional[str] = None
) -> tuple[str, Callable, Callable]:
    """Resolve ``dataset_name`` to ``(path, loader, sample_processor)``.

    ``dataset_path`` overrides the registered default path when given.

    Raises:
        ValueError: If ``dataset_name`` is not registered in ``MM_DATASETS``.
    """
    if dataset_name not in MM_DATASETS:
        raise ValueError(
            f"Dataset {dataset_name} is not supported. "
            f"Supported datasets are: {list(MM_DATASETS.keys())}"
        )

    config = MM_DATASETS[dataset_name]
    path = dataset_path or config.path
    logger.info(f"Preparing {dataset_name} dataset from {path}")
    return path, config.loader, config.sample_processor
    def __init__(
        self,
        dataset_name: str,
        dataset_path: Optional[str],
        tokenizer: Tokenizer,
        image_token: str = "<|image|>",
        tile_size: int = 448,
        max_num_tiles: int = 4,
        seq_len: int = 2048,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:
        """Set up the streaming dataset, CLIP image transform, and tokenizer.

        Args:
            dataset_name: Key into ``MM_DATASETS`` (lower-cased for lookup).
            dataset_path: Optional override for the dataset's default path.
            tokenizer: Tokenizer with an ``encode`` method; ``__iter__`` also
                reads its ``bos_id``, ``eos_id`` and ``image_id`` attributes.
            image_token: Special token marking image positions in the text.
            tile_size: Side length of each image tile for the CLIP transform.
            max_num_tiles: Upper bound on tiles per image.
            seq_len: Maximum token sequence length after truncation.
            dp_rank: Data-parallel rank used to shard the dataset.
            dp_world_size: Number of data-parallel shards.
            infinite: If True, re-loop the dataset when it is exhausted.
        """
        # Force lowercase for consistent comparison
        dataset_name = dataset_name.lower()

        path, dataset_loader, sample_processor = _validate_mm_dataset(
            dataset_name, dataset_path
        )
        ds = dataset_loader(path)

        # TODO: support shuffling
        self.dataset_name = dataset_name
        # Shard the (streaming) dataset across data-parallel ranks.
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
        self._tokenizer = tokenizer
        self.seq_len = seq_len
        self.infinite = infinite
        self._sample_processor = sample_processor
        self.image_token = (
            image_token  # TODO(tj.solergibert) Add `image_token` to JobConfig
        )
        # TODO(tj.solergibert) Add `tile_size` & `max_num_tiles` to JobConfig
        self.transform_image = CLIPTransform(
            image_mean=(
                0.48145466,
                0.4578275,
                0.40821073,
            ),  # TODO(tj.solergibert) What should we do with `image_mean` & `image_std`?,
            image_std=(0.26862954, 0.26130258, 0.27577711),
            tile_size=tile_size,
            possible_resolutions=None,
            max_num_tiles=max_num_tiles,
            resample="bilinear",
            resize_to_max_canvas=False,
        )

        # variables for checkpointing
        self._sample_idx = 0

    def __iter__(self):
        """Yield dicts with ``input_ids``, ``labels`` and ``encoder_input``."""

        while True:
            for sample in self._get_data_iter():
                # Best-effort decoding: skip samples whose images fail to load
                # or that the processor cannot format.
                try:
                    sample = self._sample_processor(
                        sample, image_token=self.image_token
                    )
                except Exception:
                    # NOTE(review): skipped samples do not advance
                    # `_sample_idx`, while `_get_data_iter` re-skips exactly
                    # `_sample_idx` raw samples on resume — a checkpoint resume
                    # can therefore replay skipped samples. Confirm intended.
                    continue

                self._sample_idx += 1

                # CLIP Transform
                encoder_input = {"images": [], "aspect_ratio": []}
                for image in sample["images"]:
                    out = self.transform_image(image)
                    encoder_input["images"].append(out["image"])
                    encoder_input["aspect_ratio"].append(out["aspect_ratio"])
                sample["encoder_input"] = encoder_input

                # Tokenize
                tokens = self._tokenizer.encode(
                    sample["text"],
                    bos=True,
                    eos=True,
                    allowed_special=set(["<|image|>"]),
                )
                # Next-token prediction: inputs drop the last token, labels
                # drop the first.
                sample["input_ids"] = torch.LongTensor(tokens[:-1])
                sample["labels"] = torch.LongTensor(tokens[1:])
                # Mask BOS, EOS & image tokens from the loss
                sample["labels"] = torch.where(
                    torch.isin(
                        sample["labels"],
                        torch.LongTensor(
                            [
                                self._tokenizer.bos_id,
                                self._tokenizer.eos_id,
                                self._tokenizer.image_id,
                            ]
                        ),
                    ),
                    IGNORE_INDEX,
                    sample["labels"],
                )
                # Truncate
                sample["input_ids"], sample["labels"] = (
                    sample["input_ids"][: self.seq_len],
                    sample["labels"][: self.seq_len],
                )
                yield sample

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def _get_data_iter(self):
        """Return an iterator positioned after the last checkpointed sample."""
        # Map-style datasets that are fully consumed yield nothing on resume.
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        it = iter(self._data)
        # Skip samples already consumed before the checkpoint was taken.
        for _ in range(self._sample_idx):
            next(it)
        return it

    def load_state_dict(self, state_dict):
        """Restore the resume offset saved by ``state_dict``."""
        self._sample_idx = state_dict["sample_idx"]

    def state_dict(self):
        """Checkpoint only the number of samples already consumed."""
        return {"sample_idx": self._sample_idx}
@dataclass
class ModelArgs:
    """Configuration for the multimodal model: vision encoder, projection, and text decoder."""

    # encoder part
    encoder_embed_dim: int = 4096
    encoder_num_layers: int = 32
    num_layers_projection: int = 32
    encoder_num_heads: int = 32
    encoder_num_kv_heads: Optional[int] = None
    patch_size: int = 1
    tile_size: int = 128
    max_num_tiles: int = 8
    # NOTE(review): a single nn.GELU() instance is shared as the class-level
    # default across all ModelArgs instances — harmless for a stateless
    # activation, but use a factory if a stateful module is ever substituted.
    activation: nn.Module = nn.GELU()
    # in_channels (int): The number of image input channels.
    in_channels: int = 3
    # return_intermediates (Optional[List[int]]): The indices of hidden layers to return.
    # If provided, it will return the intermediate results of the transformer layers
    # before they go through a next layer. For example, ``return_intermediates=[0,3]``
    # will return the tokens before they go through the first and fourth layers.
    return_intermediates: Optional[List[int]] = None
    is_causal: bool = True

    # decoder part
    decoder_embed_dim: int = 4096  # This is for linear projection to convert the output of encoder to decoder
    fusion_interval: int = 1  # This is the interval of layers that are used for fusion
    num_special_tokens: int = 2  # This is the number of special tokens in the tokenizer
    decoder_num_layers: int = 16
    decoder_num_heads: int = 32
    decoder_num_kv_heads: Optional[int] = None

    # common part
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 10000

    max_seq_len: int = 2048
    # If `True`, then each transformer block init uses its layer ID, and if
    # `False`, each uses the total number of transformer blocks
    depth_init: bool = True
    norm_type: str = "rmsnorm"
+ """ + output = nn.functional.layer_norm( + x.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(x) + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. 
+ """ + ndim = x.ndim + assert 0 <= 1 < ndim + seqlen = x.shape[1] + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, num_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=num_rep)""" + bsz, seq_len, num_kv_heads, head_dim = x.shape + if num_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bsz, seq_len, num_kv_heads, num_rep, head_dim) + .reshape(bsz, seq_len, num_kv_heads * num_rep, head_dim) + ) + + +class Attention(nn.Module): + """ + Multi-head attention module. 
    def __init__(self, model_args: ModelArgs):
        super().__init__()
        self.num_heads = model_args.encoder_num_heads
        # Fall back to MHA (kv heads == query heads) when no GQA count is given.
        self.num_kv_heads = (
            model_args.encoder_num_heads
            if model_args.encoder_num_kv_heads is None
            else model_args.encoder_num_kv_heads
        )
        # How many times each kv head must be replicated to match query heads.
        self.num_rep = self.num_heads // self.num_kv_heads
        self.head_dim = model_args.encoder_embed_dim // model_args.encoder_num_heads

        self.wq = nn.Linear(
            model_args.encoder_embed_dim,
            model_args.encoder_num_heads * self.head_dim,
            bias=False,
        )
        self.wk = nn.Linear(
            model_args.encoder_embed_dim, self.num_kv_heads * self.head_dim, bias=False
        )
        self.wv = nn.Linear(
            model_args.encoder_embed_dim, self.num_kv_heads * self.head_dim, bias=False
        )
        self.wo = nn.Linear(
            model_args.encoder_num_heads * self.head_dim,
            model_args.encoder_embed_dim,
            bias=False,
        )
        # Whether scaled_dot_product_attention applies a causal mask.
        self.is_causal = model_args.is_causal

    def init_weights(self, init_std: float):
        """Truncated-normal init: fixed 0.02 std for q/k/v, ``init_std`` for the output projection."""
        for linear in (self.wq, self.wk, self.wv):
            nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)
+ + """ + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Use -1 instead of `num_heads` (or `num_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. + xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + if ( + freqs_cis is not None + ): # Only used in the self attention layers for text decoder + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # repeat k/v heads if num_kv_heads < num_heads + keys = repeat_kv(xk, self.num_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.num_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + + # we use casual mask for training + output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=self.is_causal) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + output = output.view(bs, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None. + activation: (nn.Module): Activation function to use. Defaults to nn.silu. + + Attributes: + w1 (Linear): Linear transformation for the first layer, which projects input from input dim to + hidden dim, and multiplies by the projection from w3 for activation and second layer. + w2 (Linear): Linear transformation for the second layer. 
+ """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + activation: nn.Module = nn.SiLU(), + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.activation = activation + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + + def forward(self, x): + return self.w2(self.activation(self.w1(x))) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.w2.weight, mean=0.0, std=init_std) + + +class TanhGate(nn.Module): + """Implements a basic learnable gate to scale layer outputs""" + + def __init__(self) -> None: + super().__init__() + self.scale = nn.Parameter(torch.zeros(1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): input tensor to gate + + Returns: + torch.Tensor: The output tensor after gating. Has the same shape as ``x``. + """ + return x * self.scale.tanh() + + +class TilePositionalEmbedding(nn.Module): + """ + Positional embedding for tiles, different for every tile, same for every token within a tile. + + For details, please check the documentation of :class:`ViT`. + + Args: + max_num_tiles (int): The maximum number of tiles an image can be divided into. + emb_dim (int): The dimensionality of each tile embedding. 
+ """ + + def __init__( + self, + max_num_tiles: int, + emb_dim: int, + ): + super().__init__() + self.max_num_tiles = max_num_tiles + self.emb_dim = emb_dim + self.embedding = nn.Parameter( + torch.randn(max_num_tiles, max_num_tiles, 1, emb_dim) / math.sqrt(emb_dim) + ) + self.gate = nn.Parameter(torch.zeros(1)) + + def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor): + """ + args: + x (torch.Tensor): torch.Tensor with shape (bsz * num_imgs, num_tiles, num_tokens, emb_dim). + aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * num_imgs, 2), + representing the aspect ratio of the image before tile-cropping, e.g. (2,1). + returns: + torch.Tensor: The input tensor with added positional embeddings. + """ + bsz_and_num_imgs, num_tiles, num_tokens, emb_dim = x.shape + + for batch_idx, (num_tiles_h, num_tiles_w) in enumerate(aspect_ratio): + # When we batch images, all are padded to the same amount of tiles. + # The aspect_ratio lets us know the non padded tiles for each image. + # We only add positional encoding to those. + num_non_padded_tiles = int(num_tiles_h * num_tiles_w) + + # We get only the positional encoding for non padded tiles, + # i.e. num_tiles_h, num_tiles_w. + pos_embed = self.embedding[:num_tiles_h, :num_tiles_w, :, :] + + # Add pos encoding to the non padded tiles. + pos_embed = pos_embed.reshape(num_non_padded_tiles, 1, self.emb_dim) + x[batch_idx, :num_non_padded_tiles, :, :] += pos_embed * self.gate.tanh() + + return x + + +class TokenPositionalEmbedding(nn.Module): + """ + Token positional embedding for images, different for every token in an image. + + Args: + emb_dim (int): The dimensionality of each token embedding. + tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise, + the size of the input image. In this case, the function will consider your image as a single tile. + patch_size (int): The size of each patch. Used to divide the tiles into patches. + E.g. 
for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches + with shape (40, 40) each. + """ + + def __init__(self, emb_dim: int, tile_size: int, patch_size: int) -> None: + super().__init__() + patch_grid_size = tile_size // patch_size + scale = emb_dim**-0.5 + self.positional_embedding = nn.Parameter( + scale * torch.randn((patch_grid_size**2 + 1, emb_dim)) # +1 for CLS token + ) + + def forward(self, x: torch.Tensor, *args: Tuple[Any]) -> torch.Tensor: + """ + Args: + x (torch.Tensor): torch.Tensor with shape (..., num_tokens, emb_dim) + *args (Tuple[Any]): Optional args. + + Returns: + torch.Tensor: The input tensor with added positional embeddings. + """ + return x + self.positional_embedding + + +class TiledTokenPositionalEmbedding(nn.Module): + """ + + Token positional embedding for tiled images. There are two positional embeddings in this module: + + * local_token_positional_embedding: same for every tile, different for every token. Equivalent \ + to :class:`TokenPositionalEmbedding`, but gated. + * global_token_positional_embedding: different for every tile, different for every token. + + Notice that tile is different from patch (token). For details, please check the documentation of + :class:`ViT`. + + Args: + max_num_tiles (int): The maximum number of tiles an image can be divided into. + emb_dim (int): The dimensionality of each token embedding. + tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise, + the size of the input image. In this case, the function will consider your image as a single tile. + patch_size (int): The size of each patch. Used to divide the tiles into patches. + E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches + with shape (40, 40) each. 
+ """ + + def __init__( + self, max_num_tiles: int, emb_dim: int, tile_size: int, patch_size: int + ) -> None: + super().__init__() + patch_grid_size = tile_size // patch_size + self.num_tokens_per_tile = patch_grid_size**2 + 1 # +1 for cls token + scale = emb_dim**-0.5 + + # different for every token, same for every tile + self.local_token_positional_embedding = nn.Parameter( + scale * torch.randn((patch_grid_size**2 + 1, emb_dim)) # +1 for CLS token + ) + + # different for every token, different for every tile + self.global_token_positional_embedding = nn.Parameter( + scale + * torch.randn( + max_num_tiles, + max_num_tiles, + self.num_tokens_per_tile, + emb_dim, + ) + ) + + self.gate = nn.Parameter(torch.zeros(1)) + + def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): torch.Tensor with shape (bsz * num_imgs, num_tiles, num_tokens, emb_dim). + aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * num_imgs, 2), + where aspect_ratio[k] represents the aspect ratio of the k^th image + of the batch before tile-cropping, e.g. aspect_ratio[k] = (2,1). + Returns: + torch.Tensor: The input tensor with added positional embeddings. + """ + bsz_and_num_imgs, num_tiles, num_tokens, emb_dim = x.shape + + # apply local position embedding (same for every tile) + x = x + (self.local_token_positional_embedding * (1 - self.gate.tanh())) + + # apply global positional embedding (different for every tile) + x = x.view(bsz_and_num_imgs, num_tiles, num_tokens, emb_dim) + for batch_idx, (num_tiles_h, num_tiles_w) in enumerate(aspect_ratio): + # When we batch images, all are padded to the same amount of tiles. + # The aspect_ratio lets us know the non padded tiles for each image. + # We only add positional encoding to those. + num_non_padded_tiles = int(num_tiles_h * num_tiles_w) + + # We get only the positional encoding for non padded tiles, + # i.e. num_tiles_h, num_tiles_w. 
+ pos_embed = self.global_token_positional_embedding[ + :num_tiles_h, :num_tiles_w, :, : + ] + + # Add pos encoding to the non padded tiles. + pos_embed = pos_embed.reshape( + num_non_padded_tiles, self.num_tokens_per_tile, emb_dim + ) + pos_embed = pos_embed * self.gate.tanh() + x[batch_idx, :num_non_padded_tiles, :, :] += pos_embed + + return x + + +class Conv2dModule(torch.nn.Module): + """Conv2D Module. + This is like Conv2D in PyTorch except: + + - PyTorch Conv2D outputs shape (*, out_channels, h_out, w_out), while this module + outputs (*, h_out * w_out, out_channels). + - We implement the conv as an unfold -> permute -> linear, where we can column-wise + shard the linear. + + Arguments: + in_channels: Input channels. + out_channels: Output channels. + kernel_size: Size of convolution kernel. This module also assumes a square kernel. + stride (default 1): Stride for convolution. + bias (default False): Use bias in Conv2d. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + bias: bool = False, + ) -> None: + super().__init__() + self._unfold = torch.nn.Unfold( + kernel_size=(kernel_size, kernel_size), stride=stride + ) + self._linear = torch.nn.Linear( + in_channels * kernel_size * kernel_size, + out_channels, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Input: (bsz, in_channels, width, height) + # Output: (bsz, in_channels * kernel_size * kernel_size, num_tokens) + x = self._unfold(x) + x = x.permute(0, 2, 1) + # Output: (bsz, num_tokens, out_channels), when stride = kernel_size, + # num_tokens = grid ** 2 and out_channels is emd_dim. 
+ return self._linear(x) + + +class VitTransformerBlock(nn.Module): + def __init__( + self, + model_args: ModelArgs, + attn_scale: Optional[nn.Module] = None, + mlp_scale: Optional[nn.Module] = None, + ): + super().__init__() + self.attn = Attention(model_args) + self.ln_attn = Fp32LayerNorm(model_args.encoder_embed_dim, eps=1e-5) + self.mlp = FeedForward( + dim=model_args.encoder_embed_dim, + hidden_dim=4 * model_args.encoder_embed_dim, + multiple_of=model_args.multiple_of, + ffn_dim_multiplier=model_args.ffn_dim_multiplier, + activation=model_args.activation, + ) + self.ln_mlp = Fp32LayerNorm(model_args.encoder_embed_dim, eps=1e-5) + self.attn_scale = attn_scale or nn.Identity() + self.mlp_scale = mlp_scale or nn.Identity() + + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ): + bsz, seq_len, emd_dim = x.shape + # x = x.view(bsz * seq_len, emd_dim) + x = x + self.attn_scale(self.attn(x=self.ln_attn(x), freqs_cis=None)) + x = x + self.mlp_scale(self.mlp(self.ln_mlp(x))) + # return x.view(bsz, seq_len, emd_dim) + return x + + +class CLSEmbedding(nn.Module): + """ + Adds a CLS token to every tile of an image in the beginning of each token. + + Args: + emb_dim (int): The dimensionality of the input patch embedding. + """ + + def __init__(self, emb_dim: int) -> None: + super().__init__() + + scale = emb_dim**-0.5 + self.weight = nn.Parameter(scale * torch.randn(emb_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + # add 1 CLS token to every tile + bsz_and_num_imgs, num_tiles, _, emb_dim = x.shape + cls_emb = self.weight.broadcast_to(bsz_and_num_imgs, num_tiles, 1, emb_dim) + return torch.cat([cls_emb, x], dim=2) + + +class Vit(nn.Module): + """ + Implementation of the ViT architecture (https://arxiv.org/abs/2010.11929), + with support for tile-cropped images, outputting of hidden layers. + + (credit for the documentation below: `vision_transformer.py + + `_). 
+ + ViT is a transformer architecture that takes in images and outputs N embedded tokens that + represent this image. Each image is divided into **patches** by a convolution. + These patches are flattened and subsequently treated as **tokens** by the transformer. + + To further enhance the performance of ViT and avoid downscaling images, we support tile-cropped images, + which are images divided into **tiles** during the preprocessing stage. For example, instead of + downscaling an 800x400 image to fit 400x400, we may crop it into two 400x400 tiles, + if the ``tile_size=400``. + + Each of these tiles is further broken down into patches by a convolution operation. For example, if + your ``patch_size=40``, then each (400, 400) tile will become a grid of 10x10 patches, and your whole image will have + num_tiles * n_tokens -> num_tiles * (10x10 patches + 1 CLS token) -> num_tiles * 101. + + Before the transformer layers, a CLS token is added to each tile as the first token. + In transformers, a token called CLS is a special token that is added to the beginning of each sequence. + This token can be used to represent the whole input, instead of using a pooling operation, for example. + + To help the model "see" the whole image, we use positional embeddings. If your image + was tile-cropped, then you need to use tile positional embeddings: + + - token_pos_embedding (tiled): :class:`TiledTokenPositionalEmbedding` + - pre_tile_pos_embed: :class:`TilePositionalEmbedding` + - post_tile_pos_embed: :class:`TilePositionalEmbedding` + + Otherwise, pre and post tile_pos_embed should be None and all you need is a simple + token positional embedding: + + - token_pos_embedding (not tiled): :class:`TokenPositionalEmbedding` + + All images will be considered as a stack of tiles, even if your image was not tile-cropped. In such cases, + your image would be composed of a single tile. + + In summary: + + 1) An image is broken down into tiles during preprocessing. 
+ 2) In the ViT, the tiles will be broken down into patches. + 3) The patches will be flattened and transformed. We call them tokens, because that's how the transformer sees them. + + Image: shape (8x8) + + .. code-block:: text + + | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | + | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | + | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | + | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | + | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | + | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | + | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | + | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | + + Tiles: shape (4,4,4) # (num_tiles, tile_size, tile_size) + + .. code-block:: text + + | 1 | 2 | 3 | 4 | | 5 | 6 | 7 | 8 | + | 9 | 10 | 11 | 12 | | 13 | 14 | 15 | 16 | + | 17 | 18 | 19 | 20 | | 21 | 22 | 23 | 24 | + | 25 | 26 | 27 | 28 | | 29 | 30 | 31 | 32 | + + | 33 | 34 | 35 | 36 | | 37 | 38 | 39 | 40 | + | 41 | 42 | 43 | 44 | | 45 | 46 | 47 | 48 | + | 49 | 50 | 51 | 52 | | 53 | 54 | 55 | 56 | + | 57 | 58 | 59 | 60 | | 61 | 62 | 63 | 64 | + + Patches: shape (4,4,2,2) # (num_tiles, num_patches_per_tile, patch_size, patch_size) + + .. code-block:: text + + | 1 | 2 | | 3 | 4 | | 5 | 6 | | 7 | 8 | + | 9 | 10 | | 11 | 12 | | 13 | 14 | | 15 | 16 | + + | 17 | 18 | | 19 | 20 | | 21 | 22 | | 23 | 24 | + | 25 | 26 | | 27 | 28 | | 29 | 30 | | 31 | 32 | + + | 33 | 34 | | 35 | 36 | | 37 | 38 | | 39 | 40 | + | 41 | 42 | | 43 | 44 | | 45 | 46 | | 47 | 48 | + + | 49 | 50 | | 51 | 52 | | 53 | 54 | | 55 | 56 | + | 57 | 58 | | 59 | 60 | | 61 | 62 | | 63 | 64 | + + token: shape (4, 4, 4) # (num_tiles, num_patches_per_tile, emb_dim) + + .. code-block:: text + + | 1 | 2 | 9 | 10 | | 3 | 4 | 11 | 12 | | 17 | 18 | 25 | 26 | | 19 | 20 | 27 | 28 | + | ... continuation of data ... + | ... continuation of data ... + | 37 | 38 | 45 | 46 | | 39 | 40 | 47 | 48 | | 53 | 54 | 61 | 62 | | 55 | 56 | 63 | 64 | + + For the positional embeddings: + + Same for every tile, different for every token. 
+ + - :class:`TokenPositionalEmbedding` + + .. code-block:: text + + | 1 | 2 | 3 | 4 | | 1 | 2 | 3 | 4 | + | 9 | 10 | 11 | 12 | | 9 | 10 | 11 | 12 | + | 17 | 18 | 19 | 20 | | 17 | 18 | 19 | 20 | + | 25 | 26 | 27 | 28 | | 25 | 26 | 27 | 28 | + + | 1 | 2 | 3 | 4 | | 1 | 2 | 3 | 4 | + | 9 | 10 | 11 | 12 | | 9 | 10 | 11 | 12 | + | 17 | 18 | 19 | 20 | | 17 | 18 | 19 | 20 | + | 25 | 26 | 27 | 28 | | 25 | 26 | 27 | 28 | + + Different for every tile, different for every token. + + - :class:`TiledTokenPositionalEmbedding` + + .. code-block:: text + + | 1 | 2 | | 3 | 4 | | 5 | 6 | | 7 | 8 | + | 9 | 10 | | 11 | 12 | | 13 | 14 | | 15 | 16 | + + | 17 | 18 | | 19 | 20 | | 21 | 22 | | 23 | 24 | + | 25 | 26 | | 27 | 28 | | 29 | 30 | | 31 | 32 | + + | 33 | 34 | | 35 | 36 | | 37 | 38 | | 39 | 40 | + | 41 | 42 | | 43 | 44 | | 45 | 46 | | 47 | 48 | + + | 49 | 50 | | 51 | 52 | | 53 | 54 | | 55 | 56 | + | 57 | 58 | | 59 | 60 | | 61 | 62 | | 63 | 64 | + + different for every tile, same for every token within a tile. + + - :class:`TilePositionalEmbedding` + + .. code-block:: text + + | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | + | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | + | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | + | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | + + | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | + | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | + | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | + | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | + + Args: + model_args (ModelArgs): The model args. + + Raises: + ValueError: If `patch_size` is not greater than 0. + ValueError: If `len(return_intermediates)` is greater than `num_layers`. + """ + + def __init__( + self, + model_args: ModelArgs, + ): + super().__init__() + if model_args.patch_size <= 0: + raise ValueError(f"kernel size of conv {model_args.patch_size} must be > 0") + if model_args.return_intermediates and ( + len(model_args.return_intermediates) > model_args.encoder_num_layers + ): + raise ValueError( + "len(return_intermediates) must be <= num_layers." 
+ f" Got {model_args.return_intermediate=} and {model_args.encoder_num_layers=}" + ) + + # For test validation purposes + patch_grid_size = model_args.tile_size // model_args.patch_size + self.patches_per_tile = patch_grid_size**2 + + self.return_intermediates = model_args.return_intermediates + + self.conv = Conv2dModule( + in_channels=model_args.in_channels, + out_channels=model_args.encoder_embed_dim, + kernel_size=model_args.patch_size, + stride=model_args.patch_size, + bias=False, + ) + + self.ln_post = Fp32LayerNorm(model_args.encoder_embed_dim) + self.ln_pre = Fp32LayerNorm(model_args.encoder_embed_dim) + self.transformer_layers = nn.ModuleList( + [ + VitTransformerBlock(model_args) + for _ in range(model_args.encoder_num_layers) + ] + ) + + self.class_embedding = CLSEmbedding(model_args.encoder_embed_dim) + # pre and post tile position embedding + if model_args.max_num_tiles > 1: + self.pre_tile_pos_embed = TilePositionalEmbedding( + max_num_tiles=model_args.max_num_tiles, + emb_dim=model_args.encoder_embed_dim, + ) + self.post_tile_pos_embed = TilePositionalEmbedding( + max_num_tiles=model_args.max_num_tiles, + emb_dim=model_args.encoder_embed_dim, + ) + self.token_pos_embedding = TokenPositionalEmbedding( + emb_dim=model_args.encoder_embed_dim, + tile_size=model_args.tile_size, + patch_size=model_args.patch_size, + ) + else: + self.pre_tile_pos_embed = None + self.post_tile_pos_embed = None + self.token_pos_embedding = TiledTokenPositionalEmbedding( + max_num_tiles=model_args.max_num_tiles, + emb_dim=model_args.encoder_embed_dim, + tile_size=model_args.tile_size, + patch_size=model_args.patch_size, + ) + + def forward( + self, images: torch.Tensor, aspect_ratio: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Processes images and returns the tokens and hidden states. + + Multiple images per sample: we add a dimension num_imgs to the input. 
This is useful when a single + sample constains multiple images, for example: + + - sample 1: " what animal is this?" + - sample 2: "I like more than " + + In this case, sample 1 has one image, and sample 2 has two images. max_n_imgs = max(2,1) = 2. + So your input should have shape (bsz=2, num_imgs=2, num_tiles, num_channels, tile_size_w, tile_size_h). + + Notice that to batch it, you will have to pad num_imgs to max_num_imgs and max_num_tiles. + + Args: + images (torch.Tensor): torch.Tensor with shape (bsz, num_imgs, num_tiles, num_channels, tile_size_w, tile_size_h). + aspect_ratio (Optional[torch.Tensor]): torch.Tensor with shape (bsz, n_imgs, 2). If all + images have a single tile, i.e. they were not tile-cropped, it should be None. + Used to calculate the positional embeddings for the tiles. + + Returns: + Tuple[torch.Tensor, List[torch.Tensor]]: A tuple: (x, hidden_states), + where x is a torch.tensor of shape (bsz, num_imgs, num_tiles, num_tokens, emb_dim) and + hidden_states has shape is a list of len(out_indices) torch.tensor with shape + (bsz, num_imgs, num_tiles, num_tokens, emb_dim). + + Raises: + ValueError: If aspect_ratio is None, but num_tiles > 1 in the batch. + """ + + bsz, num_imgs, num_tiles, num_channels, width, height = images.shape + + if aspect_ratio is None: + aspect_ratio = torch.ones((bsz * num_imgs, 2), dtype=torch.int).to( + device=images.device + ) + if num_tiles > 1: + raise ValueError( + f"aspect_ratio was not provided, but found num_tiles > 1 " + f"for {images.shape=}. Please provide aspect_ratio." + ) + + aspect_ratio = aspect_ratio.reshape(bsz * num_imgs, 2) + + # patch embedding + images = images.view(bsz * num_imgs * num_tiles, num_channels, width, height) + # The op is not behaving completely same as conv2d it contains a permute inside. 
+ x = self.conv(images) # shape = [*, emb_dim, grid ** 2] + _, num_tokens, emb_dim = x.shape # num_tokens = grid ** 2 + x = x.reshape(bsz * num_imgs, num_tiles, num_tokens, emb_dim) + + # tile embeddings + if self.pre_tile_pos_embed: + x = self.pre_tile_pos_embed(x, aspect_ratio) + + # apply cls token + x = self.class_embedding(x) + num_tokens += 1 + + # apply position embeddings + x = self.token_pos_embedding(x, aspect_ratio) + + x = self.ln_pre(x) + x = x.view(bsz * num_imgs, -1, emb_dim) + + int_x = [] # intermediate outputs + for layer_idx, transformer_layer in enumerate(self.transformer_layers): + if layer_idx in self.return_intermediates: + h = x.view(bsz, num_imgs, num_tiles, num_tokens, emb_dim) + int_x.append(h) + x = transformer_layer(x) + + x = self.ln_post(x) + x = x.view(bsz * num_imgs, num_tiles, num_tokens, emb_dim) + + if self.post_tile_pos_embed: + x = self.post_tile_pos_embed(x, aspect_ratio) + + x = x.view(bsz, num_imgs, num_tiles, num_tokens, emb_dim) + return x, int_x + + +class Projection(nn.Module): + """Projection transformer to adapt the output of a + encoder (CLIP) to the decoder model. 
+ """ + + def __init__( + self, + model_args: ModelArgs, + ) -> None: + super().__init__() + self.transformer_layers = nn.ModuleList( + [ + VitTransformerBlock( + model_args, attn_scale=TanhGate(), mlp_scale=TanhGate() + ) + for _ in range(model_args.num_layers_projection) + ] + ) + + self.num_hidden = len(model_args.return_intermediates or []) + self.output = nn.Linear( + model_args.encoder_embed_dim * (self.num_hidden + 1), + model_args.decoder_embed_dim, + ) + + def forward( + self, + x: torch.Tensor, + hidden_states: Optional[List[torch.Tensor]] = None, + ) -> torch.Tensor: + bsz, num_imgs, num_tiles, num_tokens, emb_dim = x.shape + + # apply transformer layers + x = x.view(bsz * num_imgs, num_tiles * num_tokens, emb_dim) + for layer in self.transformer_layers: + x = layer(x) + x = x.view(bsz, num_imgs, num_tiles, num_tokens, emb_dim) + + # interleave hidden states and cat with x + if self.num_hidden > 0: + assert hidden_states is not None + hidden_states = torch.stack(hidden_states, dim=-1) + hidden_states = hidden_states.view(bsz, num_imgs, num_tiles, num_tokens, -1) + x = torch.cat([x, hidden_states], dim=-1) + + # [bsz x seq x decoder_emb_dim] + return self.output(x).reshape(bsz, num_imgs * num_tiles * num_tokens, -1) + + +class VisionEncoder(nn.Module): + """Vision encoder model for Llama 3.2 Vision. This combines a vision + encoder with a projection. We define two different components. + + Args: + model_args (ModelArgs): configs for the vision encoder. + """ + + def __init__(self, model_args: ModelArgs) -> None: + super().__init__() + self.vit = Vit(model_args) + self.proj = Projection(model_args) + + def forward( + self, images: torch.Tensor, aspect_ratio: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Args: + images (torch.Tensor): + Image tensor with shape [bsz x num_imgs x num_tiles x num_channels x width x height]. + aspect_ratio (Optional[torch.Tensor]): Tensor with shape [bsz x num_imgs x 2]. If all + images have a single tile, i.e. 
they were not tile-cropped, it should be None. + Used to calculate the positional embeddings for the tiles. + Returns: + Tensor: output tensor of a sequence of embedings [bsz x seq_len x decoder_emb_dim] + where sequence length is num_imgs*num_tiles+num_embeds + """ + return self.proj(*self.vit(images, aspect_ratio)) + + +class FeedForwardForDecoder(nn.Module): + """ + FeedForward module for the decoder. It's different from the one in the encoder. + This is the component which is orignally used in llama3. + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class SelfAttention(nn.Module): + """ + Multi-head self attention module with rotary position. 
+ """ + + def __init__(self, model_args: ModelArgs): + super().__init__() + self.num_heads = model_args.decoder_num_heads + self.num_kv_heads = ( + model_args.decoder_num_heads + if model_args.decoder_num_kv_heads is None + else model_args.decoder_num_kv_heads + ) + self.n_rep = self.num_heads // self.num_kv_heads + self.head_dim = model_args.decoder_embed_dim // model_args.decoder_num_heads + + self.wq = nn.Linear( + model_args.decoder_embed_dim, + model_args.decoder_num_heads * self.head_dim, + bias=False, + ) + self.wk = nn.Linear( + model_args.decoder_embed_dim, self.num_kv_heads * self.head_dim, bias=False + ) + self.wv = nn.Linear( + model_args.decoder_embed_dim, self.num_kv_heads * self.head_dim, bias=False + ) + self.wo = nn.Linear( + model_args.decoder_num_heads * self.head_dim, + model_args.decoder_embed_dim, + bias=False, + ) + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Use -1 instead of `num_heads` (or `num_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. 
+ xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # repeat k/v heads if num_kv_heads < num_heads + keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + + # we use casual mask for training + output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + output = output.view(bs, seqlen, -1) + return self.wo(output) + + +class CrossAttention(nn.Module): + """ + Multi-head cross attention module. + """ + + def __init__(self, model_args: ModelArgs): + super().__init__() + self.num_heads = model_args.decoder_num_heads + self.num_kv_heads = ( + model_args.decoder_num_heads + if model_args.decoder_num_kv_heads is None + else model_args.decoder_num_kv_heads + ) + self.n_rep = self.num_heads // self.num_kv_heads + self.head_dim = model_args.decoder_embed_dim // model_args.decoder_num_heads + + self.wq = nn.Linear( + model_args.decoder_embed_dim, + model_args.decoder_num_heads * self.head_dim, + bias=False, + ) + self.wk = nn.Linear( + model_args.decoder_embed_dim, self.num_kv_heads * self.head_dim, bias=False + ) + self.wv = nn.Linear( + model_args.decoder_embed_dim, self.num_kv_heads * self.head_dim, bias=False + ) + self.wo = nn.Linear( + model_args.decoder_num_heads * self.head_dim, + model_args.decoder_embed_dim, + bias=False, + ) + self.q_norm = nn.RMSNorm(self.head_dim, eps=1e-05) + self.k_norm = nn.RMSNorm(self.head_dim, eps=1e-05) + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + 
            nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std)

    def forward(
        self,
        x: torch.Tensor,
        encoder_input: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ):
        bs, seqlen_x, _ = x.shape
        seqlen_y = encoder_input.shape[1]
        # Queries from the decoder stream x; keys/values from the encoder output.
        xq, xk, xv = self.wq(x), self.wk(encoder_input), self.wv(encoder_input)

        # Use -1 instead of `num_heads` (or `num_kv_heads`) to infer the actual
        # local heads from sizes of xq, xk, and xv as TP may have sharded them
        # after the above linear ops.
        xq = xq.view(bs, seqlen_x, -1, self.head_dim)
        xk = xk.view(bs, seqlen_y, -1, self.head_dim)
        xv = xv.view(bs, seqlen_y, -1, self.head_dim)

        # repeat k/v heads if num_kv_heads < num_heads
        keys = repeat_kv(xk, self.n_rep)  # (bs, seqlen_y, n_local_heads, head_dim)
        values = repeat_kv(xv, self.n_rep)  # (bs, seqlen_y, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen_x, head_dim)
        xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen_y, head_dim)
        xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen_y, head_dim)

        xq = self.q_norm(xq)
        xk = self.k_norm(xk)

        # Cross attention is never causal; masking, if any, comes solely from
        # the caller-provided attention mask.
        output = F.scaled_dot_product_attention(
            xq, xk, xv, attn_mask=mask, is_causal=False
        )
        output = output.transpose(
            1, 2
        ).contiguous()  # (bs, seqlen_x, n_local_heads, head_dim)
        output = output.view(bs, seqlen_x, -1)
        return self.wo(output)


class DecoderTransformerSelfAttnBlock(nn.Module):
    """Pre-norm decoder block: causal self attention + gated (SwiGLU) feed-forward."""

    def __init__(
        self,
        model_args: ModelArgs,
    ):
        super().__init__()
        self.attn = SelfAttention(model_args)
        self.ln_attn = nn.RMSNorm(model_args.decoder_embed_dim, eps=1e-5)
        self.mlp = FeedForwardForDecoder(
            dim=model_args.decoder_embed_dim,
            hidden_dim=4 * model_args.decoder_embed_dim,
            multiple_of=model_args.multiple_of,
            ffn_dim_multiplier=model_args.ffn_dim_multiplier,
        )
        self.ln_mlp = nn.RMSNorm(model_args.decoder_embed_dim, eps=1e-5)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        **kwargs: Dict,
    ):
        bsz, seq_len, emd_dim = x.shape
        x = x + self.attn(self.ln_attn(x), freqs_cis)
        x = x + self.mlp(self.ln_mlp(x))
        return x


class DecoderTransformerCrossAttnBlock(nn.Module):
    """Gated cross-attention decoder block (fuses encoder outputs into the decoder)."""

    def __init__(
        self,
        model_args: ModelArgs,
    ):
        super().__init__()
        self.attn = CrossAttention(model_args)
        self.ln_attn = nn.RMSNorm(model_args.decoder_embed_dim)
        # NOTE(review): this uses the encoder-style FeedForward (plain SiLU MLP)
        # rather than the gated FeedForwardForDecoder used by the self-attn
        # block — confirm the asymmetry is intentional.
        self.mlp = FeedForward(
            dim=model_args.decoder_embed_dim,
            hidden_dim=4 * model_args.decoder_embed_dim,
            multiple_of=model_args.multiple_of,
            ffn_dim_multiplier=model_args.ffn_dim_multiplier,
        )
        self.ln_mlp = nn.RMSNorm(model_args.decoder_embed_dim)
        # Zero-init tanh gates: the cross-attention block starts as identity.
        self.attn_scale = TanhGate()
        self.mlp_scale = TanhGate()

    def _skip_mask(self, mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Some tokens in x may not attend to any encoder inputs
        due to the cross attention mask (encoder_mask). This results in
        a full row of the attention matrix being masked out.

        In the example below, the word "the" is masked from every embedding.
        The False value means a token can't attend to an embedding.

        .. code-block:: text

            |emb||emb||emb|
            |The| F    F    F
            |red| T    F    T
            |car| F    T    T

        This results in no inputs into the softmax layer which causes a NaN.
        The skip mask is used to mask the outputs of attention and
        mlp resulting in the token being skipped.

        The above example would result in a skip mask of: [[True], [False], [False]]
        which specifies which tokens to fully mask out.
+ + """ + # no skip_mask if no masking + if mask is None: + return None + # negate mask and convert to boolean mask + if mask.dtype == torch.bool: + mask = ~mask + else: + mask = torch.isneginf(mask) + # True where all elements in a row are True + mask = torch.all(mask, dim=-1, keepdim=True) + return mask + + def forward( + self, + x: torch.Tensor, + *, + encoder_input: Optional[torch.Tensor] = None, + encoder_mask: Optional[torch.Tensor] = None, + **kwargs: Dict, + ) -> torch.Tensor: + # Skip cross attention when no secondary input as it's primary purpose + # is to attend between x and encoder_input. + if encoder_input is None: + return x + + # A mask of tokens (x) with no encoder_input + skip_mask = self._skip_mask(encoder_mask) + + attn_out = self.attn( + self.ln_attn(x), + encoder_input, + mask=encoder_mask, + ) + if skip_mask is not None: + attn_out.masked_fill_(skip_mask, 0) + + h = self.attn_scale(attn_out) + x + # Norm applied before the feedforward layer + mlp_out = self.mlp(self.ln_mlp(h)) + if skip_mask is not None: + mlp_out.masked_fill_(skip_mask, 0) + + # Residual connection; shape: [batch_size, seq_length, embed_dim] + out = h + self.mlp_scale(mlp_out) + + return out + + +class FusionLayer(nn.Module): + """ + Deep Fusion model architectures combine pretrained encoder models with pretrained + language models by infusing the encoder outputs into the middle layers of the LLM. + This allows the language model to interpret the enocder outputs as text and + "understand" any modality for which you can train an decoder. To enable the language model + to adapt to the encoder outputs, the FusionLayer fuses a new learnable layer to an existing + decoder (language model) layer. This additional layer can take the encoder embeddings and + learn to combine them with the token embeddings from the decoder. 
+ """ + + def __init__( + self, layer: nn.Module, fusion_layer: nn.Module, fusion_first: bool = True + ): + super().__init__() + self.layer = layer + self.fusion_layer = fusion_layer + + def forward(self, x: torch.Tensor, **kwargs: Dict) -> torch.Tensor: + x = self.fusion_layer(x, **kwargs) + x = self.layer(x, **kwargs) + return x + + +class FusionEmbedding(nn.Module): + """ + Fusion embedding supports training additional special tokens while keeping + the original embedding frozen. When fusing new models with a language model, + there may be some additional tokens needed to support the fused language model. For + example, adding a vision encoder might necessitate additional tokens like ``<|image|>`` + to indicate an images position in text and require learning an embedding for this token. + The FusionEmbedding keeps the original embeddings frozen while learning a much smaller + second embedding for the additional tokens. During forward this module routes + the tokens to the appropriate embedding table. 
+ """ + + def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> None: + super().__init__() + self.embedding = nn.Embedding(vocab_size, embed_dim) + self.fusion_embedding = nn.Embedding(fusion_vocab_size, embed_dim) + self.dim = embed_dim + self.num_embeddings = vocab_size + fusion_vocab_size + + def forward(self, input: torch.Tensor) -> torch.Tensor: + bsz, seq_len = input.size() + vocab_size = self.embedding.num_embeddings + + mask = input < vocab_size + # num_tokens = (input < vocab_size).sum() + tokens = torch.masked_select(input, mask) + # num_fusion_tokens = (input >= vocab_size).sum() + fusion_tokens = torch.masked_select(input, ~mask) - vocab_size + + # [batch_size x num_tokens x embed_dim] + embeds = self.embedding(tokens) + # [batch_size x num_fusion_tokens x embed_dim] + fusion_embeds = self.fusion_embedding(fusion_tokens) + + # [batch_size x seq_length x embed_dim] + out = torch.empty( + bsz, + seq_len, + self.dim, + device=self.embedding.weight.device, + dtype=self.embedding.weight.dtype, + ) + mask = mask.unsqueeze(-1).expand(bsz, seq_len, self.dim) + out.masked_scatter_(mask, embeds) + out.masked_scatter_(~mask, fusion_embeds) + return out + + +class MultimodalDecoder(nn.Module): + """Decoder multimodal model for Llama 3.2. + + Args: + model_args (ModelArgs): configs for the vision encoder. + """ + + def __init__(self, model_args: ModelArgs): + super().__init__() + + # TODO persistent should be set to false, since this buffer can be recomputed. + # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, + # compile or pipeline-tracer will not correctly handle non-persistent buffers, + # so we need to fix that. (2) if we initialize pipeline-parallel models from + # a seed checkpoint rather than calling init_weights, we need freqs_cis to be + # initialized by the checkpoint, or we need to add a separate initializer for + # just the non-persistent buffers that is called after loading checkpoints. 
+ self.register_buffer( + "freqs_cis", self._precompute_freqs_cis(model_args), persistent=True + ) + + self.layers = [] + for idx in range(1, model_args.decoder_num_layers + 1): + # define a llama3-like decoder layer, we don't train this part. + decoder_layer = DecoderTransformerSelfAttnBlock(model_args) + # cross attention layers, mixing text and vision, + # placed every `fusion_interval` layers + if idx % model_args.fusion_interval == 0: + cross_attn_layer = DecoderTransformerCrossAttnBlock(model_args) + fusion_layer = FusionLayer( + layer=decoder_layer, fusion_layer=cross_attn_layer + ) + self.layers.append(fusion_layer) + else: + self.layers.append(decoder_layer) + + self.tok_embeddings = FusionEmbedding( + model_args.vocab_size, + model_args.num_special_tokens, + model_args.decoder_embed_dim, + ) + self.norm = nn.RMSNorm(model_args.decoder_embed_dim, eps=1e-05) + self.output = nn.Linear( + model_args.decoder_embed_dim, model_args.vocab_size, bias=False + ) + + def _precompute_freqs_cis(self, model_args) -> torch.Tensor: + return precompute_freqs_cis( + model_args.decoder_embed_dim // model_args.decoder_num_heads, + # Need to compute until at least the max token limit for generation + # (use 2x max sequence length to be safe) + model_args.max_seq_len * 2, + model_args.rope_theta, + ) + + def forward( + self, + tokens: torch.Tensor, + *, + encoder_input: Optional[torch.Tensor] = None, + encoder_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + tokens (torch.Tensor): input tensor with shape ``[b x s]`` + encoder_input (Optional[torch.Tensor]): Optional input embeds from the encoder. Shape ``[b x s_e x d_e]`` + encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between + tokens and encoder embeddings. A True value at position ``i,j`` means token ``i`` can attend + to embedding ``j`` in the decoder. Mask has shape ``[b x s x s_e]``. 
Default is None, + but this is required during inference if the model has been setup with any layers + which use encoder embeddings and caches have been setup. + """ + # input tensor of shape [b, s] + bsz, seq_len = tokens.shape + + # shape: [b, s, d] + h = self.tok_embeddings(tokens) + + for layer in self.layers: + # shape: [b, s, d] + h = layer( + h, + freqs_cis=self.freqs_cis, + encoder_input=encoder_input, + encoder_mask=encoder_mask, + ) + + # shape: [b, s, d] + h = self.norm(h) + output = self.output(h).float() + + return output diff --git a/torchtitan/experiments/multimodal/requirements.txt b/torchtitan/experiments/multimodal/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e35531e566f2a925d851b9d3b8fa99645838e6e0 --- /dev/null +++ b/torchtitan/experiments/multimodal/requirements.txt @@ -0,0 +1 @@ +torchvision diff --git a/torchtitan/experiments/multimodal/tests/__init__.py b/torchtitan/experiments/multimodal/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e41cd717f6a439a9c08d76a9d0e4a54e190fc5a --- /dev/null +++ b/torchtitan/experiments/multimodal/tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchtitan/experiments/multimodal/tests/test_multimodal_model.py b/torchtitan/experiments/multimodal/tests/test_multimodal_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b5acc51bb3d186674267a4fc47d9075f04122a60 --- /dev/null +++ b/torchtitan/experiments/multimodal/tests/test_multimodal_model.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
import pytest
import torch

from torchtitan.experiments.llama_multimodal import (
    ModelArgs,
    MultimodalDecoder,
    VisionEncoder,
)

from .test_utils import fixed_init_model, fixed_init_tensor


@pytest.fixture
def encoder_config():
    # Small vision-encoder configuration kept intentionally tiny for fast tests.
    return ModelArgs(
        encoder_embed_dim=32,
        encoder_num_layers=2,
        encoder_num_heads=4,
        tile_size=49,
        patch_size=9,
        max_num_tiles=4,
        in_channels=3,
        return_intermediates=[0, 1],
        num_layers_projection=2,
        decoder_embed_dim=128,
    )


@pytest.fixture
def decoder_config():
    # Small decoder configuration for shape-level smoke tests.
    return ModelArgs(
        decoder_embed_dim=512,
        vocab_size=10000,
        fusion_interval=2,
        num_special_tokens=3,
        decoder_num_layers=6,
        decoder_num_heads=8,
        decoder_num_kv_heads=4,
        max_seq_len=512,
        rope_theta=50000.0,
    )


class TestMultimodalModelVisionEncoder:
    """Shape-level smoke test for the vision encoder with deterministic init."""

    @pytest.fixture(autouse=True)
    def setup_class(self, encoder_config):
        self.model_args = encoder_config
        self.batch_size = 1
        self.num_imgs = 2
        self.num_tiles = 4
        # One aspect ratio (h_tiles, w_tiles) per image.
        self.aspect_ratio = torch.tensor([[1, 3], [2, 2]]).reshape(
            self.batch_size, self.num_imgs, 2
        )
        image = torch.rand(
            (
                self.batch_size,
                self.num_imgs,
                self.num_tiles,
                self.model_args.in_channels,
                self.model_args.tile_size,
                self.model_args.tile_size,
            )
        )
        # Deterministic tensor so results are reproducible across runs.
        self.image = fixed_init_tensor(image.shape, min_val=-1, max_val=1)

    def test_llama_mm_vision_encoder(self):
        model = VisionEncoder(self.model_args)
        fixed_init_model(model, min_val=-1, max_val=1)
        output = model(self.image, self.aspect_ratio)
        # +1 accounts for the CLS token appended per tile.
        expected_shape = (
            self.batch_size,
            self.num_imgs * self.num_tiles * (model.vit.patches_per_tile + 1),
            self.model_args.decoder_embed_dim,
        )
        assert (
            output.shape == expected_shape
        ), f"Expected shape {expected_shape}, but got {output.shape}"

        # TODO: Need to ensure numerical stability before doing convergence test.
        # output.mean() = 3.994, we need to debug why it is not close to 5.28800, which is
        # the test value from the original torch tune test
        # assert torch.allclose(
        #     output.mean(), torch.tensor(5.28800), atol=1e-3, rtol=1e-3
        # )


class TestMultimodalModelDecoder:
    """Shape-level smoke test for the multimodal decoder with deterministic init."""

    @pytest.fixture(autouse=True)
    def setup_class(self, decoder_config):
        self.model_args = decoder_config
        self.batch_size = 1
        self.decoder_embed_dim = self.model_args.decoder_embed_dim
        self.vocab_size = self.model_args.vocab_size
        self.seq_len = 128
        # Token ids 0..127 are all < vocab_size, so only the base embedding is hit.
        self.input = {
            "tokens": torch.arange(self.batch_size * self.seq_len).reshape(
                self.batch_size, self.seq_len
            ),
            "encoder_input": fixed_init_tensor(
                (self.batch_size, self.seq_len, self.decoder_embed_dim),
                min_val=-1,
                max_val=1,
            ),
            "encoder_mask": None,
        }

    @torch.no_grad()
    def test_llama_mm_decoder(self):
        model = MultimodalDecoder(self.model_args)
        fixed_init_model(model, min_val=-1, max_val=1)
        output = model(**self.input)
        expected_shape = (self.batch_size, self.seq_len, self.vocab_size)
        assert (
            output.shape == expected_shape
        ), f"Expected shape {expected_shape}, but got {output.shape}"

        # TODO: Need to ensure numerical stability before doing convergence test.
        # output.mean() = -0.0134, we need to debug why it is not close to -9.47548e-5, which is
        # the test value from the original torch tune test
        # assert torch.allclose(
        #     output.mean(), torch.tensor(-9.47548e-5), atol=1e-3, rtol=1e-3
        # )
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from typing import Optional, Union + +import torch +from torch import nn + + +def fixed_init_tensor( + shape: torch.Size, + min_val: Union[float, int] = 0.0, + max_val: Union[float, int] = 1.0, + nonlinear: bool = False, + dtype: torch.dtype = torch.float, +): + """ + Utility for generating deterministic tensors of a given shape. In general stuff + like torch.ones, torch.eye, etc can result in trivial outputs. This utility + generates a range tensor [min_val, max_val) of a specified dtype, applies + a sine function if nonlinear=True, then reshapes to the appropriate shape. + """ + n_elements = math.prod(shape) + step_size = (max_val - min_val) / n_elements + x = torch.arange(min_val, max_val, step_size, dtype=dtype) + x = x.reshape(shape) + if nonlinear: + return torch.sin(x) + return x + + +@torch.no_grad +def fixed_init_model( + model: nn.Module, + min_val: Union[float, int] = 0.0, + max_val: Union[float, int] = 1.0, + nonlinear: bool = False, + dtype: Optional[torch.dtype] = None, +): + """ + This utility initializes all parameters of a model deterministically using the + function fixed_init_tensor above. See that docstring for details of each parameter. + """ + for _, param in model.named_parameters(): + param.copy_( + fixed_init_tensor( + param.shape, + min_val=min_val, + max_val=max_val, + nonlinear=nonlinear, + dtype=param.dtype if dtype is None else dtype, + ) + ) diff --git a/torchtitan/experiments/multimodal/tokenizer/tiktoken.py b/torchtitan/experiments/multimodal/tokenizer/tiktoken.py new file mode 100644 index 0000000000000000000000000000000000000000..9d494a06f6557c0108b107dd3a3ba36832bb913f --- /dev/null +++ b/torchtitan/experiments/multimodal/tokenizer/tiktoken.py @@ -0,0 +1,232 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + +import os +from pathlib import Path +from typing import ( + AbstractSet, + Any, + cast, + Collection, + Dict, + Iterator, + List, + Literal, + Mapping, + Optional, + Sequence, + Union, +) + +import tiktoken +import torch +from tiktoken.load import load_tiktoken_bpe + +from torchtitan.components.tokenizer import Tokenizer +from torchtitan.config_manager import JobConfig +from torchtitan.tools.logging import logger + +IMAGE_TOKEN_ID = 128256 +IGNORE_INDEX = -100 + + +class TikTokenizer(Tokenizer): + """ + Tokenizing and encoding/decoding text using the Tiktoken tokenizer. + + Args: + model_path (str): The path to the Tiktoken model file. + """ + + special_tokens: Dict[str, int] + + num_reserved_special_tokens = 256 + + pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501, B950 + + def __init__(self, model_path: str): + super().__init__(model_path) + assert os.path.isfile(model_path), model_path + + mergeable_ranks = load_tiktoken_bpe(model_path) + num_base_tokens = len(mergeable_ranks) + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", # end of turn + ] + [ + f"<|reserved_special_token_{i}|>" + for i in range(5, self.num_reserved_special_tokens - 5) + ] + self.special_tokens = { + token: num_base_tokens + i for i, token in enumerate(special_tokens) + } + self.special_tokens["<|image|>"] = IMAGE_TOKEN_ID + self.model = 
    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
        disallowed_special: Optional[Union[Literal["all"], Collection[str]]] = None,
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): allowed special tokens in string
            disallowed_special ("all"|set[str]): special tokens that raise an error when in string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will treat all text corresponding
          to special tokens to be encoded as special tokens.
        """
        assert type(s) is str
        # Defaults: no special tokens allowed, none disallowed.
        allowed_special = allowed_special or set()
        disallowed_special = disallowed_special or ()

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        # Lazily chunk the input so very long strings never hit tiktoken's limits.
        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: List[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(List[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            # XOR: the character class flipped, so the consecutive run resets.
            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    # Emit the run so far and start a new slice at this char.
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        # Tail: whatever remains after the last split point.
        yield s[slice_start:]

    def encode_multimodal(self, sample: Mapping[str, Any]) -> List[int]:
        """
        Tokenizes a `str` of text and creates `labels` masking BOS, EOS and `image_id` tokens.
        """
        # TODO(tj.solergibert) Should we keep `input_ids` OR `tokens` across this class, VisionCrossAttentionMask & the collator?
        # For me it makes more sense to split `tokens` between `input_ids` & `labels` as in train.py BUT the `MultimodalDecoder`
        # & everything else expects `tokens`
        text = sample["text"]
        tokens = self.encode(
            text, bos=True, eos=True, allowed_special=set(["<|image|>"])
        )
        # Next-token prediction split: inputs are tokens[:-1], targets tokens[1:].
        input_ids = torch.LongTensor(tokens[:-1])
        labels = torch.LongTensor(tokens[1:])
        # Mask out BOS / EOS / image tokens from the loss with IGNORE_INDEX.
        labels = torch.where(
            torch.isin(
                labels, torch.LongTensor([self.bos_id, self.eos_id, self.image_id])
            ),
            IGNORE_INDEX,
            labels,
        )

        assert len(input_ids) == len(labels)  # TODO(tj.solergibert) Delete

        sample.update({"tokens": input_ids, "labels": labels})

        return sample


def build_tiktoken_tokenizer(job_config: JobConfig) -> TikTokenizer:
    """Construct a TikTokenizer from the job config's tokenizer path."""
    return TikTokenizer(job_config.model.tokenizer_path)
+ + The algorithm will NOT distort the image to fit a certain aspect ratio, because + that leads to a significant degradation in image quality. + + The user can choose if they want to allow upscaling by using the flag ``resize_to_max_canvas``. + + For example, if an input image is of size 300x800, and we want to allow + a maximum of 16 image tiles, with side 224px, then: + + If ``resize_to_max_canvas=False``, then: + best_resolution = (448, 896) -> smallest canvas, up to 16 tiles, that doesn't require downscaling + image is NOT resized + image is padded (300, 800) -> 448,896 + Image is tiled 2x4, for a final output shape of (8, 3, 224, 224) + + If ``resize_to_max_canvas=True``, then: + best_resolution = (448, 1344) # canvas that allows maximum upscaling, with minimum padding, up to 16 tiles + image is resized without distortion (300,800) -> (448, 1194) #448 is the limiting side for the resize + image is padded (448, 1194) -> (448, 1344) + Image is tiled 2x6, for a final output shape of (10, 3, 224, 224) + + Args: + image_mean (Optional[List[float]]): Mean values of each channel, used for normalization. + Should be the same used for the pre-trained model. If None, no normalization is performed. Default None. + image_std (Optional[List[float]]): Standard deviation values of each channel, used for normalization. + Should be the same used for the pre-trained model. If None, no normalization is performed. Default None. + possible_resolutions (Optional[List[Tuple[int, int]]]): List of possible resolutions as tuples (height, width). + where each tuple represents a possible canvas to fit the image into when calling ``get_canvas_best_fit``. + If None, this will be calculated using max_num_tiles and tile_size. Default None. + tile_size (int): Size of the tiles to divide the image into. Default 224. + max_num_tiles (Optional[int]): Only used if possible_resolutions is NOT given. + Maximum number of tiles to break an image into. 
+ This will be used to generate possible_resolutions, + e.g. [(224, 224), (224, 448), (448, 224)] if max_num_tiles = 2 and tile_size = 224. + Default 4. + dtype (torch.dtype): Data type of the output image. Default torch.bfloat16. + resample (str): Resampling method used when resizing images. Supports any enum of + ``torchvision.transforms.InterpolationMode``, e.g. "nearest", "nearest_exact", "bilinear", "bicubic". + Default 'bilinear'. + resize_to_max_canvas (bool): "If True, the image will be upscaled without distortion to fit the largest possible + resolution from possible_resolutions. + If False, it will pick the resolution that minimizes downscaling, including no downscaling at all. + In this case, the image will only be upscaled if it's size < tile_size. Default False. + + Examples: + >>> image_transform = CLIPImageTransform( + ... image_mean=None, + ... image_std=None, + ... tile_size=224, + ... possible_resolutions=None, + ... max_num_tiles=4, + ... resample="bilinear", + ... resize_to_max_canvas=True, + ...) + >>> # create random image + >>> image = (np.random.rand(100,200,3) * 255).astype(np.uint8) + >>> image = PIL.Image.fromarray(image) + >>> output = image_transform(image) + >>> output['image'].shape # [num_tiles, num_channels, tile_size, tile_size] + torch.Size([2, 3, 224, 224]) + >>> output['ar'] # image best fits the canvas 224x448 + torch.tensor([1,2]) + """ + + def __init__( + self, + *, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, + possible_resolutions: Optional[List[Tuple[int, int]]] = None, + tile_size: int = 224, + max_num_tiles: Optional[int] = 4, + dtype: torch.dtype = torch.bfloat16, + resample: str = "bilinear", + resize_to_max_canvas: bool = False, + ) -> None: + + # get_canvas_best_fit + assert ( + possible_resolutions is not None or max_num_tiles is not None + ), f"Either possible_resolutions or max_num_tiles must be given. 
Got {possible_resolutions} and {max_num_tiles}" + + # If possible_resolutions are not given, then calculate possible ones based on max_num_tiles + if not possible_resolutions and max_num_tiles: + possible_resolutions = find_supported_resolutions( + max_num_tiles=max_num_tiles, tile_size=tile_size + ) + else: + possible_resolutions = possible_resolutions + + self.possible_resolutions = torch.tensor(possible_resolutions).reshape(-1, 2) + logger.debug( + f"Found possible_resolutions: {self.possible_resolutions}. Will fit the images into the canvas with best fit." + ) + + self.resize_to_max_canvas = resize_to_max_canvas + + # normalize + assert (image_mean is None) == ( + image_std is None + ), f"Need to provide both or none of image_mean and image_std. Got {image_mean=} and {image_std=}" + self.mean = image_mean + self.std = image_std + + # resize_with_pad + self.max_size = None if resize_to_max_canvas else tile_size + self.dtype = dtype + self.resample = torchvision.transforms.InterpolationMode[resample.upper()] + + # tile_crop + self.tile_size = tile_size + + def __call__(self, image: torch.Tensor) -> Mapping[str, Any]: + """ + Apply image decoding and transformations to the "image" field in the sample. + + Args: + sample (Mapping[str, Any]): A sample with an "image" field containing + a List[Message] to tokenize + + Returns: + Mapping[str, Any]: The sample with an updated "image" filed and added + "aspect_ratio" field. + """ + assert isinstance(image, torch.Tensor), "Input image must be a torch.Tensor." 
+ + image = F.to_image(image) + image = F.grayscale_to_rgb_image(image) + image = F.to_dtype(image, dtype=self.dtype, scale=True) + + # Find the best canvas to fit the image without distortion + best_resolution = get_canvas_best_fit( + image=image, + possible_resolutions=self.possible_resolutions, + resize_to_max_canvas=self.resize_to_max_canvas, + ) + + # resize without distortion + pad to fit best_resolution + image = resize_with_pad( + image=image, + target_size=best_resolution, + resample=self.resample, + max_size=self.max_size, + ) + + # Normalize + if self.mean: + image = F.normalize(image, mean=self.mean, std=self.std) + + # Divide the image into equally sized tiles + image = tile_crop(image=image, tile_size=self.tile_size) + + aspect_ratio = torch.tensor(best_resolution).reshape(-1) // self.tile_size + + return { + "image": image, + "aspect_ratio": aspect_ratio, + } diff --git a/torchtitan/experiments/multimodal/utils.py b/torchtitan/experiments/multimodal/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c927772a5ef95ba65123c9387de4ead1e732490f --- /dev/null +++ b/torchtitan/experiments/multimodal/utils.py @@ -0,0 +1,437 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from collections import defaultdict + +from pathlib import Path +from typing import List, Optional, Set, Tuple, Union +from urllib import request + +import torch +import torchvision +from torchvision.transforms.v2 import functional as F + +# NOTE Copied from torchtune.modules.transforms.vision_utils.tile_crop.py +def tile_crop(image: torch.Tensor, tile_size: int) -> torch.Tensor: + """ + Divides a tensor into equally sized tiles. The tensor should be divisible by tile_size. + + Args: + image (torch.Tensor): Input image to crop into tiles. + tile_size (int): Size of each tile. 
# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py
def resize_with_pad(
    image: torch.Tensor,
    target_size: Tuple[int, int],
    resample: torchvision.transforms.InterpolationMode,
    max_size: Optional[int] = None,
) -> torch.Tensor:
    """
    Resizes and pads an image to target_size without causing distortion.
    The user can set max_size to limit upscaling when target_size exceeds image_size.

    Args:
        image (torch.Tensor): The input image tensor in the format [..., H, W].
        target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width].
        resample (torchvision.transforms.InterpolationMode): Resampling method used when resizing images.
            Supports torchvision.transforms.InterpolationMode.NEAREST, InterpolationMode.NEAREST_EXACT,
            InterpolationMode.BILINEAR and InterpolationMode.BICUBIC.
        max_size (Optional[int]): The maximum size to upscale the image to.
            If None, will upscale up to target_size.

    Returns:
        torch.Tensor: The resized and padded image tensor in the format [..., H, W].

    Examples:

        Example 1: The image will be upscaled from (300, 800) to (448, 1194), since 448 is the limiting side,
        and then padded from (448, 1194) to (448, 1344).

            >>> max_size = None
            >>> image = torch.rand([3, 300, 800])
            >>> target_size = (448, 1344)
            >>> resample = torchvision.transforms.InterpolationMode.BILINEAR
            >>> output = resize_with_pad(image, target_size, resample, max_size)

        Example 2: The image will stay as is, since 800 > 600, and then padded from (300, 800) to (448, 1344).

            >>> max_size = 600
            >>> image = torch.rand([3, 300, 800])
            >>> target_size = (448, 1344)
            >>> resample = torchvision.transforms.InterpolationMode.BILINEAR
            >>> output = resize_with_pad(image, target_size, resample, max_size)

        Example 3: The image will be downscaled from (500, 1000) to (224, 448),
        and padded from (224, 448) to (448, 448).

            >>> max_size = 600
            >>> image = torch.rand([3, 500, 1000])
            >>> target_size = (448, 488)
            >>> resample = torchvision.transforms.InterpolationMode.BILINEAR
            >>> output = resize_with_pad(image, target_size, resample, max_size)

    """

    image_height, image_width = image.shape[-2:]
    image_size = (image_height, image_width)

    # If target_size requires upscaling, we might want to limit the upscaling to max_size.
    # Each resize target side is capped at max(image_side, max_size): an image
    # already larger than max_size is never shrunk by this cap, only kept from
    # being upscaled further than max_size.
    if max_size is not None:
        new_target_height = min(max(image_height, max_size), target_size[0])
        new_target_width = min(max(image_width, max_size), target_size[1])
        target_size_resize = (new_target_height, new_target_width)
    else:
        target_size_resize = target_size

    # resize to target_size while preserving aspect ratio
    new_size_preserving_aspect_ratio = _get_max_res_without_distortion(
        image_size=image_size,
        target_size=target_size_resize,
    )

    image = F.resize(
        inpt=image,
        size=list(new_size_preserving_aspect_ratio),
        interpolation=resample,
        antialias=True,
    )

    # Pad the bottom/right so the final tensor is exactly target_size.
    image = _pad_image_top_left(image=image, target_size=target_size)

    return image
+ """ + + image_size = image.shape[-2:] + + height, width = image_size + target_height, target_width = target_size + + pad_x = target_width - width + pad_y = target_height - height + + padding = [0, 0, pad_x, pad_y] + return F.pad(inpt=image, padding=padding) + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py +def _get_max_res_without_distortion( + image_size: Tuple[int, int], + target_size: Tuple[int, int], +) -> Tuple[int, int]: + """ + Determines the maximum resolution to which an image can be resized to without distorting its + aspect ratio, based on the target resolution. + + For example, if image_size = (200,400) and target_size = (600,800), + scale_h = 600/200 = 3 + scale_w = 800/400 = 2 + So the maximum that we can upscale without distortion is min(scale_h, scale_w) = 2 + + Since scale_w is the limiting side, then new_w = target_w, and new_h = old_h*scale_w + + Args: + image_size (Tuple[int, int]): The original resolution of the image. + target_size (Tuple[int, int]): The desired resolution to fit the image into. + Returns: + Tuple[int, int]: The optimal dimensions to which the image should be resized. + Examples: + >>> _get_max_res_without_distortion([200, 300], target_size = (450, 200)) + (133, 200) + >>> _get_max_res_without_distortion([800, 600], target_size = (450, 1300)) + (450, 337) + """ + + original_height, original_width = image_size + target_height, target_width = target_size + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.floor(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.floor(original_width * scale_h), target_width) + + return new_height, new_width + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py +def _get_factors(n: int) -> Set[int]: + """ + Calculate all factors of a given number, i.e. 
a divisor that leaves no remainder. + + Args: + n (int): The number to find factors for. + + Returns: + set: A set containing all factors of the number. + + Examples: + >>> _get_factors(n=12) + {1, 2, 3, 4, 6, 12} + """ + factors_set = set() + + for i in range(1, int(n**0.5) + 1): + if n % i == 0: + factors_set.add(i) + factors_set.add(n // i) + return factors_set + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py +def get_canvas_best_fit( + image: torch.Tensor, possible_resolutions: torch.Tensor, resize_to_max_canvas: bool +) -> Tuple[int, int]: + """ + Determines the best canvas possible from a list of possible resolutions to + resize an image to, without distortion. + + For each possible resolution, calculates the scaling factors for + width and height, and selects the smallest one, which is the limiting side. + E.g. if to match a canvas shape you have to upscale an image's height by 2x, and width by 1.5x, + then the maximum upscaling without distortion is min(2, 1.5) = 1.5. + + If there are multiple canvases that satisfy the conditions, + we pick the one with the lowest area to minimize padding. + + Args: + image (torch.Tensor): The image we want to fit into a canvas. + possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each + row represents a possible canvas. + resize_to_max_canvas (bool): If True, pick the canvas that allows maximum scaling. + If False, pick the canvas that minimizes downscaling, including no downscaling at all. + + Returns: + Tuple[int, int]: The best resolution to fit the image into. + + Examples: + >>> image = torch.rand(3, 200, 300) + >>> possible_resolutions = torch.tensor([ + ... [224, 672], + ... [672, 224], + ... [224, 448], + ... [448, 224], + ... [224, 224] + ... 
]) + >>> get_canvas_best_fit(image, possible_resolutions, resize_to_max_canvas=False) + (224, 448) + + In the example above, we calculate the scaling factors for each possible resolution + + >>> scale_height = torch.tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200]) + >>> scale_width = torch.tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467]) + >>> scales = torch.tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467]) + + Two options have scaling_factor > 1, since resize_to_max_canvas is False, we pick the smallest + + >>> upscaling_options = torch.tensor([1.1200, 1.1200]) + >>> selected_scale = torch.tensor(1.1200) + + There are two possible options, so we pick the one with the smallest area + + >>> areas = torch.tensor([150528, 100352]) # for resolutions [672, 224] and [224, 448], respectively + >>> optimal_canvas = torch.tensor([224, 448]) # resolution with the smallest area + """ + + original_height, original_width = image.shape[-2:] + + # possible resolutions heights/widths + target_heights, target_widths = ( + possible_resolutions[:, 0], + possible_resolutions[:, 1], + ) + + # scaling factors to resize the image without distortion + scale_w = target_widths / original_width + scale_h = target_heights / original_height + + # get limiting side scaling -> no distortion + scales = torch.where(scale_w > scale_h, scale_h, scale_w) + + # filter only scales that allow upscaling + upscaling_options = scales[scales >= 1] + if len(upscaling_options) > 0: + if resize_to_max_canvas: + selected_scale = torch.max(upscaling_options) + else: + selected_scale = torch.min(upscaling_options) + else: + # no upscaling possible, + # get the minimum downscaling (max scale for scales<1) + downscaling_options = scales[scales < 1] + selected_scale = torch.max(downscaling_options) + + # get all resolutions that support this scaling factor, + # e.g. 
you can upscale to 224x224, 224x448, 224x672 without distortion + chosen_canvas = possible_resolutions[scales == selected_scale] + + # if there are multiple resolutions, + # get the one with minimum area to reduce padding + if len(chosen_canvas) > 1: + areas = chosen_canvas[:, 0] * chosen_canvas[:, 1] + optimal_idx = torch.argmin(areas) + optimal_canvas = chosen_canvas[optimal_idx] + else: + optimal_canvas = chosen_canvas[0] + + return tuple(optimal_canvas.tolist()) + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py +def find_supported_resolutions( + max_num_tiles: int, tile_size: int +) -> List[Tuple[int, int]]: + """ + Computes all combinations of resolutions, multiple of tile_size, + that contain up to max_num_tiles. Useful for when dividing an image into tiles. + + For example, if we want at most 2 tiles per image, then we can support the + following resolutions: (1x1, 1x2, 2x1) * tile_size + + Args: + max_num_tiles (int): Maximum number of tiles. + tile_size (int): Size of the side of the tile. + + Returns: + List[Tuple[int, int]]: List of possible resolutions as tuples (height, width). 
+ + Examples: + + >>> max_num_tiles = 4 + >>> tile_size = 224 + >>> find_supported_resolutions(max_num_tiles, tile_size) + [(224, 896), (448, 448), (224, 224), (896, 224), (224, 672), (672, 224), (224, 448), (448, 224)] + """ + + # create dictionary {aspect_ratio: [resolution1, ..., resolution n]} + # example {0.25: [(1,4)], 1.0: [(2,2), (1,1)], 4.0: [(4,1)]} + asp_dict = defaultdict(list) + for _tile_size in range(max_num_tiles, 0, -1): + factors = sorted(_get_factors(_tile_size)) + asp_ratios = [(factor, _tile_size // factor) for factor in factors] + for height, width in asp_ratios: + ratio_float = height / width + asp_dict[ratio_float].append((height, width)) + + # get the resolutions multiplied by the tile_size + possible_resolutions = [] + for ar, resolution in asp_dict.items(): + for height, width in resolution: + possible_resolutions.append((height * tile_size, width * tile_size)) + + return possible_resolutions + + +# NOTE Copied from torchtune.data._utils.py +def load_image(image_loc: Union[Path, str]) -> torch.Tensor: + """ + Convenience method to load an image in torch.Tensor format from a local file path or remote source. + + Args: + image_loc (Union[Path, str]): Local file path or remote source pointing to the image + which will be loaded in PIL format. + + Note: + If loading an image from a remote source, the function expects the URL provided in ``image_loc`` + to start with "http" or "https" e.g. "https://www.wikipedia.org/en/bird.jpg". + + Raises: + ValueError: If the image cannot be loaded from remote source, **or** + if the image cannot be opened as a :class:`~torch.Tensor`. + + Examples: + >>> # Load from remote source + >>> image = load_image("https://www.wikipedia.org/en/bird.jpg") + + >>> # Load from local file path + >>> image = load_image(Path("/home/user/bird.jpg")) + + Returns: + torch.Tensor: The loaded image. 
+ """ + + # If pointing to remote source, try to load to local + if isinstance(image_loc, str) and image_loc.startswith("http"): + try: + image_loc = request.urlopen(image_loc).read() + image = torchvision.io.decode_image( + torch.frombuffer(image_loc, dtype=torch.uint8), + mode="RGB", + ) + except Exception as e: + raise ValueError("Failed to load remote image as torch.Tensor") from e + + # Open the local image as a Tensor image + else: + try: + image = torchvision.io.decode_image(image_loc, mode="RGB") + except Exception as e: + raise ValueError("Failed to load local image as torch.Tensor") from e + + return image diff --git a/torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc b/torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..184a1069b52468b13e425de5a64b34259be3cb21 Binary files /dev/null and b/torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc differ diff --git a/torchtitan/experiments/simple_fsdp/__pycache__/model.cpython-312.pyc b/torchtitan/experiments/simple_fsdp/__pycache__/model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5992c210733647d400cc1dfe4283ed260e64e90 Binary files /dev/null and b/torchtitan/experiments/simple_fsdp/__pycache__/model.cpython-312.pyc differ diff --git a/torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-312.pyc b/torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..def78316cdd08bd4823b55c2262fcf87a5febd5d Binary files /dev/null and b/torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-312.pyc differ diff --git a/torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-312.pyc b/torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-312.pyc new file mode 100644 index 
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.models.llama3 import Transformer, TransformerModelArgs

from .simple_fsdp import disable_data_parallel


class SimpleFSDPTransformer(Transformer):
    """Llama3 ``Transformer`` variant for SimpleFSDP.

    Construction immediately runs ``init_weights``, and weight
    initialization executes with the data-parallel parametrization
    disabled so the raw (non-parametrized) parameters are initialized.
    """

    def __init__(self, model_args: TransformerModelArgs):
        super().__init__(model_args)
        self.init_weights()

    def init_weights(self, *args, **kwargs):
        # The ReplicateComputation parametrization must not intercept
        # parameter access during initialization.
        with disable_data_parallel():
            super().init_weights(*args, **kwargs)


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+ +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn + +from torch.distributed._tensor import ( + distribute_tensor, + DTensor, + Partial, + Replicate, + Shard, +) +from torch.utils.checkpoint import ( + checkpoint, + CheckpointPolicy, + create_selective_checkpoint_contexts, +) + + +_active_parametrization = True + + +@contextmanager +def disable_data_parallel(): + global _active_parametrization + try: + _active_parametrization = False + yield + finally: + _active_parametrization = True + + +@dataclass(frozen=True) +class MixedPrecisionPolicy: + param_dtype: Optional[torch.dtype] = None + reduce_dtype: Optional[torch.dtype] = None + + +def fsdp_policy(): + def _fsdp_recomp_policy(): + def _custom_policy(ctx, func, *args, **kwargs): + to_recompute = func in { + torch.ops._c10d_functional.all_gather_into_tensor.default, + torch.ops._c10d_functional.wait_tensor.default, + torch.ops.aten._to_copy.default, # for dtype cast in FSDP + } + return ( + CheckpointPolicy.MUST_RECOMPUTE + if to_recompute + else CheckpointPolicy.MUST_SAVE + ) + + return _custom_policy + + return create_selective_checkpoint_contexts(_fsdp_recomp_policy()) + + +class ReplicateComputation(torch.nn.Module): + def __init__(self, device_mesh, param_sharding, mode, regional_ac, mp_policy): + super().__init__() + self.device_mesh = device_mesh + self.param_sharding = param_sharding + self.mode = mode + self.compute_placements = [Replicate()] * self.device_mesh.ndim + self.grad_placements = [Partial(reduce_op="avg")] * self.device_mesh.ndim + self.regional_ac = regional_ac + mp_policy = mp_policy or MixedPrecisionPolicy() + self.param_dtype = mp_policy.param_dtype + self.reduce_dtype = mp_policy.reduce_dtype + + def replicate_compute(self, x): + # data parallel runtime replicate parameters and do local compute + # the gradients are partial tensors that needs to perform reduction + # (i.e. 
DDP: allreduce, FSDP: reduce_scatter, HSDP: mix of both) + + # NOTE: specifying mixed precision is only available in pytorch_intern24 + # https://github.com/tianyu-l/pytorch_intern24/pull/20 + # support for FSDP + TP (assuming TP shards the inner-most dim) + if self.mode == "fully_shard" and x._spec.mesh.ndim == 2: + dp_placement, tp_placement = x._spec.placements + dp_mesh, tp_mesh = self.device_mesh, x._spec.mesh["tp"] + + # re-wrap 2D DTensor to 1D DTensor on dp_mesh for efficient FSDP all-gather + # TODO: we should consider merging this logic into DTensor redistribute API + sharded_local_tensor = x.to_local() + sharded_dtensor = DTensor.from_local( + sharded_local_tensor, dp_mesh, self.param_sharding + ) + + # the actuall FSDP all-gather on dp_mesh + # TODO(ruisizhang123): enable mixed-precision training here + # add the forward_dtype and backward_dtype back after landing changes in PyTorch DTensor + replicated_dtensor = sharded_dtensor.redistribute( + placements=self.compute_placements, + # forward_dtype=self.param_dtype, + # backward_dtype=self.reduce_dtype, + ) + + # re-wrap 1D all-gathered DTensor on dp_mesh to 1D DTensor on tp_mesh + # TODO: DTensor should support this mesh collasping operation + replicated_local_tensor = replicated_dtensor.to_local( + grad_placements=self.grad_placements + ) + output = DTensor.from_local( + replicated_local_tensor, tp_mesh, (tp_placement,) + ) + else: + output = x.redistribute( + placements=self.compute_placements, + # forward_dtype=self.param_dtype, + # backward_dtype=self.reduce_dtype, + ).to_local(grad_placements=self.grad_placements) + + return output + + def forward(self, x): + global _active_parametrization + # This should never be set to true during forward, only outside for model + # inspection / debugging / initialization + # model initialization can be done now through + # with disable_data_parallel(): + # model.init_weights() + if not _active_parametrization: + return x + + if self.regional_ac and self.mode in 
("fully_shard", "hybrid_shard"): + # apply checkpointing to implement reshard_after_forward + output = checkpoint( + self.replicate_compute, x, use_reentrant=False, context_fn=fsdp_policy + ) + else: + output = self.replicate_compute(x) + + return output + + +def data_parallel( + model, + device_mesh, + mode="replicate", + ac_mode: str = "none", + mp_policy: Optional[MixedPrecisionPolicy] = None, +): + if mode == "replicate": + param_sharding = (Replicate(),) + elif mode == "fully_shard": + param_sharding = (Shard(0),) + elif mode == "hybrid_shard": + # replicate inter-host, fully shard intra-host + param_sharding = (Replicate(), Shard(0)) + assert ( + device_mesh.ndim == 2 + ), "hybrid sharded data parallel requires 2D DeviceMesh" + else: + raise ValueError(f"Unsupported mode {mode}") + + modules = list(model.modules()) + + # apply regional ac (with fsdp_policy) if no global ac is to be applied + regional_ac = ac_mode == "none" + + for mod in modules: + params_dict = dict(mod.named_parameters(recurse=False)) + for p_name, p in params_dict.items(): + if p is not None and p.numel() > 0: + mod.register_parameter( + p_name, + # NOTE: for 2D we need to distribute_tensor a DTensor + # which requires latest change in pytorch_intern24 + # https://github.com/tianyu-l/pytorch_intern24/pull/25 + nn.Parameter(distribute_tensor(p, device_mesh, param_sharding)), + ) + nn.utils.parametrize.register_parametrization( + mod, + p_name, + ReplicateComputation( + device_mesh, + param_sharding, + mode, + regional_ac, + mp_policy=mp_policy, + ), + unsafe=True, + ) + + return model diff --git a/torchtitan/experiments/simple_fsdp/tests/__init__.py b/torchtitan/experiments/simple_fsdp/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e41cd717f6a439a9c08d76a9d0e4a54e190fc5a --- /dev/null +++ b/torchtitan/experiments/simple_fsdp/tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy

import torch
from torch.distributed._composable.fsdp import fully_shard

from torch.testing._internal.common_fsdp import FSDPTest

from torchtitan.components.loss import cross_entropy_loss
from torchtitan.distributed import ParallelDims
from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel


class TestSimpleFSDP(FSDPTest):
    """Numerical-parity tests comparing SimpleFSDP against FSDP2.

    Each test trains the same tiny model under FSDP2 (``fully_shard``) and
    under SimpleFSDP's ``data_parallel`` wrapper and asserts the per-step
    losses match. ``self.mode`` must be set by the test method before
    calling :meth:`init_test`.
    """

    def init_test(self):
        """Set up optimizer/loss handles and the device mesh for ``self.mode``.

        Raises:
            ValueError: If ``self.mode`` is not one of "replicate",
                "fully_shard" or "hybrid_shard".
        """
        self.optimizer = torch.optim.Adam
        self.loss_fn = cross_entropy_loss
        data_parallel_shard_degree = -1
        if self.mode == "replicate":
            self.dp_mesh_dim_names = ("dp_replicate",)
            data_parallel_replicate_degree = self.world_size
        elif self.mode == "fully_shard":
            self.dp_mesh_dim_names = ("dp_shard_cp",)
            data_parallel_replicate_degree = 1
        elif self.mode == "hybrid_shard":
            self.dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
            data_parallel_replicate_degree = self.world_size // 2
        else:
            # FIX: previously referenced the undefined local `mode`, which
            # raised NameError instead of the intended ValueError.
            raise ValueError(f"Unsupported mode {self.mode}")

        self.parallel_dims = ParallelDims(
            dp_shard=data_parallel_shard_degree,
            dp_replicate=data_parallel_replicate_degree,
            cp=1,
            tp=1,
            pp=1,
            world_size=self.world_size,
            enable_loss_parallel=True,
        )
        self.device_mesh = self.parallel_dims.build_mesh(device_type="cuda")

    def get_input(self):
        """Return a fresh (model, inputs, labels) triple for one test run."""
        inputs = torch.randn(8, 8).cuda()
        labels = torch.randn(8, 8).cuda()
        model = torch.nn.Linear(8, 8)
        return model, inputs, labels

    def _train(self, model, inputs, labels, epoch):
        # Shared training loop for both parallelism implementations;
        # returns the loss tensor of every step for later comparison.
        optim = self.optimizer(model.parameters(), lr=1e-4)
        losses = []
        for _ in range(epoch):
            optim.zero_grad()
            out = model(inputs)
            loss = self.loss_fn(out, labels)
            loss.backward()
            optim.step()
            losses.append(loss)
        return losses

    def run_fsdp2(self, model, inputs, labels, epoch=20):
        """Train ``model`` wrapped by FSDP2's fully_shard; return per-step losses."""
        fully_shard(model, mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)])
        return self._train(model, inputs, labels, epoch)

    def run_simple_fsdp(self, model, inputs, labels, epoch=20):
        """Train ``model`` wrapped by SimpleFSDP in ``self.mode``; return per-step losses."""
        model = data_parallel(
            model,
            device_mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)],
            mode=self.mode,
        )
        return self._train(model, inputs, labels, epoch)

    def _assert_parity(self, model, inputs, labels):
        # Train identical copies under both implementations and compare
        # step-by-step losses.
        fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels)
        simple_fsdp_losses = self.run_simple_fsdp(
            copy.deepcopy(model), inputs, labels
        )
        for fsdp2_loss, simple_fsdp_loss in zip(fsdp2_losses, simple_fsdp_losses):
            assert torch.allclose(fsdp2_loss, simple_fsdp_loss)

    def test_replicate_convergence(self):
        # unit test for replicate mode
        self.mode = "replicate"
        self.init_test()
        self._assert_parity(*self.get_input())

    def test_fullyshard_convergence(self):
        # unit test for fully_shard mode
        self.mode = "fully_shard"
        self.init_test()
        self._assert_parity(*self.get_input())

    def test_hybridshard_convergence(self):
        # unit test for hybrid_shard mode
        self.mode = "hybrid_shard"
        self.init_test()
        self._assert_parity(*self.get_input())