diff --git a/fla/models/abc/modeling_abc.py b/fla/models/abc/modeling_abc.py new file mode 100644 index 0000000000000000000000000000000000000000..455d5b7f10358d1af37698e44544ca53784896a7 --- /dev/null +++ b/fla/models/abc/modeling_abc.py @@ -0,0 +1,418 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.abc import ABCAttention +from fla.layers.attn import Attention +from fla.models.abc.configuration_abc import ABCConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as ABCMLP +from fla.modules import RMSNorm + +logger = logging.get_logger(__name__) + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +class ABCBlock(nn.Module): + def __init__(self, config: ABCConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = ABCAttention( + hidden_size=config.hidden_size, + 
expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_slots=config.num_slots, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + gate_fn=config.hidden_act, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + use_rope=config.use_rope, + clamp_min=config.clamp_min, + clamp_max=config.clamp_max, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = ABCMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class ABCPreTrainedModel(PreTrainedModel): + + config_class = ABCConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = 
['ABCBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class ABCModel(ABCPreTrainedModel): + + def __init__(self, config: ABCConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([ABCBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + 
output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`ABCModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class ABCForCausalLM(ABCPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = ABCModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def 
get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. 
+ model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = 
FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla/models/gated_deltanet/__init__.py b/fla/models/gated_deltanet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f66b6488665fe05d2391ff248c96f2186d072f98 --- /dev/null +++ b/fla/models/gated_deltanet/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig +from fla.models.gated_deltanet.modeling_gated_deltanet import GatedDeltaNetForCausalLM, GatedDeltaNetModel + +AutoConfig.register(GatedDeltaNetConfig.model_type, GatedDeltaNetConfig) +AutoModel.register(GatedDeltaNetConfig, GatedDeltaNetModel) +AutoModelForCausalLM.register(GatedDeltaNetConfig, GatedDeltaNetForCausalLM) + +__all__ = ['GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel'] diff --git a/fla/models/mamba/__init__.py b/fla/models/mamba/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b67cf0a75012a2f71a0f12c53584071bdc456a6b --- /dev/null +++ b/fla/models/mamba/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + 
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.mamba.configuration_mamba import MambaConfig +from fla.models.mamba.modeling_mamba import MambaBlock, MambaForCausalLM, MambaModel + +AutoConfig.register(MambaConfig.model_type, MambaConfig, True) +AutoModel.register(MambaConfig, MambaModel, True) +AutoModelForCausalLM.register(MambaConfig, MambaForCausalLM, True) + + +__all__ = ['MambaConfig', 'MambaForCausalLM', 'MambaModel', 'MambaBlock'] diff --git a/fla/models/mamba2/__init__.py b/fla/models/mamba2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0b8ac62a700590e06d1e524979b2f21353aa5188 --- /dev/null +++ b/fla/models/mamba2/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.mamba2.configuration_mamba2 import Mamba2Config +from fla.models.mamba2.modeling_mamba2 import Mamba2ForCausalLM, Mamba2Model + +AutoConfig.register(Mamba2Config.model_type, Mamba2Config, True) +AutoModel.register(Mamba2Config, Mamba2Model, True) +AutoModelForCausalLM.register(Mamba2Config, Mamba2ForCausalLM, True) + + +__all__ = ['Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model'] diff --git a/fla/models/rwkv7/__init__.py b/fla/models/rwkv7/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6f132f3fc8de7108242e1accc51e55f4a4e6ed5 --- /dev/null +++ b/fla/models/rwkv7/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.rwkv7.configuration_rwkv7 import RWKV7Config +from fla.models.rwkv7.modeling_rwkv7 import RWKV7ForCausalLM, RWKV7Model + +AutoConfig.register(RWKV7Config.model_type, RWKV7Config, True) +AutoModel.register(RWKV7Config, RWKV7Model, True) +AutoModelForCausalLM.register(RWKV7Config, RWKV7ForCausalLM, True) + + +__all__ = ['RWKV7Config', 'RWKV7ForCausalLM', 'RWKV7Model'] diff --git 
a/fla/models/rwkv7/modeling_rwkv7.py b/fla/models/rwkv7/modeling_rwkv7.py new file mode 100644 index 0000000000000000000000000000000000000000..038e58d254883865f2f5d8a612ec0d0060c130c1 --- /dev/null +++ b/fla/models/rwkv7/modeling_rwkv7.py @@ -0,0 +1,505 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.attn import Attention +from fla.layers.rwkv7 import RWKV7Attention +from fla.models.rwkv7.configuration_rwkv7 import RWKV7Config +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, LayerNorm +from fla.modules.activations import ACT2FN + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + +logger = logging.get_logger(__name__) + + +class RWKV7FeedForward(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'sqrelu', + layer_idx: int = None + ) -> RWKV7FeedForward: + super().__init__() + + self.hidden_size = hidden_size + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio) + intermediate_size = 32 * ((intermediate_size + 32 - 1) // 32) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + self.x_k = nn.Parameter(torch.zeros(hidden_size)) + + self.key = nn.Linear(hidden_size, intermediate_size, bias=False) + self.value = 
nn.Linear(intermediate_size, hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + self.layer_idx = layer_idx + + def forward( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + state: Optional[Cache] = None + ) -> torch.Tensor: + if attention_mask is not None: + x = x.mul(attention_mask[:, -x.shape[-2]:, None]) + if x.shape[1] == 1 and state is not None and state[self.layer_idx]['ffn_state'] is not None: + shifted = state[self.layer_idx]['ffn_state'].unsqueeze(1) + else: + shifted = self.time_shift(x) + if state is not None and state[self.layer_idx]['ffn_state'] is not None: + shifted[:, 0] = state[self.layer_idx]['ffn_state'][-1] + if state is not None: + # no need to update the offset twice + state.update(ffn_state=x[:, -1], layer_idx=self.layer_idx, offset=0) + return self.value(self.act_fn(self.key(x.addcmul(shifted - x, self.x_k)))), state + + +class RWKV7Block(nn.Module): + + def __init__( + self, + config: RWKV7Config, + layer_idx: int + ) -> RWKV7Block: + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + if config.norm_first and layer_idx == 0: + self.pre_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)( + config.hidden_size, + bias=config.norm_bias, + eps=config.norm_eps + ) + self.attn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)( + config.hidden_size, + bias=config.norm_bias, + eps=config.norm_eps + ) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = RWKV7Attention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + head_dim=config.head_dim, + num_heads=config.num_heads, + 
decay_low_rank_dim=config.decay_low_rank_dim, + gate_low_rank_dim=config.gate_low_rank_dim, + a_low_rank_dim=config.a_low_rank_dim, + v_low_rank_dim=config.v_low_rank_dim, + norm_eps=config.norm_eps, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx, + value_dim=config.value_dim[layer_idx] + ) + self.ffn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)( + config.hidden_size, + bias=config.norm_bias, + eps=config.norm_eps + ) + self.ffn = RWKV7FeedForward( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + layer_idx=layer_idx + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + v_first: torch.Tensor = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = self.pre_norm(hidden_states) if hasattr(self, 'pre_norm') else hidden_states + hidden_states = self.attn_norm(residual) + hidden_states, attentions, past_key_values, v_first = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + v_first=v_first, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.ffn_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.ffn_norm(hidden_states) + hidden_states, past_key_values = self.ffn(hidden_states, attention_mask, past_key_values) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values, v_first) + + return outputs + + +class RWKV7PreTrainedModel(PreTrainedModel): + + config_class = RWKV7Config + base_model_prefix = 'model' + supports_gradient_checkpointing = 
True + _no_split_modules = ['RWKV7Block'] + _supports_cache_class = True + _skip_keys_device_placement = ["past_key_values"] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + warnings.warn( + "RWKV-7 employs a carefully designed initialization strategy tailored to its architecture. " + "The detailed initialization scheme is currently not implemented here but can be found in the " + "official code repository. We emphasize that using the recommended initialization is essential " + "for replicating the results in RWKV-7 paper. Deviations from the prescribed initialization " + "may lead to performance degradation.\n" + "Alternatively, please generate initial weights from the official RWKV code repository, and " + "convert the PyTorch checkpoint into FLA supported format." + ) + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Parameter): + nn.init.normal_(module, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class RWKV7Model(RWKV7PreTrainedModel): + + def __init__(self, config: RWKV7Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([RWKV7Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)( + config.hidden_size, + bias=config.norm_bias, + eps=config.norm_eps + ) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`RWKV7Model` 
does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + + v_first = torch.zeros_like(hidden_states) + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values, v_first = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + v_first, + **kwargs + ) + else: + hidden_states, attentions, past_key_values, v_first = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + v_first=v_first, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class RWKV7ForCausalLM(RWKV7PreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = RWKV7Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, 
new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. 
+ model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + labels: Optional[torch.LongTensor] = None, + shift_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + has_labels = (labels is not None) or (shift_labels is not None) + if not (fuse_linear_and_cross_entropy and has_labels): + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if has_labels: + if getattr(self, 'criterion', None) is 
None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + + # shift_labels: See https://github.com/huggingface/transformers/pull/36607/files. + if shift_labels is None: + shift_labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + shift_labels = shift_labels.to(hidden_states.device) + + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, shift_labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(shift_labels.numel(), -1), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla/ops/abc/__pycache__/chunk.cpython-312.pyc b/fla/ops/abc/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d87bc4f177090954cddfab849ce239ccfb3f8459 Binary files /dev/null and b/fla/ops/abc/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla/ops/attn/__pycache__/parallel.cpython-312.pyc b/fla/ops/attn/__pycache__/parallel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e3c6336f1f0453c6815e053f3e7f4927e9501f5 Binary files /dev/null and b/fla/ops/attn/__pycache__/parallel.cpython-312.pyc differ diff --git a/fla/ops/based/__pycache__/parallel.cpython-312.pyc b/fla/ops/based/__pycache__/parallel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fe66427faea81a03c48ccab145adf9987a0fca6 Binary files /dev/null and b/fla/ops/based/__pycache__/parallel.cpython-312.pyc differ diff --git 
a/fla/ops/based/fused_chunk.py b/fla/ops/based/fused_chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..ff5db4fb73022c677662a4f7d29d6b2ec3015194 --- /dev/null +++ b/fla/ops/based/fused_chunk.py @@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.jit(do_not_specialize=['T']) +def fused_chunk_based_fwd_kernel( + q, + k, + v, + o, + z, + scale, # K ** -0.5 + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + + # [BV], zero-order taylor expansion + b_h_0o = tl.zeros([BV], dtype=tl.float32) + # [BK, BV], first-order taylor expansion + b_h_1o = tl.zeros([BK, BV], dtype=tl.float32) + # [BK, BK, BV] second-order taylor expansion + b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0)) + + p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT) + k_2o = tl.zeros([1, BK * BK], dtype=tl.float32) + k_1o = tl.zeros([1, BK], dtype=tl.float32) + k_0o = 0 + + for i in range(0, tl.cdiv(T, BT)): + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK*BK, BT] + b_k_2o = b_k[:, None, :] * b_k[None, :, :] + b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype) + # [BT, BV] + b_v = 
tl.load(p_v, boundary_check=(0, 1)) + # [BT, BK] + b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype) + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_z = tl.zeros([BT], dtype=tl.float32) + + # interchunk + # zero-order + b_o += b_h_0o + b_z += k_0o + # first-order + b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False) + b_z += tl.sum(b_q * k_1o, axis=1) + # second-order + b_q_2o = b_q[:, :, None] * b_q[:, None, :] + b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype) + b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5 + b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5 + + # update running statistics + k_1o += tl.sum(b_k, axis=1)[None, :] + k_2o += tl.sum(b_k_2o, axis=1)[None, :] + k_0o += BT + + # intrachunk + # [BT, BT] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_z += tl.sum(b_s, axis=1) + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + # [TB, BV] + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=(i * BT + tl.arange(0, BT)) < T) + + # update hidden state + # [BK, BV] + b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False) + b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False) + b_h_0o = b_h_0o + tl.sum(b_v, axis=0) + + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + p_z += BT + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_chunk_based_bwd_kernel( + # NV: number of split in the V dimension. 
NK: number of split in the K dimension + q, + k, + v, + do, + dz, + dq, + dk, + dv, + scale, # K ** -0.5 + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + # [BV], zero-order taylor expansion + # b_h_0o = tl.zeros([BV], dtype=tl.float32) + # [BK, BV], first-order taylor expansion + b_h_1o = tl.zeros([BV, BK], dtype=tl.float32) + # [BK, BK, BV] second-order taylor expansion + b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32) + + k_1o = tl.zeros([1, BK], dtype=tl.float32) + k_2o = tl.zeros([1, BK * BK], dtype=tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0)) + p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + + # load tensors + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T) + # [BV, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # inter-chunk + b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False) + if i_v == 0: + b_dq += b_dz[:, None] * k_1o + b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5 + if i_v == 0: + b_dq_2o += (b_dz[:, None] * k_2o) * 0.5 + b_dq_2o = 
tl.reshape(b_dq_2o, [BT, BK, BK]) + b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1) + b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2) + b_dq *= scale + + # intra-chunk + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + b_ds = tl.where(m_s, b_ds, 0) * scale + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False) + + # store + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # update hidden state + # [BT, BK*BK] + b_k_2o = b_k[:, :, None] * b_k[:, None, :] + b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype) + # [BV, BK*BK] + b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False) + # [BV, BK] + b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False) + + if i_v == 0: + # update running statistics + k_1o += tl.sum(b_k, axis=0)[None, :] + k_2o += tl.sum(b_k_2o, axis=0)[None, :] + + tl.debug_barrier() + b_h_1o = None + b_h_2o = None + + # [BK, BV], first-order taylor expansion + b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32) + # [BK, BK, BV] second-order taylor expansion + b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32) + b_dh_0o = tl.zeros([BV], dtype=tl.float32) + m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :] + + dq_1o = tl.zeros([1, BK], dtype=tl.float32) + dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32) + + for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT): + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * T*K, (T, K), (K, 1), (i, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + 
(i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (i, i_v*BV), (BT, BV), (1, 0)) + p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i + + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T) + b_q = (b_q * scale).to(b_k.dtype) + + # intra chunk + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + if i_v == 0: + b_ds += b_dz[None, :] + b_ds = tl.where(m_s, b_ds, 0) + b_s = tl.dot(b_k, b_q, allow_tf32=False) + b_s2 = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_s2 = tl.where(m_s, b_s2, 0) + b_ds *= (1+b_s) + + b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False) + b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False) + + # inter chunk + b_k_2o = b_k[:, :, None] * b_k[:, None, :] + b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype) + + b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False) + b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False) + b_dv += b_dh_0o + + b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False) + + if i_v == 0: + b_dk += dq_1o + + b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype), tl.trans(b_v), allow_tf32=False) + if i_v == 0: + b_dk_2o += dq_2o + b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT]) + b_k_fp32 = tl.trans(b_k.to(tl.float32)) + b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0) + b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1) + b_dk += tl.trans(b_dk2) + + # hidden state update + b_dh_0o += tl.sum(b_do, axis=0) + b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False) + b_q_2o = b_q[None, :, :] * b_q[:, None, :] + b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype) + b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5 + + if i_v == 0: + dq_1o += (tl.sum(b_dz[None, :] * b_q, 
axis=1))[None, :] + dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None] + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkBasedFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, scale=1): + B, H, T, K, V = *k.shape, v.shape[-1] + + scale = scale + BT = 16 + BK, BV = min(K, 16), min(V, 32) + BK, BV = max(BK, 16), max(BV, 16) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + num_warps = 4 + + # the norm of o might explode, so we need to use float32 here + o = q.new_empty(NK, B, H, T, V, dtype=torch.float32) + z = q.new_empty(NK, B, H, T, dtype=torch.float32) + + grid = (NV, NK, B * H) + fused_chunk_based_fwd_kernel[grid]( + q, k, v, o, z, + scale, + T=T, B=B, H=H, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + ) + o = o.sum(0) + z = z.sum(0) + ctx.save_for_backward(q, k, v) + ctx.scale = scale + return o.to(q.dtype), z.to(z.dtype) + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dz): + q, k, v = ctx.saved_tensors + B, H, T, K, V = *k.shape, v.shape[-1] + scale = ctx.scale + + BT = 16 + BK, BV = min(K, 16), min(V, 32) + BK, BV = max(BK, 16), max(BV, 16) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 4 + + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + grid = (NV, NK, B * H) + + fused_chunk_based_bwd_kernel[grid]( + q, k, v, do, dz, dq, dk, dv, + scale, + T=T, B=B, H=H, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None + + +def fused_chunk_based( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + use_norm: bool = True, + head_first: bool = True +): + assert 
q.shape[-1] <= 16, 'only support feature dimension up to 16.' + if scale is None: + scale = q.shape[-1] ** -0.5 + if not head_first: + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + o, z = FusedChunkBasedFunction.apply(q, k, v, scale) + if use_norm: + o = o / (z[..., None] + 1e-6) + if not head_first: + o = o.transpose(1, 2) + return o.to(q.dtype) diff --git a/fla/ops/common/__pycache__/chunk_delta_h.cpython-312.pyc b/fla/ops/common/__pycache__/chunk_delta_h.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abf894f9617d4608f574055fa6e73210fdfb2f17 Binary files /dev/null and b/fla/ops/common/__pycache__/chunk_delta_h.cpython-312.pyc differ diff --git a/fla/ops/common/__pycache__/chunk_scaled_dot_kkt.cpython-312.pyc b/fla/ops/common/__pycache__/chunk_scaled_dot_kkt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ae6d7bebb918a4f187260dfb61cbed986ce744f Binary files /dev/null and b/fla/ops/common/__pycache__/chunk_scaled_dot_kkt.cpython-312.pyc differ diff --git a/fla/ops/common/__pycache__/utils.cpython-312.pyc b/fla/ops/common/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe39fc7d5c9db6e7ece27a56baa06fce98ac2cbc Binary files /dev/null and b/fla/ops/common/__pycache__/utils.cpython-312.pyc differ diff --git a/fla/ops/delta_rule/__pycache__/__init__.cpython-312.pyc b/fla/ops/delta_rule/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a461a8df20ed92c3dba841be9163ac9fa8efd474 Binary files /dev/null and b/fla/ops/delta_rule/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla/ops/delta_rule/wy_fast.py b/fla/ops/delta_rule/wy_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..5a863b91556ba0c33ef47f331d24d1c352d64c79 --- /dev/null +++ b/fla/ops/delta_rule/wy_fast.py @@ -0,0 +1,340 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin 
Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from fla.ops.utils.solve_tril import solve_tril +from fla.utils import check_shared_mem, is_nvidia_hopper + +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'], +) +@triton.jit(do_not_specialize=['T']) +def fwd_recompute_w_u_kernel( + k, + v, + beta, + w, + u, + A, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + 
i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A.to(b_vb.dtype), b_vb, allow_tf32=False) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + b_w = tl.dot(b_A.to(b_kb.dtype), b_kb, allow_tf32=False) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'], +) +@triton.jit(do_not_specialize=['T']) +def bwd_prepare_wy_repr_kernel( + k, + v, + beta, + A, + dw, + du, + dk, + dv, + dbeta, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = 
tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T*BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + else: + p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + + b_beta = tl.load(p_beta, boundary_check=(0,)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + + b_dbeta = tl.zeros([BT], dtype=tl.float32) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False) + b_dv_beta = tl.dot(b_A, b_du, allow_tf32=False) + b_dv = b_dv_beta * b_beta[:, None] + b_dbeta += tl.sum(b_dv_beta * b_v, 1) + + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * 
T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False) + b_dk_beta = tl.dot(b_A, b_dw, allow_tf32=False) + b_dk = b_dk_beta * b_beta[:, None] + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), b_A) + b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) + b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + + b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False) + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += 
tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False) + b_dk += b_dk_beta * b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + if HEAD_FIRST: + p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_dbeta = tl.make_block_ptr(dbeta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,)) + + +def fwd_prepare_wy_repr( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + offsets: Optional[torch.LongTensor], + indices: Optional[torch.LongTensor], + head_first: bool = False, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + A = chunk_scaled_dot_kkt_fwd( + k=k, + beta=beta, + cu_seqlens=offsets, + head_first=head_first, + chunk_size=chunk_size, + output_dtype=torch.float32 + ) + A = solve_tril( + A=A, + cu_seqlens=offsets, + head_first=head_first, + output_dtype=k.dtype + ) + + w, u = fwd_recompute_w_u( + k=k, + v=v, + beta=beta, + A=A, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + return w, u, A + + +def fwd_recompute_w_u( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + offsets: Optional[torch.LongTensor], + indices: Optional[torch.LongTensor], + head_first: bool, + chunk_size: int +) -> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + CONST_TILING = 64 if check_shared_mem() else 32 + BK = min(triton.next_power_of_2(K), CONST_TILING) + BV = min(triton.next_power_of_2(V), CONST_TILING) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + u = torch.empty_like(v) + w = torch.empty_like(k) + fwd_recompute_w_u_kernel[(NT, B*H)]( + k, + v, + beta, + w, + u, + A, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, 
+ BV=BV, + HEAD_FIRST=head_first + ) + return w, u + + +def bwd_prepare_wy_repr( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + dw: torch.Tensor, + du: torch.Tensor, + offsets: Optional[torch.LongTensor], + indices: Optional[torch.LongTensor], + head_first: bool, + chunk_size: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + CONST_TILING = 64 if check_shared_mem() else 32 + BK = min(triton.next_power_of_2(K), CONST_TILING) + BV = min(triton.next_power_of_2(V), CONST_TILING) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dbeta = torch.empty_like(beta) + bwd_prepare_wy_repr_kernel[(NT, B * H)]( + k, + v, + beta, + A, + dw, + du, + dk, + dv, + dbeta, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return dk, dv, dbeta diff --git a/fla/ops/gated_delta_rule/__pycache__/__init__.cpython-312.pyc b/fla/ops/gated_delta_rule/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5d112abbdb0abe99a43b465102b1b20aa37f77d Binary files /dev/null and b/fla/ops/gated_delta_rule/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla/ops/gated_delta_rule/__pycache__/chunk.cpython-312.pyc b/fla/ops/gated_delta_rule/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32fde82c4eaa28a90804d49ec316f24e0e6e8364 Binary files /dev/null and b/fla/ops/gated_delta_rule/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla/ops/gated_delta_rule/__pycache__/wy_fast.cpython-312.pyc b/fla/ops/gated_delta_rule/__pycache__/wy_fast.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..49819cf46bd534df65a63e463a63178b7a3c6473 Binary files /dev/null and b/fla/ops/gated_delta_rule/__pycache__/wy_fast.cpython-312.pyc differ diff --git a/fla/ops/gated_delta_rule/wy_fast.py b/fla/ops/gated_delta_rule/wy_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..f80b2251f32e60dda83735f74183546b15ef45a0 --- /dev/null +++ b/fla/ops/gated_delta_rule/wy_fast.py @@ -0,0 +1,620 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import safe_exp +from fla.utils import check_shared_mem + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'BT', 'BK', 'BC', 'HEAD_FIRST', 'USE_OFFSETS'], +) +@triton.jit(do_not_specialize=['T']) +def fwd_prepare_wy_repr_kernel_chunk32( + k, + g, + beta, + Aw, + Au, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BC: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + b_Aw = tl.zeros([BC, BC], dtype=tl.float32) + if HEAD_FIRST: + p_beta = tl.make_block_ptr(beta + i_bh*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + + b_beta = tl.load(p_beta, boundary_check=(0,)) + + for i_k in range(tl.cdiv(K, BK)): + if 
HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + b_Aw += tl.dot(b_kb, tl.trans(b_k)) + + b_Aw = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_Aw, 0) + + if HEAD_FIRST: + p_g = tl.make_block_ptr(g + i_bh*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_g = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + + b_g = tl.load(p_g, boundary_check=(0,)) + b_Au = b_Aw * safe_exp(b_g[:, None] - b_g[None, :]) + + for i in range(1, BC): + mask = tl.arange(0, BC) == i + b_aw = tl.sum(tl.where(mask[:, None], b_Aw, 0), 0) + b_au = tl.sum(tl.where(mask[:, None], b_Au, 0), 0) + b_aw = b_aw + tl.sum(b_aw[:, None] * b_Aw, 0) * (tl.arange(0, BC) < i) + b_au = b_au + tl.sum(b_au[:, None] * b_Au, 0) * (tl.arange(0, BC) < i) + b_Aw = tl.where(mask[:, None], b_aw, b_Aw) + b_Au = tl.where(mask[:, None], b_au, b_Au) + + # blockwise computation of lower triangular matrix's inverse + # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1] + b_Aw += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_Au += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + if HEAD_FIRST: + p_Aw = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_Au = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + else: + p_Aw = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_Au = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + tl.store(p_Aw, b_Aw.to(p_Aw.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Au, b_Au.to(p_Au.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'BT', 'BK', 'BC', 'USE_OFFSETS', 'HEAD_FIRST'],
)
@triton.jit(do_not_specialize=['T'])
def fwd_prepare_wy_repr_kernel_chunk64(
    k,
    g,
    beta,
    Aw,
    Au,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BC: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Chunk-size-64 variant of the WY-representation forward kernel.  The BT x BT
    # problem is split into two BC x BC (BC = BT/2) diagonal sub-blocks that are
    # inverted independently, plus one off-diagonal sub-block combined at the end
    # via the 2x2 block-triangular inverse formula (see comment below).
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Variable-length sequences: recover (sequence, chunk) from `indices`.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # b_Aw / b_Aw2: the two diagonal sub-blocks; b_Aw3: the lower-left sub-block.
    b_Aw = tl.zeros([BC, BC], dtype=tl.float32)
    b_Aw2 = tl.zeros([BC, BC], dtype=tl.float32)
    b_Aw3 = tl.zeros([BC, BC], dtype=tl.float32)
    if HEAD_FIRST:
        p_beta = tl.make_block_ptr(beta + i_bh*T, (T,), (1,), (i_t * BT,), (BC,), (0,))
        p_beta2 = tl.make_block_ptr(beta + i_bh*T, (T,), (1,), (i_t * BT + BC,), (BC,), (0,))
    else:
        p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BC,), (0,))
        p_beta2 = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT + BC,), (BC,), (0,))

    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_beta2 = tl.load(p_beta2, boundary_check=(0,))

    # Accumulate the three sub-blocks of (beta * K) @ K^T over K tiles.
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
            p_k2 = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
            p_k2 = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
        b_k2 = tl.load(p_k2, boundary_check=(0, 1))
        b_kb2 = (b_k2 * b_beta2[:, None]).to(b_k2.dtype)
        b_Aw += tl.dot(b_kb, tl.trans(b_k))
        b_Aw2 += tl.dot(b_kb2, tl.trans(b_k2))
        b_Aw3 += tl.dot(b_kb2, tl.trans(b_k))

    # Strictly lower triangle, negated, for the two diagonal sub-blocks.
    b_Aw = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_Aw, 0)
    b_Aw2 = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_Aw2, 0)

    if HEAD_FIRST:
        p_g = tl.make_block_ptr(g + i_bh*T, (T,), (1,), (i_t * BT,), (BC,), (0,))
        p_g2 = tl.make_block_ptr(g + i_bh*T, (T,), (1,), (i_t * BT + BC,), (BC,), (0,))
    else:
        p_g = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT,), (BC,), (0,))
        p_g2 = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT + BC,), (BC,), (0,))
    b_g = tl.load(p_g, boundary_check=(0,))
    b_g2 = tl.load(p_g2, boundary_check=(0,))

    # mask_c keeps the lower triangle; mask_g/mask_g2 zero out columns that fall
    # past the end of a (possibly short) sequence.
    mask_c = tl.arange(0, BC)[:, None] >= tl.arange(0, BC)[None, :]
    mask_g = i_t * BT + tl.arange(0, BC) < T
    mask_g2 = i_t * BT + BC + tl.arange(0, BC) < T

    # Gate-modulated counterparts of the three Aw sub-blocks.
    b_Au = tl.where(mask_g[None, :] & mask_c, b_Aw * safe_exp(b_g[:, None] - b_g[None, :]), 0)
    b_Au2 = tl.where(mask_g2[None, :] & mask_c, b_Aw2 * safe_exp(b_g2[:, None] - b_g2[None, :]), 0)
    b_Au3 = tl.where(mask_g[None, :], b_Aw3 * safe_exp(b_g2[:, None] - b_g[None, :]), 0)

    # Forward substitution inverting both diagonal sub-blocks in place.
    for i in range(1, BC):
        mask = tl.arange(0, BC) == i
        b_aw = tl.sum(tl.where(mask[:, None], b_Aw, 0), 0)
        b_aw2 = tl.sum(tl.where(mask[:, None], b_Aw2, 0), 0)
        b_au = tl.sum(tl.where(mask[:, None], b_Au, 0), 0)
        b_au2 = tl.sum(tl.where(mask[:, None], b_Au2, 0), 0)
        b_aw = b_aw + tl.sum(b_aw[:, None] * b_Aw, 0) * (tl.arange(0, BC) < i)
        b_aw2 = b_aw2 + tl.sum(b_aw2[:, None] * b_Aw2, 0) * (tl.arange(0, BC) < i)
        b_au = b_au + tl.sum(b_au[:, None] * b_Au, 0) * (tl.arange(0, BC) < i)
        b_au2 = b_au2 + tl.sum(b_au2[:, None] * b_Au2, 0) * (tl.arange(0, BC) < i)
        b_Aw = tl.where(mask[:, None], b_aw, b_Aw)
        b_Aw2 = tl.where(mask[:, None], b_aw2, b_Aw2)
        b_Au = tl.where(mask[:, None], b_au, b_Au)
        b_Au2 = tl.where(mask[:, None], b_au2, b_Au2)
    # blockwise computation of lower triangular matrix's inverse
    # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
    b_Aw += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_Aw2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    # improve precision by disallowing tf32.
    b_Aw3 = -tl.dot(tl.dot(b_Aw2, b_Aw3, allow_tf32=False), b_Aw, allow_tf32=False)
    b_Au += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_Au2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_Au3 = -tl.dot(tl.dot(b_Au2, b_Au3, allow_tf32=False), b_Au, allow_tf32=False)

    # Store the four BC x BC quadrants of each BT x BT matrix; the upper-right
    # quadrant is identically zero (block lower-triangular result).
    if HEAD_FIRST:
        p_Aw1 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
        p_Aw2 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
        p_Aw3 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
        p_Aw4 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
        p_Au1 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
        p_Au2 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
        p_Au3 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
        p_Au4 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
    else:
        p_Aw1 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
        p_Aw2 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
        p_Aw3 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
        p_Aw4 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
        p_Au1 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
        p_Au2 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
        p_Au3 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
        p_Au4 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))

    tl.store(p_Aw1, b_Aw.to(p_Aw1.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_Aw2, b_Aw2.to(p_Aw2.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_Aw3, b_Aw3.to(p_Aw3.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_Aw4, tl.zeros([BC, BC], dtype=tl.float32).to(p_Aw4.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_Au1, b_Au.to(p_Au1.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_Au2, b_Au2.to(p_Au2.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_Au3, b_Au3.to(p_Au3.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_Au4, tl.zeros([BC, BC], dtype=tl.float32).to(p_Au4.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'],
)
@triton.jit(do_not_specialize=['T'])
def fwd_recompute_w_u_kernel(
    k,
    v,
    beta,
    w,
    u,
    Aw,
    Au,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # Given the precomputed per-chunk matrices Aw/Au, recompute
    #   u = Au @ (beta * v)   and   w = Aw @ (beta * k)
    # tile-by-tile over the V and K dimensions respectively.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Variable-length sequences: recover (sequence, chunk) from `indices`.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    if HEAD_FIRST:
        p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_Au = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    else:
        p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_Au = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_Au = tl.load(p_Au, boundary_check=(0, 1))

    # u = Au @ (beta * v), one BV-wide tile of V at a time.
    for i_v in range(tl.cdiv(V, BV)):
        if HEAD_FIRST:
            p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_Au, b_vb, allow_tf32=False)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))

    tl.debug_barrier()
    # Drop b_Au before loading b_Aw — presumably to reduce register pressure;
    # the barrier separates the two phases.  TODO(review): confirm intent.
    b_Au = None
    if HEAD_FIRST:
        p_Aw = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    else:
        p_Aw = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_Aw = tl.load(p_Aw, boundary_check=(0, 1))

    # w = Aw @ (beta * k), one BK-wide tile of K at a time.
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
        b_w = tl.dot(b_Aw, b_kb)
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))


def fwd_prepare_wy_repr(
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the WY representation for one forward pass.

    Launches the chunk-32 or chunk-64 kernel (depending on the effective
    chunk size BT) to build Aw/Au, then derives w and u from them via
    ``fwd_recompute_w_u``.

    Args:
        k, v: key/value tensors, laid out ``[B, H, T, K]`` / ``[B, H, T, V]``
            when ``head_first`` else ``[B, T, H, K]`` / ``[B, T, H, V]``.
        g: gate tensor (log-space decay, per the ``safe_exp`` usage in the
            kernels); same time/head layout as ``beta``.
        beta: per-position scaling applied to k and v inside the kernels.
        offsets, indices: optional packed variable-length metadata; when
            given, ``len(indices)`` determines the launch grid.
        head_first: memory layout switch forwarded to the kernels.
        chunk_size: upper bound on the per-chunk block size BT.

    Returns:
        ``(w, u, Aw, Au)``.
    """
    if head_first:
        B, H, T, K = k.shape
    else:
        B, T, H, K = k.shape
    # BT: per-chunk block size (power of two, >= 16, <= chunk_size).
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    BC = min(BT, 32)
    BK = min(triton.next_power_of_2(K), 64)
    # bf16 should be good enough.
    Aw = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=k.dtype)
    Au = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=k.dtype)

    # The chunk-64 kernel handles BT == 64 via 2x2 block inversion; smaller
    # chunks use the single-block chunk-32 kernel.
    fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32
    fwd_fn[(NT, B*H)](
        k=k,
        g=g,
        beta=beta,
        Aw=Aw,
        Au=Au,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BK=BK,
        BC=BC,
        HEAD_FIRST=head_first
    )
    w, u = fwd_recompute_w_u(
        k=k,
        v=v,
        beta=beta,
        Aw=Aw,
        Au=Au,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )
    return w, u, Aw, Au


def fwd_recompute_w_u(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    Aw: torch.Tensor,
    Au: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch ``fwd_recompute_w_u_kernel`` to materialize w and u.

    Allocates ``w`` like ``k`` and ``u`` like ``v`` and fills them with
    ``Aw @ (beta * k)`` and ``Au @ (beta * v)`` per chunk.  Block sizes BT,
    BK, BV mirror those used when Aw/Au were produced.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)

    u = torch.empty_like(v)
    w = torch.empty_like(k)
    fwd_recompute_w_u_kernel[(NT, B*H)](
        k=k,
        v=v,
        beta=beta,
        w=w,
        u=u,
        Aw=Aw,
        Au=Au,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return w, u
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4]
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS']
)
@triton.jit(do_not_specialize=['T'])
def bwd_prepare_wy_repr_kernel(
    k,
    v,
    beta,
    g,
    Aw,
    Au,
    dw,
    du,
    dk,
    dv,
    dbeta,
    dg,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # Backward kernel of the WY representation: given incoming gradients dw/du
    # it produces dk, dv, dbeta and dg for one chunk.  The structure mirrors the
    # forward pass: a K-phase driven by Aw/dw and a V-phase driven by Au/du,
    # each accumulating a gradient through the triangular matrices (b_dA/b_dA2),
    # followed by a final K-phase that pushes b_dA back onto k and beta.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Variable-length sequences: recover (sequence, chunk) from `indices`.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    b_dbeta = tl.zeros([BT], dtype=tl.float32)
    b_dA = tl.zeros([BT, BT], dtype=tl.float32)
    # Note: Aw is loaded TRANSPOSED here (shape args (BT, T) with swapped strides).
    if HEAD_FIRST:
        p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_A = tl.make_block_ptr(Aw + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
    else:
        p_beta = tl.make_block_ptr(beta + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_A = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))

    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_beta = tl.load(p_beta, boundary_check=(0,))

    # Phase 1: gradient w.r.t. k (through w = Aw @ (beta * k)) plus the
    # accumulation of dA from dw.
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
        b_dw = tl.load(p_dw, boundary_check=(0, 1))
        b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
        b_dk_beta = tl.dot(b_A, b_dw, allow_tf32=False)
        b_dk = b_dk_beta * b_beta[:, None]
        b_dbeta += tl.sum(b_dk_beta * b_k, 1)
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

    # Back-propagate dA through the triangular inverse: dA <- -tril(A^T dA A^T)
    # (strictly lower triangle only).
    b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)
    b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
    b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
    b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)

    # Phase 2: same pattern for the value path via Au/du, producing dv and dA2.
    if HEAD_FIRST:
        p_A = tl.make_block_ptr(Au + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
    else:
        p_A = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_dA2 = tl.zeros([BT, BT], dtype=tl.float32)

    for i_v in range(tl.cdiv(V, BV)):
        if HEAD_FIRST:
            p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_du = tl.load(p_du, boundary_check=(0, 1))
        b_dA2 += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)
        b_dv_beta = tl.dot(b_A, b_du, allow_tf32=False)
        b_dv = b_dv_beta * b_beta[:, None]
        b_dbeta += tl.sum(b_dv_beta * b_v, 1)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

    b_dA2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA2, 0)
    b_dA2 = tl.dot(b_dA2.to(b_A.dtype), b_A)
    b_dA2 = tl.dot(b_A, b_dA2.to(b_A.dtype))
    b_dA2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA2, 0).to(k.dtype.element_ty)
    if HEAD_FIRST:
        p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        p_g = tl.make_block_ptr(g + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
    b_g = tl.load(p_g, boundary_check=(0,))
    # The Au path carried an extra exp(g_i - g_j) factor in the forward pass;
    # apply it to dA2 before merging into the shared dA.
    b_dA2 *= safe_exp(b_g[:, None] - b_g[None, :])
    b_dA += b_dA2
    b_dA = b_dA.to(k.dtype.element_ty)
    b_A = tl.zeros([BT, BT], dtype=tl.float32)

    # Phase 3: push the combined dA back onto k/beta, and rebuild
    # A = (beta * k) k^T (needed below for the gate gradient).
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_dk = tl.load(p_dk, boundary_check=(0, 1))
        b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
        b_A += tl.dot(b_k_beta, tl.trans(b_k))
        b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)
        b_dbeta += tl.sum(b_dk_beta * b_k, 1)
        b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False)
        b_dk += b_dk_beta * b_beta[:, None]
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    # Gate gradient: d/dg of exp(g_i - g_j) contributes +row-sum and -col-sum.
    b_dA2 *= b_A
    b_dg = tl.sum(b_dA2, axis=1) - tl.sum(b_dA2, axis=0)
    if HEAD_FIRST:
        p_dg = tl.make_block_ptr(dg + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        p_dg = tl.make_block_ptr(dg + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_dbeta = tl.make_block_ptr(dbeta + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))


def bwd_prepare_wy_repr(
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    Aw: torch.Tensor,
    Au: torch.Tensor,
    dw: torch.Tensor,
    du: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Backward pass of the WY representation.

    Launches ``bwd_prepare_wy_repr_kernel`` over (chunk, batch*head) and
    returns ``(dk, dv, dbeta, dg)`` with the same shapes as their forward
    counterparts.  Tiling constants mirror the forward launch; the K/V tile
    width is halved (32) on devices with limited shared memory.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    CONST_TILING = 64 if check_shared_mem() else 32
    BK = min(triton.next_power_of_2(K), CONST_TILING)
    BV = min(triton.next_power_of_2(V), CONST_TILING)

    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    dbeta = torch.empty_like(beta)
    dg = torch.empty_like(g)
    bwd_prepare_wy_repr_kernel[(NT, B * H)](
        k=k,
        v=v,
        beta=beta,
        g=g,
        Aw=Aw,
        Au=Au,
        dw=dw,
        du=du,
        dk=dk,
        dv=dv,
        dbeta=dbeta,
        dg=dg,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return dk, dv, dbeta, dg
0000000000000000000000000000000000000000..a4d6a50ba8b3c6f3d533a212259ef6b8f195f87c Binary files /dev/null and b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-312.pyc differ diff --git a/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-312.pyc b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72c27c100bf77114f72db43c5836d6bd3548dea5 Binary files /dev/null and b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-312.pyc differ diff --git a/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-312.pyc b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62b367387b2f44d3c5b2cfae916ef9b193b56d64 Binary files /dev/null and b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-312.pyc differ diff --git a/fla/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-312.pyc b/fla/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53dc8089d153e014972c7263db6d064b7592cc1d Binary files /dev/null and b/fla/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-312.pyc differ diff --git a/fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-312.pyc b/fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ba9b9261786c824aae9984a527d826ea2cbb99d Binary files /dev/null and b/fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-312.pyc differ diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..08518c203594e0f63f1e88b849ab688922e94f34 --- 
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=['BC', 'K'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_fwd_A_kernel_intra_sub_inter(
    q,
    k,
    a,
    b,
    gi,  # cumsum
    ge,  # before cumsum
    Aqk,
    Aqb,
    Aab,
    Aak,
    offsets,
    indices,
    scale: tl.constexpr,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    # DPLR forward, inter-sub-chunk part: fills the four attention-like score
    # matrices Aqk, Aqb, Aab, Aak for pairs of sub-chunks (i_i, i_j) with
    # i_i > i_j (strictly below the sub-chunk diagonal).  Gate factors are
    # normalized against the gate value just before sub-chunk i_i (b_gn) to
    # keep the exponentials bounded.
    i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    # Decompose the flat sub-chunk-pair index into (row, column) sub-chunks.
    i_i, i_j = i_c // NC, i_c % NC
    if USE_OFFSETS:
        # Variable-length sequences: recover (sequence, chunk) from `indices`.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Out-of-range sub-chunk, or on/above the diagonal (handled by the
    # intra-sub-chunk kernel): nothing to do.
    if i_t * BT + i_i * BC >= T:
        return
    if i_i <= i_j:
        return

    b_Aqk = tl.zeros([BC, BC], dtype=tl.float32)
    b_Aqb = tl.zeros([BC, BC], dtype=tl.float32)
    b_Aab = tl.zeros([BC, BC], dtype=tl.float32)
    b_Aak = tl.zeros([BC, BC], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        o_k = i_k * BK + tl.arange(0, BK)
        m_k = o_k < K

        if HEAD_FIRST:
            p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_gq_i = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_gq_e = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
            p_b = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
            p_gk = tl.make_block_ptr(gi + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
            p_gn = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + i_t * BT + i_i * BC - 1) * K + o_k, BK), BK)
        else:
            p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_a = tl.make_block_ptr(a + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_gq_i = tl.make_block_ptr(gi + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_gq_e = tl.make_block_ptr(ge + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
            p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
            p_b = tl.make_block_ptr(b + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
            p_gk = tl.make_block_ptr(gi + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
            p_gn = gi + (bos + i_t * BT + i_i * BC - 1) * H*K + i_h * K + o_k
        # [BK,]
        b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)
        # [BC, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_a = tl.load(p_a, boundary_check=(0, 1))
        b_gq_i = tl.load(p_gq_i, boundary_check=(0, 1))
        b_gq_e = tl.load(p_gq_e, boundary_check=(0, 1))
        b_ag = b_a * exp(b_gq_e - b_gn[None, :])
        b_qg = b_q * exp(b_gq_i - b_gn[None, :]) * scale
        # [BK, BC]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_b = tl.load(p_b, boundary_check=(0, 1))
        b_gk = tl.load(p_gk, boundary_check=(0, 1)).to(tl.float32)
        tmp = exp(b_gn[:, None] - b_gk)
        b_kg = b_k * tmp
        b_bg = b_b * tmp
        # [BC, BC] using tf32 to improve precision here.
        b_Aab += tl.dot(b_ag, b_bg)
        b_Aak += tl.dot(b_ag, b_kg)
        b_Aqk += tl.dot(b_qg, b_kg)
        b_Aqb += tl.dot(b_qg, b_bg)

    if HEAD_FIRST:
        p_Aqk = tl.make_block_ptr(Aqk + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
        p_Aqb = tl.make_block_ptr(Aqb + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
        p_Aab = tl.make_block_ptr(Aab + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
        p_Aak = tl.make_block_ptr(Aak + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
    else:
        p_Aqk = tl.make_block_ptr(Aqk + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
        p_Aqb = tl.make_block_ptr(Aqb + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
        p_Aab = tl.make_block_ptr(Aab + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
        p_Aak = tl.make_block_ptr(Aak + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
    # Round-to-nearest-even downcast on store for better accumulated precision.
    tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Aqb, b_Aqb.to(Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Aab, b_Aab.to(Aab.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Aak, b_Aak.to(Aak.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
[2, 3, 4] + ], + key=['BK', 'BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_fwd_A_kernel_intra_sub_intra( + q, + k, + a, + b, + gi, + ge, + qg, + kg, + ag, + bg, + Aqk, + Aqb, + Aab, + Aak, + offsets, + indices, + scale: tl.constexpr, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, + GATHER_SUPPORTED: tl.constexpr +): + i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_j = i_i + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT + i_i * BC >= T: + return + + o_i = tl.arange(0, BC) + o_k = tl.arange(0, BK) + m_k = o_k < K + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + last_idx = min((i_t+1) * BT, T) - 1 + if HEAD_FIRST: + o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_b = tl.make_block_ptr(b + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_gi = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_ge = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_g_last = gi + i_bh * T*K + last_idx * K + tl.arange(0, BK) + b_g_last = tl.load(p_g_last, mask=m_k, other=0) + + p_qg = tl.make_block_ptr(qg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), 
(BC, BK), (1, 0)) + p_kg = tl.make_block_ptr(kg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_ag = tl.make_block_ptr(ag + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_bg = tl.make_block_ptr(bg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + else: + o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_gi = tl.make_block_ptr(gi + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_ge = tl.make_block_ptr(ge + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_g_last = gi + (bos * H + i_h) * K + last_idx * H * K + tl.arange(0, BK) + b_g_last = tl.load(p_g_last, mask=m_k, other=0) + p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_kg = tl.make_block_ptr(kg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_bg = tl.make_block_ptr(bg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = b_q * scale + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_a = tl.load(p_a, boundary_check=(0, 1)) + b_b = tl.load(p_b, boundary_check=(0, 1)) + b_gi = tl.load(p_gi, boundary_check=(0, 1)).to(tl.float32) + b_ge = tl.load(p_ge, boundary_check=(0, 
1)).to(tl.float32) + + # deal with decay term. + g_exp = exp(b_gi) + g_exp_inv = exp(-b_gi + b_g_last[None, :]) + b_qg = b_q * g_exp + b_kg = b_k * g_exp_inv + b_bg = b_b * g_exp_inv + b_ag = b_a * exp(b_ge) + tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_bg, b_bg.to(p_bg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_ag, b_ag.to(p_ag.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + # tl.debug_barrier() + + b_q = b_q.to(b_k.dtype) + # inner attn + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # a trick to index the j-th row of b_k, b_g, b_b + if GATHER_SUPPORTED: + row_idx = tl.full([1, BK], j, dtype=tl.int16) + # [1, BK] + b_k_j = gather(b_k, row_idx, axis=0) + b_gk_j = gather(b_gi, row_idx, axis=0) + b_b_j = gather(b_b, row_idx, axis=0) + else: + mask = tl.arange(0, BC) == j + b_k_j = tl.sum(tl.where(mask[:, None], b_k, 0), 0)[None, :] + b_gk_j = tl.sum(tl.where(mask[:, None], b_gi, 0), 0)[None, :] + b_b_j = tl.sum(tl.where(mask[:, None], b_b, 0), 0)[None, :] + mask = tl.arange(0, BC) == j + tmp = exp(b_gi - b_gk_j) + b_A_qk = tl.sum(b_q * b_k_j * tmp, 1) + b_A_qk = tl.where(o_i >= j, b_A_qk, 0.) + b_A_qb = tl.sum(b_q * b_b_j * tmp, 1) + b_A_qb = tl.where(o_i >= j, b_A_qb, 0.) + tmp2 = exp(b_ge - b_gk_j) + b_A_ak = tl.sum(b_a * b_k_j * tmp2, 1) + b_A_ak = tl.where(o_i > j, b_A_ak, 0.) + b_A_ab = tl.sum(b_a * b_b_j * tmp2, 1) + b_A_ab = tl.where(o_i > j, b_A_ab, 0.) 
        # Flush the four inner-chunk score columns for row offset j.
        # NOTE(review): b_A_ab / b_A_ak are downcast through Aqb/Aqk's element
        # type even though the host allocates Aab/Aak as float32 — presumably
        # intentional to match matmul input precision; confirm upstream.
        tl.store(Aqk + o_A + j, b_A_qk.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
        tl.store(Aqb + o_A + j, b_A_qb.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
        tl.store(Aab + o_A + j, b_A_ab.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
        tl.store(Aak + o_A + j, b_A_ak.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)


def chunk_fwd_intra_dplr_fn(
    q: torch.Tensor,
    k: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gi: torch.Tensor,
    ge: torch.Tensor,
    scale: float,
    chunk_size: int,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
):
    """
    Host wrapper for the intra-chunk forward pass of the chunked DPLR kernel.

    Launches two kernels: first the inter-sub-chunk kernel that fills the
    off-diagonal sub-blocks of the per-chunk score matrices, then the
    intra-sub-chunk kernel, which also writes the gate-scaled inputs
    ``qg``/``kg``/``ag``/``bg``.

    Returns:
        Tuple ``(Aab, Aqk, Aak, Aqb, qg, kg, ag, bg)``.
    """
    if head_first:
        B, H, T, K = k.shape
    else:
        B, T, H, K = k.shape
    # Chunk length: capped by `chunk_size`, at least 16, a power of two.
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    # Sub-chunk length within a chunk.
    BC = min(16, BT)
    NC = triton.cdiv(BT, BC)

    Aqk = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=q.dtype)
    Aqb = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=q.dtype)
    # Aab/Aak feed a matrix inverse downstream, so keep them in float32.
    Aab = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float)
    Aak = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float)
    # One program per (chunk, sub-chunk pair, batch*head).
    grid = (NT, NC * NC, B * H)

    chunk_dplr_fwd_A_kernel_intra_sub_inter[grid](
        q=q, k=k, a=a, b=b, gi=gi, ge=ge, Aqk=Aqk, Aqb=Aqb, Aab=Aab, Aak=Aak,
        offsets=offsets, indices=indices,
        scale=scale,
        T=T, H=H, K=K, BT=BT, BC=BC, NC=NC,
        HEAD_FIRST=head_first
    )
    # One program per (chunk, sub-chunk, batch*head) for the diagonal blocks.
    grid = (NT, NC, B * H)
    BK = triton.next_power_of_2(K)
    qg = torch.empty_like(q)
    kg = torch.empty_like(k, dtype=q.dtype)
    ag = torch.empty_like(a, dtype=q.dtype)
    bg = torch.empty_like(b, dtype=q.dtype)
    chunk_dplr_fwd_A_kernel_intra_sub_intra[grid](
        q=q, k=k, a=a, b=b, gi=gi, ge=ge, Aqk=Aqk, Aqb=Aqb, Aab=Aab, Aak=Aak,
        qg=qg, kg=kg, ag=ag, bg=bg,
        offsets=offsets, indices=indices,
        scale=scale,
        T=T, H=H, K=K, BT=BT, BC=BC, BK=BK, HEAD_FIRST=head_first, NC=NC,
        GATHER_SUPPORTED=is_gather_supported
    )
    return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg
diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..b382d5905af9547a0626585453e43b01bf1b706c
--- /dev/null
+++ b/fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from fla.ops.common.utils import prepare_chunk_offsets
from fla.ops.utils.op import exp
from fla.utils import check_shared_mem, use_cuda_graph


@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for
num_stages in [2, 3, 4] + ], + key=['BT', 'BK', 'BV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_fwd_kernel_h( + kg, + v, + w, + bg, + u, + v_new, + gk, + h, + h0, + ht, + offsets, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + if HEAD_FIRST: + p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + + b_hc = tl.zeros([BK, BV], dtype=tl.float32) + # since we need to make all DK in the SRAM. we face serve SRAM memory burden. 
By subchunking we allievate such burden + for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)): + if HEAD_FIRST: + p_kg = tl.make_block_ptr(kg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_bg = tl.make_block_ptr(bg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_w = tl.make_block_ptr(w + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_u = tl.make_block_ptr(u + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + else: + p_kg = tl.make_block_ptr(kg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0)) + # [BK, BC] + b_kg = tl.load(p_kg, boundary_check=(0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_bg = tl.load(p_bg, boundary_check=(0, 1)) + b_v2 = tl.dot(b_w, b_h.to(b_w.dtype)) + tl.load(p_u, boundary_check=(0, 1)) + b_hc += tl.dot(b_kg, b_v) + b_hc += tl.dot(b_bg.to(b_hc.dtype), b_v2) + tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + if HEAD_FIRST: + b_g_last = tl.load(gk + i_nh * T * K + last_idx * K + 
tl.arange(0, BK), mask=tl.arange(0, BK) < K).to(tl.float32) + else: + b_g_last = tl.load(gk + (bos + last_idx) * H * K + i_h * K + + tl.arange(0, BK), mask=tl.arange(0, BK) < K).to(tl.float32) + b_h *= exp(b_g_last[:, None]) + b_h += b_hc + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +def chunk_dplr_fwd_h( + kg: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + bg: torch.Tensor, + gk: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *kg.shape, u.shape[-1] + else: + B, T, H, K, V = *kg.shape, u.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + # N: the actual number of sequences in the batch with either equal or variable lengths + if offsets is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT) + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension larger than 256." 
    # H100 can have larger block size thanks to its larger shared memory.
    # Pick the V-tile (BV) and sub-chunk (BC) sizes by available shared memory.
    if check_shared_mem('hopper', kg.device.index):
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem('ampere', kg.device.index):  # A100
        BV = 32
        BC = 32
    else:
        BV = 16
        BC = 16

    # The sub-chunk can never exceed the chunk itself.
    BC = min(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    # Per-chunk hidden states: one [K, V] state snapshot per chunk.
    if head_first:
        h = kg.new_empty(B, H, NT, K, V)
    else:
        h = kg.new_empty(B, NT, H, K, V)
    # Final state is kept in float32 for accuracy across long sequences.
    final_state = kg.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    v_new = torch.empty_like(u)
    grid = (NK, NV, N * H)
    chunk_dplr_fwd_kernel_h[grid](
        kg=kg,
        v=v,
        w=w,
        bg=bg,
        u=u,
        v_new=v_new,
        h=h,
        gk=gk,
        h0=initial_state,
        ht=final_state,
        offsets=offsets,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
        NT=NT,
        HEAD_FIRST=head_first
    )
    return h, v_new, final_state
diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4a17bcfb2bad98fbf5df1dab70f21a86a59f111
--- /dev/null
+++ b/fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from fla.ops.utils.op import exp
from fla.utils import check_shared_mem, use_cuda_graph

BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32]


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BV', 'BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_bwd_kernel_dAu(
+ v, + do, + v_new, + A_qb, + dA_qk, + dA_qb, + dv_new, + offsets, + indices, + scale: tl.constexpr, + T, + H: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos + + b_dA_qk = tl.zeros([BT, BT], dtype=tl.float32) + b_dA_qb = tl.zeros([BT, BT], dtype=tl.float32) + + if HEAD_FIRST: + p_A_qb = tl.make_block_ptr(A_qb + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_A_qb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + b_A_qb = tl.load(p_A_qb, boundary_check=(0, 1)) + # causal mask + b_A_qb = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_A_qb, 0.).to(b_A_qb.dtype) + + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + p_v_new = tl.make_block_ptr(v_new + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + p_dv_new = tl.make_block_ptr(dv_new + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + p_v_new = tl.make_block_ptr(v_new + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + p_dv_new = tl.make_block_ptr(dv_new + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v 
* BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_v_new = tl.load(p_v_new, boundary_check=(0, 1)) + b_dA_qk += tl.dot(b_do, b_v) + b_dA_qb += tl.dot(b_do, b_v_new) + b_dv_new = tl.dot(tl.trans(b_A_qb), b_do) + # for recurrent + tl.store(p_dv_new, b_dv_new.to(p_dv_new.dtype.element_ty), boundary_check=(0, 1)) + + if HEAD_FIRST: + p_dA_qk = tl.make_block_ptr(dA_qk + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dA_qb = tl.make_block_ptr(dA_qb + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_dA_qk = tl.make_block_ptr(dA_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dA_qb = tl.make_block_ptr(dA_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + b_dA_qk = tl.where(m_s, b_dA_qk * scale, 0.) + tl.store(p_dA_qk, b_dA_qk.to(p_dA_qk.dtype.element_ty), boundary_check=(0, 1)) + b_dA_qb = tl.where(m_s, b_dA_qb * scale, 0.) 
+ tl.store(p_dA_qb, b_dA_qb.to(p_dA_qb.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT', 'BK', 'BV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit +def chunk_dplr_bwd_o_kernel( + v, + v_new, + h, + do, + dh, + dk, + db, + w, + dq, + dv, + dw, + gk, + dgk_last, + k, + b, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + v += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V + v_new += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V + do += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V + h += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K * V + dh += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K * V + dk += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + k += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + db += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + b += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + dw += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + dv += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V + dq += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + w += i_bh * T * K if 
HEAD_FIRST else (bos * H + i_h) * K + # CHECK HEAD_FIRST is FALSE + dgk_last += (i_bh * NT + i_t) * K if HEAD_FIRST else (i_tg * H + i_h) * K + gk += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + + stride_qk = K if HEAD_FIRST else H*K + stride_vo = V if HEAD_FIRST else H*V + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT, BK], dtype=tl.float32) + b_db = tl.zeros([BT, BK], dtype=tl.float32) + b_dgk_last = tl.zeros([BK], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_new = tl.load(p_v_new, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + b_dgk_last += tl.sum((b_h * b_dh).to(tl.float32), axis=0) + + # [BT, BV] @ [BV, BK] -> [BT, BK] + b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) + # [BT, BV] @ [BV, BK] -> [BT, BK] + b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) + b_db += tl.dot(b_v_new, b_dh.to(b_v_new.dtype)) + p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype)) + + m_k = (i_k*BK+tl.arange(0, BK)) < K + last_idx = min(i_t * BT + BT, T) - 1 + b_gk_last = tl.load(gk + last_idx * stride_qk + i_k*BK + tl.arange(0, BK), mask=m_k, other=float('-inf')) + b_dgk_last *= exp(b_gk_last) + p_k = tl.make_block_ptr(k, (T, K), 
(stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_b = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_b = tl.load(p_b, boundary_check=(0, 1)) + b_dgk_last += tl.sum(b_k * b_dk, axis=0) + b_dgk_last += tl.sum(b_b * b_db, axis=0) + tl.store(dgk_last + tl.arange(0, BK) + i_k * BK, b_dgk_last, mask=m_k) + + p_dw = tl.make_block_ptr(dw, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_db = tl.make_block_ptr(db, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dw, b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + for BK in BK_LIST + for BV in BK_LIST + ], + key=['BT', 'BK', 'BV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit +def chunk_dplr_bwd_kernel_dv( + A_qk, + kg, + do, + dv, + dh, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + 
i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        # Fixed-length batch: every sequence occupies exactly T steps.
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # Accumulator for this (time-chunk, V-tile) block of dv, kept in float32.
    b_dv = tl.zeros([BT, BV], dtype=tl.float32)

    # offset calculation: advance the base pointers to this (batch, head)
    A_qk += i_bh * T * BT if HEAD_FIRST else (bos * H + i_h) * BT
    do += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
    dv += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V
    kg += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
    dh += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K*V

    # Row strides differ between the two layouts (head-major vs. time-major).
    stride_qk = K if HEAD_FIRST else H*K
    stride_vo = V if HEAD_FIRST else H*V
    stride_A = BT if HEAD_FIRST else H*BT

    # Inter-chunk contribution: dv += kg @ dh, tiled over the K dimension.
    for i_k in range(tl.cdiv(K, BK)):
        p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_kg = tl.make_block_ptr(kg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        b_kg = tl.load(p_kg, boundary_check=(0, 1))
        b_dv += tl.dot(b_kg, b_dh.to(b_kg.dtype))

    # Intra-chunk contribution: A_qk is loaded transposed, so masking with
    # `<=` keeps the transpose of the causal (lower-triangular) part.
    p_Aqk = tl.make_block_ptr(A_qk, (BT, T), (1, stride_A), (0, i_t * BT), (BT, BT), (0, 1))
    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], tl.load(p_Aqk, boundary_check=(0, 1)), 0)
    p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_do = tl.load(p_do, boundary_check=(0, 1))
    b_dv += tl.dot(b_A.to(b_do.dtype), b_do)
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))


def chunk_dplr_bwd_dv(
    A_qk: torch.Tensor,
    kg: torch.Tensor,
    do: torch.Tensor,
    dh: torch.Tensor,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> torch.Tensor:
    """Host wrapper launching ``chunk_dplr_bwd_kernel_dv`` to compute the gradient w.r.t. ``v``."""
    if head_first:
        B, H, T, K, V = *kg.shape, do.shape[-1]
    else:
        B, T, H, K, V = *kg.shape, do.shape[-1]
    BT =
min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + dv = torch.empty_like(do) + + def grid(meta): return ( + triton.cdiv(V, meta['BV']), + NT, + B * H + ) + chunk_dplr_bwd_kernel_dv[grid]( + A_qk=A_qk, + kg=kg, + do=do, + dv=dv, + dh=dh, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + HEAD_FIRST=head_first + ) + return dv + + +def chunk_dplr_bwd_o( + k: torch.Tensor, + b: torch.Tensor, + v: torch.Tensor, + v_new: torch.Tensor, + gk: torch.Tensor, + do: torch.Tensor, + h: torch.Tensor, + dh: torch.Tensor, + dv: torch.Tensor, + w: torch.Tensor, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + scale: float = 1.0, + head_first: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + if head_first: + B, H, T, K, V = *w.shape, v.shape[-1] + else: + B, T, H, K, V = *w.shape, v.shape[-1] + + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + BK = min(triton.next_power_of_2(K), 64) if check_shared_mem() else min(triton.next_power_of_2(K), 32) + BV = min(triton.next_power_of_2(V), 64) if check_shared_mem() else min(triton.next_power_of_2(K), 32) + NK = triton.cdiv(K, BK) + dq = torch.empty_like(k) + dk = torch.empty_like(k) + dw = torch.empty_like(w) + db = torch.empty_like(b) + grid = (NK, NT, B * H) + + dgk_last = torch.empty(B, H, NT, K, dtype=torch.float, device=w.device) if head_first \ + else torch.empty(B, NT, H, K, dtype=torch.float, device=w.device) + + chunk_dplr_bwd_o_kernel[grid]( + k=k, + b=b, + v=v, + v_new=v_new, + h=h, + do=do, + dh=dh, + dq=dq, + dk=dk, + db=db, + dgk_last=dgk_last, + w=w, + dv=dv, + dw=dw, + gk=gk, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first, + ) + return dq, dk, dw, db, dgk_last + + +def chunk_dplr_bwd_dAu( + v: 
torch.Tensor, + v_new: torch.Tensor, + do: torch.Tensor, + A_qb: torch.Tensor, + scale: float, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> torch.Tensor: + if head_first: + B, H, T, V = v.shape + else: + B, T, H, V = v.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + if check_shared_mem('ampere'): # A100 + BV = min(triton.next_power_of_2(V), 128) + elif check_shared_mem('ada'): # 4090 + BV = min(triton.next_power_of_2(V), 64) + else: + BV = min(triton.next_power_of_2(V), 32) + + grid = (NT, B * H) + dA_qk = torch.empty(B, H, T, BT, dtype=torch.float, device=v.device) if head_first \ + else torch.empty(B, T, H, BT, dtype=torch.float, device=v.device) + dA_qb = torch.empty(B, H, T, BT, dtype=torch.float, device=v.device) if head_first \ + else torch.empty(B, T, H, BT, dtype=torch.float, device=v.device) + dv_new = torch.empty_like(v_new) + chunk_dplr_bwd_kernel_dAu[grid]( + v=v, + do=do, + v_new=v_new, + A_qb=A_qb, + dA_qk=dA_qk, + dA_qb=dA_qb, + dv_new=dv_new, + offsets=offsets, + indices=indices, + scale=scale, + T=T, + H=H, + V=V, + BT=BT, + BV=BV, + HEAD_FIRST=head_first + ) + return dv_new, dA_qk, dA_qb diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..981901295b1b79ad881d7bd8600582e6c421f28a --- /dev/null +++ b/fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.utils import check_shared_mem, use_cuda_graph + +BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32] + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None, 
+}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BK_LIST + for BV in BK_LIST + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_fwd_kernel_o( + qg, + v, + v_new, + A_qk, + A_qb, + h, + o, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_qg = tl.make_block_ptr(qg + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_qg = tl.load(p_qg, boundary_check=(0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot(b_qg, b_h) + + if HEAD_FIRST: + p_Aqk = tl.make_block_ptr(A_qk + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_Aqb = tl.make_block_ptr(A_qb + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * 
BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_Aqk = tl.make_block_ptr(A_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_Aqb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1)) + b_Aqb = tl.load(p_Aqb, boundary_check=(0, 1)) + b_Aqk = tl.where(m_s, b_Aqk, 0) + b_Aqb = tl.where(m_s, b_Aqb, 0) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_new = tl.load(p_v_new, boundary_check=(0, 1)) + b_o = b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_v_new.dtype), b_v_new) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_dplr_fwd_o( + qg: torch.Tensor, + v: torch.Tensor, + v_new: torch.Tensor, + A_qk: torch.Tensor, + A_qb: torch.Tensor, + h: torch.Tensor, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> torch.Tensor: + if head_first: + B, H, T, K, V = *qg.shape, v.shape[-1] + else: + B, T, H, K, V = *qg.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + o = torch.empty_like(v) + def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) + chunk_dplr_fwd_kernel_o[grid]( + qg=qg, + v=v, + v_new=v_new, + A_qk=A_qk, + A_qb=A_qb, 
+ h=h, + o=o, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + HEAD_FIRST=head_first + ) + return o diff --git a/fla/ops/generalized_delta_rule/dplr/fused_recurrent.py b/fla/ops/generalized_delta_rule/dplr/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..17e7f3483a21de3f634ae70e58f8b7858810dda3 --- /dev/null +++ b/fla/ops/generalized_delta_rule/dplr/fused_recurrent.py @@ -0,0 +1,292 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard, use_cuda_graph + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BV in [16, 32, 64] + for num_warps in [2, 4, 8, 16] + for num_stages in [2, 3, 4] + ], + key=['BK'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_dplr_delta_rule_fwd_kernel( + q, + k, + v, + a, + b, + gk, + o, + h0, + ht, + offsets, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + REVERSE: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_v, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64) + i_n, i_h = i_nh // H, i_nh % H + + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + + o_k = tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + if 
HEAD_FIRST: + p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k + p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k + p_a = a + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k + p_b = b + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k + p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k + p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v + p_o = o + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v + + else: + p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_a = a + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_b = b + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + p_o = o + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[None, :] & mask_v[:, None] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K*V + o_k[None, :] * V + o_v[:, None] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32) + b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32) + b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + + tmp = tl.sum(b_h * b_a[None, :], axis=1) + b_h = exp(b_gk)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None]) + b_o = tl.sum(b_h * b_q[None, :], axis=1) + + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + p_q += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_a += (-1 if REVERSE else 1) * (1 if 
HEAD_FIRST else H) * K + p_b += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V + p_o += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V + + if STORE_FINAL_STATE: + p_ht = ht + i_nh * K*V + o_k[None, :] * V + o_v[:, None] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +def fused_recurrent_dplr_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gk: torch.Tensor, + scale: Optional[float] = 1.0, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + offsets: Optional[torch.LongTensor] = None, + head_first: bool = True +): + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if offsets is None else len(offsets) - 1 + BK = triton.next_power_of_2(K) + + h0 = initial_state + if output_final_state: + ht = q.new_empty(N, H, K, V, dtype=torch.float32) + else: + ht = None + o = torch.empty_like(v) + + def grid(meta): return (triton.cdiv(V, meta['BV']), N * H) + fused_recurrent_dplr_delta_rule_fwd_kernel[grid]( + q, + k, + v, + a, + b, + gk, + o, + h0, + ht, + offsets, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + REVERSE=reverse, + HEAD_FIRST=head_first + ) + return o, ht + + +class FusedRecurrentDPLRDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gk: torch.Tensor, + scale: Optional[float] = 1.0, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + offsets: Optional[torch.LongTensor] = None, + head_first: bool = False + ): + o, ht = fused_recurrent_dplr_delta_rule_fwd( + q=q, + k=k, + v=v, + a=a, + b=b, + gk=gk, + 
scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + reverse=reverse, + offsets=offsets, + head_first=head_first + ) + return o, ht + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht): + raise NotImplementedError( + "Backward pass for fused_recurrent_dplr_delta_rule is not implemented and will not be supported. " + "This kernel is only for inference. " + "For training, please use `chunk_dplr_delta_rule`." + ) + + +def fused_recurrent_dplr_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gk: torch.Tensor, + scale: Optional[float] = 1.0, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + This function computes the recurrence S_t = S_t @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner. + + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` + k (torch.Tensor): + keys of shape `[B, H, T, K]` + v (torch.Tensor): + values of shape `[B, H, T, V]` + a (torch.Tensor): + as of shape `[B, H, T, K]` + b (torch.Tensor): + bs of shape `[B, H, T, K]` + gk (torch.Tensor): + gk of shape `[B, H, T, K]` + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If None, it will default to `1 / sqrt(K)`. Default: `1.0`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[B, H, K, V]`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[B, H, K, V]`. Default: `False`. + reverse (Optional[bool]): + If `True`, process the state passing in reverse order. Default: `False`. + cu_seqlens (Optional[torch.Tensor]): + Cumulative sequence lengths of shape `[N + 1]` used for variable-length training, + consistent with the FlashAttention API. 
+ head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + """ + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.") + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + if scale is None: + scale = q.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + o, final_state = FusedRecurrentDPLRDeltaRuleFunction.apply( + q, + k, + v, + a, + b, + gk, + scale, + initial_state, + output_final_state, + reverse, + cu_seqlens, + head_first + ) + return o, final_state diff --git a/fla/ops/generalized_delta_rule/dplr/naive.py b/fla/ops/generalized_delta_rule/dplr/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..d6ac253673e5361a375286347253f7d4e6f7a2f3 --- /dev/null +++ b/fla/ops/generalized_delta_rule/dplr/naive.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + +# S_t = S_t @ (I + alpha_t beta_t^T) + v_t k_t^T +# q, k, alpha, beta [B, H, L, D_K] +# v [B, H, L, D_V] + + +def dplr_recurrence(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True): + orig_dtype = q.dtype + b, h, l, d_k = q.shape + q, k, v, beta, gk = map(lambda x: x.float(), [q, k, v, beta, gk]) + d_v = v.shape[-1] + o = torch.zeros_like(v) + S = torch.zeros(b, h, d_k, d_v).to(v) + q = q * (d_k ** -0.5) + + if initial_state is not None: + S += initial_state + + for i in range(l): + _k = k[:, :, i] + _q = q[:, :, i] + _v = v[:, :, i] + 
_alpha = alpha[:, :, i].clone() + _beta = beta[:, :, i].clone() + _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None] + S = S.clone() * gk[:, :, i].exp()[..., None] + _kv + o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) + S = None if output_final_state is False else S + return o.to(orig_dtype), S + + +def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True, chunk_size=32): + b, h, l, d_k = q.shape + d_v = v.shape[-1] + q = q * (d_k ** -0.5) + v = v + assert l % chunk_size == 0 + + S = k.new_zeros(b, h, d_k, d_v).to(q) + if initial_state is not None: + S += initial_state + + # note that diagonal is masked. + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0) + q, k, v, alpha, beta, gk = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', + c=chunk_size).float(), [q, k, v, alpha, beta, gk]) + + gk_cumsum = gk.cumsum(-2) + + # v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(mask, 0) @ v + A_ab = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device) + A_qk = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device) + A_ak = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device) + A_qb = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device) + + for i in range(chunk_size): + alpha_i = alpha[:, :, :, i, None] + q_i = q[:, :, :, i, None] + gk_i = gk_cumsum[:, :, :, i, None] + mask = (torch.arange(chunk_size) <= i).to(q.device) + attn_i = (gk_i - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp() + A_qk[:, :, :, i, :] = (q_i * k * attn_i).sum(-1).clone() + A_qb[:, :, :, i, :] = (q_i * beta * attn_i).sum(-1).clone() + mask = (torch.arange(chunk_size) < i).to(q.device) + # shift by one. 
+ attn_i = (gk_i - gk[:, :, :, i, None] - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp() + A_ab[:, :, :, i, :] = (alpha_i * beta * attn_i).sum(-1).clone() + A_ak[:, :, :, i, :] = (alpha_i * k * attn_i).sum(-1).clone() + + A_ab = A_ab + for i in range(1, chunk_size): + A_ab[..., i, :i] = A_ab[..., i, :i].clone() + (A_ab[..., i, :, None].clone() * A_ab[..., :, :i].clone()).sum(-2) + + A_ab = A_ab + torch.eye(chunk_size, dtype=torch.float, device=q.device) + u = A_ab @ (A_ak @ v) + w = A_ab @ ((gk_cumsum-gk).exp() * alpha) + + o = torch.zeros_like(v) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1) + for i in range(0, l // chunk_size): + q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i] + v2_i = u_i + w_i @ S + + o_1 = A_qk[:, :, i] @ v_i + o_2 = A_qb[:, :, i] @ v2_i + o_3 = (q_i * gk_cumsum[:, :, i].exp()) @ S + o[:, :, i] = o_1 + o_2 + o_3 + decay = (gk_cumsum[:, :, i, -1, None] - gk_cumsum[:, :, i]).exp() + S = S*gk_cumsum[:, :, i, -1, :, None].exp() + (k_i * decay).transpose(-1, -2) @ v_i + \ + (beta_i * decay).transpose(-1, -2) @ v2_i + S = None if output_final_state is False else S + return rearrange(o, 'b h n c d -> b h (n c) d'), S diff --git a/fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py b/fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..d9ff775184d4f1fa4472bb172da19fdd45553ed6 --- /dev/null +++ b/fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.utils import check_shared_mem, is_intel_alchemist, use_cuda_graph + +# https://github.com/intel/intel-xpu-backend-for-triton/issues/3449 +triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {} + + 
+@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config(triton_config, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT', 'BK', 'BV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def bwd_prepare_wy_repr_kernel( + A_ab_inv, + A_ak, + ag, + v, + dw, + du, + dv, + dv0, + dag, + dAak, + dAab, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_Aak_t = tl.make_block_ptr(A_ak + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_dAak = tl.make_block_ptr(dAak + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dAab = tl.make_block_ptr(dAab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_Aak_t = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_dAak = tl.make_block_ptr(dAak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dAab = tl.make_block_ptr(dAab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + b_A_ab_inv_t = tl.load(p_Aab_inv_t, boundary_check=(0, 1)) + b_A_ak_t = 
tl.load(p_Aak_t, boundary_check=(0, 1)) + b_A_ak_t = tl.where(tl.arange(0, BT)[:, None] < tl.arange(0, BT)[None, :], b_A_ak_t, 0) + b_A_ab_inv_t = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A_ab_inv_t, 0) + b_A_tmp_t = tl.dot(b_A_ak_t, b_A_ab_inv_t).to(v.dtype.element_ty) + b_dA_tmp = tl.zeros([BT, BT], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv0 = tl.make_block_ptr(dv0 + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv0 = tl.make_block_ptr(dv0 + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA_tmp += tl.dot(b_du.to(b_v.dtype), tl.trans(b_v)) + b_dv0 = tl.load(p_dv0, boundary_check=(0, 1)) + b_dv = b_dv0 + tl.dot(b_A_tmp_t, b_du) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + b_dA_tmp = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_tmp, 0) + b_dA_ak = tl.dot(b_A_ab_inv_t, b_dA_tmp) + b_dA_ak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_ak, 0) + tl.store(p_dAak, b_dA_ak, boundary_check=(0, 1)) + b_dA_ab_inv = tl.dot(b_dA_tmp, b_A_ak_t) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_ag = tl.make_block_ptr(ag + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k 
* BK), (BT, BK), (1, 0)) + p_dag = tl.make_block_ptr(dag + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dag = tl.make_block_ptr(dag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_ag = tl.load(p_ag, boundary_check=(0, 1)) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA_ab_inv += tl.dot(b_dw, tl.trans(b_ag)) + b_dag = tl.dot(b_A_ab_inv_t.to(b_dw.dtype), b_dw) + tl.store(p_dag, b_dag.to(p_dag.dtype.element_ty), boundary_check=(0, 1)) + + # if we know dL/dA^(-1), for dL/dA, we can use the following formula: + # dL/dA = -(A^(-1))^T @ (dL/dA^(-1)) @ (A^(-1))^T + # in the fwd pass we use fwd substitution to calculate (I-lower(A_ab))^-1. + # denote A = I - lower(A_ab), B = A^-1 + # in the backward pass. 
+ # dL/dA = -(B)^T @ (dL/dB) @ B^T + # dL/dA_ab = lower(B^T @ dL/dB @ B^T) + b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_dA_ab_inv, 0) + b_dA_ab_inv = tl.dot(b_A_ab_inv_t, b_dA_ab_inv) + b_dA_ab_inv = tl.dot(b_dA_ab_inv, b_A_ab_inv_t) + b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_ab_inv, 0) + tl.store(p_dAab, b_dA_ab_inv, boundary_check=(0, 1)) + + +def chunk_dplr_bwd_wy( + A_ab_inv: torch.Tensor, + A_ak: torch.Tensor, + v: torch.Tensor, + ag: torch.Tensor, + dw: torch.Tensor, + du: torch.Tensor, + dv0: torch.Tensor, + offsets: Optional[torch.LongTensor], + indices: Optional[torch.LongTensor], + head_first: bool, + chunk_size: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + A_ab_inv, A_ak, v, ag, dw, du = map(lambda x: x.contiguous(), [A_ab_inv, A_ak, v, ag, dw, du]) + if head_first: + B, H, T, K, V = *dw.shape, du.shape[-1] + else: + B, T, H, K, V = *dw.shape, du.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) if check_shared_mem() else min(triton.next_power_of_2(V), 32) + + dA_ab = torch.empty_like(A_ab_inv, dtype=torch.float) + dA_ak = torch.empty_like(A_ak, dtype=torch.float) + dv = torch.empty_like(v) + dag = torch.empty_like(ag) + + bwd_prepare_wy_repr_kernel[(NT, B * H)]( + A_ab_inv=A_ab_inv, + A_ak=A_ak, + ag=ag, + v=v, + dw=dw, + du=du, + dv=dv, + dv0=dv0, + dag=dag, + dAak=dA_ak, + dAab=dA_ab, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return dA_ab, dA_ak, dv, dag diff --git a/fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py b/fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..3ef5ac298d5218a6a1c10087a2bafc547f03acff --- /dev/null +++ 
b/fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py @@ -0,0 +1,318 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import gather +from fla.utils import is_gather_supported, use_cuda_graph + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + ], + key=['BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def fwd_prepare_wy_repr_kernel_chunk32( + A_ab, + A_ab_inv, + offsets, + indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, # placeholder, do not delete + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + if HEAD_FIRST: + p_Aab = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_Aab_inv = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_Aab = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_Aab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_A_ab = tl.load(p_Aab, boundary_check=(0, 1)) + b_A_ab = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A_ab, 0) + for i in range(1, BT): + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:, None], b_A_ab, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A_ab, 0) * (tl.arange(0, BT) < i) + b_A_ab = 
tl.where(mask[:, None], b_a, b_A_ab) + b_A_ab += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :] + tl.store(p_Aab_inv, b_A_ab.to(p_Aab_inv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BC'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def fwd_prepare_wy_repr_kernel_chunk64( + A_ab, + A_ab_inv, + offsets, + indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, + GATHER_SUPPORTED: tl.constexpr = is_gather_supported +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + + p_A1 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_A2 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_A3 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_A_inv1 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_A_inv2 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_A_inv3 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_A_inv4 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + else: + p_A1 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * 
BT, 0), (BC, BC), (1, 0)) + p_A2 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_A3 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_A_inv1 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_A_inv2 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_A_inv3 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_A_inv4 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + + b_A = tl.load(p_A1, boundary_check=(0, 1)) + b_A2 = tl.load(p_A2, boundary_check=(0, 1)) + b_A3 = tl.load(p_A3, boundary_check=(0, 1)) + b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0) + b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0) + + for i in range(1, BC): + if GATHER_SUPPORTED: + row_idx = tl.full([1, BC], i, dtype=tl.int16) + # [1, BK] -> [BK] + b_a = tl.sum(gather(b_A, row_idx, axis=0), 0) + b_a2 = tl.sum(gather(b_A2, row_idx, axis=0), 0) + else: + mask = tl.arange(0, BC) == i + b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0) + mask = tl.arange(0, BC) == i + # b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + # b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i) + b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + b_A2 = tl.where(mask[:, None], b_a2, b_A2) + + # blockwise computation of lower triangular matrix's inverse + # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1] + b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_A2 += tl.arange(0, BC)[:, None] == 
tl.arange(0, BC)[None, :] + b_A3 = tl.dot(tl.dot(b_A2, b_A3), b_A) + # tl.debug_barrier() + tl.store(p_A_inv1, b_A.to(p_A_inv1.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_A_inv2, b_A2.to(p_A_inv2.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_A_inv3, b_A3.to(p_A_inv3.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + # causal mask + tl.store(p_A_inv4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A_inv4.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT', 'BK', 'BV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def fwd_wu_kernel( + u, + w, + ag, + v, + A_ab_inv, + A_ak, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_A_ab_inv = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_A_ak = tl.make_block_ptr(A_ak + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_A_ab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_A_ak = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_Aab_inv = 
tl.load(p_A_ab_inv, boundary_check=(0, 1)) + b_Aak = tl.load(p_A_ak, boundary_check=(0, 1)) + o_s = tl.arange(0, BT) + b_Aab_inv = tl.where(o_s[:, None] >= o_s[None, :], b_Aab_inv, 0) + b_Aak = tl.where(o_s[:, None] > o_s[None, :], b_Aak, 0) + # let's use tf32 here + b_Aak = tl.dot(b_Aab_inv, b_Aak) + # (SY 01/04) should be bf16 or tf32? To verify. + b_Aak = b_Aak.to(v.dtype.element_ty, fp_downcast_rounding="rtne") + b_Aab_inv = b_Aab_inv.to(ag.dtype.element_ty, fp_downcast_rounding="rtne") + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_ag = tl.make_block_ptr(ag + i_bh * T * K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + i_bh * T * K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_ag = tl.make_block_ptr(ag + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_ag = tl.load(p_ag, boundary_check=(0, 1)) + b_w = tl.dot(b_Aab_inv, b_ag) # both bf16 or fp16 + tl.store(p_w, b_w.to(p_w.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_v = tl.make_block_ptr(v + i_bh * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + i_bh * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_u = tl.dot(b_Aak, b_v) # both bf16 or fp16 + tl.store(p_u, b_u.to(p_u.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +def fwd_prepare_wy_repr( + ag: torch.Tensor, + v: torch.Tensor, + A_ak: torch.Tensor, + A_ab: torch.Tensor, + offsets: Optional[torch.LongTensor], + 
indices: Optional[torch.LongTensor], + head_first: bool = True, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K = ag.shape + else: + B, T, H, K = ag.shape + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + BC = min(BT, 32) + fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32 + A_ab_inv = torch.empty_like(A_ab) + fwd_fn[(NT, B * H)]( + A_ab=A_ab, + A_ab_inv=A_ab_inv, + offsets=offsets, + indices=indices, + T=T, + H=H, + BT=BT, + BC=BC, + HEAD_FIRST=head_first + ) + w, u = fwd_wu( + ag=ag, + v=v, + A_ak=A_ak, + A_ab_inv=A_ab_inv, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + return w, u, A_ab_inv + + +def fwd_wu( + ag: torch.Tensor, + v: torch.Tensor, + A_ak: torch.Tensor, + A_ab_inv: torch.Tensor, + offsets: Optional[torch.LongTensor], + indices: Optional[torch.LongTensor], + head_first: bool, + chunk_size: int +) -> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *ag.shape, v.shape[-1] + else: + B, T, H, K, V = *ag.shape, v.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) + + u = torch.empty_like(v) + w = torch.empty_like(ag) + fwd_wu_kernel[(NT, B*H)]( + ag=ag, + v=v, + A_ak=A_ak, + A_ab_inv=A_ab_inv, + w=w, + u=u, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return w, u diff --git a/fla/ops/generalized_delta_rule/iplr/__init__.py b/fla/ops/generalized_delta_rule/iplr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e44d2a773b31f43fce68c5a9d1e67a3b33f42411 --- /dev/null +++ b/fla/ops/generalized_delta_rule/iplr/__init__.py @@ -0,0 +1,7 @@ +from .chunk 
import chunk_iplr_delta_rule +from .fused_recurrent import fused_recurrent_iplr_delta_rule + +__all__ = [ + 'chunk_iplr_delta_rule', + 'fused_recurrent_iplr_delta_rule' +] diff --git a/fla/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-312.pyc b/fla/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49fe01b1c2e15b59b9c029cd527537330bfc0ab1 Binary files /dev/null and b/fla/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc b/fla/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..721485351f47fcdd9c079b4a9c376ca327fb20b1 Binary files /dev/null and b/fla/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc differ diff --git a/fla/ops/generalized_delta_rule/iplr/chunk.py b/fla/ops/generalized_delta_rule/iplr/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..07f76533b10f022ba2dd1bc2af075c5a4f537760 --- /dev/null +++ b/fla/ops/generalized_delta_rule/iplr/chunk.py @@ -0,0 +1,528 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.common.chunk_delta_h import prepare_chunk_offsets +from fla.ops.generalized_delta_rule.iplr.wy_fast import fwd_prepare_wy_repr +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard, use_cuda_graph + +BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, 
        num_warps=num_warps)
        for num_warps in [2, 4, 8, 16]
    ],
    key=['BT', 'BK', 'BV'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_generalized_iplr_delta_rule_fwd_kernel_h(
    k,
    v,
    d,
    b,
    u,
    v_new,
    h,
    h0,
    ht,
    offsets,
    chunk_offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NT: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    # Recurrently materializes the per-chunk hidden state h [K, V]:
    # one program per (key tile, value tile, batch*head).
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if USE_OFFSETS:
        # variable-length: per-sequence [bos, eos) bounds and chunk base offset
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV] running state, accumulated in fp32
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT):
        # store the state *entering* chunk i_t before updating it
        if HEAD_FIRST:
            p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        else:
            p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        b_hc = tl.zeros([BK, BV], dtype=tl.float32)
        # since we need to keep the whole K dim in SRAM we face a severe SRAM
        # memory burden; by subchunking (BC-sized slices) we alleviate it
        for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
            if HEAD_FIRST:
                p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_b = tl.make_block_ptr(b + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_d = tl.make_block_ptr(d + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
                p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_u = tl.make_block_ptr(u + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
            else:
                p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_b = tl.make_block_ptr(b+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
                p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
                p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
                p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
            # [BK, BC]
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_v = tl.load(p_v, boundary_check=(0, 1))
            b_d = tl.load(p_d, boundary_check=(0, 1))
            b_b = tl.load(p_b, boundary_check=(0, 1))
            # v2 = d @ h + u : value stream corrected by the current state
            b_v2 = tl.dot(b_d, b_h.to(b_d.dtype)) + tl.load(p_u, boundary_check=(0, 1))
            # state delta: k-contribution plus b-weighted corrected values
            b_hc += tl.dot(b_k, b_v)
            b_hc += tl.dot(b_b, b_v2.to(b_k.dtype))
            tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
        b_h += b_hc

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV),
                                 (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in BKV_LIST
        for BV in BKV_LIST
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_generalized_iplr_delta_rule_fwd_kernel_o(
    q,
    k,
    v,
    u,
    b,
    h,
    o,
    offsets,
    indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    # Output kernel: o = q @ h + causal(q k^T) v + causal(q b^T) u, scaled.
    # One program per (value tile, chunk, batch*head).
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # offset calculation: advance base pointers to this (batch, head) slice
    q += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
    k += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
    b += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K)
    v += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
    u += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
    o += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V)
    h += ((i_bh * NT + i_t) * K * V) if HEAD_FIRST else ((i_tg * H + i_h) * K * V)
    stride_qk = K if HEAD_FIRST else H*K
    stride_vo = V if HEAD_FIRST else H*V

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_Aqk = tl.zeros([BT, BT], dtype=tl.float32)
    b_Aqb = tl.zeros([BT, BT], dtype=tl.float32)

    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q, (T,
                                     K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        p_b = tl.make_block_ptr(b, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_b = tl.load(p_b, boundary_check=(0, 1))
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BK] @ [BK, BV] -> [BT, BV]
        b_o += tl.dot(b_q, b_h)
        # [BT, BK] @ [BK, BT] -> [BT, BT]
        b_Aqk += tl.dot(b_q, b_k)
        # [BT, BK] @ [BK, BT] -> [BT, BT]
        b_Aqb += tl.dot(b_q, b_b)

    # causal (inclusive lower-triangular) mask on the intra-chunk attentions
    o_i = tl.arange(0, BT)
    m_A = o_i[:, None] >= o_i[None, :]
    b_Aqk = tl.where(m_A, b_Aqk, 0)
    b_Aqb = tl.where(m_A, b_Aqb, 0)

    p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_u = tl.make_block_ptr(u, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_u = tl.load(p_u, boundary_check=(0, 1))
    b_o = (b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_u.dtype), b_u)) * scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


def chunk_generalized_iplr_delta_rule_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    b: torch.Tensor,
    h: torch.Tensor,
    scale: Optional[float] = None,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> torch.Tensor:
    """Launch the output kernel; returns o with the same shape/dtype as v."""
    if head_first:
        B, H, T, K, V = *q.shape, v.shape[-1]
    else:
        B, T, H, K, V = *q.shape, v.shape[-1]
    if scale is None:
        # default 1/sqrt(K) scaling
        scale = k.shape[-1] ** -0.5
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    NT = triton.cdiv(T,
                    BT) if offsets is None else len(indices)

    o = torch.empty_like(v)

    # BV is chosen by the autotuner; only the V-tile and chunk axes are gridded
    def grid(meta): return (
        triton.cdiv(V, meta['BV']),
        NT,
        B * H
    )
    chunk_generalized_iplr_delta_rule_fwd_kernel_o[grid](
        q=q,
        k=k,
        v=v,
        u=v_new,
        b=b,
        h=h,
        o=o,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        HEAD_FIRST=head_first
    )
    return o


def chunk_generalized_iplr_delta_rule_fwd_h(
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    b: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Materialize per-chunk hidden states h, the corrected values v_new and
    (optionally) the final state, by launching the fwd_kernel_h recurrence.
    """
    if head_first:
        B, H, T, K, V = *k.shape, u.shape[-1]
    else:
        B, T, H, K, V = *k.shape, u.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if offsets is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)

    # the whole K dim must fit in one tile (NK == 1 is asserted below)
    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension larger than 256."
    # H100 can have larger block size thanks to its bigger shared memory
    if check_shared_mem('hopper', k.device.index):
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem('ampere', k.device.index):  # A100
        BV = 32
        BC = 32
    else:
        BV = 16
        BC = 16

    BC = min(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)

    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'

    # one h snapshot per chunk
    if head_first:
        h = k.new_empty(B, H, NT, K, V)
    else:
        h = k.new_empty(B, NT, H, K, V)
    final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None

    v_new = torch.empty_like(u)
    grid = (NK, NV, N * H)

    chunk_generalized_iplr_delta_rule_fwd_kernel_h[grid](
        k=k,
        v=v,
        d=w,
        b=b,
        u=u,
        v_new=v_new,
        h=h,
        h0=initial_state,
        ht=final_state,
        offsets=offsets,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
        NT=NT,
        HEAD_FIRST=head_first
    )
    return h, v_new, final_state


def chunk_generalized_iplr_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
):
    """Full forward: WY representation -> chunk states -> outputs."""
    T = q.shape[2] if head_first else q.shape[1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    w, u, _ = fwd_prepare_wy_repr(
        a=a,
        b=b,
        k=k,
        v=v,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )

    h, v_new, final_state = chunk_generalized_iplr_delta_rule_fwd_h(
        k=k,
        v=v,
        b=b,
        w=w,
        u=u,
        initial_state=initial_state,
        output_final_state=output_final_state,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    o = chunk_generalized_iplr_delta_rule_fwd_o(
        q=q,
        k=k,
        v=v,
        v_new=v_new,
        b=b,
        h=h,
        scale=scale,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=BT
    )
    return o, final_state


class ChunkGeneralizedIPLRDeltaRuleFunction(torch.autograd.Function):
    # Autograd wrapper: forward is implemented, backward is not yet.

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        offsets: Optional[torch.LongTensor] = None,
        head_first: bool = True
    ):
        chunk_size = 64

        # 2-d indices denoting the offsets of chunks in each sequence
        # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
        # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        indices = None
        if offsets is not None:
            indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()])
            indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)

        o, final_state = chunk_generalized_iplr_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            offsets=offsets,
            indices=indices,
            head_first=head_first,
            chunk_size=chunk_size
        )
        return o.to(q.dtype), final_state

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(
        ctx,
        do: torch.Tensor,
        dht: torch.Tensor
    ):
        raise NotImplementedError(
            "Backward pass for ChunkGeneralizedIPLRDeltaRuleFunction is not implemented yet. "
            "Stay tuned!"
+ ) + + +@torch.compiler.disable +def chunk_iplr_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + a (torch.Tensor): + activations of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + b (torch.Tensor): + betas of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + """ + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16." 
+ + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.") + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + scale = k.shape[-1] ** -0.5 if scale is None else scale + o, final_state = ChunkGeneralizedIPLRDeltaRuleFunction.apply( + q, + k, + v, + a, + b, + scale, + initial_state, + output_final_state, + cu_seqlens, + head_first + ) + return o, final_state diff --git a/fla/ops/generalized_delta_rule/iplr/wy_fast.py b/fla/ops/generalized_delta_rule/iplr/wy_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..9fdfa7091500873765a36c6ef86506a203f4be19 --- /dev/null +++ b/fla/ops/generalized_delta_rule/iplr/wy_fast.py @@ -0,0 +1,338 @@ + +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.utils import check_shared_mem, is_nvidia_hopper + +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + ], + key=['BK'] +) +@triton.jit(do_not_specialize=['T']) +def fwd_prepare_wy_repr_kernel_chunk32( + a, + b, + A, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BC: tl.constexpr, # dummy placeholder + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), 
                    tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # accumulate A = a @ b^T for this chunk across key tiles
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_b = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        else:
            p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        b_a = tl.load(p_a, boundary_check=(0, 1))
        b_b = tl.load(p_b, boundary_check=(0, 1))
        b_A += tl.dot(b_a, b_b)

    # keep the strictly lower-triangular part only
    b_A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
    # row-by-row forward substitution: builds (I - A)^-1 - I in place
    for i in range(1, BT):
        mask = tl.arange(0, BT) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
    # add the identity to complete the inverse
    b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]

    if HEAD_FIRST:
        p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    else:
        p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
    ],
    key=['BK']
)
@triton.jit(do_not_specialize=['T'])
def fwd_prepare_wy_repr_kernel_chunk64(
    a,
    b,
    A,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BC: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # 64-wide variant: invert the chunk in two BC=32 diagonal blocks plus one
    # off-diagonal block, using the 2x2 block-triangular inverse formula.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    b_A = tl.zeros([BC, BC], dtype=tl.float32)
    b_A2 = tl.zeros([BC, BC], dtype=tl.float32)
    b_A3 = tl.zeros([BC, BC], dtype=tl.float32)

    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_a1 = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
            p_a2 = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
            p_b1 = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BC), (0, 1))
            p_b2 = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1))
        else:
            p_a1 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
            p_a2 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
            p_b1 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BC), (0, 1))
            p_b2 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1))
        b_a1 = tl.load(p_a1, boundary_check=(0, 1))
        b_a2 = tl.load(p_a2, boundary_check=(0, 1))
        b_b1 = tl.load(p_b1, boundary_check=(0, 1))
        b_b2 = tl.load(p_b2, boundary_check=(0, 1))
        # tf32 disabled: the inverse recurrence is numerically sensitive
        b_A += tl.dot(b_a1, b_b1, allow_tf32=False)
        b_A2 += tl.dot(b_a2, b_b2, allow_tf32=False)
        b_A3 += tl.dot(b_a2, b_b1, allow_tf32=False)

    b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
    b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :],
                    b_A2, 0)

    # forward substitution on both diagonal blocks simultaneously
    for i in range(1, BC):
        mask = tl.arange(0, BC) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
        b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
        b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
        b_A2 = tl.where(mask[:, None], b_a2, b_A2)

    # blockwise computation of lower triangular matrix's inverse
    # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
    b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A3 = tl.dot(tl.dot(b_A2, b_A3, allow_tf32=False), b_A, allow_tf32=False)

    if HEAD_FIRST:
        p_A1 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
        p_A2 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
        p_A3 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
        p_A4 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
    else:
        p_A1 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
        p_A2 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
        p_A3 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
        p_A4 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
    tl.store(p_A1, b_A.to(p_A1.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_A2, b_A2.to(p_A2.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_A3, b_A3.to(p_A3.dtype.element_ty), boundary_check=(0, 1))
    # causal mask: the upper-right block of the inverse is identically zero
    tl.store(p_A4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A4.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
    'USE_OFFSETS': lambda args:
        args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in NUM_WARPS
    ],
    key=['BT', 'BK', 'BV']
)
@triton.jit(do_not_specialize=['T'])
def fwd_wu_kernel(
    w,
    u,
    a,
    k,
    v,
    A,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr
):
    # Computes per-chunk w = A @ a and u = A @ (tril(a k^T) @ v),
    # where A is the precomputed chunk inverse. One program per (chunk, B*H).
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    else:
        p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_Aak = tl.zeros([BT, BT], dtype=tl.float32)

    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_a = tl.load(p_a, boundary_check=(0, 1))
        # w tile and the a k^T Gram matrix accumulated in one pass over K
        b_w = tl.dot(b_A, b_a)
        b_Aak += tl.dot(b_a, tl.trans(b_k))
        tl.store(p_w,
                 b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))

    # strictly-causal mask, then downcast to the input dtype for the V pass
    b_Aak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_Aak, 0)
    b_Aak = b_Aak.to(k.dtype.element_ty)

    for i_v in range(tl.cdiv(V, BV)):
        if HEAD_FIRST:
            p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_v = tl.dot(b_Aak, b_v).to(v.dtype.element_ty)
        b_u = tl.dot(b_A, b_v)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))


def fwd_prepare_wy_repr(
    a: torch.Tensor,
    b: torch.Tensor,
    v: torch.Tensor,
    k: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool = True,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the chunkwise inverse A of (I - tril(a b^T)) and the derived
    WY representation (w, u) for the IPLR delta rule.

    Returns:
        (w, u, A) with A of shape [B, H, T, BT] (head-first) or [B, T, H, BT].
    """
    if head_first:
        B, H, T, K = a.shape
    else:
        B, T, H, K = a.shape
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    BC = min(BT, 32)
    BK = min(triton.next_power_of_2(K), 64)

    A = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=a.device, dtype=a.dtype)
    fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32

    fwd_fn[(NT, B * H)](
        a=a,
        b=b,
        A=A,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BK=BK,
        BC=BC,
        HEAD_FIRST=head_first
    )
    w, u = fwd_wu(
        a=a,
        v=v,
        k=k,
        A=A,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )
    return w, u, A


def fwd_wu(
    a: torch.Tensor,
    v: torch.Tensor,
    k: torch.Tensor,
    A: torch.Tensor,
    offsets:
             Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch `fwd_wu_kernel`; returns (w, u) shaped like (a, v)."""
    if head_first:
        B, H, T, K, V = *a.shape, v.shape[-1]
    else:
        B, T, H, K, V = *a.shape, v.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    # larger tiles when the device has enough shared memory
    CONST_TILING = 64 if check_shared_mem() else 32
    BK = min(triton.next_power_of_2(K), CONST_TILING)
    BV = min(triton.next_power_of_2(V), CONST_TILING)

    u = torch.empty_like(v)
    w = torch.empty_like(a)
    fwd_wu_kernel[(NT, B*H)](
        a=a,
        v=v,
        w=w,
        u=u,
        A=A,
        k=k,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return w, u
b/fla/ops/gsa/__pycache__/fused_recurrent.cpython-312.pyc differ diff --git a/fla/ops/hgrn/__pycache__/__init__.cpython-312.pyc b/fla/ops/hgrn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c112279ad6b4f1651b0fe68cc799b492349acd76 Binary files /dev/null and b/fla/ops/hgrn/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla/ops/lightning_attn/__pycache__/__init__.cpython-312.pyc b/fla/ops/lightning_attn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38981ac556cfa1e70fbc9449473693c59e75ff78 Binary files /dev/null and b/fla/ops/lightning_attn/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla/ops/linear_attn/__pycache__/fused_recurrent.cpython-312.pyc b/fla/ops/linear_attn/__pycache__/fused_recurrent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30d299f9c161310dad13ef746f4459c0b5de9b47 Binary files /dev/null and b/fla/ops/linear_attn/__pycache__/fused_recurrent.cpython-312.pyc differ diff --git a/fla/ops/linear_attn/__pycache__/utils.cpython-312.pyc b/fla/ops/linear_attn/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c059a78a9cf13e1f6cc001a0c1aa2f335849499 Binary files /dev/null and b/fla/ops/linear_attn/__pycache__/utils.cpython-312.pyc differ diff --git a/fla/ops/nsa/__pycache__/__init__.cpython-312.pyc b/fla/ops/nsa/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a73e3f72711566de0b8330abb277c26b3a33f95 Binary files /dev/null and b/fla/ops/nsa/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla/ops/nsa/__pycache__/utils.cpython-312.pyc b/fla/ops/nsa/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8c2154d000d4034c88ccac83a5af1c4d94441dd Binary files /dev/null and b/fla/ops/nsa/__pycache__/utils.cpython-312.pyc differ 
diff --git a/fla/ops/rebased/__pycache__/__init__.cpython-312.pyc b/fla/ops/rebased/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..277679bfa86e524ad1f8939206932b655a2ceb7a Binary files /dev/null and b/fla/ops/rebased/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla/ops/rebased/__pycache__/parallel.cpython-312.pyc b/fla/ops/rebased/__pycache__/parallel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dfbea7aa018ea6f62180d719cc2b9511ece8b46 Binary files /dev/null and b/fla/ops/rebased/__pycache__/parallel.cpython-312.pyc differ diff --git a/fla/ops/retention/__pycache__/chunk.cpython-312.pyc b/fla/ops/retention/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54a6cda157e69231d0767c75fd699db9ccc87fb5 Binary files /dev/null and b/fla/ops/retention/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla/ops/rwkv4/fused_recurrent.py b/fla/ops/rwkv4/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..63a5c6577dd3ef288aa59c494e74b8d29d8580ad --- /dev/null +++ b/fla/ops/rwkv4/fused_recurrent.py @@ -0,0 +1,476 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Songlin Yang, Yu Zhang + +from typing import Any, cast + +import torch +import triton +import triton.language as tl +from torch import Tensor +from torch.autograd.function import Function, FunctionCtx, once_differentiable + +from fla.ops.utils.op import exp + + +def get_block_size_c(chans: int) -> int: + if chans < 32: + return 32 + if chans < 64: + return 64 + return 128 + + +@triton.jit +def fused_recurrent_rwkv4_forward_kernel( + # W + w_ptr, + w_s_c, + # U + u_ptr, + u_s_c, + # K + k_ptr, + k_s_b, + k_s_t, + k_s_c, + # V + v_ptr, + v_s_b, + v_s_t, + v_s_c, + # State + state_ptr, + state_s_b, + state_s_abe, + state_s_c, + # WKV + wkv_ptr, + wkv_s_b, + wkv_s_t, + wkv_s_c, + # Output state + state_out_ptr, + 
state_out_s_b, + state_out_s_abe, + state_out_s_t, + state_out_s_c, + # Params + chans, + tsz, + BLOCK_SIZE_C: tl.constexpr, +): + # Parallelize over the batch dimension. + b_idx = tl.program_id(0) + c_idx = tl.program_id(1) + + cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C) + cmask = cs < chans + + # Pointers to the batch (and possibly channel) for the input tensors. + k_ptr = k_ptr + b_idx * k_s_b + v_ptr = v_ptr + b_idx * v_s_b + alpha_ptr = state_ptr + b_idx * state_s_b + beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe + eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe + + # Pointers to the batch (and possibly channel) for the output tensors. + wkv_ptr = wkv_ptr + b_idx * wkv_s_b + alpha_out_ptr = state_out_ptr + b_idx * state_out_s_b + beta_out_ptr = state_out_ptr + b_idx * state_out_s_b + state_out_s_abe + eps_out_ptr = state_out_ptr + b_idx * state_out_s_b + 2 * state_out_s_abe + + # Loads parameters. + alpha = tl.load(alpha_ptr + cs * state_s_c, mask=cmask).to(tl.float32) + beta = tl.load(beta_ptr + cs * state_s_c, mask=cmask).to(tl.float32) + eps = tl.load(eps_ptr + cs * state_s_c, mask=cmask).to(tl.float32) + w = tl.load(w_ptr + cs * w_s_c, mask=cmask).to(tl.float32) + u = tl.load(u_ptr + cs * u_s_c, mask=cmask).to(tl.float32) + + for t in range(tsz): + kt = tl.load(k_ptr + t * k_s_t + cs * k_s_c, mask=cmask).to(tl.float32) + vt = tl.load(v_ptr + t * v_s_t + cs * v_s_c, mask=cmask).to(tl.float32) + + ukt = u + kt + tau = tl.maximum(ukt, eps) + e1a = exp(eps - tau) + e2a = exp(ukt - tau) + wkv = (e1a * alpha + e2a * vt) / (e1a * beta + e2a) + tl.store(wkv_ptr + t * wkv_s_t + cs * wkv_s_c, wkv, mask=cmask) + + w_eps = w + eps + eps = tl.maximum(w_eps, kt) + e1b = exp(w_eps - eps) + e2b = exp(kt - eps) + alpha = e1b * alpha + e2b * vt + beta = e1b * beta + e2b + tl.store(alpha_out_ptr + t * state_out_s_t + cs * state_out_s_c, alpha, mask=cmask) + tl.store(beta_out_ptr + t * state_out_s_t + cs * state_out_s_c, beta, mask=cmask) + 
tl.store(eps_out_ptr + t * state_out_s_t + cs * state_out_s_c, eps, mask=cmask) + + +def fused_recurrent_rwkv4_forward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + state: Tensor, +) -> tuple[Tensor, Tensor]: + (bsz, tsz, chans) = k.shape + + # New tensors to output. + wkvs = k.new_empty(bsz, tsz, chans) + state_out = k.new_empty(bsz, 3, tsz, chans) + + # Constants. + block_size_c = get_block_size_c(chans) + + def grid(meta: dict[str, Any]) -> tuple[int, ...]: + return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"])) + + fused_recurrent_rwkv4_forward_kernel[grid]( + # W + w, + w.stride(0), + # U + u, + u.stride(0), + # K + k, + k.stride(0), + k.stride(1), + k.stride(2), + # V + v, + v.stride(0), + v.stride(1), + v.stride(2), + # State + state, + state.stride(0), + state.stride(1), + state.stride(3), + # WKV + wkvs, + wkvs.stride(0), + wkvs.stride(1), + wkvs.stride(2), + # Output state + state_out, + state_out.stride(0), + state_out.stride(1), + state_out.stride(2), + state_out.stride(3), + # Params + chans, + tsz, + BLOCK_SIZE_C=block_size_c, + ) + + state_out = torch.cat((state, state_out), dim=2) + + return wkvs, state_out + + +@triton.jit +def fused_recurrent_rwkv4_backward_kernel( + # W + w_ptr, + w_s_c, + # U + u_ptr, + u_s_c, + # K + k_ptr, + k_s_b, + k_s_t, + k_s_c, + # V + v_ptr, + v_s_b, + v_s_t, + v_s_c, + # State + state_ptr, + state_s_b, + state_s_abe, + state_s_t, + state_s_c, + # WKV grad + gwkv_ptr, + gwkv_s_b, + gwkv_s_t, + gwkv_s_c, + # Output state grad + gstate_out_ptr, + gstate_out_s_b, + gstate_out_s_abe, + gstate_out_s_c, + # W grad + gw_ptr, + gw_s_c, + # U grad + gu_ptr, + gu_s_c, + # K grad + gk_ptr, + gk_s_b, + gk_s_t, + gk_s_c, + # V grad + gv_ptr, + gv_s_b, + gv_s_t, + gv_s_c, + # State grad + gstate_ptr, + gstate_s_b, + gstate_s_abe, + gstate_s_c, + # Params + tsz, + chans, + BLOCK_SIZE_C: tl.constexpr, +): + # Parallelize over the batch dimension. 
+ b_idx = tl.program_id(0) + c_idx = tl.program_id(1) + + cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C) + cmask = cs < chans + + # Pointers to the batch (and possibly channel) for the input tensors. + k_ptr = k_ptr + b_idx * k_s_b + v_ptr = v_ptr + b_idx * v_s_b + alpha_ptr = state_ptr + b_idx * state_s_b + beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe + eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe + + # Pointers to the batch (and possibly channel) for the output tensors. + gk_ptr = gk_ptr + b_idx * gk_s_b + gv_ptr = gv_ptr + b_idx * gv_s_b + + # Pointers to gradients which were recieved by the function. + gwkv_ptr = gwkv_ptr + b_idx * gwkv_s_b + galpha_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gbeta_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gstate_out_s_abe + geps_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + 2 * gstate_out_s_abe + + # Loads parameters. + galpha = tl.load(galpha_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32) + gbeta = tl.load(gbeta_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32) + geps = tl.load(geps_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32) + w = tl.load(w_ptr + w_s_c * cs, mask=cmask).to(tl.float32) + u = tl.load(u_ptr + u_s_c * cs, mask=cmask).to(tl.float32) + + # Gradient accumulators. 
+ gw = tl.zeros_like(w) + gu = tl.zeros_like(u) + + alpha_prev = tl.load(alpha_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + beta_prev = tl.load(beta_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + eps_prev = tl.load(eps_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + + for t in range(tsz): + tc = tsz - t - 1 + + kt = tl.load(k_ptr + tc * k_s_t + k_s_c * cs, mask=cmask).to(tl.float32) + vt = tl.load(v_ptr + tc * v_s_t + v_s_c * cs, mask=cmask).to(tl.float32) + + alpha_curr = alpha_prev + beta_curr = beta_prev + eps_curr = eps_prev + + alpha_prev = tl.load(alpha_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + beta_prev = tl.load(beta_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + eps_prev = tl.load(eps_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + + ukt = u + kt + tau = tl.maximum(ukt, eps_prev) + e1 = exp(eps_prev - tau) + e2 = exp(ukt - tau) + + euke = exp(ukt + eps_prev - 2 * tau) + + denom = e1 * beta_prev + e2 + denom_sq = denom * denom + + gwkvt = tl.load(gwkv_ptr + tc * gwkv_s_t + gwkv_s_c * cs, mask=cmask).to(tl.float32) + + # Backpropagates wkv gradients. + guk = gwkvt * e2 * (e1 * beta_prev * vt - e1 * alpha_prev) / denom_sq + gu += guk + gk = guk + gv = gwkvt * e2 / denom + + galpha_wkv = gwkvt * e1 / denom + gbeta_wkv = -gwkvt * e1 * (e2 * vt + e1 * alpha_prev) / denom_sq + geps_wkv_denom = e1 * beta_prev + e2 + geps_wkv = gwkvt * euke * (alpha_prev - vt * beta_prev) / (geps_wkv_denom * geps_wkv_denom) + + e1 = exp(w + eps_prev - eps_curr) + e2 = exp(kt - eps_curr) + + # Backpropagates alpha gradients. + galpha_we = galpha * e1 * alpha_prev + gw += galpha_we + gk += galpha * e2 * vt + gv += galpha * e2 + geps += galpha * -alpha_curr + + # Backpropagates beta gradients. + gbeta_we = gbeta * e1 * beta_prev + gw += gbeta_we + gk += gbeta * e2 + geps += gbeta * -beta_curr + + # Backpropagates epsilon gradients. 
+ geps_mask = w + eps_prev > kt + geps_we = tl.where(geps_mask, geps, tl.zeros_like(geps)) + gw += geps_we + gk += tl.where(geps_mask, tl.zeros_like(geps), geps) + + # Stores the gradients for k and v. + tl.store(gk_ptr + tc * gk_s_t + gk_s_c * cs, gk, mask=cmask) + tl.store(gv_ptr + tc * gv_s_t + gv_s_c * cs, gv, mask=cmask) + + # Computes new gradients for alpha and beta. + galpha = galpha * e1 + galpha_wkv + gbeta = gbeta * e1 + gbeta_wkv + geps = galpha_we + gbeta_we + geps_we + geps_wkv + + # Stores final gradients for alpha and beta. + galpha_ptr = gstate_ptr + b_idx * gstate_s_b + gbeta_ptr = gstate_ptr + b_idx * gstate_s_b + gstate_s_abe + geps_ptr = gstate_ptr + b_idx * gstate_s_b + 2 * gstate_s_abe + tl.store(galpha_ptr + gstate_s_c * cs, galpha, mask=cmask) + tl.store(gbeta_ptr + gstate_s_c * cs, gbeta, mask=cmask) + tl.store(geps_ptr + gstate_s_c * cs, geps, mask=cmask) + + # Stores final gradients for w and u. + gw_temp = tl.load(gw_ptr + gw_s_c * cs, mask=cmask).to(tl.float32) + gw_temp += gw + tl.store(gw_ptr + gw_s_c * cs, gw_temp, mask=cmask) + gu_temp = tl.load(gu_ptr + gu_s_c * cs, mask=cmask).to(tl.float32) + gu_temp += gu + tl.store(gu_ptr + gu_s_c * cs, gu_temp, mask=cmask) + + +def fused_recurrent_rwkv4_backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + state: Tensor, + grad_wkv: Tensor, + grad_state: Tensor, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + bsz, tsz, chans = k.shape + + gw = torch.zeros_like(w) # New tensors to output. + gu = torch.zeros_like(u) + gk = torch.empty_like(k) + gv = torch.empty_like(v) + gstate = k.new_empty(bsz, 3, 1, chans) + + block_size_c = get_block_size_c(chans) # Constants. 
+ + def grid(meta: dict[str, Any]) -> tuple[int, ...]: + return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"])) + + fused_recurrent_rwkv4_backward_kernel[grid]( + # W + w, + w.stride(0), + # U + u, + u.stride(0), + # K + k, + k.stride(0), + k.stride(1), + k.stride(2), + # V + v, + v.stride(0), + v.stride(1), + v.stride(2), + # State + state, + state.stride(0), + state.stride(1), + state.stride(2), + state.stride(3), + # WKV grad + grad_wkv, + grad_wkv.stride(0), + grad_wkv.stride(1), + grad_wkv.stride(2), + # Output state grad + grad_state, + grad_state.stride(0), + grad_state.stride(1), + grad_state.stride(3), + # W grad + gw, + gw.stride(0), + # U grad + gu, + gu.stride(0), + # K grad + gk, + gk.stride(0), + gk.stride(1), + gk.stride(2), + # V grad + gv, + gv.stride(0), + gv.stride(1), + gv.stride(2), + # State grad + gstate, + gstate.stride(0), + gstate.stride(1), + gstate.stride(3), + # Params + tsz, + chans, + BLOCK_SIZE_C=block_size_c, + ) + + return gw, gu, gk, gv, gstate + + +class FusedRecurrentRWKV4Function(Function): + @staticmethod + def forward( + ctx: FunctionCtx, + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + state: Tensor, + ) -> tuple[Tensor, Tensor]: + ctx.input_dtype = k.dtype + + w = -torch.exp(w.float().contiguous()) + if k.dtype == torch.float16: + u = u.float() + k = k.float() + v = v.float() + u = u.contiguous() + k = k.contiguous() + v = v.contiguous() + wkv, state_out = fused_recurrent_rwkv4_forward(w, u, k, v, state) + ctx.save_for_backward(w, u, k, v, state_out[:, :, :-1]) + return wkv, state_out[:, :, -1:] + + @staticmethod + @once_differentiable + def backward(ctx: FunctionCtx, gwkv: Tensor, gstate: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + w, u, k, v, state = cast(tuple[Tensor, ...], ctx.saved_tensors) + gw, gu, gk, gv, gstate = fused_recurrent_rwkv4_backward(w, u, k, v, state, gwkv, gstate) + return gw, gu, gk, gv, gstate + + +def fused_recurrent_rwkv4(w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: 
Tensor) -> tuple[Tensor, Tensor]: + return FusedRecurrentRWKV4Function.apply(w, u, k, v, state) diff --git a/fla/ops/rwkv6/__pycache__/chunk.cpython-312.pyc b/fla/ops/rwkv6/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b03156c7e02419aeeb6b28bdd3f00c5f1eeecfd Binary files /dev/null and b/fla/ops/rwkv6/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla/ops/rwkv6/__pycache__/fused_recurrent.cpython-312.pyc b/fla/ops/rwkv6/__pycache__/fused_recurrent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac7932b966ac18a7ba83fdd10aa625403f5543cc Binary files /dev/null and b/fla/ops/rwkv6/__pycache__/fused_recurrent.cpython-312.pyc differ diff --git a/fla/ops/rwkv7/__pycache__/fused_recurrent.cpython-312.pyc b/fla/ops/rwkv7/__pycache__/fused_recurrent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ce194dbdbe7dd1e9bfd580003c5345415697591 Binary files /dev/null and b/fla/ops/rwkv7/__pycache__/fused_recurrent.cpython-312.pyc differ diff --git a/fla/ops/simple_gla/__pycache__/fused_recurrent.cpython-312.pyc b/fla/ops/simple_gla/__pycache__/fused_recurrent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f7ade450ed7c345c81a7e48b764718244416831 Binary files /dev/null and b/fla/ops/simple_gla/__pycache__/fused_recurrent.cpython-312.pyc differ diff --git a/fla/ops/simple_gla/__pycache__/parallel.cpython-312.pyc b/fla/ops/simple_gla/__pycache__/parallel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2b70f70f44a4b3d536e394585ff785178cec008 Binary files /dev/null and b/fla/ops/simple_gla/__pycache__/parallel.cpython-312.pyc differ diff --git a/fla/ops/utils/__pycache__/__init__.cpython-312.pyc b/fla/ops/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fdc7b5e5b235de5d6dd621a07c80ca216c93df8 Binary files /dev/null 
and b/fla/ops/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla/ops/utils/__pycache__/asm.cpython-312.pyc b/fla/ops/utils/__pycache__/asm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62c2f8ddd1c2182220e3a412faaa16ca9d5fcd7a Binary files /dev/null and b/fla/ops/utils/__pycache__/asm.cpython-312.pyc differ diff --git a/fla/ops/utils/__pycache__/cumsum.cpython-312.pyc b/fla/ops/utils/__pycache__/cumsum.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85186a403e2832f0a71ffe56a087bde07ac13746 Binary files /dev/null and b/fla/ops/utils/__pycache__/cumsum.cpython-312.pyc differ diff --git a/fla/ops/utils/__pycache__/logcumsumexp.cpython-312.pyc b/fla/ops/utils/__pycache__/logcumsumexp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67e47b615114b5c57cc9360cd947e292bb739f95 Binary files /dev/null and b/fla/ops/utils/__pycache__/logcumsumexp.cpython-312.pyc differ diff --git a/fla/ops/utils/__pycache__/logsumexp.cpython-312.pyc b/fla/ops/utils/__pycache__/logsumexp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c98bd6d3cd7d1b8df0ddcdb06fca0197c5442534 Binary files /dev/null and b/fla/ops/utils/__pycache__/logsumexp.cpython-312.pyc differ diff --git a/fla/ops/utils/__pycache__/matmul.cpython-312.pyc b/fla/ops/utils/__pycache__/matmul.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5dfd37dbb51338ee2f10f39c785d7171db0721c Binary files /dev/null and b/fla/ops/utils/__pycache__/matmul.cpython-312.pyc differ diff --git a/fla/ops/utils/pooling.py b/fla/ops/utils/pooling.py new file mode 100644 index 0000000000000000000000000000000000000000..0dd9059b4abd0a87fb65e25c01fd5897452f77e0 --- /dev/null +++ b/fla/ops/utils/pooling.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import 
triton +import triton.language as tl + +from fla.ops.common.utils import prepare_chunk_indices +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BD': BD}, num_warps=num_warps) + for BD in [16, 32, 64, 128] + for num_warps in [1, 2, 4, 8] + ], + key=['BT'] +) +@triton.jit(do_not_specialize=['T']) +def mean_pooling_fwd_kernel( + x, + o, + offsets, + indices, + T: tl.constexpr, + H: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + NT: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + p_x = tl.make_block_ptr(x + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_o = tl.make_block_ptr(o + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,)) + # [BT, BD] + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + # [BD] + b_o = tl.sum(b_x, axis=0) / min(BT, T - i_t * BT) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BD': BD}, num_warps=num_warps) + for BD in [16, 32, 64, 128] + for num_warps in [1, 2, 4, 8] + ], + key=['BT'] +) +@triton.jit(do_not_specialize=['T']) +def mean_pooling_bwd_kernel( + do, + dx, + offsets, + indices, + T: tl.constexpr, + H: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + NT: tl.constexpr, + 
USE_OFFSETS: tl.constexpr +): + i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + p_dx = tl.make_block_ptr(dx + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_do = tl.make_block_ptr(do + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,)) + # [BD] + b_do = tl.load(p_do, boundary_check=(0,)).to(tl.float32) + # [BT, BD] + b_dx = b_do / tl.full((BT,), min(BT, T - i_t * BT), dtype=tl.float32)[:, None] + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1)) + + +def mean_pooling_fwd( + x: torch.Tensor, + chunk_size: int, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None +) -> torch.Tensor: + B, T, H, D = x.shape + BT = chunk_size + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + o = x.new_empty(B, NT, H, D) + def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B * H) + mean_pooling_fwd_kernel[grid]( + x, + o, + offsets, + indices, + T=T, + H=H, + D=D, + BT=BT, + NT=NT, + ) + return o + + +def mean_pooling_bwd( + do: torch.Tensor, + batch_size: int, + seq_len: int, + chunk_size: int, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None +) -> torch.Tensor: + B, T, H, D = batch_size, seq_len, *do.shape[-2:] + BT = chunk_size + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + dx = do.new_empty(B, T, H, D) + def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B * H) + mean_pooling_bwd_kernel[grid]( + do, + dx, + offsets, + indices, + T=T, + H=H, + D=D, + BT=BT, + NT=NT + ) + return dx + + +class 
MeanPoolingFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + x: torch.Tensor, + chunk_size: int, + offsets: Optional[torch.LongTensor] = None + ) -> torch.Tensor: + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None + o = mean_pooling_fwd(x, chunk_size, offsets, indices) + ctx.batch_size = x.shape[0] + ctx.seq_len = x.shape[1] + ctx.chunk_size = chunk_size + ctx.offsets = offsets + ctx.indices = indices + return o + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward( + ctx, do + ) -> Tuple[torch.Tensor, None, None]: + batch_size = ctx.batch_size + seq_len = ctx.seq_len + chunk_size = ctx.chunk_size + offsets = ctx.offsets + indices = ctx.indices + dx = mean_pooling_bwd(do, batch_size, seq_len, chunk_size, offsets, indices) + return dx, None, None + + +def mean_pooling( + x: torch.Tensor, + chunk_size: int, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> torch.Tensor: + if head_first: + x = x.transpose(1, 2) + if cu_seqlens is not None: + if x.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {x.shape[0]} when using `cu_seqlens`." 
+ f"Please flatten variable-length inputs before processing.") + o = MeanPoolingFunction.apply(x, chunk_size, cu_seqlens) + if head_first: + o = o.transpose(1, 2) + return o diff --git a/torchtitan/components/__pycache__/checkpoint.cpython-312.pyc b/torchtitan/components/__pycache__/checkpoint.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..877122ed71e92fa56d23b733470bb1f596608d71 Binary files /dev/null and b/torchtitan/components/__pycache__/checkpoint.cpython-312.pyc differ diff --git a/torchtitan/components/__pycache__/dataloader.cpython-312.pyc b/torchtitan/components/__pycache__/dataloader.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..820c739e768a36a05aa8895f38c1ffb6bd298150 Binary files /dev/null and b/torchtitan/components/__pycache__/dataloader.cpython-312.pyc differ diff --git a/torchtitan/components/__pycache__/ft.cpython-312.pyc b/torchtitan/components/__pycache__/ft.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af846758e12e377c92a33e0ff81718a2a5a117b0 Binary files /dev/null and b/torchtitan/components/__pycache__/ft.cpython-312.pyc differ diff --git a/torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc b/torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94ad9a61da3189a70ec2ef930a01555ccc92babd Binary files /dev/null and b/torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc differ diff --git a/torchtitan/components/__pycache__/metrics.cpython-312.pyc b/torchtitan/components/__pycache__/metrics.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87f50cfb3547c68d7fba1abfedbd020fe0526ffd Binary files /dev/null and b/torchtitan/components/__pycache__/metrics.cpython-312.pyc differ diff --git a/torchtitan/components/__pycache__/optimizer.cpython-312.pyc 
b/torchtitan/components/__pycache__/optimizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d8db91f1775f99114b507abbebada15b642df8c Binary files /dev/null and b/torchtitan/components/__pycache__/optimizer.cpython-312.pyc differ diff --git a/torchtitan/components/ft.py b/torchtitan/components/ft.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d270d26dd9ceb604f1189c2caee0b6733dd73e --- /dev/null +++ b/torchtitan/components/ft.py @@ -0,0 +1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import importlib +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.distributed._functional_collectives as funcol +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor +from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims + +if importlib.util.find_spec("torchft") is not None: + import torchft as ft + + has_torchft = True +else: + has_torchft = False + + +class FTManager: + def __init__( + self, + manager: Optional["ft.Manager"], + group_size: int = 1, + replica_id: int = 0, + ) -> None: + self._manager = manager + self.group_size = group_size + self.replica_id = replica_id + + @property + def enabled(self) -> bool: + return self._manager is not None + + @property + def manager(self) -> "ft.Manager": + assert self._manager is not None + return self._manager + + def get_dp_info(self, dp_degree: int, dp_rank: int) -> tuple[int, int]: + return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank + + +def init_ft_manager(job: JobConfig) -> FTManager: + """Initialize the FT manager if TorchFT is enabled. + + Args: + job (JobConfig): The job configuration. 
+ + Returns: + Optional[ft.Manager]: The FT manager if TorchFT is enabled, otherwise None. + """ + if not job.fault_tolerance.enable: + return FTManager(None) + + if not has_torchft: + raise ImportError("torchft is not installed. Please install it.") + + if job.fault_tolerance.min_replica_size < 1: + raise ValueError("At least one FT replica is required.") + + pg = ft.ProcessGroupBabyNCCL() + + return FTManager( + ft.Manager( + pg=pg, + min_replica_size=job.fault_tolerance.min_replica_size, + load_state_dict=None, + state_dict=None, + use_async_quorum=True, + replica_id=f"torchtitan_ft_{job.fault_tolerance.replica_id}", + ), + group_size=job.fault_tolerance.group_size, + replica_id=job.fault_tolerance.replica_id, + ) + + +@dataclass +class FTParallelDims(ParallelDims): + ft_manager: FTManager + + def build_mesh(self, device_type: str) -> DeviceMesh: + def func( + device_type: str, mesh_shape: list[int], mesh_dim_names: list[str] + ) -> DeviceMesh: + from torchft.process_group import ft_init_device_mesh + + return ft_init_device_mesh( + device_type=device_type, + mesh_shape=mesh_shape, + mesh_dim_names=mesh_dim_names, + replicate_dim=mesh_dim_names.index("dp_replicate"), + manager=self.ft_manager.manager, + ) + + dims = [] + names = [] + for d, name in zip( + [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp], + ["pp", "dp_replicate", "dp_shard", "cp", "tp"], + ): + if d > 1 or name == "dp_replicate": + dims.append(d) + names.append(name) + + return self._build_mesh(device_type, dims, names, func) + + @property + def dp_replicate_enabled(self): + return True + + +def ft_dist_reduce( + x: torch.Tensor, reduceOp: str, mesh: DeviceMesh +) -> tuple[torch.Tensor, str, DeviceMesh]: + if has_torchft and isinstance(mesh, ft.process_group._FlattenDeviceMesh): + x = funcol.all_reduce( + x, reduceOp=reduceOp, group=mesh.managed_mesh.replicate_pg + ) + return x, reduceOp, mesh.managed_mesh.mesh + return x, reduceOp, mesh + + +def ft_clip_grad_norm_util(total_norm: 
DTensor) -> torch.Tensor: + if has_torchft: + mesh = total_norm._spec.mesh + if isinstance(mesh, ft.process_group.ManagedDeviceMesh): + # The gradients along the replicated dim has already been reduced. + # So we don't need another reducution beforing removing the + # replicate dimension + local_tensor = total_norm.to_local() + placements = list(copy.copy(total_norm._spec.placements)) + placements.pop(mesh.replicate_dim) + return DTensor.from_local(local_tensor, mesh.mesh, placements) + + return total_norm diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..5935a5b09f2756933b46870db18e435478dc4282 --- /dev/null +++ b/torchtitan/datasets/hf_datasets.py @@ -0,0 +1,173 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import Any, Callable + +import torch + +from datasets import Dataset, load_dataset +from datasets.distributed import split_dataset_by_node +from torch.distributed.checkpoint.stateful import Stateful +from torch.utils.data import IterableDataset + +from torchtitan.components.dataloader import ParallelAwareDataloader +from torchtitan.components.tokenizer import Tokenizer +from torchtitan.config_manager import JobConfig +from torchtitan.tools.logging import logger + + +def _load_c4_dataset(dataset_path: str): + """Load C4 dataset with default configuration.""" + return load_dataset(dataset_path, name="en", split="train", streaming=True) + + +def _process_c4_text(sample: dict[str, Any]) -> str: + """Process C4 dataset sample text.""" + return sample["text"] + + +@dataclass +class DatasetConfig: + path: str + loader: Callable + text_processor: Callable + + +# Add your dataset here here - more information at docs/datasets.md +DATASETS = { + "c4": 
DatasetConfig( + path="allenai/c4", + loader=_load_c4_dataset, + text_processor=_process_c4_text, + ), + "c4_test": DatasetConfig( + path="tests/assets/c4_test", + loader=lambda path: load_dataset(path, split="train"), + text_processor=_process_c4_text, + ), +} + + +def _validate_dataset( + dataset_name: str, dataset_path: str | None = None +) -> tuple[str, Callable, Callable]: + """Validate dataset name and path.""" + if dataset_name not in DATASETS: + raise ValueError( + f"Dataset {dataset_name} is not supported. " + f"Supported datasets are: {list(DATASETS.keys())}" + ) + + config = DATASETS[dataset_name] + path = dataset_path or config.path + logger.info(f"Preparing {dataset_name} dataset from {path}") + return path, config.loader, config.text_processor + + +class HuggingFaceDataset(IterableDataset, Stateful): + def __init__( + self, + dataset_name: str, + dataset_path: str | None, + tokenizer: Tokenizer, + seq_len: int = 2048, + dp_rank: int = 0, + dp_world_size: int = 1, + infinite: bool = False, + ) -> None: + # Force lowercase for consistent comparison + dataset_name = dataset_name.lower() + + path, dataset_loader, text_processor = _validate_dataset( + dataset_name, dataset_path + ) + ds = dataset_loader(path) + + self.dataset_name = dataset_name + self._data = split_dataset_by_node(ds, dp_rank, dp_world_size) + self._tokenizer = tokenizer + self.seq_len = seq_len + self.infinite = infinite + self._text_processor = text_processor + + # Variables for checkpointing + self._sample_idx = 0 + self._all_tokens: list[int] = [] + + def _get_data_iter(self): + if isinstance(self._data, Dataset) and self._sample_idx == len(self._data): + return iter([]) + + it = iter(self._data) + for _ in range(self._sample_idx): + next(it) + return it + + def __iter__(self): + max_buffer_token_len = 1 + self.seq_len + + while True: + for sample in self._get_data_iter(): + # Use the dataset-specific text processor + sample_text = self._text_processor(sample) + sample_tokens = 
self._tokenizer.encode(sample_text, bos=True, eos=True) + self._all_tokens.extend(sample_tokens) + self._sample_idx += 1 + + while len(self._all_tokens) >= max_buffer_token_len: + x = torch.LongTensor(self._all_tokens[:max_buffer_token_len]) + # update tokens to the remaining tokens + self._all_tokens = self._all_tokens[max_buffer_token_len:] + input = x[:-1] + label = x[1:] + yield {"input": input}, label + + if not self.infinite: + logger.warning(f"Dataset {self.dataset_name} has run out of data") + break + else: + # Reset offset for the next iteration + self._sample_idx = 0 + logger.warning(f"Dataset {self.dataset_name} is being re-looped") + + def load_state_dict(self, state_dict): + self._sample_idx = state_dict["sample_idx"] + self._all_tokens = state_dict["token_buffer"] + + def state_dict(self): + return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx} + + +def build_hf_dataloader( + dp_world_size: int, + dp_rank: int, + tokenizer: Tokenizer, + job_config: JobConfig, + infinite: bool = True, +) -> ParallelAwareDataloader: + """Build a data loader for HuggingFace datasets.""" + dataset_name = job_config.training.dataset + dataset_path = job_config.training.dataset_path + batch_size = job_config.training.batch_size + seq_len = job_config.training.seq_len + + hf_ds = HuggingFaceDataset( + dataset_name=dataset_name, + dataset_path=dataset_path, + tokenizer=tokenizer, + seq_len=seq_len, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + infinite=infinite, + ) + + return ParallelAwareDataloader( + dataset=hf_ds, + dp_rank=dp_rank, + dp_world_size=dp_world_size, + batch_size=batch_size, + ) diff --git a/torchtitan/datasets/tokenizer/tiktoken.py b/torchtitan/datasets/tokenizer/tiktoken.py new file mode 100644 index 0000000000000000000000000000000000000000..401757a93e6b598a6a3a60c4ca934ea0427f25a4 --- /dev/null +++ b/torchtitan/datasets/tokenizer/tiktoken.py @@ -0,0 +1,190 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import os
from collections.abc import Collection, Iterator, Sequence, Set as AbstractSet
from pathlib import Path
from typing import cast, Literal

import tiktoken
from tiktoken.load import load_tiktoken_bpe

from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger


class TikTokenizer(Tokenizer):
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.

    Args:
        model_path (str): The path to the Tiktoken model file.
    """

    # Mapping from special-token string to its ID (filled in __init__).
    special_tokens: dict[str, int]

    # Number of IDs reserved for special tokens above the BPE vocabulary.
    num_reserved_special_tokens = 256

    # Llama 3 pre-tokenization regex: contractions, letter runs, numbers of
    # up to 3 digits, punctuation runs, and whitespace handling.
    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501, B950

    def __init__(self, model_path: str):
        super().__init__()
        assert os.path.exists(
            model_path
        ), f"The tokenizer path does not exist: {model_path}"
        assert os.path.isfile(model_path), model_path

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [
            f"<|reserved_special_token_{i}|>"
            for i in range(5, self.num_reserved_special_tokens - 5)
        ]
        # Special tokens occupy the IDs directly above the BPE ranks.
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )

        self._n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        self.pad_id: int = -1
        self.stop_tokens = {
            self.special_tokens["<|end_of_text|>"],
            self.special_tokens["<|eot_id|>"],
        }
        logger.info(
            f"TikTokenizer built: #words {self.n_words}, BOS ID {self.bos_id}, EOS ID {self.eos_id}"
        )

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Literal["all"] | AbstractSet[str] | None = None,
        disallowed_special: Literal["all"] | Collection[str] | None = None,
    ) -> list[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): allowed special tokens in string
            disallowed_special ("all"|set[str]): special tokens that raise an error when in string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will treat all text corresponding
          to special tokens to be encoded as special tokens.
        """
        assert type(s) is str
        allowed_special = allowed_special or set()
        disallowed_special = disallowed_special or ()

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: list[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(list[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than
        `max_consecutive_slice_len` consecutive whitespaces or consecutive
        non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            # XOR: the character class (space vs non-space) flipped.
            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]


def build_tiktoken_tokenizer(job_config: JobConfig) -> TikTokenizer:
    """Construct a TikTokenizer from the configured tokenizer path."""
    return TikTokenizer(job_config.model.tokenizer_path)
diff --git a/torchtitan/experiments/deepseek_v3/model.py b/torchtitan/experiments/deepseek_v3/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..0669df9528b3db0de3325db36f010312b5b3eac7
--- /dev/null
+++ b/torchtitan/experiments/deepseek_v3/model.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This code is based on model definition of `deepseek-ai/DeepSeek-V3-Base` on
# Hugging Face Model Hub. Url:
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/resolve/main/configuration_deepseek.py
#
# It has been modified from its original forms to accommodate naming convention
# and usage patterns of the TorchTitan project.

# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeepSeek model.""" +import math +from typing import Optional, Tuple + +import torch +import torch.distributed as dist + +import torch.distributed._symmetric_memory as symm_mem +import torch.nn.functional as F +import torch.utils.checkpoint + +from attn_mask_utils import _prepare_4d_causal_attention_mask +from indices import generate_permute_indices +from model_config import ModelArgs +from symm_mem_recipes import OnDeviceAllToAllV +from torch import nn +from torch.distributed._functional_collectives import all_to_all_single_autograd + +from torchtitan.experiments.kernels.triton_mg_group_gemm.torchao_pr import ( + ALIGN_SIZE_M, + grouped_gemm_forward, +) + +# Get model parallel subgroup by name: +# e.g. 
"pp", "ep", None +def get_group(dim_name: Optional[str] = None) -> dist.ProcessGroup: + glob = torch.distributed.device_mesh._mesh_resources.get_current_mesh() + return glob.get_group(dim_name) + + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq.to(t.device)) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class LinearScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with linear scaling. 
class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling ("position interpolation").

    Positions are divided by `scaling_factor` before the cos/sin tables are
    built. Credits to the Reddit user /u/kaiokendev.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__, which builds the cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        # Compress the position axis by the scaling factor.
        positions = positions / self.scaling_factor

        freqs = torch.outer(positions, self.inv_freq)
        # Duplicate along the feature dim (same layout trick as the base class).
        table = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Deepseek
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling.

    When the requested length exceeds the training context, the rotary
    `base` is enlarged so frequencies stretch smoothly rather than being
    truncated. Credits to the Reddit users /u/bloc97 and /u/emozilla.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__, which builds the cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # NTK-aware rescaling of the rotary base for extrapolation.
            scaled_base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings)
                - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                scaled_base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.outer(positions, self.inv_freq)
        # Duplicate along the feature dim (same layout trick as the base class).
        table = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)


# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(
    num_rotations, dim, base=10000, max_position_embeddings=2048
):
    """Dimension index whose frequency completes `num_rotations` rotations
    over the training context (inverse of the RoPE wavelength formula)."""
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )


# Find dim range bounds based on rotations
def yarn_find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    """Clamped [low, high] dimension bounds between which YaRN interpolates."""
    lo = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    )
    hi = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    )
    return max(lo, 0), min(hi, dim - 1)  # Clamp values just in case


def yarn_get_mscale(scale=1, mscale=1):
    """Attention magnitude ("mscale") correction factor used by YaRN."""
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
def yarn_linear_ramp_mask(min, max, dim):
    """Linear ramp from 0 to 1 over dimension indices [min, max], clamped
    outside that interval.

    NOTE: parameter names intentionally keep the upstream form even though
    they shadow the builtins `min`/`max`.
    """
    if min == max:
        max += 0.001  # Prevent singularity

    ramp = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    return torch.clamp(ramp, 0, 1)


class YarnRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN scaling: interpolated frequencies
    blended with extrapolated ones per-dimension, plus a magnitude
    correction applied to the cos/sin tables."""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        # All knobs must exist before super().__init__ builds the cache.
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        dim = self.dim

        half_range = torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim
        # Extrapolated (original) vs interpolated (scaled) frequencies.
        extrapolated = 1.0 / (self.base**half_range)
        interpolated = 1.0 / (self.scaling_factor * self.base**half_range)

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.original_max_position_embeddings,
        )
        # Per-dimension blend weight: 1 -> pure extrapolation, 0 -> pure
        # interpolation, ramping linearly between `low` and `high`.
        blend = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
            device=device, dtype=torch.float32
        )
        inv_freq = interpolated * (1 - blend) + extrapolated * blend
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(seq_len, device=device, dtype=torch.float32)

        freqs = torch.outer(positions, inv_freq)

        # Relative magnitude correction applied to cos/sin.
        _mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale)
            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        )

        table = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", (table.cos() * _mscale).to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", (table.sin() * _mscale).to(dtype), persistent=False
        )
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotate half the hidden dims of the input: (a, b) -> (-b, a)."""
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    `q`/`k` are first de-interleaved on the head dimension (pairs
    `(x0, x1, x2, x3, ...)` become `(x0, x2, ..., x1, x3, ...)`) so the
    half-rotation layout matches the cached cos/sin tables; `position_ids`
    selects the rows of `cos`/`sin` (e.g. offset ids when using a KV
    cache), and `unsqueeze_dim` is the axis along which those tables are
    broadcast against `q`/`k` (1 for [b, h, s, d] layout, 2 for
    [b, s, h, d]).

    Returns:
        `tuple(torch.Tensor)` with the rotated query and key tensors.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _deinterleave(t):
        # (b, h, s, d) with interleaved pairs -> even dims first, odd last.
        b, h, s, d = t.shape
        return t.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q = _deinterleave(q)
    k = _deinterleave(k)

    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


class MLP(nn.Module):
    """SwiGLU feed-forward block: down(silu(gate(x)) * up(x)), no biases."""

    # Shared activation module (stateless), used by MoE's grouped-GEMM path too.
    act_fn = nn.SiLU()

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size if hidden_size is not None else config.hidden_size
        self.intermediate_size = (
            intermediate_size
            if intermediate_size is not None
            else config.intermediate_size
        )

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
class MoEGate(nn.Module):
    """Top-k expert router.

    Scores every expert per token (sigmoid or softmax, in fp32) and picks
    `top_k` experts either greedily ("greedy") or via DeepSeek-V3's
    bias-corrected, group-limited selection ("noaux_tc"). Returns
    (topk_idx, topk_weight), each of shape [bsz * seq_len, top_k].
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        self.seq_aux = config.seq_aux
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group

        # topk selection algorithm
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim))
        )
        if self.topk_method == "noaux_tc":
            # torch.rand (not torch.empty) so runs without real checkpoint
            # weights still see a non-degenerate bias distribution.
            self.e_score_correction_bias = nn.Parameter(
                torch.rand((self.n_routed_experts))
            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, dim = hidden_states.shape
        # Flatten tokens and score every expert in fp32.
        flat = hidden_states.view(-1, dim)
        logits = F.linear(
            flat.type(torch.float32), self.weight.type(torch.float32), None
        )
        if self.scoring_func == "sigmoid":
            scores = logits.sigmoid()
        elif self.scoring_func == "softmax":
            scores = logits.softmax(dim=-1, dtype=torch.float32)
        else:
            raise NotImplementedError(
                f"insupportable scoring function for MoE gating: {self.scoring_func}"
            )

        # select top-k experts
        if self.topk_method == "noaux_tc":
            # Selection uses bias-corrected scores; weights use raw scores.
            biased_scores = scores.view(
                bsz * seq_len, -1
            ) + self.e_score_correction_bias.unsqueeze(0)
            # Score each expert group by the sum of its top-2 member scores.
            group_scores = (
                biased_scores.view(bsz * seq_len, self.n_group, -1)
                .topk(2, dim=-1)[0]
                .sum(dim=-1)
            )  # [n, n_group]
            group_idx = torch.topk(
                group_scores, k=self.topk_group, dim=-1, sorted=False
            )[1]  # [n, top_k_group]
            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
            # Expand the kept-group mask to per-expert granularity.
            score_mask = (
                group_mask.unsqueeze(-1)
                .expand(
                    bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
                )
                .reshape(bsz * seq_len, -1)
            )  # [n, e]
            masked_scores = biased_scores.masked_fill(
                ~score_mask.bool(), 0.0
            )  # [n, e]
            _, topk_idx = torch.topk(masked_scores, k=self.top_k, dim=-1, sorted=False)
            topk_weight = scores.gather(1, topk_idx)
        elif self.topk_method == "greedy":
            topk_weight, topk_idx = torch.topk(
                scores, k=self.top_k, dim=-1, sorted=False
            )
        else:
            raise NotImplementedError(
                f"insupportable TopK function for MoE gating: {self.topk_method}"
            )

        # norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)
        topk_weight = (
            topk_weight * self.routed_scaling_factor
        )  # must multiply the scaling factor

        return topk_idx, topk_weight
        self.experts = nn.ModuleDict()
        for i in range(self.experts_per_rank):
            abs_expert_id = self.ep_rank * self.experts_per_rank + i
            self.experts[str(abs_expert_id)] = MLP(
                config, intermediate_size=config.moe_intermediate_size
            )
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = MLP(
                config=config, intermediate_size=intermediate_size
            )

    def combine_experts(self, submod_name: str):
        """Concatenate the `submod_name` Linear weights of all local experts
        into one parameter `{submod_name}_weight` for grouped GEMM; the
        original per-expert weights are released (set to None)."""
        all_weights = []
        for expert in self.experts.values():
            lin = expert.get_submodule(submod_name)
            all_weights.append(lin.weight)
            lin.weight = None

        concat_weight = torch.cat(all_weights)
        self.register_parameter(f"{submod_name}_weight", nn.Parameter(concat_weight))

    # This function is used to create a symm mem buffer for MoE's. It is for
    # shuffling tokens fully "on-device", as compared to traditional torch
    # all_to_all APIs which require a GPU-to-CPU sync of the splits. If a user
    # calls this function, the `shuffle_method` would switch from
    # `torch_all_to_all` to `symm_mem`.
    def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
        # Switch shuffle method
        self.shuffle_method = "symm_mem"

        # Combine expert weights
        print("Combining expert weights for Group GEMM")
        self.combine_experts("gate_proj")
        self.combine_experts("up_proj")
        self.combine_experts("down_proj")

        # Assuming worst case, 2x tokens are routed to one EP rank
        overflow = 2
        OnDeviceAllToAllV.max_output_len = (
            self.config.max_seq_len * self.num_experts_per_tok * overflow
        )

        # Symmetric memory buffers are shared by all MoE instances across
        # layers, we only need to initialize them once
        if MoE.token_send_buf is not None:
            return

        # Input buffer for DP-to-EP shuffle
        MoE.token_send_buf = symm_mem.empty(
            self.config.max_seq_len
            * self.num_experts_per_tok,  # seq len * top k (flattened)
            self.config.hidden_size,  # hidden dim
            dtype=dtype,
            device=device,
        )
        # Input buffer for EP-to-DP shuffle (sized with the overflow factor)
        MoE.token_gather_buf = symm_mem.empty(
            self.config.max_seq_len
            * self.num_experts_per_tok  # seq len * top k (flattened)
            * overflow,
            self.config.hidden_size,  # hidden dim
            dtype=dtype,
            device=device,
        )
        print(f"EP rank [{self.ep_rank}]: Created Symmetric Memory for MoE")

    def get_send_buf(self):
        # [Why detach?] During a first forward-backward step, the buffer would
        # be included in a computational graph. In a second step, autograd will
        # return an error saying "Trying to backward through the graph a second
        # time (or directly access saved tensors more than once)". This is
        # because the buffer is still in the graph, and autograd is trying to
        # backward through the graph a second time. To avoid this, we detach the
        # buffer from the graph. `detach()` returns a new tensor, which shares
        # the same storage with the original one.
        self.token_send_buf.grad = None
        return self.token_send_buf.detach()
    def get_gather_buf(self):
        # See [Why detach?] in `get_send_buf`
        self.token_gather_buf.grad = None
        return self.token_gather_buf.detach()

    def forward(self, hidden_states):
        """Route `hidden_states` through the distributed routed experts (plus
        the shared experts, if configured); output has the input's shape."""
        identity = hidden_states
        orig_shape = hidden_states.shape
        # for each token, select top-k experts, and compute the weight for each expert
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        if self.shuffle_method == "symm_mem":
            y = self.moe_on_device(hidden_states, topk_idx, topk_weight)
        else:  # "torch_all_to_all"
            y = self.moe_forward(hidden_states, topk_idx, topk_weight)

        y = y.view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    def moe_forward(self, x, topk_ids, topk_weight):
        """Expert-parallel MoE forward using all-to-all token shuffles.

        x: [n_tokens, hidden]; topk_ids/topk_weight: [n_tokens, top_k].
        Returns the weighted sum of expert outputs, [n_tokens, hidden].
        """
        # This part sorts the token indices so that tokens routed to the same expert reside consecutively.
        # An implication is that tokens to the same "expert group" (i.e., device) are also consecutive.
        # Since this is an "artificial" index creation (final outcome being
        # `idxs`), we don't need gradients here.
        with torch.no_grad():
            # [seq_len, n_routed_experts]
            cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts))
            # Fill 1 to the selected experts
            cnts.scatter_(1, topk_ids, 1)
            tokens_per_expert = cnts.sum(dim=0)
            # Token indices for each expert
            idxs = topk_ids.view(-1).argsort()
            sorted_tokens_shape = idxs.shape + x.shape[1:]

        sorted_tokens = x[idxs // topk_ids.shape[1]]
        assert sorted_tokens.shape == sorted_tokens_shape

        # This part exchanges the information about the number of tokens sent
        # and received by each expert. We can understand this information as
        # "side band", which is not part of the actual data. Thus no gradient
        # is needed.
        with torch.no_grad():
            # Sum the tokens over local experts, then we get tokens per EP rank,
            # which is the input splits
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0]
            )
            dist.all_to_all_single(
                tokens_per_expert_group, tokens_per_expert, group=self.ep_group
            )
            input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)

        # DP to EP token shuffle. This part needs gradient.
        if self.shuffle_method == "symm_mem":
            # Move input to the `token_send_buf` symm mem
            token_send_buf = self.get_send_buf()
            token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
            # Note: `out=` avoids copy, but it is not differentiable
            # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
            token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
                token_send_buf,
                input_splits,
                self.ep_group,
            )
            with torch.no_grad():
                # Received tokens from all other ranks. TODO: use mask instead
                received = output_splits.sum()
            # TODO: don't use `received`
            gathered_tokens = token_gather_buf[:received]
        else:  # "torch_all_to_all"
            # Prepare input and output splits
            with torch.no_grad():
                output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(
                    dim=1
                )
            gathered_tokens = all_to_all_single_autograd(
                sorted_tokens,
                output_splits.tolist(),
                input_splits.tolist(),
                self.ep_group,
            )

        # This part prepares a 1D tensor with the same length as
        # `gathered_tokens`. The 1D tensor is filled with local expert IDs which
        # the tokens in `gathered_tokens` are headed for. This part doesn't need
        # gradient.
        with torch.no_grad():
            gatherd_idxs = (
                torch.arange(
                    tokens_per_expert_group.numel(),
                    device=tokens_per_expert_group.device,
                )
                % self.experts_per_rank
            )
            gatherd_idxs = gatherd_idxs.repeat_interleave(tokens_per_expert_group)

        # Prepare buffer for tokens processed by experts
        if self.shuffle_method == "symm_mem":
            # Take necessary space from `token_gather_buf` symm mem because we are
            # going to send them out after expert processing
            processed_tokens = self.get_gather_buf()[: gathered_tokens.shape[0]]
        else:  # "torch_all_to_all"
            processed_tokens = torch.empty_like(gathered_tokens)

        # This part processes the tokens routed to the local experts.
        # TODO: can we use group GEMM here?
        for i, expert in enumerate(self.experts.values()):
            processed_tokens[gatherd_idxs == i] = expert(
                gathered_tokens[gatherd_idxs == i]
            )

        # Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
        # The input/output splits are just a reverse of the previous shuffle.
        if self.shuffle_method == "symm_mem":
            token_return_buf, _ = OnDeviceAllToAllV.apply(
                processed_tokens,
                output_splits,
                self.ep_group,
            )
            returned_tokens = token_return_buf[: sorted_tokens_shape[0]]
        else:  # "torch_all_to_all"
            returned_tokens = all_to_all_single_autograd(
                processed_tokens,
                input_splits.tolist(),
                output_splits.tolist(),
                self.ep_group,
            )

        # Undo the argsort, then combine each token's top-k expert outputs
        # with the gate weights.
        output_tokens = torch.empty_like(returned_tokens)
        output_tokens[idxs] = returned_tokens
        final_out = (
            output_tokens.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(returned_tokens.dtype)
        )
        return final_out

    def moe_on_device(self, x, topk_ids, topk_weight):
        # This part sorts the token indices so that tokens routed to the same expert reside consecutively.
        # An implication is that tokens to the same "expert group" (i.e., device) are also consecutive.
+ # Since this is an "aritificial" index creation (final outcome being + # `idxs`), we don't need gradients here. + with torch.no_grad(): + # [seq_len, n_routed_experts] + cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts)) + # Fill 1 to the selected experts + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + # Token indices for each expert + idxs = topk_ids.view(-1).argsort() + sorted_tokens_shape = idxs.shape + x.shape[1:] + + sorted_tokens = x[idxs // topk_ids.shape[1]] + assert sorted_tokens.shape == sorted_tokens_shape + + # This part exchange the information about the number of tokens send and + # received by each expert. We can understand this information as "side + # band", which is not part of the actual data. Thus no gradient is + # needed. + with torch.no_grad(): + # Sum the tokens over local experts, then we get tokens per EP rank, + # which is the input splits + tokens_per_expert_group = tokens_per_expert.new_empty( + tokens_per_expert.shape[0] + ) + dist.all_to_all_single( + tokens_per_expert_group, tokens_per_expert, group=self.ep_group + ) + input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) + + # Move input to the `token_send_buf` symm mem + token_send_buf = self.get_send_buf() + token_send_buf[: idxs.shape[0]].copy_(sorted_tokens) + # Note: `out=` avoids copy, but it is not differentiable + # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]]) + token_gather_buf, output_splits = OnDeviceAllToAllV.apply( + token_send_buf, + input_splits, + self.ep_group, + ) + + # We need to permute the received tokens so that tokens for the same expert are contiguous. + # This part prepares a 1D tensor `permuted_indices` for such permutation. + # This part doesn't need gradient. 
+ with torch.no_grad(): + permuted_indices, m_sizes = generate_permute_indices( + tokens_per_expert_group, + self.experts_per_rank, + self.ep_size, + token_gather_buf.shape[0], + ALIGN_SIZE_M, + ) + + # Permute the received tokens so that tokens for the same expert are contiguous. + contig_tokens = token_gather_buf[permuted_indices] + + # Run the first grouped GEMM + w1 = self.get_parameter("gate_proj_weight") + gate_proj = grouped_gemm_forward(contig_tokens, w1, m_sizes) + + # Run the second grouped GEMM + w3 = self.get_parameter("up_proj_weight") + up_proj = grouped_gemm_forward(contig_tokens, w3, m_sizes) + + # Apply activation + hidden_outputs = MLP.act_fn(gate_proj) * up_proj + + # Run the third grouped GEMM + w2 = self.get_parameter("down_proj_weight") + hidden_outputs = grouped_gemm_forward(hidden_outputs, w2, m_sizes) + + # Prepare buffer for tokens processed by experts + # Take necessary space from `token_gather_buf` symm mem because we are + # going to send them out after expert processing + processed_tokens = self.get_gather_buf() + + # Move into Symmetric Memory for the return shuffle + processed_tokens[permuted_indices] = hidden_outputs + + # Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle. + # The input/output splits are just a reverse of the previous shuffle. 
+ token_return_buf, _ = OnDeviceAllToAllV.apply( + processed_tokens, + output_splits, + self.ep_group, + ) + returned_tokens = token_return_buf[: sorted_tokens_shape[0]] + + output_tokens = torch.empty_like(returned_tokens) + output_tokens[idxs] = returned_tokens + final_out = ( + output_tokens.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(returned_tokens.dtype) + ) + return final_out + + +class Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: ModelArgs, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + + if self.q_lora_rank is None: + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.q_head_dim, bias=False + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, config.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear( + config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False + ) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = RMSNorm(config.kv_lora_rank) + self.kv_b_proj = nn.Linear( + config.kv_lora_rank, + self.num_heads + * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, 
+ ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + self._init_rope() + + self.softmax_scale = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LinearScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = DynamicNTKScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "yarn": + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rotary_emb = YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], 
Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if attention_mask is not None: + # Attention mask was made 4D because the `attn_weights` above is 4D. + # We probably can make this mask smarter if we want to pack sequences + # together, instead of using padding. This optimization can be used in + # inference. For training, if we want to pack sequences, data loader + # will pass in a mask containing such info. 
+ attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, # None, or user provided mask in 2D + (bsz, q_len), + hidden_states, + 0, # past_key_values_length, 0 when training + ) + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query=query_states, + key=key_states, + value=value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout, + is_causal=attention_mask is None, + scale=self.softmax_scale, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = self.o_proj(attn_output) + + return attn_output + + +class DecoderLayer(nn.Module): + def __init__(self, config: ModelArgs, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Attention(config=config, layer_idx=layer_idx) + + self.mlp = ( + MoE(config) + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else MLP(config) + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. 
+ """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +Deepseek_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). 
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class DeepseekModel(torch.nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DecoderLayer`] + + Args: + config: ModelArgs + """ + + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # Creating model parts related to my stage + assert ( + config.stage_idx < config.num_stages + ), f"Stage {config.stage_idx} is not in the model" + print(f"Creating model stage {config.stage_idx} of {config.num_stages}") + + self.embed_tokens = ( + nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + if config.stage_idx == 0 + else None + ) + + self.layers = torch.nn.ModuleDict() + division = config.num_hidden_layers // config.num_stages + residual = config.num_hidden_layers % config.num_stages + # Some earlier stages may have 1 more layer than latter stages because + # the division may have residual; this is more even than giving the + # entire residual to the last stage. 
+ layers_per_stage = [ + division + 1 if stage < residual else division + for stage in range(config.num_stages) + ] + assert sum(layers_per_stage) == config.num_hidden_layers + layer_id_start = sum(layers_per_stage[: config.stage_idx]) + layer_id_end = layer_id_start + layers_per_stage[config.stage_idx] + for layer_id in range(layer_id_start, layer_id_end): + self.layers[str(layer_id)] = DecoderLayer(config, layer_id) + + self.norm = ( + RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if config.stage_idx == config.num_stages - 1 + else None + ) + + # Initialize weights and apply final processing + self.apply(self._init_weights) + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def forward( + self, + tokens: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + # Embedding + hidden_states = ( + self.embed_tokens(tokens) if self.embed_tokens is not None else tokens + ) + + # decoder layers + for decoder_layer in self.layers.values(): + hidden_states = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + hidden_states = ( + self.norm(hidden_states) if self.norm is not None else hidden_states + ) + return hidden_states + + +class DeepseekForCausalLM(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.model = DeepseekModel(config) + self.lm_head = ( + nn.Linear(config.hidden_size, config.vocab_size, bias=False) + if config.stage_idx == config.num_stages - 1 + else None + ) + + # Initialize weights and apply final processing + # self.post_init() + + def forward( + self, + 
tokens: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Tuple: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, DeepseekForCausalLM + + >>> model = DeepseekForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + hidden_states = self.model( + tokens, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + logits = ( + self.lm_head(hidden_states) if self.lm_head is not None else hidden_states + ) + return logits + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + **kwargs, + ): + if past_key_values is not None: + # Assuming isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. 
+ elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past + + # Setup Symmetric Memory for MoE token shuffle. + # Supports inference currently. 
+ def setup_symm_mem(self, dtype: torch.dtype, device: torch.device): + for layer in self.model.layers.values(): + if not isinstance(layer.mlp, MoE): + continue + layer.mlp.setup_symm_mem(dtype, device) diff --git a/torchtitan/experiments/deepseek_v3/model_config.py b/torchtitan/experiments/deepseek_v3/model_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d559d4ee94ecf7fccc933cf1a243161d1796a123 --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/model_config.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + + +@dataclass +class ModelArgs: + r""" + This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V3. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 129280): + Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV3Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1407): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. 
+ num_nextn_predict_layers (`int`, *optional*, defaults to 1): + Number of nextn predict layers in the DeepSeekV3 Model. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + n_shared_experts (`int`, *optional*, defaults to None): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to None): + Number of routed experts, None means dense model. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor or routed experts. + topk_method (`str`, *optional*, defaults to `gready`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to None): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to None): + Number of selected groups for each token(for each token, ensuring the selected experts is only within + `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to None): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to False): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to 'softmax'): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + seq_aux = (`bool`, *optional*, defaults to True): + Whether to compute the auxiliary loss for each individual sample. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). 
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + """ + + vocab_size: int = 129280 + hidden_size: int = 7168 + intermediate_size: int = 18432 + moe_intermediate_size: int = 2048 + num_hidden_layers: int = 61 + num_nextn_predict_layers: int = 1 + num_attention_heads: int = 128 + num_key_value_heads: int = 128 + n_shared_experts: int = 1 + n_routed_experts: int = 256 + ep_size: int = 1 + routed_scaling_factor: float = 2.5 + kv_lora_rank: int = 512 + q_lora_rank: int = 1536 + qk_rope_head_dim: int = 64 + v_head_dim: int = 128 + qk_nope_head_dim: int = 128 + topk_method: str = "noaux_tc" + n_group: int = 8 + topk_group: int = 4 + num_experts_per_tok: int = 8 + moe_layer_freq: int = 1 + first_k_dense_replace: int = 3 + norm_topk_prob: bool = True + scoring_func: str = "sigmoid" + aux_loss_alpha: float = 0.001 + seq_aux: bool = True + hidden_act: str = "silu" + max_position_embeddings: int = 163840 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + rope_theta: float = 10000.0 + rope_scaling: dict = field( + default_factory=lambda: { + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 1.0, + "mscale_all_dim": 1.0, + 
"original_max_position_embeddings": 4096, + "type": "yarn", + } + ) + attention_bias: bool = False + attention_dropout: float = 0.0 + pad_token_id = None + # Added for symmetric memory + max_seq_len: int = 4096 + dtype: str = "bfloat16" + # Added for pipeline parallel + num_stages: int = 1 + stage_idx: int = 0 + + +# This is the configuration for deepseek-ai/DeepSeek-V2-Lite. +deepseek_v2_lite_config = ModelArgs( + vocab_size=102400, + hidden_size=2048, + intermediate_size=10944, + moe_intermediate_size=1408, + num_hidden_layers=27, + num_attention_heads=16, + num_key_value_heads=16, + n_shared_experts=2, + n_routed_experts=64, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=None, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method="greedy", + n_group=1, + topk_group=1, + num_experts_per_tok=6, + first_k_dense_replace=1, + norm_topk_prob=False, + scoring_func="softmax", + max_position_embeddings=4096, + rope_scaling={ + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 0.707, + "mscale_all_dim": 0.707, + "original_max_position_embeddings": 4096, + "type": "yarn", + }, +) + + +# Model configuration registry +# Key is the model distribution ID on HuggingFace Hub +deepseek_config_registry = { + "deepseek-ai/DeepSeek-V2-Lite": deepseek_v2_lite_config, + "deepseek-ai/DeepSeek-V2-Lite-Chat": deepseek_v2_lite_config, + "deepseek-ai/deepseek-v3": ModelArgs(), +} diff --git a/torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py b/torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..335bc2d966efbe486418525cb784078a6ec879d5 --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
def test_forward_pass():
    """
    Smoke-test the M*G grouped GEMM forward kernel against a PyTorch reference.

    In the M*G layout the rows of ``x`` are partitioned into G contiguous
    groups (M_total = sum of group sizes) while every group shares the same
    output dimension N.

    Returns:
        True when the Triton result matches the reference implementation,
        False on a mismatch or on any runtime error (which is logged).
    """
    try:
        dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # DeepSeek-like problem shape: a single group of 2048 rows.
        G = 1  # Number of groups
        sizes = [2048]  # Per-group row counts
        M_total = sum(sizes)  # Total M dimension
        N = 4096  # Output dimension (same for all groups)
        K = 7168  # Hidden dimension

        # Per-group row counts as an int32 device tensor.
        m_sizes = torch.tensor(sizes, device=dev, dtype=torch.int32)

        # Random fp16 activations [M_total, K] and weights [N, K].
        x = torch.randn(M_total, K, dtype=torch.float16, device=dev)
        w = torch.randn(N, K, dtype=torch.float16, device=dev)

        logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}")
        logging.info(f"Group sizes: {m_sizes}")
        logging.info(f"Input x shape: {x.shape}")
        logging.info(f"Weight w shape: {w.shape}")

        logging.info("Running forward pass with grouped GEMM")
        result = grouped_gemm_forward(x, w, m_sizes)
        logging.info(f"Forward result shape: {result.shape}")

        logging.info("Computing reference result with PyTorch")
        reference_result = compute_reference_forward(x, w, m_sizes)

        logging.info("Comparing with PyTorch reference")
        return analyze_tensor_differences(result, reference_result, "Forward output")

    except Exception as e:
        logging.error(f"Test failed with error: {e}")
        import traceback

        logging.error(traceback.format_exc())
        return False
+ + In M*G grouping: + - M dimension is partitioned into G groups (M_total = sum(M_sizes)) + - N dimension is the same for all groups + """ + try: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Test parameters for DeepSeek-like models + G = 4 # Number of groups + M_sizes = [2048, 2048, 2048, 2048] # Group sizes (will be adjusted) + M_total = sum(M_sizes) # Total M dimension + N = 4096 # Output dimension (same for all groups) + K = 7168 # Hidden dimension + + # Create group sizes tensor + m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32) + + # Create input and weight tensors - using float16 for higher precision + x = torch.randn( + M_total, K, dtype=torch.float16, device=device, requires_grad=True + ) + w = torch.randn(N, K, dtype=torch.float16, device=device, requires_grad=True) + + # Log the setup + logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}") + logging.info(f"Group sizes: {m_sizes}") + logging.info(f"Input x shape: {x.shape}") + logging.info(f"Weight w shape: {w.shape}") + + # Step 1: Run forward pass + logging.info("Running forward pass") + result = grouped_gemm_forward(x, w, m_sizes) + logging.info(f"Forward result shape: {result.shape}") + + # Create a gradient for backpropagation + grad_output = torch.randn_like(result) + logging.info(f"Created gradient with shape: {grad_output.shape}") + + # Step 2: Run backward pass directly + logging.info("Running backward pass directly") + grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes) + + # Verify gradient shapes + logging.info( + f"Gradient shapes - grad_x: {grad_x.shape}, grad_w: {grad_w.shape}" + ) + + # Step 3: Verify gradient computation using PyTorch's autograd + logging.info("Running PyTorch reference implementation") + + # Compute reference gradients + x_ref_grad, w_ref_grad = compute_reference_backward(x, w, m_sizes, grad_output) + + # Compare gradients + logging.info("Comparing gradients with PyTorch reference") + 
def test_multiple_deepseek_configs():
    """
    Test multiple DeepSeek model configurations with both forward and backward pass verification.

    Each config is a (G, M, K, N) tuple; M is split into G near-even groups.
    Returns True only if every config passes both the forward and backward
    comparison against the PyTorch reference.
    """
    # DeepSeek configurations: (G, M, K, N)
    configs = [
        (4, 8192, 7168, 4096),  # Config 1
        (4, 8192, 2048, 7168),  # Config 2
        (8, 4096, 7168, 4096),  # Config 3
        (8, 4096, 2048, 7168),  # Config 4
    ]

    # Per-config results as (index, overall, forward_ok, backward_ok) tuples.
    results = []

    for config_idx, (G, M, K, N) in enumerate(configs):
        logging.info(f"\n\n===== Testing DeepSeek Config {config_idx+1} =====")
        logging.info(f"G={G}, M={M}, K={K}, N={N}")

        try:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            # Create even group sizes (first `remainder` groups get one extra row)
            base_size = M // G
            remainder = M % G
            M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)]
            m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32)

            # Create input and weight tensors using float16
            # (presumably "higher precision" relative to an fp8 path — TODO confirm)
            x = torch.randn(
                M, K, dtype=torch.float16, device=device, requires_grad=True
            )
            w = torch.randn(
                N, K, dtype=torch.float16, device=device, requires_grad=True
            )

            logging.info(f"Input x shape: {x.shape}, Weight w shape: {w.shape}")

            # Run forward pass
            result = grouped_gemm_forward(x, w, m_sizes)
            logging.info(f"Forward result shape: {result.shape}")

            # ===== FORWARD PASS VERIFICATION =====
            # Compute reference forward result
            reference_result = compute_reference_forward(x, w, m_sizes)

            # Compare forward results
            forward_close = analyze_tensor_differences(
                result, reference_result, "Forward output"
            )

            # ===== BACKWARD PASS VERIFICATION =====
            # Create gradient for backpropagation
            grad_output = torch.randn_like(result)

            # Run backward pass (calls the kernels directly, not autograd)
            grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes)

            # Compute reference gradients
            x_ref_grad, w_ref_grad = compute_reference_backward(
                x, w, m_sizes, grad_output
            )

            # Compare backward results
            grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x")
            grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w")

            # Overall config result
            backward_close = grad_x_close and grad_w_close
            config_success = forward_close and backward_close
            results.append(
                (config_idx + 1, config_success, forward_close, backward_close)
            )

            # Log overall config result
            if config_success:
                logging.info(f"✓ SUCCESS: Config {config_idx+1} passed all tests!")
            else:
                logging.error(
                    f"✗ FAILURE: Config {config_idx+1} failed one or more tests"
                )

        except Exception as e:
            # A crash in any single config is recorded as a failure for that
            # config; remaining configs still run.
            logging.error(f"Config {config_idx+1} test failed with error: {e}")
            import traceback

            logging.error(traceback.format_exc())
            results.append((config_idx + 1, False, False, False))

    # Summary
    logging.info("\n===== Test Results Summary =====")
    for config_idx, overall_success, forward_success, backward_success in results:
        overall_status = "✓ PASSED" if overall_success else "✗ FAILED"
        forward_status = "✓ PASSED" if forward_success else "✗ FAILED"
        backward_status = "✓ PASSED" if backward_success else "✗ FAILED"

        logging.info(f"Config {config_idx}: {overall_status}")
        logging.info(f" - Forward pass: {forward_status}")
        logging.info(f" - Backward pass: {backward_status}")

    return all(overall_success for _, overall_success, _, _ in results)
pass test + logging.info("\n===== Running basic forward pass test =====") + success_forward = test_forward_pass() + logging.info(f"Basic forward test {'succeeded' if success_forward else 'failed'}") + + # Run basic backward pass test + logging.info("\n===== Running basic backward pass test =====") + success_backward = test_backward_pass() + logging.info(f"Basic backward test {'succeeded' if success_backward else 'failed'}") + + # Run multiple DeepSeek configs with forward and backward verification + logging.info("\n===== Running tests for all DeepSeek configs =====") + success_configs = test_multiple_deepseek_configs() + logging.info( + f"DeepSeek configs tests {'all succeeded' if success_configs else 'had failures'}" + ) + + # Overall result + overall_success = success_forward and success_backward and success_configs + logging.info( + f"\nOverall test result: {'SUCCESS' if overall_success else 'FAILURE'}" + ) diff --git a/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..37bf59f29e89b0bd3abb69d3e5d75bc14721b97b --- /dev/null +++ b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py @@ -0,0 +1,1304 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# credit - flat index forward kernel is derived from FBGemm: +# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm + +# pyre-unsafe +import functools +import logging + +import os +import sys +from typing import Any, Dict, Optional, Tuple + +import torch + +import triton +import triton.language as tl +from triton import Config as TConfig + +from triton.runtime import driver # @manual + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from tma_autotuning import ( + ALIGN_SIZE_M, + _NV_CONFIGS, + CudaUtils, + early_config_prune, + TmaDescriptorHelper, +) + + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +# ============== Start Triton Kernels =============== + + +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_forward_hopper( + a_desc_ptr, + b_desc_ptr, + c_ptr, + workspace, + m_sizes, + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + TMA_SIZE: tl.constexpr, + USE_EPILOGUE_SUBTILING: tl.constexpr, + # tiles + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +) -> None: + """ + Flat index style forward kernel for Hopper. 
+ For simplicity, we always use TMA Load and TMA Store + """ + tbidx = tl.program_id(0) # thread block index + + c_dtype = c_ptr.dtype.element_ty # output dtype + + c_desc_ptr = workspace + (tbidx * TMA_SIZE) # for TMA Store + + M_end = 0 + M_start = 0 + processed_tiles = 0 + # Size of individual weight matrix + n_size = N // G + n_start = 0 + + for g in range(G): + # Move down along groups + # reset to new M offset + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + n_start = n_size * g + + if m_size > 0: + # Process this group + + # Acquire hold on c_desc_ptr for TMA Store + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=c_ptr + M_start * n_size, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], + global_size=[m_size, n_size], + element_ty=c_dtype, + ) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + # tiles for this group + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) + group_num_tiles = num_m_tiles * num_n_tiles + + while tbidx >= processed_tiles and tbidx < ( + processed_tiles + group_num_tiles + ): + group_index = tbidx - processed_tiles + + # columnwise + tile_m_index = group_index % num_m_tiles + tile_n_index = group_index // num_m_tiles + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32) + n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) + global_n_offset = (n_start + n_offset).to(tl.int32) + + for k_offset in range(0, K, BLOCK_SIZE_K): + # input block [M,K] + a = tl._experimental_descriptor_load( + a_desc_ptr, + [m_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + c_dtype, + ) + # weight block [N, K] + b = tl._experimental_descriptor_load( + b_desc_ptr, + [global_n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + c_dtype, + ) + + accumulator += tl.dot(a, b.T) + + # Store using TMA + + m_offset = (tile_m_index * 
BLOCK_SIZE_M).to(tl.int32) + + if USE_EPILOGUE_SUBTILING: + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(c_dtype) + tl._experimental_descriptor_store( + c_desc_ptr, c0, [m_offset, n_offset] + ) + c1 = acc1.to(c_dtype) + tl._experimental_descriptor_store( + c_desc_ptr, c1, [m_offset, n_offset + BLOCK_SIZE_N // 2] + ) + else: + tl._experimental_descriptor_store( + c_desc_ptr, + accumulator.to(c_dtype), + [m_offset, n_offset], + ) + # move to next tile in group + tbidx += NUM_SMS + # Update the total tiles count for the next group + processed_tiles += group_num_tiles + + +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_forward_tma( + a_desc_ptr, + b_desc_ptr, + c_ptr, + workspace, + m_sizes, + a_scale_ptr, + b_scale_ptr, + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_TMA_STORE: tl.constexpr, + TMA_SIZE: tl.constexpr, + USE_FP8: tl.constexpr, + # tiles + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +) -> None: + """ + Flat index style forward kernel. 
+ For simplicity, we always use TMA Load and TMA Store + """ + tbidx = tl.program_id(0) # thread block index + + c_dtype = c_ptr.dtype.element_ty + + c_desc_ptr = workspace + (tbidx * TMA_SIZE) + + M_end = 0 + processed_tiles = 0 + + for g in range(G): + # Move down along groups + # reset to new M offset + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + + if m_size > 0: + # Process this group + n_size = N + + # TMA Store prep + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=c_ptr + M_start * N, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], + global_size=[m_size, n_size], + element_ty=c_dtype, + ) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + # tiles for this group + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) + group_num_tiles = num_m_tiles * num_n_tiles + + while tbidx >= processed_tiles and tbidx < ( + processed_tiles + group_num_tiles + ): + group_index = tbidx - processed_tiles + + tile_m_index = group_index % num_m_tiles + tile_n_index = group_index // num_m_tiles + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32) + n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) + + for k_offset in range(0, K, BLOCK_SIZE_K): + # input block [M,K] + a = tl._experimental_descriptor_load( + a_desc_ptr, + [m_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + c_dtype, + ) + # weight block [N, K] + b = tl._experimental_descriptor_load( + b_desc_ptr, + [n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + c_dtype, + ) + + accumulator += tl.dot(a, b.T) + + # Store using TMA + + m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) + # n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) + + tl._experimental_descriptor_store( + c_desc_ptr, + accumulator.to(c_dtype), + [m_offset, n_offset], + ) + + # Move to the next tile + tbidx += 
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_no_tma(
    a_ptr,
    b_ptr,
    c_ptr,
    workspace,
    m_sizes,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel.
    For bc and Ampere, we never use TMA Load and TMA Store.

    Computes, per row group g described by m_sizes:
        c[rows_g, :] = a[rows_g, :] @ b.T
    walking output tiles with a flat linear index strided by NUM_SMS.

    Fixes vs. the previous revision:
    - masked loads pass `other=0.0` so out-of-bounds lanes contribute
      nothing to the tl.dot accumulation (masked lanes are otherwise
      undefined);
    - the K dimension is bounds-checked, so K need not be a multiple of
      BLOCK_SIZE_K;
    - the store mask combines row/column predicates with `&` instead of the
      Python `and` operator, which is not valid on Triton tensors.
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty
    c_desc_ptr = None  # kept for parity with the TMA variants; unused here

    M_end = 0
    processed_tiles = 0

    for g in range(G):
        # Advance to this group's row range [M_start, M_end).
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size

        if m_size > 0:
            # Process this group
            n_size = N

            # Tile grid for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                # Column-major tile order within the group
                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                offs_k = tl.arange(0, BLOCK_SIZE_K)

                # Row/column validity is loop-invariant; hoist it.
                row_mask = offs_am[:, None] < m_size
                col_mask = offs_bn[:, None] < n_size

                a_ptrs = a_ptr + (M_start + offs_am[:, None]) * K + offs_k[None, :]
                b_ptrs = b_ptr + (offs_bn[:, None]) * K + offs_k[None, :]

                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # Bounds-check all dims; masked lanes read 0.0 so they
                    # contribute nothing to the dot product.
                    k_mask = (k_offset + offs_k)[None, :] < K
                    a = tl.load(a_ptrs, mask=row_mask & k_mask, other=0.0)
                    b = tl.load(b_ptrs, mask=col_mask & k_mask, other=0.0)

                    # Main matmul
                    accumulator += tl.dot(a, b.T)

                    # Update pointers for next block
                    a_ptrs += BLOCK_SIZE_K
                    b_ptrs += BLOCK_SIZE_K

                # Store without TMA; `&` (not `and`) combines tensor masks.
                c = accumulator.to(c_dtype)
                tl.store(
                    c_ptr
                    + (M_start + offs_am[:, None]) * N  # Row stride is N
                    + offs_bn[None, :],  # Column offset
                    c,
                    mask=(offs_am[:, None] < m_size) & (offs_bn[None, :] < n_size),
                )
                # Move to next tile
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles
W is used directly and not transposed + 2. The reduction dimension is now N (not K) + 3. Output is [M, K] instead of [M, N] + """ + tbidx = tl.program_id(0) # thread block index + + c_dtype = grad_input_ptr.dtype.element_ty + c_desc_ptr = workspace + (tbidx * TMA_SIZE) + + M_end = 0 + processed_tiles = 0 + + for g in range(G): + # Move down along groups - same as forward + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + + if m_size > 0: + # Process this group + # tiles for this group - now producing [M, K] output + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + group_num_tiles = num_m_tiles * num_k_tiles + + # TMA Store prep for [M, K] output + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=grad_input_ptr + M_start * K, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], + global_size=[m_size, K], + element_ty=c_dtype, + ) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + while tbidx >= processed_tiles and tbidx < ( + processed_tiles + group_num_tiles + ): + group_index = tbidx - processed_tiles + + # Different tiling scheme for [M, K] output + tile_m_index = group_index % num_m_tiles + tile_k_index = group_index // num_m_tiles + + # for grad_input block [M, K] + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + + # Position in full matrix + m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32) + k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32) + + # reduce along N dimension (instead of K in forward) + for n_offset in range(0, N, BLOCK_SIZE_N): + # grad_output block [M, N] + grad_output = tl._experimental_descriptor_load( + grad_output_desc_ptr, + [m_offset, n_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + c_dtype, + ) + + # weight block [N, K] - no transpose needed + w = tl._experimental_descriptor_load( + w_desc_ptr, + [n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + c_dtype, + ) + + # grad_x 
= grad_output @ w + # reducing along N dimension + accumulator += tl.dot(grad_output, w) + + # Store using TMA + m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) + # k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32) + + tl._experimental_descriptor_store( + c_desc_ptr, + accumulator.to(c_dtype), + [m_offset, k_offset], + ) + + # Move to the next tile + tbidx += NUM_SMS + + # Update the total tiles count for the next group + processed_tiles += group_num_tiles + + +# ---- dw flat linear indexed ---- + + +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_dw_tma( + x_desc_ptr, # input descriptor [M_total, K] + grad_output_desc_ptr, # grad_output descriptor [M_total, N] + grad_weight_ptr, # output grad_w [N, K] + workspace, # workspace for TMA store + m_sizes, # group sizes [G] + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_TMA_STORE: tl.constexpr, + TMA_SIZE: tl.constexpr, + # tiles + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, # block size for reduction dimension +) -> None: + """ + Improved TMA-optimized kernel for computing gradients with respect to weights (dw). + Uses flat index structure similar to forward. 
+ + For the forward pass Y = X @ W.T, + the backward for weights is: + grad_W = grad_Y.T @ X + + Where: + - grad_Y is [MG, N] + - X is [MG, K] + - grad_W is [N, K] + - we return [N,K] + """ + # Get thread block index l + tbidx = tl.program_id(0) + + # Get output data type + c_dtype = grad_weight_ptr.dtype.element_ty + + # Calculate number of output tiles + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + total_output_tiles = num_n_tiles * num_k_tiles + + # Process tiles in strided manner across SMs + for tile_idx in range(tbidx, total_output_tiles, NUM_SMS): + # Calculate tile indices + tile_n_idx = tile_idx % num_n_tiles + tile_k_idx = tile_idx // num_n_tiles + + # Calculate global offsets + n_offset = tile_n_idx * BLOCK_SIZE_N + k_offset = tile_k_idx * BLOCK_SIZE_K + + # Initialize accumulator for this output tile [N, K] + accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32) + + # Process each group + M_end = 0 + for g in range(G): + # Get group boundaries + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + + # Only process if group is non-empty + if m_size > 0: + # Process this group in chunks along the M dimension + for m_offset in range(0, m_size, BLOCK_SIZE_M): + # Calculate actual block size (handling boundary) + m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - m_offset) + + # Only process if we have actual work to do + if m_block_size > 0: + # Global offset for this chunk + m_global_offset = M_start + m_offset + + if USE_TMA_LOAD: + # Load input chunk [M_chunk, K] using TMA + x_block = tl._experimental_descriptor_load( + x_desc_ptr, + [m_global_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + c_dtype, + ) + + # Load grad_output chunk [M_chunk, N] using TMA + grad_output_block = tl._experimental_descriptor_load( + grad_output_desc_ptr, + [m_global_offset, n_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + c_dtype, + ) + + # Apply masks for valid regions + offs_m = tl.arange(0, 
BLOCK_SIZE_M) + m_mask = offs_m < m_block_size + + # Zero out invalid elements + x_block = tl.where(m_mask[:, None], x_block, 0.0) + grad_output_block = tl.where( + m_mask[:, None], grad_output_block, 0.0 + ) + else: + # Manual load with bounds checking + offs_m = tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Create masks + m_mask = offs_m < m_block_size + n_mask = offs_n < N - n_offset + k_mask = offs_k < K - k_offset + + # Combined masks + mk_mask = m_mask[:, None] & k_mask[None, :] + mn_mask = m_mask[:, None] & n_mask[None, :] + + # Global offsets for loading + m_global_offs = m_global_offset + offs_m + + # Load x block [M_chunk, K] + x_block = tl.load( + x_desc_ptr + + m_global_offs[:, None] * K + + (k_offset + offs_k)[None, :], + mask=mk_mask, + other=0.0, + ) + + # Load grad_output block [M_chunk, N] + grad_output_block = tl.load( + grad_output_desc_ptr + + m_global_offs[:, None] * N + + (n_offset + offs_n)[None, :], + mask=mn_mask, + other=0.0, + ) + + # Compute partial contribution: grad_W += grad_Y.T @ X + # transpose grad_output for the matmul + contribution = tl.dot( + grad_output_block.to(tl.float32).T, # [N, M_chunk] + x_block.to(tl.float32), # [M_chunk, K] + ) + + # Accumulate + accumulator += contribution + + # Store the result + if USE_TMA_STORE: + # Store using TMA + tl._experimental_descriptor_store( + workspace, # TMA store descriptor + accumulator.to(c_dtype), + [n_offset, k_offset], + ) + else: + # Manual store with bounds checking + offs_n = tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Create masks for bounds checking + n_mask = offs_n < N - n_offset + k_mask = offs_k < K - k_offset + output_mask = n_mask[:, None] & k_mask[None, :] + + # Store the result + tl.store( + grad_weight_ptr + + (n_offset + offs_n)[:, None] * K + + (k_offset + offs_k)[None, :], + accumulator.to(c_dtype), + mask=output_mask, + ) + + +# ======== End Triton kernels ======== + +# 
def grouped_gemm_forward(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    M*G style grouped GEMM with TMA and Float8 support.
    # Removed for now - FP8 support is triggered by passing x_scale and w_scale tensors.

    Args:
        x: input activations, shape [M_total, K], contiguous.
        w: weights, shape [N, K], contiguous.
        m_sizes: per-group row counts, shape [G]; rows of x are grouped
            contiguously and sum(m_sizes) must equal M_total.
        tma_size: size in bytes of each TMA descriptor.

    Raises:
        NotImplementedError: on devices without TMA support.
    """
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Grouped GEMM without TMA is not supported yet")

    G = m_sizes.shape[0]

    assert x.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    # Total input size is now [M_total, K] where M_total is the sum of all group sizes
    M_total, K = x.shape
    N = w.shape[0]  # N is now the same for all groups

    assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"

    # Verify that all group sizes are multiples of ALIGN_SIZE_M
    # This check is commented out because it will involve a GPU-CPU sync
    # assert torch.remainder(m_sizes, ALIGN_SIZE_M).max() == 0, "Group sizes must be a multiple of ALIGN_SIZE_M"

    # NOTE(review): allocates N // G output columns, but the docstring-level
    # convention elsewhere in this file describes the output as [M_total, N].
    # The hopper kernel launched below slices the weight per group
    # (n_size = N // G), so this is internally consistent, but callers that
    # expect [M_total, N] (e.g. the G=1 debug tests) only work because
    # N // 1 == N — confirm the intended output width for G > 1.
    y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype)

    if M_total == 0:
        return y

    NUM_SMS = CudaUtils.get_num_sms()
    USE_TMA_LOAD = True
    USE_TMA_STORE = True
    USE_EPILOGUE_SUBTILING = False

    # TMA descriptor helper
    desc_helper = None
    desc_x = x
    desc_w = w
    workspace = None

    if USE_TMA_LOAD:
        desc_helper = TmaDescriptorHelper(tma_size=tma_size)
        desc_helper.init_tma_descriptor("x")
        desc_helper.init_tma_descriptor("w")
        desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
        desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    if USE_TMA_STORE:
        # One tma_size-byte store descriptor slot per SM.
        # NOTE(review): relies on desc_helper being set by the USE_TMA_LOAD
        # branch above; both flags are hard-coded True so this holds today.
        workspace = torch.empty(
            NUM_SMS * desc_helper.tma_size,
            device=x.device,
            dtype=torch.uint8,
        )

    def grid(META):
        # Grid callback: (re)fills the TMA descriptors with the tile sizes
        # chosen by the autotuner before each launch.
        if USE_TMA_LOAD:
            nonlocal desc_helper
            desc_helper.fill_2d_tma_descriptor(
                "x",
                x.data_ptr(),
                M_total,
                K,
                META["BLOCK_SIZE_M"],
                META["BLOCK_SIZE_K"],
                x.element_size(),
            )

            desc_helper.fill_2d_tma_descriptor(
                "w",
                w.data_ptr(),
                N,
                K,
                META["BLOCK_SIZE_N"],
                META["BLOCK_SIZE_K"],
                w.element_size(),
            )
        return (NUM_SMS,)

    # M_BUCKET only buckets the autotune key so similar M sizes share configs.
    M_BUCKET = triton.next_power_of_2(M_total)

    _kernel_mg_forward_hopper[grid](
        desc_x,
        desc_w,
        y,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N,
        K,
        NUM_SMS,
        TMA_SIZE=tma_size,
        USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
    )

    return y
Make sure inputs are contiguous + grad_output = grad_output.contiguous() + x = x.contiguous() + w = w.contiguous() + m_sizes = m_sizes.contiguous() + + # Check TMA support + can_use_tma = use_tma and CudaUtils.verify_tma() + if use_tma and not can_use_tma: + logging.info("TMA requested but not supported on this device") + use_tma = False + + # Compute grad_x using flat linear implementation + try: + logging.info(f"Computing grad_x with flat linear kernel") + + # Use TMA-optimized implementation + grad_x = grouped_gemm_dx_tma( + grad_output=grad_output, + w=w, + m_sizes=m_sizes, + num_sms=NUM_SMS, + tma_size=tma_size, + ) + + except Exception as e: + logging.error(f"Error in grad_x computation: {e}") + raise + + # Compute grad_w using flat linear style implementation + try: + logging.info(f"Computing grad_w with flat linear kernel") + + grad_w = grouped_gemm_dw_tma( + x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size + ) + except Exception as e: + logging.error(f"Error in grad_w computation: {e}") + raise + + return grad_x, grad_w + + +# ----- dx backward pass wrapper ----- + + +def grouped_gemm_dx_tma( + grad_output: torch.Tensor, + w: torch.Tensor, + m_sizes: torch.Tensor, + num_sms: int = 132, + tma_size: int = 128, +) -> torch.Tensor: + """ + Optimized backward pass wrapper for computing gradient with respect to input (dx) + using TMA patterns similar to the forward pass. + + Args: + grad_output: Gradient of output, shape [M_total, N] + w: Weight tensor, shape [N, K] + m_sizes: Group sizes tensor, shape [G] + tma_size: Size of TMA descriptor + # using_fp8: Whether to use FP8 quantization + # grad_output_scale: Scale for grad_output in FP8 mode + # w_scale: Scale for w in FP8 mode + + Returns: + grad_x: Gradient with respect to x, shape [M_total, K] + """ + """ + Optimized backward pass for computing gradient with respect to input (dx) + using TMA patterns similar to the forward pass. 
+ + Args: + grad_output: Gradient of output, shape [M_total, N] + w: Weight tensor, shape [N, K] + m_sizes: Group sizes tensor, shape [G] + tma_size: Size of TMA descriptor + using_fp8: Whether to use FP8 quantization + # grad_output_scale: Scale for grad_output in FP8 mode + # w_scale: Scale for w in FP8 mode + + Returns: + grad_x: Gradient with respect to x, shape [M_total, K] + """ + if not CudaUtils.verify_tma(): + raise NotImplementedError("Optimized dx computation requires TMA support") + + G = m_sizes.shape[0] + + assert grad_output.is_contiguous() + assert w.is_contiguous() + assert m_sizes.is_contiguous() + + M_total, N_grad = grad_output.shape + N_w, K = w.shape + + # Check dimensions + assert N_grad == N_w, f"Grad_output N ({N_grad}) must match weight N ({N_w})" + + # Verify that the sum of m_sizes matches M_total + sum_m_sizes = m_sizes.sum().item() + assert ( + M_total == sum_m_sizes + ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})" + + # Create output tensor (grad_x) with shape [M_total, K] + grad_x = torch.empty( + (M_total, K), device=grad_output.device, dtype=grad_output.dtype + ) + + NUM_SMS = num_sms # CudaUtils.get_num_sms() + USE_TMA_LOAD = True + USE_TMA_STORE = True + + # Set up TMA descriptors + desc_helper = TmaDescriptorHelper(tma_size=tma_size) + desc_helper.init_tma_descriptor("grad_output") + desc_helper.init_tma_descriptor("w") + desc_grad_output = desc_helper.get_tma_descriptor_kernel_param("grad_output") + desc_w = desc_helper.get_tma_descriptor_kernel_param("w") + + # Allocate workspace for TMA store + workspace = torch.empty( + NUM_SMS * desc_helper.tma_size, + device=grad_output.device, + dtype=torch.uint8, + ) + + def grid(META): + # Fill TMA descriptors with appropriate dimensions + desc_helper.fill_2d_tma_descriptor( + "grad_output", + grad_output.data_ptr(), + M_total, + N_grad, + META["BLOCK_SIZE_M"], + META["BLOCK_SIZE_N"], + grad_output.element_size(), + ) + + desc_helper.fill_2d_tma_descriptor( + "w", + 
w.data_ptr(), + N_w, + K, + META["BLOCK_SIZE_N"], + META["BLOCK_SIZE_K"], + w.element_size(), + ) + return (NUM_SMS,) + + M_BUCKET = triton.next_power_of_2(M_total) + + # Launch the flat linear kernel for computing grad_x + _kernel_mg_dx_tma[grid]( + desc_grad_output, + desc_w, + grad_x, + workspace, + m_sizes, + G, + M_BUCKET, + N_grad, # N dimension is now the reduction dimension + K, + NUM_SMS, + USE_TMA_LOAD, + USE_TMA_STORE, + TMA_SIZE=tma_size, + ) + + return grad_x + + +# ======== dw wrapper function ========== + + +def grouped_gemm_dw_tma( + x: torch.Tensor, + grad_output: torch.Tensor, + m_sizes: torch.Tensor, + num_sms: int = 132, + tma_size: int = 128, +) -> torch.Tensor: + """ + Optimized flat linear kernel computation of gradients with respect to weights (dw) using TMA. + For the forward pass Y = X @ W.T, the backward for weights is: + grad_W = grad_Y.T @ X + + Args: + x: Input tensor, shape [M_total, K] + grad_output: Gradient of output, shape [M_total, N] + m_sizes: Group sizes tensor, shape [G] + tma_size: Size of TMA descriptor in bytes + + + Returns: + grad_w: Gradient with respect to weights, shape [N, K] + """ + # Check TMA support + has_tma_support = CudaUtils.verify_tma() + + # Get group count + G = m_sizes.shape[0] + + # Ensure contiguous tensors + x = x.contiguous() + grad_output = grad_output.contiguous() + m_sizes = m_sizes.contiguous() + + # Get dimensions + M_total, K_x = x.shape + M_grad, N = grad_output.shape + + # Check dimensions + assert M_total == M_grad, f"x M ({M_total}) must match grad_output M ({M_grad})" + + # Verify that the sum of m_sizes matches M_total + sum_m_sizes = m_sizes.sum().item() + assert ( + sum_m_sizes == M_total + ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})" + + # Create output tensor (grad_w) with shape [N, K] + grad_w = torch.zeros((N, K_x), device=x.device, dtype=x.dtype) + + NUM_SMS = num_sms + + # TODO - hardcoded for now...but should set TMA flags based on hardware support + 
USE_TMA_LOAD = True # has_tma_support + USE_TMA_STORE = True # has_tma_support + + # Set up TMA descriptors or direct pointers + if USE_TMA_LOAD or USE_TMA_STORE: + desc_helper = TmaDescriptorHelper(tma_size=tma_size) + + if USE_TMA_LOAD: + desc_helper.init_tma_descriptor("x") + desc_helper.init_tma_descriptor("grad_output") + x_desc = desc_helper.get_tma_descriptor_kernel_param("x") + grad_output_desc = desc_helper.get_tma_descriptor_kernel_param( + "grad_output" + ) + else: + x_desc = x + grad_output_desc = grad_output + + if USE_TMA_STORE: + desc_helper.init_tma_descriptor("grad_w") + workspace = desc_helper.get_tma_descriptor_kernel_param("grad_w") + else: + workspace = torch.empty(1, device=x.device, dtype=torch.uint8) + else: + # If not using TMA, just use the tensors directly + x_desc = x + grad_output_desc = grad_output + workspace = torch.empty(1, device=x.device, dtype=torch.uint8) + + # M_BUCKET for grid size + M_BUCKET = triton.next_power_of_2(M_total) + + # Define grid for kernel launch + def grid(META): + if USE_TMA_LOAD or USE_TMA_STORE: + + if USE_TMA_LOAD: + desc_helper.fill_2d_tma_descriptor( + "x", + x.data_ptr(), + M_total, + K_x, + META["BLOCK_SIZE_M"], + META["BLOCK_SIZE_K"], + x.element_size(), + ) + + desc_helper.fill_2d_tma_descriptor( + "grad_output", + grad_output.data_ptr(), + M_total, + N, + META["BLOCK_SIZE_M"], + META["BLOCK_SIZE_N"], + grad_output.element_size(), + ) + + if USE_TMA_STORE: + desc_helper.fill_2d_tma_descriptor( + "grad_w", + grad_w.data_ptr(), + N, + K_x, + META["BLOCK_SIZE_N"], + META["BLOCK_SIZE_K"], + grad_w.element_size(), + ) + + # Return grid size - one block per SM for balanced work distribution + return (NUM_SMS,) + + # Launch the optimized kernel + _kernel_mg_dw_tma[grid]( + x_desc, + grad_output_desc, + grad_w, + workspace, + m_sizes, + G, + M_BUCKET, + N, + K_x, + NUM_SMS, + USE_TMA_LOAD, + USE_TMA_STORE, + TMA_SIZE=tma_size, + ) + + return grad_w + + +# ======== End Backwards Wrapper Functions ============= 
+ +# ======== PyTorch wrapper functions ======== + + +class GroupedGEMM_mg(torch.autograd.Function): + """ + Autograd function for GroupedGEMM with M*G grouping. + Supports both standard and FP8 quantized operations. + """ + + @staticmethod + def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128): + """ + Forward pass of GroupedGEMM. + + Args: + x: Input tensor, shape [M_total, K] + w: Weight tensor, shape [N, K] + m_sizes: Tensor of shape [G] containing the size of each group + use_tma: Whether to try using TMA acceleration (if available) + tma_size: Size of TMA descriptor in bytes + using_fp8: Whether to use FP8 quantization + + Returns: + Output tensor, shape [M_total, N] + """ + + # Use regular forward without quantization + output = grouped_gemm_forward( + x=x, w=w, m_sizes=m_sizes, tma_size=tma_size, using_fp8=False + ) + + # Save inputs and parameters for backward pass + ctx.save_for_backward(x, w, m_sizes) + ctx.use_tma = use_tma + ctx.tma_size = tma_size + + ctx.save_for_backward(x, w, m_sizes) + + return output + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass of M*G GroupedGEMM. 
+ + Args: + grad_output: Gradient of output, shape [M_total, N] + + Returns: + Tuple of gradients: + - grad_x: Gradient with respect to x, shape [M_total, K] + - grad_w: Gradient with respect to w, shape [N, K] + - None: Gradient with respect to m_sizes (not differentiable) + - None: Gradient with respect to use_tma (not differentiable) + - None: Gradient with respect to tma_size (not differentiable) + + """ + # Retrieve saved tensors and parameters + + x, w, m_sizes = ctx.saved_tensors + + use_tma = ctx.use_tma + tma_size = ctx.tma_size + + # Compute gradients using the unified implementation + grad_x, grad_w = grouped_gemm_backward( + grad_output=grad_output, + x=x, + w=w, + m_sizes=m_sizes, + use_tma=use_tma, + tma_size=tma_size, + ) + + # Return gradients for all inputs (None for non-differentiable parameters) + return grad_x, grad_w, None, None + + +def mg_grouped_gemm( + x: torch.Tensor, + w: torch.Tensor, + m_sizes: torch.Tensor, + use_tma: bool = True, + tma_size: int = 128, + using_fp8: bool = False, +) -> torch.Tensor: + """ + Unified differentiable grouped GEMM operation for M*G grouped GEMM. + Supports both standard precision and FP8 quantized operations. 
+ + Args: + x: Input tensor, shape [M_total, K] + w: Weight tensor, shape [N, K] + m_sizes: Tensor of shape [G] containing the size of each group + use_tma: Whether to try using TMA acceleration (if available) + tma_size: Size of TMA descriptor in bytes + using_fp8: Whether to use FP8 quantization + + Returns: + Output tensor, shape [M_total, N] + """ + return GroupedGEMM_mg.apply(x, w, m_sizes, use_tma, tma_size, using_fp8) diff --git a/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py new file mode 100644 index 0000000000000000000000000000000000000000..2429432d756ae4d5bb6f91a6108c7ba8a4b9c627 --- /dev/null +++ b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe +import logging +import unittest +from typing import Tuple + +import torch +import torch.nn as nn + +from mg_grouped_gemm import grouped_gemm_forward + + +class TestMG_GroupedGEMM(unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(2020) + + def _run_grouped_gemm_test( + self, + shape: Tuple[int, int, int, int], + device: torch.device, + dtype: torch.dtype = torch.bfloat16, + atol: float = 1e-5, + rtol: float = 1.6e-2, + ) -> None: + G, M, N, K = shape + # In M*G grouping, input is [M*G, K] and weights are [N*G, K] + a = torch.randn(M * G, K, dtype=dtype, device=device) + b = torch.randn(N * G, K, dtype=dtype, device=device) + + # Create equal-sized groups for simplicity + m_size = M + m_sizes = torch.full((G,), m_size, device=device, dtype=torch.int32) + + result = grouped_gemm_forward(a, b, m_sizes) + self.assertTrue(result.shape == (M * G, N)) + + expected_result = torch.zeros(M * G, N, dtype=dtype, device=device) + m_start = 0 + for g in range(G): + m_end = m_start + m_sizes[g] + b_slice = b[N * g : N * (g+1), :] + expected_result[m_start:m_end, :] = a[m_start:m_end, :] @ b_slice.T + m_start = m_end + + # Convert result to match input dtype if needed + result = result.to(dtype) + torch.testing.assert_close(result, expected_result, atol=atol, rtol=rtol) + + def test_MG_grouped_gemm_bf16(self) -> None: + for G in (1, 4, 16): + for M in (128, 512, 1024): + print(f"Testing BF16 M*G GroupGeMM with G={G}, M={M}") + self._run_grouped_gemm_test( + (G, M, 1024, 1024), + torch.device("cuda"), + dtype=torch.bfloat16, + atol=1e-5, + rtol=1.6e-2, + ) + + def test_MG_grouped_gemm_deepseek_shapes(self) -> None: + """Test with shapes from Deepseek model.""" + deepseek_shapes = [ + (4, 2048, 4096, 7168), # G, M, N, K + (4, 2048, 7168, 2048), + (8, 512, 4096, 7168), + (8, 512, 7168, 2048), + ] + + device = torch.device("cuda") + + for shape in deepseek_shapes: + G, M, N, K = shape + print(f"Testing BF16 M*G Deepseek shape: G={G}, M={M}, N={N}, 
K={K}") + self._run_grouped_gemm_test( + shape, device, dtype=torch.bfloat16, atol=1e-5, rtol=1.6e-2 + ) diff --git a/torchtitan/experiments/llama4/__init__.py b/torchtitan/experiments/llama4/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0907e1892fa3840be81e7eefe12047d2e1cf1661 --- /dev/null +++ b/torchtitan/experiments/llama4/__init__.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer +from torchtitan.models.llama3 import pipeline_llama +from torchtitan.protocols.train_spec import register_train_spec, TrainSpec + +from .infra.parallelize_llama import parallelize_llama +from .model.args import TransformerModelArgs +from .model.model import Transformer + +__all__ = [ + "TransformerModelArgs", + "Transformer", + "llama4_configs", +] + + +llama4_configs = { + "debugmodel": TransformerModelArgs( + dim=256, + n_layers=8, + n_heads=16, + rope_theta=500000, + ), + "17bx16e": TransformerModelArgs( + dim=5120, + n_layers=48, + n_heads=40, + n_kv_heads=8, + ffn_dim_multiplier=1.2, + multiple_of=2048, + rope_theta=500000, + num_experts=16, + interleave_moe_layer_step=1, + ), + "17bx128e": TransformerModelArgs( + dim=5120, + n_layers=48, + n_heads=40, + n_kv_heads=8, + ffn_dim_multiplier=1.2, + multiple_of=2048, + rope_theta=500000, + num_experts=128, + ), +} + + +register_train_spec( + TrainSpec( + name="llama4", + cls=Transformer, + config=llama4_configs, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llama, + 
build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_tiktoken_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) +) diff --git a/torchtitan/experiments/llama4/infra/expert_parallel.py b/torchtitan/experiments/llama4/infra/expert_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..63945e8cd6a3f9509ca34c779b09a2f2f7581c2f --- /dev/null +++ b/torchtitan/experiments/llama4/infra/expert_parallel.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from functools import partial +from typing import Optional, Tuple + +import torch.nn as nn +from torch.distributed.tensor import ( + DeviceMesh, + distribute_module, + distribute_tensor, + DTensor, + Partial, + Replicate, + Shard, +) +from torch.distributed.tensor.parallel import ParallelStyle +from torch.distributed.tensor.placement_types import Placement + + +# implementation of Tensor Parallel on the non-shared experts in MoE +class TensorParallel(ParallelStyle): + def __init__( + self, + *, + input_layouts: Optional[Tuple[Optional[Placement]]] = None, + output_layout: Optional[Placement] = None, + use_local_output: bool = True, + ): + super().__init__() + self.input_layouts = input_layouts or (Replicate(), None) + self.output_layout = output_layout or Partial() + self.desired_input_layouts = (Replicate(), None) + self.use_local_output = use_local_output + + @staticmethod + def _prepare_input_fn( + input_layouts, desired_input_layouts, mod, inputs, device_mesh + ): + # TODO: figure out dynamo support for instance method and switch this to instance method + + # annotate module input placements/sharding with input_layouts + input_tensor, input_layout, desired_input_layout = ( + inputs[0], + input_layouts[0], + 
desired_input_layouts[0], + ) + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local( + input_tensor, device_mesh, (input_layout,), run_check=False + ) + + if input_layouts != desired_input_layouts: + input_tensor = input_tensor.redistribute( + placements=(desired_input_layout,), async_op=True + ) + return (input_tensor, *inputs[1:]) + + def _partition_fn(self, name, module, device_mesh): + module.register_parameter( + "w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(2)])) + ) # Column-wise sharding + module.register_parameter( + "w2", + nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(1)])), + ) # Row-wise sharding + module.register_parameter( + "w3", + nn.Parameter(distribute_tensor(module.w3, device_mesh, [Shard(2)])), + ) # Column-wise sharding + + @staticmethod + def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh): + if outputs.placements != (output_layout,): + outputs = outputs.redistribute(placements=(output_layout,), async_op=True) + # back to local tensor + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + self._partition_fn, + partial( + self._prepare_input_fn, self.input_layouts, self.desired_input_layouts + ), + partial(self._prepare_output_fn, self.output_layout, self.use_local_output), + ) + + +# NOTE: This is to achieve replicate computation on the gate module in the MoE router. +# It does nothing other than (1) setting the module parameters as DTensors on the given mesh +# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back. +# TODO: The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh, +# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation. 
+class NoParallel(ParallelStyle):
+    """ParallelStyle that applies no sharding to the module's parameters.
+
+    It installs boundary hooks that wrap plain tensors into replicated
+    DTensors on the way in and (optionally) unwrap the output back to a
+    local tensor on the way out; `_apply` passes no partition function, so
+    parameter placement is otherwise left untouched.
+    """
+
+    def __init__(
+        self,
+        *,
+        input_layout: Optional[Placement] = None,
+        output_layout: Optional[Placement] = None,
+        use_local_output: bool = True,
+    ):
+        # Both layouts default to Replicate(); use_local_output controls
+        # whether the module output is converted back to a plain torch.Tensor.
+        super().__init__()
+        self.input_layout = input_layout or Replicate()
+        self.output_layout = output_layout or Replicate()
+        self.desired_input_layout = Replicate()
+        self.use_local_output = use_local_output
+
+    @staticmethod
+    def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh):
+        # annotate module input placements/sharding with input_layouts
+        input_tensor = inputs[0]
+        if not isinstance(input_tensor, DTensor):
+            # Wrap the raw tensor as a DTensor with the declared input layout.
+            input_tensor = DTensor.from_local(
+                input_tensor, device_mesh, (input_layout,), run_check=False
+            )
+
+        if input_layout != desired_input_layout:
+            # Redistribute (async) to the replicated layout this style expects.
+            input_tensor = input_tensor.redistribute(
+                placements=(desired_input_layout,), async_op=True
+            )
+        return (input_tensor, *inputs[1:])
+
+    @staticmethod
+    def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh):
+        # Redistribute the output to the requested layout only when needed.
+        if outputs.placements != (output_layout,):
+            outputs = outputs.redistribute(placements=(output_layout,), async_op=True)
+        # back to local tensor
+        return outputs.to_local() if use_local_output else outputs
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        # partition_fn is None: no parameter re-sharding, only the
+        # input/output boundary hooks above are installed.
+        return distribute_module(
+            module,
+            device_mesh,
+            None,
+            partial(
+                self._prepare_input_fn, self.input_layout, self.desired_input_layout
+            ),
+            partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
+        )
diff --git a/torchtitan/experiments/llama4/model/__pycache__/args.cpython-312.pyc b/torchtitan/experiments/llama4/model/__pycache__/args.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e41de97e6bf649bb4e4bc8775658a0518e7db4d9
Binary files /dev/null and b/torchtitan/experiments/llama4/model/__pycache__/args.cpython-312.pyc differ
diff --git a/torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc 
b/torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01040b2a10d5cf39ccc18a7dfd722d5e0cdc487e Binary files /dev/null and b/torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc differ diff --git a/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..99eb36ac6ffa8e546d8895358978e937088f7ee1 --- /dev/null +++ b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py @@ -0,0 +1,545 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import json +import math +import os +import pprint +import time +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Optional + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor, Shard +from torch.distributed.tensor._utils import compute_local_shape_and_global_offset +from torchtitan.components.checkpoint import MODEL +from torchtitan.config_manager import JobConfig +from torchtitan.tools.logging import init_logger, logger +from torchtitan.train import Trainer + + +def extract_layer_number(s): + import re + + match = re.search(r"layers\.(\d+)", s) + if match: + return int(match.group(1)) + else: + return None + + +def convert_to_titan_fqns(fqn: str) -> list[str]: + # From the stored checkpoint keys to TorchTitan keys. + if "language_model." 
not in fqn: + # TODO: Not support video model yet + return [fqn] + + layer = extract_layer_number(fqn) + + if layer is None: + if "embed_tokens.weight" in fqn: + return ["tok_embeddings.weight"] + elif "norm.weight" in fqn: + return ["norm.weight"] + elif "lm_head.weight" in fqn: + return ["output.weight"] + else: + raise ValueError(f"Unknown fqn {fqn}") + + if "feed_forward.experts.down_proj" in fqn: + return [f"layers.{layer}.moe.experts.w2"] + elif "feed_forward.experts.gate_up_proj" in fqn: + return [f"layers.{layer}.moe.experts.w1", f"layers.{layer}.moe.experts.w3"] + elif "feed_forward.router.weight" in fqn: + return [f"layers.{layer}.moe.router.gate.weight"] + elif "feed_forward.shared_expert.down_proj.weight" in fqn: + return [f"layers.{layer}.moe.shared_expert.w2"] + elif "feed_forward.shared_expert.gate_proj.weight" in fqn: + return [f"layers.{layer}.moe.shared_expert.w3"] + elif "feed_forward.shared_expert.up_proj.weight" in fqn: + return [f"layers.{layer}.moe.shared_expert.w1"] + elif "input_layernorm.weight" in fqn: + return [f"layers.{layer}.ffn_norm.weight"] + elif "self_attn.k_proj" in fqn: + return [f"layers.{layer}.attention.wk.weight"] + elif "self_attn.o_proj" in fqn: + return [f"layers.{layer}.attention.wo.weight"] + elif "self_attn.q_proj" in fqn: + return [f"layers.{layer}.attention.wq.weight"] + elif "self_attn.v_proj" in fqn: + return [f"layers.{layer}.attention.wv.weight"] + elif "post_attention_layernorm.weight" in fqn: + return [f"layers.{layer}.attention_norm.weight"] + else: + raise ValueError(f"Unknown fqn {fqn}") + + +def convert_to_hf_shape(fqn: str, titan_fqns: list[str], dtensor: DTensor) -> list[str]: + if "feed_forward.experts.gate_up_proj" in fqn: + assert len(titan_fqns) == 2 + shape = dtensor.shape + return torch.Size(list(shape[:-1]) + [shape[-1] * 2]) + elif "shared_expert" in fqn: + s = dtensor.shape + # TODO: this is not right but I have to do this to load the checkpoint. 
+ return torch.Size((s[2], s[1])) + return dtensor.shape + + +def convert_to_titan_tensors(fqn: str, full_tensor: torch.Tensor) -> torch.Tensor: + if "feed_forward.experts.gate_up_proj" in fqn: + full_tensors = full_tensor.chunk(2, dim=-1) + elif "shared_expert" in fqn: + # TODO: this is not right but I have to do this to load the checkpoint. + full_tensor = full_tensor.transpose(1, 0) + full_tensors = [full_tensor.unsqueeze(0)] + else: + full_tensors = [full_tensor] + return full_tensors + + +@dataclass +class _Assignment: + loader_id: int + filename: str + fqns: list[str] + shapes: list[torch.Size] + dtypes: list[torch.dtype] + + +@dataclass +class _AssignmentRound: + loader_assignments: dict[int, _Assignment] # List of assignments for each loader + + +@dataclass +class TensorMetadata: + fqn: str + shape: torch.Size + dtype: torch.dtype + + +class CheckpointConverter: + def __init__( + self, + process_group: dist.ProcessGroup, + path: str, + token: Optional[str] = None, + loader_every_n_ranks: int = 8, + ) -> None: + self.path = path + self.token = token + self.pg = process_group + self.my_rank = dist.get_rank(self.pg) + + self.loader_every_n_ranks = loader_every_n_ranks + self.loader_id = self.my_rank // loader_every_n_ranks + self.should_load = self.my_rank % loader_every_n_ranks == 0 + self.total_loader = dist.get_world_size(self.pg) // loader_every_n_ranks + + self.titan_fqn_to_stored_fqn: dict[str, str] = {} + self.stored_fqn_to_titan_fqn: dict[str, list[str]] = {} + self.total_send_bytes = 0 + self.total_recv_bytes = 0 + + def convert(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + begin = time.time() + self._load_metadata() + self._create_fqn_mappings(state_dict) + rounds = self._get_load_assignments(state_dict) + + logger.info(f"Got {len(rounds)} rounds of assignments.") + for idx, assignments in enumerate(rounds): + loader_assignments = assignments.loader_assignments + loaded_state_dict = None + # Let each loader to load its own 
data and move to its GPU. + logger.info(f"Loading round {idx}") + for i in range(self.total_loader): + # This loader doesn't have any loading assignment for this round. + if i not in loader_assignments: + continue + # This rank is not the loader + if i != self.loader_id or not self.should_load: + continue + loaded_state_dict = self._load_round(loader_assignments[i]) + + torch.cuda.synchronize() + logger.info(f"Loading round {idx} finished") + for i in range(self.total_loader): + if i not in loader_assignments: + continue + + logger.info(f"Resharding round {idx} loader {i} data. ") + if i == self.loader_id and self.should_load: + # This rank is the loader. It needs to send the loaded data to + # the other ranks. + assert loaded_state_dict is not None + results = self._reshard_send( + loader_assignments[i], loaded_state_dict + ) + else: + results = self._reshard_receive(loader_assignments[i], state_dict) + torch.cuda.synchronize() + + logger.info(f"Communication round {idx} loader {i} is done.") + self._reshard(results, state_dict) + logger.info(f"Resharding round {idx} loader {i} is done.") + self._reshard(results, state_dict) + torch.cuda.synchronize() + + dist.barrier() + torch.cuda.synchronize() + logger.info(f"Checkpoint conversion took {time.time() - begin:.2f} seconds.") + logger.info(f"Total send bytes: {self.total_send_bytes / 1e9:.2f} GB") + logger.info(f"Total recv bytes: {self.total_recv_bytes / 1e9:.2f} GB") + return state_dict + + def _load_metadata(self) -> None: + metadata_path = os.path.join(self.path, "model.safetensors.index.json") + with open(metadata_path, "r") as f: + self.metadata = json.load(f)["weight_map"] + + def _create_fqn_mappings(self, state_dict: dict[str, torch.Tensor]) -> None: + if not self.metadata: + return + + # Create the mapping from the stored checkpoint keys to TorchTitan keys. 
+ for fqn in list(self.metadata.keys()): + titan_fqns = convert_to_titan_fqns(fqn) + # We don't know how to process _extra_state + if "_extra_state" in fqn: + self.metadata.pop(fqn) + continue + + if titan_fqns[0] not in state_dict: + for titan_fqn in titan_fqns: + assert titan_fqn not in state_dict + self.metadata.pop(fqn) + continue + + self.stored_fqn_to_titan_fqn[fqn] = titan_fqns + for titan_fqn in titan_fqns: + self.titan_fqn_to_stored_fqn[titan_fqn] = fqn + + torchtitan_extra = sorted( + list(set(state_dict.keys()) - set(self.titan_fqn_to_stored_fqn.keys())) + ) + converted_extra = sorted( + list(set(self.titan_fqn_to_stored_fqn.keys()) - set(state_dict.keys())) + ) + assert set(state_dict.keys()) == set(self.titan_fqn_to_stored_fqn.keys()), ( + f"{pprint.pformat(torchtitan_extra)}", + f"{pprint.pformat(converted_extra)}", + ) + + def _get_load_assignments( + self, state_dict: dict[str, Any] + ) -> list[_AssignmentRound]: + if self.my_rank == 0: + filename_to_metas = defaultdict(list) + for fqn, filename in self.metadata.items(): + titan_fqns = self.stored_fqn_to_titan_fqn[fqn] + shape = convert_to_hf_shape(fqn, titan_fqns, state_dict[titan_fqns[0]]) + meta = TensorMetadata( + fqn=fqn, + shape=shape, + # TODO: don't hardcode this + dtype=torch.bfloat16, + ) + filename_to_metas[filename].append(meta) + + loader_filename_to_metas = [{} for _ in range(self.total_loader)] + for idx, (filename, metas) in enumerate(filename_to_metas.items()): + loader_id = idx % self.total_loader + loader_filename_to_metas[loader_id][filename] = metas + + rounds = [] + while any(len(remain) > 0 for remain in loader_filename_to_metas): + round_assignment = _AssignmentRound(loader_assignments={}) + for loader_id in range(self.total_loader): + if not loader_filename_to_metas[loader_id]: + continue + + filename, metas = loader_filename_to_metas[loader_id].popitem() + round_assignment.loader_assignments[loader_id] = _Assignment( + filename=filename, + fqns=[meta.fqn for meta in metas], 
+ shapes=[meta.shape for meta in metas], + dtypes=[meta.dtype for meta in metas], + loader_id=loader_id, + ) + + rounds.append(round_assignment) + + object_list: list[Any] = [ + rounds, + self.titan_fqn_to_stored_fqn, + self.stored_fqn_to_titan_fqn, + ] + else: + object_list = [None, None, None] + + dist.broadcast_object_list(object_list, src=0, group=self.pg) + rounds = object_list[0] + self.titan_fqn_to_stored_fqn = object_list[1] + self.stored_fqn_to_titan_fqn = object_list[2] + return rounds + + def _load_round(self, assignment: _Assignment) -> dict[str, Any]: + from safetensors.torch import load_file as hf_load_file + + path = os.path.join(self.path, assignment.filename) + state_dict = hf_load_file(path) + return { + k: v.to(device="cuda") + for k, v in state_dict.items() + if k in assignment.fqns + } + + def _reshard_send( + self, + assignment: _Assignment, + loaded_state_dict: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + flatten_tensors = [t.flatten() for t in loaded_state_dict.values()] + flatten_tensor = torch.concat(flatten_tensors) + assert self.loader_id == assignment.loader_id + rank = self.loader_id * self.loader_every_n_ranks + assert rank == self.my_rank + logger.info( + f"Sending {assignment.filename} from {rank} {self.loader_id} " + f"{flatten_tensor.shape=} {flatten_tensor.dtype=} {loaded_state_dict.keys()=}." 
+ ) + logger.info(f"Sending {assignment}") + dist.broadcast(flatten_tensor, src=rank, group=self.pg) + self.total_send_bytes += flatten_tensor.numel() * flatten_tensor.element_size() + return loaded_state_dict + + def _reshard_receive( + self, assignment: _Assignment, state_dict: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + + flatten_tensor = torch.empty( + sum(math.prod(s) for s, d in zip(assignment.shapes, assignment.dtypes)), + dtype=assignment.dtypes[0], + device="cuda", + ) + rank = assignment.loader_id * self.loader_every_n_ranks + logger.info( + f"Receiving {assignment.filename} from {rank} " + f"{flatten_tensor.shape=} {flatten_tensor.dtype=}" + ) + logger.info(f"Receiving {assignment}") + dist.broadcast(flatten_tensor, src=rank, group=self.pg) + self.total_recv_bytes += flatten_tensor.numel() * flatten_tensor.element_size() + + ret: dict[str, torch.Tensor] = {} + loc = 0 + for fqn, shape, dtype in zip( + assignment.fqns, assignment.shapes, assignment.dtypes + ): + n_ele = math.prod(shape) + ret[fqn] = flatten_tensor[loc : loc + n_ele].view(shape) + loc += n_ele + return ret + + def _reshard( + self, + result: dict[str, torch.Tensor], + state_dict: dict[str, torch.Tensor], + ) -> None: + def _inplace_copy(fqn: str, full_tensors: list[torch.Tensor]): + titan_fqns = self.stored_fqn_to_titan_fqn[fqn] + assert len(titan_fqns) == len(full_tensors) + for titan_fqn, full_tensor in zip(titan_fqns, full_tensors): + dtensor = state_dict[titan_fqn] + assert isinstance(dtensor, DTensor) + assert dtensor.shape == full_tensor.shape, ( + (fqn, titan_fqn), + dtensor.shape, + full_tensor.shape, + ) + shape, offset = compute_local_shape_and_global_offset( + full_tensor.shape, dtensor.device_mesh, dtensor.placements + ) + slices = [ + slice(cur_offset, cur_offset + cur_shape) + for cur_shape, cur_offset in zip(shape, offset) + ] + logger.debug( + f"Copying {titan_fqn} with {slices=} {dtensor._local_tensor.shape=} " + f"{shape=} {offset=} {self.my_rank=} 
{dtensor.shape=} {full_tensor.shape=} " + f"{dtensor.placements=} {dtensor.device_mesh=} " + ) + dtensor.to_local().copy_(full_tensor[slices].to(dtensor.dtype)) + + for fqn, full_tensor in result.items(): + full_tensors = convert_to_titan_tensors(fqn, full_tensor) + _inplace_copy(fqn, full_tensors) + + +def _create_verified_state_dict( + pg: dist.ProcessGroup, mesh: DeviceMesh +) -> dict[str, torch.Tensor]: + placements = [Shard(0)] + state_dict = { + "vision_model.vision_adapter.mlp.fc1.weight": torch.rand( + 4096, 5632, device="cuda", dtype=torch.bfloat16 + ), + "vision_model.vision_adapter.mlp.fc2.weight": torch.rand( + 4096, 4096, device="cuda", dtype=torch.bfloat16 + ), + "language_model.model.layers.3.feed_forward.experts.gate_up_proj": torch.rand( + 16, 5120, 16384, device="cuda", dtype=torch.bfloat16 + ), + } + return {k: distribute_tensor(v, mesh, placements) for k, v in state_dict.items()} + + +def _verify_state_dict( + state_dict: dict[str, torch.Tensor], path: str, rank: int +) -> None: + metadata_path = os.path.join(path, "model.safetensors.index.json") + with open(metadata_path, "r") as f: + metadata = json.load(f)["weight_map"] + all_filenames = set() + for fqn, tensor in state_dict.items(): + filename = os.path.join(path, metadata[fqn]) + all_filenames.add(filename) + + stored_state_dict = {} + from safetensors.torch import load_file as hf_load_file + + for filename in all_filenames: + _sd = hf_load_file(filename) + for k in list(_sd.keys()): + if k not in state_dict: + _sd.pop(k) + else: + stored_state_dict[k] = _sd[k] + + def read_and_verify_tensor(fqn: str, dtensor: DTensor) -> None: + logger.info(f"Verifying {fqn} {dtensor.shape=} {dtensor.placements=} ") + stored_tensor = stored_state_dict[fqn] + full_tensor = dtensor.full_tensor() + logger.info(f"Gather {fqn} {full_tensor.shape} completely.") + + if rank > 0: + return + + stored_tensor = stored_tensor.to(device="cuda") + logger.info(f"Move to GPU {fqn} completely.") + + assert 
stored_tensor.shape == full_tensor.shape, fqn + assert stored_tensor.dtype == full_tensor.dtype, fqn + assert stored_tensor.device == full_tensor.device, fqn + assert torch.allclose(stored_tensor, full_tensor), fqn + + for k, v in state_dict.items(): + read_and_verify_tensor(k, v) + + +if __name__ == "__main__": + init_logger() + config = JobConfig() + config.parser.add_argument( + "--checkpoint.convert_path", + type=str, + default="", + help="""Specify the path of the target checkpoint to convert.""", + ) + config.parser.add_argument( + "--checkpoint.convert_hf_token", + type=str, + default="", + help="""Specify hf token.""", + ) + config.parser.add_argument( + "--checkpoint.convert_load_every_n_ranks", + type=int, + default=8, + help=""" + Specify the interval at which ranks are assigned to load checkpoints. + + For example, if this number is 4, then ranks 0, 4, 8, ... will load the + checkpoint. Each loader is responsible for loading one file. If there + are more loaders than files, only the first few loaders will be assigned + to load the checkpoint. The default value is 8. + """, + ) + config.parser.add_argument( + "--checkpoint.fake_model", + action="store_true", + help="""If true, the model will be fake.""", + ) + config.parse_args() + assert config.checkpoint.convert_path != "" + + trainer: Optional[Trainer] = None + + try: + trainer = Trainer(config) + if os.path.exists(trainer.checkpointer.folder): + raise RuntimeError( + "The checkpoint folder already exists. Abort to avoid overwriting " + f"the checkpoint. 
{trainer.checkpointer.folder=}" + ) + if config.checkpoint.fake_model: + state_dict = _create_verified_state_dict( + trainer.world_mesh.get_group(), trainer.world_mesh + ) + else: + state_dict = trainer.checkpointer.states[MODEL].state_dict() + + size = 0 + for v in state_dict.values(): + size += v.numel() * v.element_size() + logger.info(f"Total size of the model: {size / 1e9:.2f} GB") + + # Do not support PP yet, we will need to iterate over the PP dimension and + # extract the corresponding state_dict and device_mesh. + if "freqs_cis" in state_dict: + state_dict.pop("freqs_cis") + + # Our tokenizer is not up-to-date yet. + tok_embeddings_weight = state_dict.pop("tok_embeddings.weight") + output_weight = state_dict.pop("output.weight") + state_dict = CheckpointConverter( + process_group=trainer.world_mesh.get_group(), + path=config.checkpoint.convert_path, + token=config.checkpoint.convert_hf_token, + loader_every_n_ranks=config.checkpoint.convert_load_every_n_ranks, + ).convert(state_dict) + state_dict["tok_embeddings.weight"] = tok_embeddings_weight + state_dict["output.weight"] = output_weight + + class DummyModel: + def __init__(self, state_dict: dict[str, torch.Tensor]) -> None: + self._state_dict = state_dict + + def state_dict(self) -> dict[str, torch.Tensor]: + return self._state_dict + + if config.checkpoint.fake_model: + begin = time.time() + _verify_state_dict( + state_dict, + config.checkpoint.convert_path, + trainer.world_mesh.get_rank(), + ) + dist.barrier() + logger.info(f"Verifies state_dict {time.time() - begin}.") + else: + # oh, this is pretty bad, when can we get rid of the freqs_cis issue? 
+ state_dict["freqs_cis"] = None + trainer.checkpointer.states[MODEL] = DummyModel(state_dict) + trainer.checkpointer.model_weights_only = True + trainer.checkpointer.export_dtype = next(iter(state_dict.values())).dtype + trainer.checkpointer.save(curr_step=0, force=True) + time.sleep(2) + finally: + pass diff --git a/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py b/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..7756afe3de1527f469a38fc6a0bdc6c62eaa2526 --- /dev/null +++ b/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py @@ -0,0 +1,536 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +import os +import time +from dataclasses import dataclass +from typing import Any, Optional + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor, Shard +from torch.distributed.tensor._utils import compute_local_shape_and_global_offset +from torchtitan.components.checkpoint import MODEL +from torchtitan.config_manager import JobConfig +from torchtitan.tools.logging import init_logger, logger +from torchtitan.train import Trainer + +# Sharding dims for MP checkpoints + +column_parallel = [ + "tok_embeddings", + "wq", + "wk", + "wv", + "wqkv", + "w_in_shared_FD", + "w_out_eF_D", + "w_swiglu_FD", + "output", + "_linear", + "c_fc", + "vision_projection", +] + +row_parallel = [ + "wo", + "w_out_shared_DF", + "w_in_eD_F", + "moe_w_swiglu_eD_F", + "c_proj", +] + + +def convert_to_titan_fqns(fqn: str) -> list[str]: + # From the stored checkpoint keys to TorchTitan keys. 
+ if "wqkv" in fqn and "layer_norm_weight" not in fqn: + ret = [] + for k in ("wq", "wk", "wv"): + ret.append(fqn.replace("wqkv", k)) + return ret + return [fqn] + + +def get_shard_dim(fqn: str) -> Optional[int]: + if "bias" in fqn: + # Some bias params are still sharded + if "resblocks" in fqn: + for k in ("wq", "wk", "wv", "c_fc"): + if k in fqn: + return 0 + return None + elif any([x in fqn for x in column_parallel]): + return 0 + elif any([x in fqn for x in row_parallel]): + return 1 + else: + return None + + +def split_fused_qkv(shards: list[torch.Tensor]) -> tuple[torch.Tensor, ...]: + qkvs = [torch.split(shard, [640, 128, 128]) for shard in shards] + q = torch.cat([qkv[0] for qkv in qkvs], dim=0) + k = torch.cat([qkv[1] for qkv in qkvs], dim=0) + v = torch.cat([qkv[2] for qkv in qkvs], dim=0) + return q, k, v + + +@dataclass +class _Assignment: + loader_id: int + filename: str + fqns: tuple[str, ...] + shapes: tuple[torch.Size, ...] + dtypes: tuple[torch.dtype, ...] + + +@dataclass +class _AssignmentRound: + loader_assignments: dict[int, _Assignment] # List of assignments for each loader + + +class CheckpointConverter: + TOTAL_SHARDS = 8 + + def __init__( + self, + process_group: dist.ProcessGroup, + path: str, + loader_every_n_ranks: int = 8, + ) -> None: + self.path = path + self.pg = process_group + self.my_rank = dist.get_rank(self.pg) + self.loader_every_n_ranks = loader_every_n_ranks + self.loader_id = self.my_rank // loader_every_n_ranks + self.should_load = ( + self.my_rank % loader_every_n_ranks == 0 + and self.loader_id < CheckpointConverter.TOTAL_SHARDS + ) + self.total_loader = CheckpointConverter.TOTAL_SHARDS + self.titan_fqn_to_stored_fqn: dict[str, str] = {} + self.stored_fqn_to_titan_fqn: dict[str, list[str]] = {} + self.total_send_bytes = 0 + self.total_recv_bytes = 0 + + def convert(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + begin = time.time() + self._load_metadata() + self._create_fqn_mappings(state_dict) + 
rounds = self._get_load_assignments(state_dict) + + for assignments in rounds: + loader_assignments = assignments.loader_assignments + loaded_state_dict = None + # Let each loader to load its own data and move to its GPU. + for i in range(self.total_loader): + # This loader doesn't have any loading assignment for this round. + if i not in loader_assignments: + continue + # This rank is not the loader + if i != self.loader_id or not self.should_load: + continue + loaded_state_dict = self._load_round(loader_assignments[i]) + + results = [] + for i in range(self.total_loader): + if i not in loader_assignments: + continue + + if i == self.loader_id and self.should_load: + # This rank is the loader. It needs to send the loaded data to + # the other ranks. + assert loaded_state_dict is not None + results.append( + self._reshard_send(loader_assignments[i], loaded_state_dict) + ) + else: + results.append( + self._reshard_receive(loader_assignments[i], state_dict) + ) + + self._reshard(results, state_dict) + + torch.cuda.synchronize() + logger.info(f"Checkpoint conversion took {time.time() - begin:.2f} seconds.") + logger.info(f"Total send bytes: {self.total_send_bytes / 1e9:.2f} GB") + logger.info(f"Total recv bytes: {self.total_recv_bytes / 1e9:.2f} GB") + return state_dict + + def _get_file_path(self, loader_id: int) -> str: + return os.path.join(self.path, f"consolidated.0{loader_id}.pth") + + def _load_metadata(self) -> None: + if not self.should_load: + self.read_dict = {} + return + self.read_dict = torch.load( + self._get_file_path(self.loader_id), + mmap=True, + weights_only=False, + ) + + def _create_fqn_mappings(self, state_dict: dict[str, torch.Tensor]) -> None: + if not self.read_dict: + return + + # Create the mapping from the stored checkpoint keys to TorchTitan keys. 
+ for fqn in list(self.read_dict.keys()): + titan_fqns = convert_to_titan_fqns(fqn) + # We don't know how to process _extra_state + if "_extra_state" in fqn: + self.read_dict.pop(fqn) + continue + + if titan_fqns[0] not in state_dict: + for titan_fqn in titan_fqns: + assert titan_fqns[0] not in state_dict + self.read_dict.pop(fqn) + continue + self.stored_fqn_to_titan_fqn[fqn] = titan_fqns + for titan_fqn in titan_fqns: + self.titan_fqn_to_stored_fqn[titan_fqn] = fqn + + assert set(state_dict.keys()) == set(self.titan_fqn_to_stored_fqn.keys()), ( + set(state_dict.keys()) - set(self.titan_fqn_to_stored_fqn.keys()), + set(self.titan_fqn_to_stored_fqn.keys()) - set(state_dict.keys()), + ) + + def _get_load_assignments( + self, state_dict: dict[str, torch.Tensor] + ) -> list[_AssignmentRound]: + if self.my_rank == 0: + rounds: list[_AssignmentRound] = [] + size = 0 + fqns = [] + shapes = [] + dtypes = [] + + # All loader must load all the FQNs because the checkpoint is purely TP sharded. + all_keys = list(self.read_dict.keys()) + for fqn in all_keys: + fqns.append(fqn) + shapes.append(self.read_dict[fqn].shape) + dtypes.append(self.read_dict[fqn].dtype) + size += self.read_dict[fqn].numel() * self.read_dict[fqn].element_size() + if size < 1e9 and fqn != all_keys[-1]: + continue + + logger.info(f"Adding {fqns} to round {len(rounds)}") + round_assignment = _AssignmentRound(loader_assignments={}) + for loader_id in range(self.total_loader): + path = self._get_file_path(loader_id) + round_assignment.loader_assignments[loader_id] = _Assignment( + filename=path, + fqns=tuple(fqns), + shapes=tuple(shapes), + dtypes=tuple(dtypes), + loader_id=loader_id, + ) + rounds.append(round_assignment) + size = 0 + fqns.clear() + shapes.clear() + dtypes.clear() + + object_list: list[Any] = [ + rounds, + self.titan_fqn_to_stored_fqn, + self.stored_fqn_to_titan_fqn, + ] + else: + object_list = [None, None, None] + + dist.broadcast_object_list(object_list, src=0, group=self.pg) + rounds = 
object_list[0] + self.titan_fqn_to_stored_fqn = object_list[1] + self.stored_fqn_to_titan_fqn = object_list[2] + return rounds + + def _load_round(self, assignment: _Assignment) -> dict[str, torch.Tensor]: + ret = {} + assert self.read_dict + for fqn in assignment.fqns: + ret[fqn] = self.read_dict[fqn].to(device="cuda") + return ret + + def _reshard_send( + self, + assignment: _Assignment, + loaded_state_dict: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + flatten_tensors = [t.flatten() for t in loaded_state_dict.values()] + flatten_tensor = torch.concat(flatten_tensors) + assert self.loader_id == assignment.loader_id + rank = self.loader_id * self.loader_every_n_ranks + assert rank == self.my_rank + logger.info(f"Sending {assignment.filename} from {rank} {self.loader_id}") + logger.info(f"Sending {assignment.fqns}") + dist.broadcast(flatten_tensor, src=rank, group=self.pg) + self.total_send_bytes += flatten_tensor.numel() * flatten_tensor.element_size() + return loaded_state_dict + + def _reshard_receive( + self, assignment: _Assignment, state_dict: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + flatten_tensor = torch.empty( + sum(math.prod(s) for s, d in zip(assignment.shapes, assignment.dtypes)), + dtype=assignment.dtypes[0], + device="cuda", + ) + rank = assignment.loader_id * self.loader_every_n_ranks + dist.broadcast(flatten_tensor, src=rank, group=self.pg) + self.total_recv_bytes += flatten_tensor.numel() * flatten_tensor.element_size() + + ret: dict[str, torch.Tensor] = {} + loc = 0 + for fqn, shape, dtype in zip( + assignment.fqns, assignment.shapes, assignment.dtypes + ): + n_ele = math.prod(shape) + ret[fqn] = flatten_tensor[loc : loc + n_ele].view(shape) + loc += n_ele + return ret + + def _reshard( + self, + results: list[dict[str, torch.Tensor]], + state_dict: dict[str, torch.Tensor], + ) -> None: + def _inplace_copy(fqn: str, full_tensors: tuple[torch.Tensor, ...]): + titan_fqns = self.stored_fqn_to_titan_fqn[fqn] + assert 
len(titan_fqns) == len(full_tensors) + for titan_fqn, full_tensor in zip(titan_fqns, full_tensors): + dtensor = state_dict[titan_fqn] + logger.info(f"{titan_fqn} {full_tensor.sum()}") + assert isinstance(dtensor, DTensor) + shape, offset = compute_local_shape_and_global_offset( + full_tensor.shape, dtensor.device_mesh, dtensor.placements + ) + slices = [ + slice(cur_offset, cur_offset + cur_shape) + for cur_shape, cur_offset in zip(shape, offset) + ] + logger.info( + f"Copying {titan_fqn} with {slices=} {dtensor._local_tensor.shape=} " + f"{shape=} {offset=} {self.my_rank=} {dtensor.shape=} {full_tensor.shape=} " + f"{dtensor.placements=} {dtensor.device_mesh=} " + ) + dtensor.to_local().copy_(full_tensor[slices]) + + def _concat_shards(fqn, shards: list[torch.Tensor]) -> tuple[torch.Tensor, ...]: + if "wqkv" in fqn: + if "layer_norm" in fqn: + return (shards[0],) + return split_fused_qkv(shards) + + shard_dim = get_shard_dim(fqn) + if shard_dim is None: + return (shards[0],) + return (torch.cat(shards, dim=shard_dim),) + + fqns = list(results[0].keys()) + for result in results: + assert list(result.keys()) == fqns + + for fqn in fqns: + full_tensors = _concat_shards(fqn, [result[fqn] for result in results]) + _inplace_copy(fqn, full_tensors) + + +def _create_verified_state_dict( + pg: dist.ProcessGroup, mesh: DeviceMesh +) -> dict[str, torch.Tensor]: + placements = [Shard(0)] + state_dict = { + "tok_embeddings.weight": torch.rand( + 25256 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.attention.wqkv.layer_norm_weight": torch.rand( + 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.attention.wq.weight": torch.rand( + 640 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.attention.wk.weight": torch.rand( + 128 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.attention.wv.weight": torch.rand( + 128 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.attention.wo.weight": torch.rand( + 5120, 640 
* 8, device="cuda", dtype=torch.bfloat16 + ), + # "layers.47.feed_forward.router_DE": torch.rand(5120, 128, device="cuda", dtype=torch.bfloat16), + # "layers.47.feed_forward.running_gate_stats_3E": torch.rand(3, 128, device="cuda", dtype=torch.bfloat16), + # "layers.47.feed_forward.global_gate_stats_3E": torch.rand(3, 128, device="cuda", dtype=torch.bfloat16), + "layers.47.feed_forward.w_in_shared_FD.weight": torch.rand( + 1024 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.feed_forward.w_out_shared_DF.weight": torch.rand( + 5120, 1024 * 8, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.feed_forward.w_swiglu_FD.weight": torch.rand( + 1024 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.feed_forward.norm.weight": torch.rand( + 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.feed_forward.experts.moe_w_in_eD_F": torch.rand( + 655360, 1024 * 8, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.feed_forward.experts.moe_w_out_eF_D": torch.rand( + 131072 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.feed_forward.experts.moe_w_swiglu_eD_F": torch.rand( + 655360, 1024 * 8, device="cuda", dtype=torch.bfloat16 + ), + } + return {k: distribute_tensor(v, mesh, placements) for k, v in state_dict.items()} + + +def _verify_state_dict( + state_dict: dict[str, torch.Tensor], path: str, rank: int +) -> None: + stored_state_dicts = [ + torch.load( + os.path.join(path, f"consolidated.0{i}.pth"), + map_location="cpu", + weights_only=False, + mmap=True, + ) + for i in range(8) + ] + + def read_and_verify_tensor(fqn: str, dtensor: DTensor) -> None: + logger.info(f"Verifying {fqn} {dtensor.shape=} {dtensor.placements=} ") + shards = [stored_state_dicts[i][fqn] for i in range(8)] + full_tensor = dtensor.full_tensor() + logger.info(f"Gather {fqn} {full_tensor.shape} completely.") + + if rank > 0: + return + + if len(shards[0].shape) == 1: + assert full_tensor.shape == shards[0].shape, fqn + assert 
torch.allclose(shards[0].to(device="cuda"), full_tensor), fqn + return + elif shards[0].shape[0] == full_tensor.shape[0]: + concat_shards = torch.cat(shards, dim=1) + logger.info(f"Load {fqn} completely.") + elif shards[0].shape[1] == full_tensor.shape[1]: + concat_shards = torch.cat(shards, dim=0) + logger.info(f"Load {fqn} completely.") + + concat_shards = concat_shards.to(device="cuda") + logger.info(f"Move to GPU {fqn} completely.") + + assert concat_shards.shape == full_tensor.shape, fqn + assert concat_shards.dtype == full_tensor.dtype, fqn + assert concat_shards.device == full_tensor.device, fqn + assert torch.allclose(concat_shards, full_tensor), fqn + + for k, v in state_dict.items(): + if "wq" in k and "wqkv" not in k: + pass + elif "wk" in k: + pass + elif "wv" in k: + pass + else: + assert v is not None, k + read_and_verify_tensor(k, v) + + +if __name__ == "__main__": + init_logger() + config = JobConfig() + config.parser.add_argument( + "--checkpoint.convert_path", + type=str, + default="", + help="""Specify the path of the target checkpoint to convert.""", + ) + config.parser.add_argument( + "--checkpoint.convert_load_every_n_ranks", + type=int, + default=8, + help=""" + Specify the interval at which ranks are assigned to load checkpoints. + + For example, if this number is 4, then ranks 0, 4, 8, ... will load the + checkpoint. Each loader is responsible for loading one file. If there + are more loaders than files, only the first few loaders will be assigned + to load the checkpoint. The default value is 8. + """, + ) + config.parser.add_argument( + "--checkpoint.fake_model", + action="store_true", + help="""If true, the model will be fake.""", + ) + config.parse_args() + assert config.checkpoint.convert_path != "" + + trainer: Optional[Trainer] = None + + try: + trainer = Trainer(config) + if os.path.exists(trainer.checkpointer.folder): + raise RuntimeError( + "The checkpoint folder already exists. Abort to avoid overwriting " + f"the checkpoint. 
{trainer.checkpointer.folder=}" + ) + if config.checkpoint.fake_model: + state_dict = _create_verified_state_dict( + trainer.world_mesh.get_group(), trainer.world_mesh + ) + else: + state_dict = trainer.checkpointer.states[MODEL].state_dict() + + size = 0 + for v in state_dict.values(): + size += v.numel() * v.element_size() + logger.info(f"Total size of the model: {size / 1e9:.2f} GB") + + # Do not support PP yet, we will need to iterate over the PP dimension and + # extract the corresponding state_dict and device_mesh. + if "freq_cis" in state_dict: + state_dict.pop("freqs_cis") + + state_dict = CheckpointConverter( + process_group=trainer.world_mesh.get_group(), + path=config.checkpoint.convert_path, + loader_every_n_ranks=config.checkpoint.convert_load_every_n_ranks, + ).convert(state_dict) + + class DummyModel: + def __init__(self, state_dict: dict[str, torch.Tensor]) -> None: + self._state_dict = state_dict + + def state_dict(self) -> dict[str, torch.Tensor]: + return self._state_dict + + if config.checkpoint.fake_model: + begin = time.time() + _verify_state_dict( + state_dict, + config.checkpoint.convert_path, + trainer.world_mesh.get_rank(), + ) + dist.barrier() + logger.info(f"Verifies state_dict {time.time() - begin}.") + else: + # oh, this is pretty bad, when can we get rid of the freqs_cis issue? 
+ state_dict["freqs_cis"] = None + trainer.checkpointer.states[MODEL] = DummyModel(state_dict) + trainer.checkpointer.model_weights_only = True + trainer.checkpointer.export_dtype = next(iter(state_dict.values())).dtype + trainer.checkpointer.save(curr_step=0, force=True) + time.sleep(2) + finally: + pass diff --git a/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh b/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh new file mode 100644 index 0000000000000000000000000000000000000000..f3fd310934b1181ed83fa9fc4463f0c2336b46fc --- /dev/null +++ b/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh @@ -0,0 +1,25 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +# use envs as local overrides for convenience +# e.g. +# LOG_RANK=0,1 NGPU=4 ./convert_meta_to_dcp_with_gpus.sh +NGPU=${NGPU:-"8"} +LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7} +CONFIG_FILE=${CONFIG_FILE:-"../train_configs/llama4_17bx16e.toml"} + +overrides="" +if [ $# -ne 0 ]; then + overrides="$*" +fi + +PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ +--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ +convert_meta_to_dcp_with_gpus_meta.py --job.config_file ${CONFIG_FILE} $overrides diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml new file mode 100644 index 0000000000000000000000000000000000000000..e947afba56fd3b8ee5bf1fe45e65160c99a6fd18 --- /dev/null +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml @@ -0,0 +1,65 @@ +# TODO: this toml config is still under development + +[job] +dump_folder = "./outputs" +description = "Llama 4 Maverick 17Bx128E 
training" + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 10 +enable_tensorboard = false +save_tb_folder = "tb" + +[model] +name = "llama4" +flavor = "17bx128e" +norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm +tokenizer_path = "./assets/tokenizer/tokenizer.model" +# converters = "float8" + +[optimizer] +name = "AdamW" +lr = 4e-3 +eps = 1e-15 + +[lr_scheduler] +warmup_steps = 600 +lr_min = 0.1 + +[training] +batch_size = 1 +seq_len = 8192 +max_norm = 1.0 # grad norm clipping +steps = 3000 +compile = false +dataset = "c4" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 8 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 4 +# pipeline_parallel_schedule = "interleaved1f1b" +# pipeline_parallel_microbatches = 2 +context_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 500 +model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = 'full' # ['none', 'selective', 'full'] + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = "output,router.gate" diff --git a/torchtitan/experiments/multimodal/tests/__init__.py b/torchtitan/experiments/multimodal/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e41cd717f6a439a9c08d76a9d0e4a54e190fc5a --- /dev/null +++ b/torchtitan/experiments/multimodal/tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
diff --git a/torchtitan/experiments/multimodal/tokenizer/tiktoken.py b/torchtitan/experiments/multimodal/tokenizer/tiktoken.py new file mode 100644 index 0000000000000000000000000000000000000000..9d494a06f6557c0108b107dd3a3ba36832bb913f --- /dev/null +++ b/torchtitan/experiments/multimodal/tokenizer/tiktoken.py @@ -0,0 +1,232 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + +import os +from pathlib import Path +from typing import ( + AbstractSet, + Any, + cast, + Collection, + Dict, + Iterator, + List, + Literal, + Mapping, + Optional, + Sequence, + Union, +) + +import tiktoken +import torch +from tiktoken.load import load_tiktoken_bpe + +from torchtitan.components.tokenizer import Tokenizer +from torchtitan.config_manager import JobConfig +from torchtitan.tools.logging import logger + +IMAGE_TOKEN_ID = 128256 +IGNORE_INDEX = -100 + + +class TikTokenizer(Tokenizer): + """ + Tokenizing and encoding/decoding text using the Tiktoken tokenizer. + + Args: + model_path (str): The path to the Tiktoken model file. 
+ """ + + special_tokens: Dict[str, int] + + num_reserved_special_tokens = 256 + + pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501, B950 + + def __init__(self, model_path: str): + super().__init__(model_path) + assert os.path.isfile(model_path), model_path + + mergeable_ranks = load_tiktoken_bpe(model_path) + num_base_tokens = len(mergeable_ranks) + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|reserved_special_token_2|>", + "<|reserved_special_token_3|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|reserved_special_token_4|>", + "<|eot_id|>", # end of turn + ] + [ + f"<|reserved_special_token_{i}|>" + for i in range(5, self.num_reserved_special_tokens - 5) + ] + self.special_tokens = { + token: num_base_tokens + i for i, token in enumerate(special_tokens) + } + self.special_tokens["<|image|>"] = IMAGE_TOKEN_ID + self.model = tiktoken.Encoding( + name=Path(model_path).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, + ) + + self._n_words: int = self.model.n_vocab + # BOS / EOS token IDs + self.bos_id: int = self.special_tokens["<|begin_of_text|>"] + self.eos_id: int = self.special_tokens["<|end_of_text|>"] + self.pad_id: int = -1 + self.image_id = IMAGE_TOKEN_ID + self.stop_tokens = { + self.special_tokens["<|end_of_text|>"], + self.special_tokens["<|eot_id|>"], + } + logger.info( + f"TikTokenizer built: #words {self.n_words}, BOS ID {self.bos_id}, EOS ID {self.eos_id}, IMAGE ID {self.image_id}" + ) + + def encode( + self, + s: str, + *, + bos: bool, + eos: bool, + allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, + disallowed_special: Optional[Union[Literal["all"], Collection[str]]] = None, + ) -> List[int]: + """ + Encodes a string into a list of token IDs. 
+
+        Args:
+            s (str): The input string to be encoded.
+            bos (bool): Whether to prepend the beginning-of-sequence token.
+            eos (bool): Whether to append the end-of-sequence token.
+            allowed_special ("all"|set[str]): allowed special tokens in string
+            disallowed_special ("all"|set[str]): special tokens that raise an error when in string
+
+        Returns:
+            list[int]: A list of token IDs.
+
+        By default, setting disallowed_special=() encodes a string by ignoring
+        special tokens. Specifically:
+        - Setting `disallowed_special` to () will cause all text corresponding
+          to special tokens to be encoded as natural text (instead of raising
+          an error).
+        - Setting `allowed_special` to "all" will treat all text corresponding
+          to special tokens to be encoded as special tokens.
+        """
+        assert type(s) is str
+        allowed_special = allowed_special or set()
+        disallowed_special = disallowed_special or ()
+
+        # The tiktoken tokenizer can handle <=400k chars without
+        # pyo3_runtime.PanicException.
+        TIKTOKEN_MAX_ENCODE_CHARS = 400_000
+
+        # https://github.com/openai/tiktoken/issues/195
+        # Here we iterate over subsequences and split if we exceed the limit
+        # of max consecutive non-whitespace or whitespace characters.
+        MAX_NO_WHITESPACES_CHARS = 25_000
+
+        # Lazily chop `s` into tiktoken-safe substrings before encoding.
+        substrs = (
+            substr
+            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
+            for substr in self._split_whitespaces_or_nonwhitespaces(
+                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
+            )
+        )
+        t: List[int] = []
+        for substr in substrs:
+            t.extend(
+                self.model.encode(
+                    substr,
+                    allowed_special=allowed_special,
+                    disallowed_special=disallowed_special,
+                )
+            )
+        if bos:
+            t.insert(0, self.bos_id)
+        if eos:
+            t.append(self.eos_id)
+        return t
+
+    def decode(self, t: Sequence[int]) -> str:
+        """
+        Decodes a list of token IDs into a string.
+
+        Args:
+            t (List[int]): The list of token IDs to be decoded.
+
+        Returns:
+            str: The decoded string.
+        """
+        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
+        return self.model.decode(cast(List[int], t))
+
+    @staticmethod
+    def _split_whitespaces_or_nonwhitespaces(
+        s: str, max_consecutive_slice_len: int
+    ) -> Iterator[str]:
+        """
+        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
+        consecutive whitespaces or consecutive non-whitespaces.
+        """
+        current_slice_len = 0
+        # An empty string yields itself once via the final yield, so the
+        # initial whitespace flag defaults to False for len(s) == 0.
+        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
+        slice_start = 0
+
+        for i in range(len(s)):
+            is_now_space = s[i].isspace()
+
+            # XOR: the character class flipped, so the consecutive run restarts.
+            if current_slice_is_space ^ is_now_space:
+                current_slice_len = 1
+                current_slice_is_space = is_now_space
+            else:
+                current_slice_len += 1
+                if current_slice_len > max_consecutive_slice_len:
+                    yield s[slice_start:i]
+                    slice_start = i
+                    current_slice_len = 1
+        yield s[slice_start:]
+
+    def encode_multimodal(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
+        """
+        Tokenizes a `str` of text and creates `labels` masking BOS, EOS and `image_id` tokens.
+
+        Returns the same `sample` mapping, updated in place with `tokens`
+        (shifted input ids) and `labels` tensors.
+        """
+        # TODO(tj.solergibert) Should we keep `input_ids` OR `tokens` across this class, VisionCrossAttentionMask & the collator?
+        # For me it makes more sense to split `tokens` between `input_ids` & `labels` as in train.py BUT the `MultimodalDecoder`
+        # & everything else expects `tokens`
+        text = sample["text"]
+        tokens = self.encode(
+            text, bos=True, eos=True, allowed_special=set(["<|image|>"])
+        )
+        # Standard next-token shift: inputs drop the last token, labels the first.
+        input_ids = torch.LongTensor(tokens[:-1])
+        labels = torch.LongTensor(tokens[1:])
+        # Mask special tokens (BOS/EOS/image) out of the loss.
+        labels = torch.where(
+            torch.isin(
+                labels, torch.LongTensor([self.bos_id, self.eos_id, self.image_id])
+            ),
+            IGNORE_INDEX,
+            labels,
+        )
+
+        assert len(input_ids) == len(labels)  # TODO(tj.solergibert) Delete
+
+        sample.update({"tokens": input_ids, "labels": labels})
+
+        return sample
+
+
+def build_tiktoken_tokenizer(job_config: JobConfig) -> TikTokenizer:
+    """Build a TikTokenizer from the tokenizer path in the job config."""
+    return TikTokenizer(job_config.model.tokenizer_path)
diff --git a/torchtitan/experiments/simple_fsdp/README.md b/torchtitan/experiments/simple_fsdp/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..887653ac0298369a04df9b791b9676bd7c6107c1
--- /dev/null
+++ b/torchtitan/experiments/simple_fsdp/README.md
@@ -0,0 +1,40 @@
+## SimpleFSDP
+
+This folder includes an experimental frontend implementation for [SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile](https://arxiv.org/abs/2411.00284). SimpleFSDP is a compiler-based Fully Sharded Data Parallel (FSDP) framework, which has a simple implementation for maintenance and composability, allows full computation-communication graph tracing, and brings performance enhancement via compiler backend optimizations.
+
+### Enable SimpleFSDP Training
+
+```bash
+CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ./run_train.sh --model.name llama3_simple_fsdp --training.compile --training.mixed_precision_param float32
+```
+
+Note: The mixed precision training support is on-going. We set `training.mixed_precision_param` to `float32` for now and will remove it once the integration is completed.
+
+### Composability Support
+
+Some of these features require updates to PyTorch; we are working on providing composability support for the following features:
+
+| Feature | Support |
+| :--------: | :--------: |
+|Meta Initialization| ✅ |
+|Activation Checkpointing| ✅ |
+|Mixed Precision Training| 🚧 |
+|Tensor Parallelism| 🚧 |
+|Context Parallelism| ✅ |
+|Pipeline Parallelism| ✅ |
+|Distributed Checkpointing| 🚧 |
+|Float8 Training| ❌ |
+
+
+### Citation
+
+If you find SimpleFSDP useful, please kindly consider citing the following paper:
+
+```latex
+@article{zhang2024simplefsdp,
+  title={SimpleFSDP: Simpler Fully Sharded Data Parallel with torch. compile},
+  author={Zhang, Ruisi and Liu, Tianyu and Feng, Will and Gu, Andrew and Purandare, Sanket and Liang, Wanchao and Massa, Francisco},
+  journal={arXiv preprint arXiv:2411.00284},
+  year={2024}
+}
+```
diff --git a/torchtitan/experiments/simple_fsdp/__init__.py b/torchtitan/experiments/simple_fsdp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9732a5ad22ab05544d15690b0120ea5dc6762ae5
--- /dev/null
+++ b/torchtitan/experiments/simple_fsdp/__init__.py
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+from torchtitan.components.loss import build_cross_entropy_loss
+from torchtitan.components.lr_scheduler import build_lr_schedulers
+from torchtitan.components.optimizer import build_optimizers
+from torchtitan.datasets.hf_datasets import build_hf_dataloader
+from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
+from torchtitan.models.llama3 import llama3_configs, pipeline_llama
+from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
+
+from .model import SimpleFSDPTransformer
+from .parallelize_llama import parallelize_llama
+
+# Register the SimpleFSDP variant of llama3: it reuses the stock llama3 model
+# configs, pipelining, and training components, swapping in only the
+# SimpleFSDP transformer class and its parallelization function.
+register_train_spec(
+    TrainSpec(
+        name="llama3_simple_fsdp",
+        cls=SimpleFSDPTransformer,
+        config=llama3_configs,
+        parallelize_fn=parallelize_llama,
+        pipelining_fn=pipeline_llama,
+        build_optimizers_fn=build_optimizers,
+        build_lr_schedulers_fn=build_lr_schedulers,
+        build_dataloader_fn=build_hf_dataloader,
+        build_tokenizer_fn=build_tiktoken_tokenizer,
+        build_loss_fn=build_cross_entropy_loss,
+    )
+)
diff --git a/torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-312.pyc b/torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e968bbb29b128e89fa1832739241a306683da0f7
Binary files /dev/null and b/torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-312.pyc differ
diff --git a/torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-312.pyc b/torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1bf9816b8cc4ba6ee22f8487193269c51d1d0dd8
Binary files /dev/null and b/torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-312.pyc differ
diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..94e447cde456803682e2beff12f177ad6b5d19f1
--- /dev/null
+++ 
b/torchtitan/models/attention.py
@@ -0,0 +1,124 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+from typing import Callable, ClassVar, Optional
+
+import torch
+import torch.nn.functional as F
+from torch.nn.attention.flex_attention import (
+    BlockMask,
+    create_block_mask,
+    flex_attention,
+)
+
+
+class FlexAttention(torch.nn.Module):
+    """FlexAttention wrapper that shares its compiled kernels and block
+    masks across all instances via class-level state."""
+
+    # We registered flex_attention related attributes as class variables as we
+    # need to amortize the cost of compilation.
+    flex_attn: ClassVar[Callable] = torch.compile(
+        flex_attention, mode="max-autotune-no-cudagraphs"
+    )
+    compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
+    used_attn_mask_types: ClassVar[set[str]] = set()
+    # Attention mask type to the created BlockMask.
+    # This allows us to keep track the created block masks for each
+    # new batch. We will use this to update the block mask when a
+    # new batch is created. This also allows user to create different
+    # block masks for different layers.
+    block_masks: ClassVar[dict[str, BlockMask]] = {}
+
+    # Instance variables.
+    attn_mask_type: str
+
+    def __init__(self, attn_mask_type: str) -> None:
+        """Validate the mask type and register it for later mask creation.
+
+        Args:
+            attn_mask_type: Either "causal" or "block_causal".
+
+        Raises:
+            ValueError: If `attn_mask_type` is not recognized.
+        """
+        super().__init__()
+        if attn_mask_type not in ["causal", "block_causal"]:
+            raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.")
+        self.attn_mask_type = attn_mask_type
+        FlexAttention.used_attn_mask_types.add(attn_mask_type)
+
+    def forward(
+        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> torch.Tensor:
+        # Requires init_attention_mask() to have been called for this mask
+        # type first; otherwise this lookup raises KeyError.
+        block_mask = FlexAttention.block_masks[self.attn_mask_type]
+        return FlexAttention.flex_attn(q, k, v, block_mask=block_mask)
+
+    @staticmethod
+    def _get_causal_mask_fn() -> Callable:
+        """Return a mask_fn letting each query attend to itself and earlier positions."""
+        def causal_mask(b, h, q_idx, kv_idx):
+            return q_idx >= kv_idx
+
+        return causal_mask
+
+    @staticmethod
+    def _get_block_causal_mask_fn(batch: torch.Tensor, eos_id: int) -> Callable:
+        """Return a mask_fn restricting attention to causal positions within
+        the same EOS-delimited document."""
+        # NOTE(review): `batch` is compared elementwise against eos_id and
+        # cumsum'd over dim=1, so it appears to be token ids of shape [b, s];
+        # the original comment said [b, s, h, d] — confirm against callers.
+        mask = batch == eos_id
+        # Treat the last position as a document end so cumsum covers the tail.
+        mask[:, -1] = True
+        # seq_idx labels each position with its document index: positions
+        # strictly after the k-th EOS get index k.
+        acc_mask = torch.cumsum(torch.where(mask, 1, 0), dim=1)
+        seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32)
+        seq_idx[:, 1:] = acc_mask[:, :-1]
+
+        def block_causal_mask(b, h, q_idx, kv_idx):
+            return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx)
+
+        return block_causal_mask
+
+    @staticmethod
+    @torch.no_grad()
+    def init_attention_mask(batch: torch.Tensor, eos_id: Optional[int] = None) -> None:
+        """(Re)build the class-level BlockMask for every mask type in use.
+
+        Args:
+            batch: Batch tensor; seq_len is taken from batch.shape[1], and for
+                "block_causal" its values are compared to eos_id (presumably
+                token ids of shape [b, s] — see note above; TODO confirm).
+            eos_id: Document-separator token id; required for "block_causal".
+
+        Raises:
+            RuntimeError: If "block_causal" is in use but eos_id is None.
+        """
+        for attn_mask_type in FlexAttention.used_attn_mask_types:
+            match attn_mask_type:
+                case "causal":
+                    # The causal mask is batch-independent; build it only once.
+                    if FlexAttention.block_masks.get(attn_mask_type, None) is not None:
+                        continue
+                    # We don't care about batch dimension --
+                    # all samples have the same lower triangle mask.
+                    batch_dimension = 1
+                    mask_fn = FlexAttention._get_causal_mask_fn()
+                case "block_causal":
+                    if eos_id is None:
+                        raise RuntimeError(
+                            "eos_id must be provided for block_causal mask."
+                        )
+                    batch_dimension = batch.shape[0]
+                    mask_fn = FlexAttention._get_block_causal_mask_fn(batch, eos_id)
+                case _:
+                    raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}")
+
+            seq_len = batch.shape[1]
+            block_mask = FlexAttention.compiled_create_block_mask(
+                mask_fn, batch_dimension, None, seq_len, seq_len
+            )
+            FlexAttention.block_masks[attn_mask_type] = block_mask
+
+
+class ScaledDotProductAttention(torch.nn.Module):
+    """Thin wrapper around F.scaled_dot_product_attention; causal mask only."""
+
+    def __init__(self, attn_mask_type: str) -> None:
+        super().__init__()
+        if attn_mask_type != "causal":
+            raise ValueError(
+                "TorchTitan with SDPA currently only supports causal mask."
+            )
+
+    def forward(
+        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> torch.Tensor:
+        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
+
+
+def build_attention(use_flex_attn: bool, attn_mask_type: str):
+    """Factory: FlexAttention when use_flex_attn, else SDPA (causal only)."""
+    if use_flex_attn:
+        return FlexAttention(attn_mask_type)
+    else:
+        return ScaledDotProductAttention(attn_mask_type)
+
+
+def init_attention_mask(batch: torch.Tensor, eos_id: Optional[int] = None) -> None:
+    """Module-level convenience wrapper for FlexAttention.init_attention_mask."""
+    FlexAttention.init_attention_mask(batch, eos_id)
diff --git a/torchtitan/models/llama3/__pycache__/__init__.cpython-312.pyc b/torchtitan/models/llama3/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14fdfaddb9f5887a341b77ee41a2e0aaec551d88
Binary files /dev/null and b/torchtitan/models/llama3/__pycache__/__init__.cpython-312.pyc differ
diff --git a/torchtitan/models/llama3/__pycache__/model.cpython-312.pyc b/torchtitan/models/llama3/__pycache__/model.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..53abf9e64e5d588fd333f9953a8741750fbe072f
Binary files /dev/null and b/torchtitan/models/llama3/__pycache__/model.cpython-312.pyc differ
diff --git a/torchtitan/models/llama3/__pycache__/parallelize_llama.cpython-312.pyc b/torchtitan/models/llama3/__pycache__/parallelize_llama.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0c87a4e10b2896a7f42393cee1b24361c972cfe
Binary files /dev/null and 
b/torchtitan/models/llama3/__pycache__/parallelize_llama.cpython-312.pyc differ