diff --git a/fla/layers/__init__.py b/fla/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..16b64d5b84beda1d06fc7334f6fea2e5aba6a7fc --- /dev/null +++ b/fla/layers/__init__.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from .abc import ABCAttention +from .attn import Attention +from .based import BasedLinearAttention +from .bitattn import BitAttention +from .delta_net import DeltaNet +from .forgetting_attn import ForgettingAttention +from .gated_deltanet import GatedDeltaNet +from .gated_deltaproduct import GatedDeltaProduct +from .gla import GatedLinearAttention +from .gsa import GatedSlotAttention +from .hgrn import HGRNAttention +from .hgrn2 import HGRN2Attention +from .lightnet import LightNetAttention +from .linear_attn import LinearAttention +from .multiscale_retention import MultiScaleRetention +from .nsa import NativeSparseAttention +from .rebased import ReBasedLinearAttention +from .rwkv6 import RWKV6Attention +from .rwkv7 import RWKV7Attention + +__all__ = [ + 'ABCAttention', + 'Attention', + 'BasedLinearAttention', + 'BitAttention', + 'DeltaNet', + 'ForgettingAttention', + 'GatedDeltaNet', + 'GatedDeltaProduct', + 'GatedLinearAttention', + 'GatedSlotAttention', + 'HGRNAttention', + 'HGRN2Attention', + 'LightNetAttention', + 'LinearAttention', + 'MultiScaleRetention', + 'NativeSparseAttention', + 'ReBasedLinearAttention', + 'RWKV6Attention', + 'RWKV7Attention', +] diff --git a/fla/models/abc/__pycache__/configuration_abc.cpython-312.pyc b/fla/models/abc/__pycache__/configuration_abc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..769514785efdda58c9c56e369954122435b36ab0 Binary files /dev/null and b/fla/models/abc/__pycache__/configuration_abc.cpython-312.pyc differ diff --git a/fla/models/abc/configuration_abc.py b/fla/models/abc/configuration_abc.py new file mode 100644 index 
# -*- coding: utf-8 -*-

from typing import Dict, Optional

from transformers.configuration_utils import PretrainedConfig


class ABCConfig(PretrainedConfig):
    """
    Configuration class for ABC (Attention with Bounded-memory Control) models.

    Stores the hyperparameters of the model architecture (hidden sizes, number of
    layers/heads, slot memory size, gating and fusion switches) plus the optional
    ``attn`` dict that configures hybrid softmax-attention layers.

    NOTE: the constructor keeps the historical, misspelled keyword names
    ``exapnd_k``/``exapnd_v`` for backward compatibility with existing configs and
    checkpoints; their values are stored on the correctly spelled attributes
    ``expand_k``/``expand_v``.
    """

    model_type = 'abc'
    # `past_key_values` holds recurrent state, not logits — skip it at inference.
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        gate_low_rank_dim: int = 16,
        clamp_min: float = -32,
        clamp_max: float = 32,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        num_slots: Optional[int] = 64,
        use_short_conv: bool = False,
        conv_size: int = 4,
        exapnd_k: float = 0.5,  # [sic] misspelled name kept for backward compatibility
        exapnd_v: float = 1,    # [sic] misspelled name kept for backward compatibility
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_rope: bool = True,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.hidden_size = hidden_size
        self.gate_low_rank_dim = gate_low_rank_dim
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_slots = num_slots
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        # Store under the correctly spelled attribute names.
        self.expand_k = exapnd_k
        self.expand_v = exapnd_v
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_rope = use_rope
        self.attn = attn
        self.use_cache = use_cache

        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            # Validate the hybrid-attention spec and fill in its defaults.
            # Use the builtin `dict` for the isinstance check (checking against
            # `typing.Dict` is deprecated).
            if not isinstance(attn, dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn.setdefault('num_kv_heads', attn['num_heads'])
            attn.setdefault('qkv_bias', False)
            attn.setdefault('window_size', None)
            attn.setdefault('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
b/fla/models/gla/__pycache__/modeling_gla.cpython-312.pyc differ diff --git a/fla/models/gsa/__pycache__/__init__.cpython-312.pyc b/fla/models/gsa/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..424184a18dd8944f6bf669a04720ade111ce16c9 Binary files /dev/null and b/fla/models/gsa/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc b/fla/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46bd8d344e9288c050d397665157c282cae01294 Binary files /dev/null and b/fla/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc differ diff --git a/fla/models/gsa/__pycache__/modeling_gsa.cpython-312.pyc b/fla/models/gsa/__pycache__/modeling_gsa.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..209d7f6e64ce6e3e8f8b37e11d5104f028449ee6 Binary files /dev/null and b/fla/models/gsa/__pycache__/modeling_gsa.cpython-312.pyc differ diff --git a/fla/models/hgrn/__pycache__/modeling_hgrn.cpython-312.pyc b/fla/models/hgrn/__pycache__/modeling_hgrn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4a6d286e5536273724bdc8904e7af62cd02875f Binary files /dev/null and b/fla/models/hgrn/__pycache__/modeling_hgrn.cpython-312.pyc differ diff --git a/fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc b/fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb8fb48b443ac4126bc5866683edd74cfdd9e426 Binary files /dev/null and b/fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc differ diff --git a/fla/models/lightnet/__pycache__/configuration_lightnet.cpython-312.pyc b/fla/models/lightnet/__pycache__/configuration_lightnet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f03d3bb27aaf7d79813f7b2dfedc897689a0e8f 
Binary files /dev/null and b/fla/models/lightnet/__pycache__/configuration_lightnet.cpython-312.pyc differ diff --git a/fla/models/linear_attn/__init__.py b/fla/models/linear_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4446d4725923bdcf649dd38e400e7a44ee2cae0 --- /dev/null +++ b/fla/models/linear_attn/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.linear_attn.configuration_linear_attn import LinearAttentionConfig +from fla.models.linear_attn.modeling_linear_attn import LinearAttentionForCausalLM, LinearAttentionModel + +AutoConfig.register(LinearAttentionConfig.model_type, LinearAttentionConfig) +AutoModel.register(LinearAttentionConfig, LinearAttentionModel) +AutoModelForCausalLM.register(LinearAttentionConfig, LinearAttentionForCausalLM) + +__all__ = ['LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel'] diff --git a/fla/models/mamba/__pycache__/__init__.cpython-312.pyc b/fla/models/mamba/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..389dd1905a3646584f79a759a659556b130bb49b Binary files /dev/null and b/fla/models/mamba/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla/models/mamba/__pycache__/modeling_mamba.cpython-312.pyc b/fla/models/mamba/__pycache__/modeling_mamba.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a424b0cf05696880e3f0f058823ceb12c0a44b7 Binary files /dev/null and b/fla/models/mamba/__pycache__/modeling_mamba.cpython-312.pyc differ diff --git a/fla/models/mamba/modeling_mamba.py b/fla/models/mamba/modeling_mamba.py new file mode 100644 index 0000000000000000000000000000000000000000..93671429b60fc063958f6e31f3326f3906bfcf52 --- /dev/null +++ b/fla/models/mamba/modeling_mamba.py @@ -0,0 +1,843 @@ +# coding=utf-8 +# Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch MAMBA model."""

import math
import warnings
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationMixin
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging
from transformers.utils.deprecation import deprecate_kwarg

from fla.models.mamba.configuration_mamba import MambaConfig
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm

logger = logging.get_logger(__name__)


# Probe for the optional fused CUDA kernels; the model falls back to the pure
# PyTorch `slow_forward` path when any of them is missing. Import warnings are
# suppressed so an absent extension does not spam the log.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    try:
        from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
        from mamba_ssm.ops.triton.selective_state_update import selective_state_update
    except ImportError:
        selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

    try:
        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
    except ImportError:
        causal_conv1d_update, causal_conv1d_fn = None, None
    # True only when every fused kernel imported successfully.
    is_fast_path_available = all((
        selective_state_update,
        selective_scan_fn,
        causal_conv1d_fn,
        causal_conv1d_update,
        mamba_inner_fn
    ))


class MambaCache:
    """
    Cache for mamba model which does not have attention mechanism and key value states.

    Arguments:
        config (`PretrainedConfig):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        batch_size (`int`):
            The batch size with which the model will be used. Note that a new instance must be instantiated if a
            smaller batch size is used.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
            The default `dtype` to use when initializing the layer.
        device (`torch.device` or `str`, *optional*):
            The device on which the cache should be initialized. Should be the same as the layer.

    Attributes:
        dtype: (`torch.dtype`):
            The default `dtype` used when initializing the cache.
        intermediate_size: (`int`):
            Model's intermediate_size taken from config.
        ssm_state_size: (`int`):
            Model's state_size taken from config.
        conv_kernel_size: (`int`):
            Model's convolution kernel size taken from config
        conv_states: (`torch.Tensor`):
            A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states.
        ssm_states: (`torch.Tensor`):
            A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds ssm states

    Example:

        ```python
        >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache

        >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")

        >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values
        MambaCache()
        ```
    """

    # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
    def __init__(
        self,
        config: PretrainedConfig,
        batch_size: Optional[int] = None,
        dtype: torch.dtype = torch.float16,
        device: Optional[Union[torch.device, str]] = None,
        max_batch_size: Optional[int] = None,
    ):
        if max_batch_size is not None:
            logger.warning_once(
                f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
                "v4.46. Use the more precisely named 'batch_size' argument instead."
            )
        self.dtype = dtype
        # `max_batch_size` is the deprecated alias for `batch_size`.
        self.batch_size = batch_size or max_batch_size
        self.intermediate_size = config.intermediate_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel

        # One slot per layer; both buffers are allocated up-front so their
        # addresses stay stable (see mark_static_address below).
        self.conv_states: torch.Tensor = torch.zeros(
            config.num_hidden_layers,
            self.batch_size,
            self.intermediate_size,
            self.conv_kernel_size,
            device=device,
            dtype=dtype,
        )
        self.ssm_states: torch.Tensor = torch.zeros(
            config.num_hidden_layers,
            self.batch_size,
            self.intermediate_size,
            self.ssm_state_size,
            device=device,
            dtype=dtype,
        )

        # Keep the buffer addresses fixed so torch.compile/dynamo can treat
        # them as static and avoid recompilation between decoding steps.
        torch._dynamo.mark_static_address(self.conv_states)
        torch._dynamo.mark_static_address(self.ssm_states)

    def update_conv_state(
        self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
    ) -> torch.Tensor:
        # Shift the rolling convolution window left by one and write the new
        # state at `cache_position` (clamped to the kernel window).
        conv_state = self.conv_states[layer_idx]
        cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)

        conv_state = conv_state.roll(shifts=-1, dims=-1)
        conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
        # In-place zero-then-add keeps the static buffer address unchanged
        # (a plain assignment would rebind the tensor).
        self.conv_states[layer_idx].zero_()
        self.conv_states[layer_idx] += conv_state
        return self.conv_states[layer_idx]

    def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
        # Overwrite the selective-scan state for one layer.
        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
        return self.ssm_states[layer_idx]

    def reset(self):
        # Clear all cached states in place (buffers keep their addresses).
        self.conv_states.zero_()
        self.ssm_states.zero_()


class MambaMixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute the
`contextualized_states`. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__(self, config: MambaConfig, layer_idx: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = config.intermediate_size + self.time_step_rank = int(config.time_step_rank) + self.layer_idx = layer_idx + self.use_conv_bias = config.use_conv_bias + self.conv1d = nn.Conv1d( + in_channels=self.intermediate_size, + out_channels=self.intermediate_size, + bias=config.use_conv_bias, + kernel_size=config.conv_kernel, + groups=self.intermediate_size, + padding=config.conv_kernel - 1, + ) + + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + + # projection of the input hidden states + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias) + # selective projection used to make dt, B and C input dependant + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) + # time step projection (discretization) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. 
Keeps the memory bounded + A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :] + A = A.expand(self.intermediate_size, -1).contiguous() + + self.A_log = nn.Parameter(torch.log(A)) + self.D = nn.Parameter(torch.ones(self.intermediate_size)) + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + self.use_bias = config.use_bias + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of " + "`(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. Falling back to the naive implementation. " + "To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[MambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states).transpose(1, 2) + + if self.training and cache_params is None: # Doesn't support outputting the states -> used for training + contextualized_states = mamba_inner_fn( + projected_states, + self.conv1d.weight, + self.conv1d.bias if self.use_conv_bias else None, + self.x_proj.weight, + self.dt_proj.weight, + self.out_proj.weight, + self.out_proj.bias.float() if self.use_bias else None, + -torch.exp(self.A_log.float()), + None, # input-dependent B + None, # input-dependent C + self.D.float(), + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + ) + + else: + hidden_states, gate = projected_states.chunk(2, dim=1) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 2. 
Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if cache_params is not None and cache_position[0] > 0: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), + cache_params.conv_states[self.layer_idx], + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.unsqueeze(-1) + else: + if cache_params is not None: + conv_states = nn.functional.pad( + hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.update_conv_state(self.layer_idx, conv_states, cache_position) + hidden_states = causal_conv1d_fn( + hidden_states, conv_weights, self.conv1d.bias, activation=self.activation + ) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2) + + A = -torch.exp(self.A_log.float()) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None + if cache_params is not None and cache_position[0] > 0: + scan_outputs = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states[..., 0], + discrete_time_step[..., 0], + A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + time_proj_bias, + dt_softplus=True, + ).unsqueeze(-1) + else: + scan_outputs, ssm_state = selective_scan_fn( + hidden_states, + discrete_time_step, + A, + B.transpose(1, 2), + C.transpose(1, 2), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + return_last_state=True, + ) + if ssm_state is not None and cache_params is not None: + 
cache_params.update_ssm_state(self.layer_idx, ssm_state) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) + return contextualized_states + + def slow_forward( + self, + input_states, + cache_params: Optional[MambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None + ): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + # 1. Gated MLP's linear projection + # [batch, 2 * intermediate_size, seq_len] + projected_states = self.in_proj(input_states).transpose(1, 2) + hidden_states, gate = projected_states.chunk(2, dim=1) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 2. Convolution sequence transformation + if cache_params is not None: + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = ssm_state.to(hidden_states.device) + # use `cache_position.shape[0]` to check whether we are in prefill + # stage, it's equivalent to check `cache_position[0] == 0`, which + # breaks dynamo fullgraph constraints + if cache_position.shape[0] == self.conv_kernel_size: + conv_state = nn.functional.pad( + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + + cache_params.update_conv_state(self.layer_idx, conv_state, cache_position) + # [batch, intermediate_size, seq_len] + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) + else: + conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + # [batch, intermediate_size, 1] : decoding + hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), + device=hidden_states.device, dtype=dtype + ) + # [batch, 
intermediate_size, seq_len] + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 3. State Space Model sequence transformation + # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2] + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + # [batch, seq_len, intermediate_size] + discrete_time_step = self.dt_proj(time_step) + # [batch, intermediate_size, seq_len] + discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) + + # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM) + # [intermediate_size, ssm_state_size] + A = -torch.exp(self.A_log.float()) + # [batch, intermediate_size, seq_len, ssm_state_size] + discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) + # [batch, intermediate_size, seq_len, ssm_state_size] + discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() + deltaB_u = discrete_B * hidden_states[:, :, :, None].float() + + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + scan_outputs = [] + for i in range(seq_len): + # [batch, intermediade_size, ssm_state] + ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] + # [batch, intermediade_size, 1] + scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) + scan_outputs.append(scan_output[:, :, 0]) + # [batch, seq_len, intermediade_size] + scan_output = torch.stack(scan_outputs, dim=-1) + scan_output = scan_output + (hidden_states * self.D[None, :, None]) + scan_output = (scan_output * self.act(gate)) + + if cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + # 4. 
Final linear projection + # [batch, seq_len, hidden_size] + contextualized_states = self.out_proj(scan_output.transpose(1, 2)) + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[MambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + if is_fast_path_available and "cuda" in self.x_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class MambaBlock(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.residual_in_fp32 = config.residual_in_fp32 + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.mixer = MambaMixer(config, layer_idx=layer_idx) + + def forward( + self, + hidden_states, + cache_params: Optional[MambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + residual = hidden_states + hidden_states = self.norm(hidden_states) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + hidden_states = self.mixer( + hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask + ) + hidden_states = residual + hidden_states + if self.residual_in_fp32: + hidden_states = hidden_states.to(dtype=self.norm.weight.dtype) + return hidden_states + + +class MambaPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = MambaConfig + base_model_prefix = "backbone" + _no_split_modules = ["MambaBlock", "MambaMixer"] + supports_gradient_checkpointing = True + _is_stateful = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, MambaMixer): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale + if self.config.time_step_init_scheme == "constant": + nn.init.constant_(module.dt_proj.weight, dt_init_std) + elif self.config.time_step_init_scheme == "random": + nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std) + + dt = torch.exp( + torch.rand(self.config.intermediate_size) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) + ).clamp(min=self.config.time_step_floor) + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + module.dt_proj.bias.data = nn.Parameter(inv_dt.to(module.dt_proj.bias.device)) + module.dt_proj.bias._no_reinit = True + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(self.config.num_hidden_layers) + + +@dataclass +class MambaOutput(ModelOutput): + """ + Class for the MAMBA model outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + cache_params (`MambaCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + cache_params: Optional[MambaCache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MambaCausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. 
+ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + cache_params (`MambaCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + cache_params: Optional[MambaCache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class MambaModel(MambaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + # Initialize weights and apply final processing + self._register_load_state_dict_pre_hook(self.load_hook) + self.post_init() + + def load_hook(self, state_dict, prefix, *args): + for k in state_dict: + if "embedding." 
in k: + state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k) + break + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + cache_params: Optional[MambaCache] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, MambaOutput]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + if use_cache: + if cache_params is None: + cache_params = MambaCache( + self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device) + elif cache_position is None: + # cases when we do manual forward instead of using `model.generate` which will initiate + # `cache_position` and makes sure it is not None, throw error here instead of doing some + # hack to conjecture the current cache position + raise ValueError( + "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is 
passed, " + "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will " + "be initialized for you automatically" + ) + else: + cache_params = None + + hidden_states = inputs_embeds + all_hidden_states = () if output_hidden_states else None + for mixer_block in self.layers: + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask + ) + else: + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) + + return MambaOutput( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + ) + + +class MambaForCausalLM(MambaPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.backbone = MambaModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.backbone.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.backbone.set_input_embeddings(new_embeddings) + + def _update_model_kwargs_for_generation( + self, outputs: ModelOutput, + model_kwargs: Dict[str, Any], + num_new_tokens: 
int = 1, + **kwargs + ) -> Dict[str, Any]: + model_kwargs["cache_params"] = outputs.get("cache_params", None) + if ( + model_kwargs.get("use_cache", True) + and "cache_position" in model_kwargs + and model_kwargs["cache_position"] is not None + ): + model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens + + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + + return model_kwargs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids, + inputs_embeds=None, + use_cache=None, + cache_params: Optional[MambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + logits_to_keep: Optional[int] = None, + **kwargs, + ): + if use_cache: + # `cache_position` should have been initialized in `generate` + if cache_position is None: + raise ValueError( + "`cache_position` should not be None as it should have been initialized in " + "`model.generate`, you are responsible for passing in a valid `cache_position` if " + "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`" + ) + if cache_position[0] > 0: + input_ids = input_ids[:, -1].unsqueeze(-1) + + if attention_mask is not None: + attention_mask = None + + else: + # we initialize the `cache_position` to full size of `conv_states` at prefill stage + # considering padding will be applied when input length is shorter, and truncation + # will be applied when it is longer, so it will be equivalent to always have it match + # the length of `cache_params.conv_states`, which is `config.conv_kernel` + cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device) + + if inputs_embeds is not None and cache_params is None: + model_inputs = 
{"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'cache_params': cache_params, + 'use_cache': use_cache, + 'cache_position': cache_position, + 'attention_mask': attention_mask, + 'logits_to_keep': logits_to_keep, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_params: Optional[MambaCache] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + logits_to_keep: Optional[int] = 0, + **kwargs, # for now we need this for generation + ) -> Union[Tuple, MambaCausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + mamba_outputs = self.backbone( + input_ids, + cache_params=cache_params, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + cache_position=cache_position, + attention_mask=attention_mask, + ) + hidden_states = mamba_outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + # Enable model parallelism + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + mamba_outputs[1:] + return (loss,) + output if loss is not None else output + + return MambaCausalLMOutput( + loss=loss, + logits=logits, + cache_params=mamba_outputs.cache_params, + hidden_states=mamba_outputs.hidden_states, + ) diff --git a/fla/models/nsa/__init__.py b/fla/models/nsa/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..65b8d8982cfb751a9dc0b15b4c8546ac08bf1b06 --- /dev/null +++ b/fla/models/nsa/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.nsa.configuration_nsa import NSAConfig +from fla.models.nsa.modeling_nsa import NSAForCausalLM, NSAModel + +AutoConfig.register(NSAConfig.model_type, NSAConfig) +AutoModel.register(NSAConfig, NSAModel) +AutoModelForCausalLM.register(NSAConfig, NSAForCausalLM) + + +__all__ = [ + 'NSAConfig', 'NSAModel', 'NSAForCausalLM', +] diff --git a/fla/models/nsa/__pycache__/__init__.cpython-312.pyc b/fla/models/nsa/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32a4fe047418e7735bc5636b067c8b5537a0414f Binary files /dev/null and b/fla/models/nsa/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla/models/retnet/__pycache__/__init__.cpython-312.pyc b/fla/models/retnet/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ca7b09bebc20f464d2d071751f61ebe8f65d143 Binary files /dev/null and b/fla/models/retnet/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla/models/rwkv6/__init__.py b/fla/models/rwkv6/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2a5902546e88200af55af16acf7c6f85512d72cf --- /dev/null +++ b/fla/models/rwkv6/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config +from fla.models.rwkv6.modeling_rwkv6 import RWKV6ForCausalLM, RWKV6Model + +AutoConfig.register(RWKV6Config.model_type, RWKV6Config, True) +AutoModel.register(RWKV6Config, RWKV6Model, True) +AutoModelForCausalLM.register(RWKV6Config, RWKV6ForCausalLM, True) + + +__all__ = ['RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model'] diff --git 
a/fla/models/rwkv6/configuration_rwkv6.py b/fla/models/rwkv6/configuration_rwkv6.py
new file mode 100644
index 0000000000000000000000000000000000000000..8635aa543bf0373e260279fc9d6db3c7e8985f7d
--- /dev/null
+++ b/fla/models/rwkv6/configuration_rwkv6.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+
+from typing import Dict, Optional
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class RWKV6Config(PretrainedConfig):
+
+    model_type = 'rwkv6'
+    keys_to_ignore_at_inference = ['past_key_values']
+
+    def __init__(
+        self,
+        attn_mode: str = "chunk",
+        hidden_size: int = 2048,
+        expand_k: float = 0.5,  # FIX: was annotated `int` with a float default (expansion ratio)
+        expand_v: float = 1,  # FIX: was annotated `int`; this is an expansion ratio, not a count
+        hidden_ratio: Optional[float] = 3.5,  # FIX: was Optional[int] with a 3.5 default
+        intermediate_size: Optional[int] = None,
+        num_hidden_layers: int = 24,
+        num_heads: int = 4,
+        proj_low_rank_dim: int = 32,
+        gate_low_rank_dim: int = 64,
+        hidden_act: str = "sqrelu",
+        max_position_embeddings: int = 2048,
+        norm_first: bool = True,
+        norm_bias: bool = True,
+        norm_eps: float = 1e-5,
+        attn: Optional[Dict] = None,
+        use_cache: bool = True,
+        pad_token_id: Optional[int] = None,  # FIX: was `int = None`
+        bos_token_id: int = 1,
+        eos_token_id: int = 2,
+        tie_word_embeddings: bool = False,
+        initializer_range: float = 0.006,
+        fuse_norm: bool = True,
+        fuse_cross_entropy: bool = True,
+        vocab_size: int = 32000,
+        **kwargs
+    ):
+        self.attn_mode = attn_mode
+        self.hidden_size = hidden_size
+        self.expand_k = expand_k
+        self.expand_v = expand_v
+        self.hidden_ratio = hidden_ratio
+        self.intermediate_size = intermediate_size
+        self.norm_first = norm_first
+        self.num_hidden_layers = num_hidden_layers
+        self.num_heads = num_heads
+        self.proj_low_rank_dim = proj_low_rank_dim
+        self.gate_low_rank_dim = gate_low_rank_dim
+        self.hidden_act = hidden_act
+        self.max_position_embeddings = max_position_embeddings
+        self.norm_bias = norm_bias
+        self.norm_eps = norm_eps
+        self.attn = attn
+        self.use_cache = use_cache
+        self.initializer_range = initializer_range
+        self.fuse_norm = fuse_norm
+        self.fuse_cross_entropy =
fuse_cross_entropy + self.vocab_size = vocab_size + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla/models/rwkv7/__pycache__/configuration_rwkv7.cpython-312.pyc b/fla/models/rwkv7/__pycache__/configuration_rwkv7.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..780d6bc67a0b8272a0bd7dd76fbb922ccec3915c Binary files /dev/null and b/fla/models/rwkv7/__pycache__/configuration_rwkv7.cpython-312.pyc differ diff --git a/fla/models/rwkv7/modeling_rwkv7.py b/fla/models/rwkv7/modeling_rwkv7.py new file mode 100644 index 0000000000000000000000000000000000000000..038e58d254883865f2f5d8a612ec0d0060c130c1 --- /dev/null +++ b/fla/models/rwkv7/modeling_rwkv7.py @@ -0,0 +1,505 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.attn import Attention +from 
fla.layers.rwkv7 import RWKV7Attention
+from fla.models.rwkv7.configuration_rwkv7 import RWKV7Config
+from fla.models.utils import Cache
+from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, LayerNorm
+from fla.modules.activations import ACT2FN
+
+if TYPE_CHECKING:
+    from transformers.processing_utils import Unpack
+
+logger = logging.get_logger(__name__)
+
+
+class RWKV7FeedForward(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        hidden_ratio: Optional[int] = None,
+        intermediate_size: Optional[int] = None,
+        hidden_act: str = 'sqrelu',
+        layer_idx: Optional[int] = None  # FIX: was `int = None`
+    ) -> None:  # FIX: `__init__` returns None, not RWKV7FeedForward
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        if hidden_ratio is None:
+            hidden_ratio = 4
+        if intermediate_size is None:
+            intermediate_size = int(hidden_size * hidden_ratio)
+            intermediate_size = 32 * ((intermediate_size + 32 - 1) // 32)  # round up to a multiple of 32
+        self.hidden_ratio = hidden_ratio
+        self.intermediate_size = intermediate_size
+
+        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+
+        self.x_k = nn.Parameter(torch.zeros(hidden_size))
+
+        self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
+        self.value = nn.Linear(intermediate_size, hidden_size, bias=False)
+        self.act_fn = ACT2FN[hidden_act]
+
+        self.layer_idx = layer_idx
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        state: Optional[Cache] = None
+    ) -> Tuple[torch.Tensor, Optional[Cache]]:  # FIX: was `-> torch.Tensor`; the body returns (tensor, state)
+        if attention_mask is not None:
+            x = x.mul(attention_mask[:, -x.shape[-2]:, None])
+        if x.shape[1] == 1 and state is not None and state[self.layer_idx]['ffn_state'] is not None:
+            shifted = state[self.layer_idx]['ffn_state'].unsqueeze(1)
+        else:
+            shifted = self.time_shift(x)
+            if state is not None and state[self.layer_idx]['ffn_state'] is not None:
+                shifted[:, 0] = state[self.layer_idx]['ffn_state'][-1]
+        if state is not None:
+            # no need to update the offset twice
+            state.update(ffn_state=x[:, -1], layer_idx=self.layer_idx, offset=0)
+        return
self.value(self.act_fn(self.key(x.addcmul(shifted - x, self.x_k)))), state + + +class RWKV7Block(nn.Module): + + def __init__( + self, + config: RWKV7Config, + layer_idx: int + ) -> RWKV7Block: + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + if config.norm_first and layer_idx == 0: + self.pre_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)( + config.hidden_size, + bias=config.norm_bias, + eps=config.norm_eps + ) + self.attn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)( + config.hidden_size, + bias=config.norm_bias, + eps=config.norm_eps + ) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = RWKV7Attention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + head_dim=config.head_dim, + num_heads=config.num_heads, + decay_low_rank_dim=config.decay_low_rank_dim, + gate_low_rank_dim=config.gate_low_rank_dim, + a_low_rank_dim=config.a_low_rank_dim, + v_low_rank_dim=config.v_low_rank_dim, + norm_eps=config.norm_eps, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx, + value_dim=config.value_dim[layer_idx] + ) + self.ffn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)( + config.hidden_size, + bias=config.norm_bias, + eps=config.norm_eps + ) + self.ffn = RWKV7FeedForward( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + layer_idx=layer_idx + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + 
output_attentions: Optional[bool] = False, + v_first: torch.Tensor = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = self.pre_norm(hidden_states) if hasattr(self, 'pre_norm') else hidden_states + hidden_states = self.attn_norm(residual) + hidden_states, attentions, past_key_values, v_first = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + v_first=v_first, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.ffn_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.ffn_norm(hidden_states) + hidden_states, past_key_values = self.ffn(hidden_states, attention_mask, past_key_values) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values, v_first) + + return outputs + + +class RWKV7PreTrainedModel(PreTrainedModel): + + config_class = RWKV7Config + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['RWKV7Block'] + _supports_cache_class = True + _skip_keys_device_placement = ["past_key_values"] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + warnings.warn( + "RWKV-7 employs a carefully designed initialization strategy tailored to its architecture. " + "The detailed initialization scheme is currently not implemented here but can be found in the " + "official code repository. We emphasize that using the recommended initialization is essential " + "for replicating the results in RWKV-7 paper. 
Deviations from the prescribed initialization " + "may lead to performance degradation.\n" + "Alternatively, please generate initial weights from the official RWKV code repository, and " + "convert the PyTorch checkpoint into FLA supported format." + ) + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Parameter): + nn.init.normal_(module, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class RWKV7Model(RWKV7PreTrainedModel): + + def __init__(self, config: RWKV7Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([RWKV7Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)( + config.hidden_size, + bias=config.norm_bias, + eps=config.norm_eps + ) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`RWKV7Model` 
does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + + v_first = torch.zeros_like(hidden_states) + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values, v_first = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + v_first, + **kwargs + ) + else: + hidden_states, attentions, past_key_values, v_first = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + v_first=v_first, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class RWKV7ForCausalLM(RWKV7PreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = RWKV7Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, 
new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. 
+ model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + labels: Optional[torch.LongTensor] = None, + shift_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + has_labels = (labels is not None) or (shift_labels is not None) + if not (fuse_linear_and_cross_entropy and has_labels): + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if has_labels: + if getattr(self, 'criterion', None) is 
None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + + # shift_labels: See https://github.com/huggingface/transformers/pull/36607/files. + if shift_labels is None: + shift_labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + shift_labels = shift_labels.to(hidden_states.device) + + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, shift_labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(shift_labels.numel(), -1), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla/models/samba/__init__.py b/fla/models/samba/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a27a4b4cac782eb4a3e6c35216405d320e2c6507 --- /dev/null +++ b/fla/models/samba/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.samba.configuration_samba import SambaConfig +from fla.models.samba.modeling_samba import SambaBlock, SambaForCausalLM, SambaModel + +AutoConfig.register(SambaConfig.model_type, SambaConfig, True) +AutoModel.register(SambaConfig, SambaModel, True) +AutoModelForCausalLM.register(SambaConfig, SambaForCausalLM, True) + + +__all__ = ['SambaConfig', 'SambaForCausalLM', 'SambaModel', 'SambaBlock'] diff --git a/fla/models/samba/configuration_samba.py b/fla/models/samba/configuration_samba.py new file mode 100644 index 
0000000000000000000000000000000000000000..27311f06a81f0132a409b9dab10b63fc9e19333a --- /dev/null +++ b/fla/models/samba/configuration_samba.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- + +import math +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class SambaConfig(PretrainedConfig): + + model_type = "samba" + + def __init__( + self, + hidden_size: int = 2304, + state_size: int = 16, + num_hidden_layers: int = 18, + norm_eps=1e-5, + pad_token_id: int = 0, + bos_token_id: int = 1, + eos_token_id: int = 2, + expand: int = 2, + conv_kernel: int = 4, + use_bias: bool = False, + use_conv_bias: bool = True, + hidden_act: str = "swish", + initializer_range: str = 0.02, + residual_in_fp32: bool = False, + time_step_rank: str = "auto", + time_step_scale: float = 1.0, + time_step_min: float = 0.001, + time_step_max: float = 0.1, + time_step_init_scheme: str = "random", + time_step_floor: float = 1e-4, + max_position_embeddings: int = 2048, + attn: Optional[Dict] = { + 'layers': (1, 3, 5, 7, 9, 11, 13, 15, 17), + 'num_heads': 18, + 'num_kv_heads': 18, + 'qkv_bias': False, + 'window_size': 2048, + 'rope_theta': 10000. 
+ }, + hidden_ratio: Optional[int] = 4, + rescale_prenorm_residual: bool = False, + use_cache: bool = True, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + tie_word_embeddings: bool = False, + **kwargs, + ): + self.hidden_size = hidden_size + self.state_size = state_size + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.conv_kernel = conv_kernel + self.expand = expand + self.intermediate_size = int(expand * self.hidden_size) + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.use_bias = use_bias + self.use_conv_bias = use_conv_bias + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank + self.time_step_scale = time_step_scale + self.time_step_min = time_step_min + self.time_step_max = time_step_max + self.time_step_init_scheme = time_step_init_scheme + self.time_step_floor = time_step_floor + self.max_position_embeddings = max_position_embeddings + self.attn = attn + self.hidden_ratio = hidden_ratio + self.rescale_prenorm_residual = rescale_prenorm_residual + self.residual_in_fp32 = residual_in_fp32 + self.use_cache = use_cache + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs + ) diff --git a/fla/models/samba/modeling_samba.py b/fla/models/samba/modeling_samba.py new file mode 100644 index 0000000000000000000000000000000000000000..0da2cfa64f3bb89e51e2f799194625e4448138a9 --- /dev/null +++ b/fla/models/samba/modeling_samba.py @@ -0,0 +1,413 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +from 
dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers.generation import GenerationMixin +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput, logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.attn import Attention +from fla.models.mamba.modeling_mamba import MambaCache, MambaMixer +from fla.models.samba.configuration_samba import SambaConfig +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as SambaMLP +from fla.modules import RMSNorm + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + +logger = logging.get_logger(__name__) + + +class SambaBlock(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.mixer_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.mixer = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.mixer = MambaMixer(config, layer_idx=layer_idx) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = SambaMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[Tuple[torch.Tensor]] = None, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, 
Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + hidden_states = self.mixer_norm(hidden_states) + if isinstance(self.mixer, MambaMixer): + hidden_states = self.mixer(hidden_states, cache_params=cache_params, **kwargs) + else: + hidden_states, _, cache_params = self.mixer(hidden_states=hidden_states, past_key_values=cache_params, **kwargs) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + return hidden_states + + +class SambaPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SambaConfig + base_model_prefix = "backbone" + _no_split_modules = ["SambaBlock"] + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, MambaMixer): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale + if self.config.time_step_init_scheme == "constant": + nn.init.constant_(module.dt_proj.weight, dt_init_std) + elif self.config.time_step_init_scheme == "random": + nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std) + + dt = torch.exp( + torch.rand(self.config.intermediate_size) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) + 
).clamp(min=self.config.time_step_floor) + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + module.dt_proj.bias.data = nn.Parameter(inv_dt.to(module.dt_proj.bias.device)) + module.dt_proj.bias._no_reinit = True + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(self.config.num_layers) + + +@dataclass +class SambaOutput(ModelOutput): + """ + Class for the Samba model outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + cache_params (`MambaCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. 
+ + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + cache_params: Optional[MambaCache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class SambaCausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + cache_params (`MambaCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + cache_params: Optional[MambaCache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class SambaModel(SambaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([SambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + self.norm_f = RMSNorm(config.hidden_size, eps=config.norm_eps) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + cache_params: Optional[MambaCache] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, SambaOutput]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + if cache_params is None and use_cache: + cache_params = MambaCache( + self.config, inputs_embeds.size(0), device=inputs_embeds.device, 
dtype=inputs_embeds.dtype + ) + + hidden_states = inputs_embeds + all_hidden_states = () if output_hidden_states else None + for mixer_block in self.layers: + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + mixer_block.__call__, + hidden_states, + cache_params, + **kwargs + ) + else: + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + **kwargs + ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if use_cache: + cache_params.seqlen_offset += inputs_embeds.shape[1] + + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) + + return SambaOutput( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + ) + + +class SambaForCausalLM(SambaPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.backbone = SambaModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.backbone.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.backbone.set_input_embeddings(new_embeddings) + + def _update_model_kwargs_for_generation( + self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs + ) -> Dict[str, Any]: + model_kwargs["cache_params"] = outputs.get("cache_params", None) + return model_kwargs + + @deprecate_kwarg("num_logits_to_keep", 
version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids, + cache_params: + Optional[MambaCache] = None, + inputs_embeds=None, + attention_mask=None, + use_cache: Optional[bool] = True, + logits_to_keep: Optional[int] = None, + **kwargs: Unpack[Dict] + ): + # only last token for inputs_ids if the state is passed along. + if cache_params is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + + if inputs_embeds is not None and cache_params is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'cache_params': cache_params, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + 'logits_to_keep': logits_to_keep, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_params: Optional[MambaCache] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, SambaCausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.backbone( + input_ids, + cache_params=cache_params, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + **kwargs + ) + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return SambaCausalLMOutput( + loss=loss, + logits=logits, + cache_params=outputs.cache_params, + hidden_states=outputs.hidden_states, + ) diff --git a/fla/models/transformer_dsmtp/__pycache__/__init__.cpython-312.pyc b/fla/models/transformer_dsmtp/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..486480b593d5e9484906c465ca6dac44b9f3bf82 Binary files /dev/null and b/fla/models/transformer_dsmtp/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla/models/transformer_dsmtp/__pycache__/configuration_transformer.cpython-312.pyc b/fla/models/transformer_dsmtp/__pycache__/configuration_transformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c54025129aa34f01aceead8a3c9185ef39fce49 Binary files /dev/null and b/fla/models/transformer_dsmtp/__pycache__/configuration_transformer.cpython-312.pyc differ diff --git a/fla/models/transformer_mtp/__pycache__/configuration_transformer.cpython-312.pyc b/fla/models/transformer_mtp/__pycache__/configuration_transformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bde6621bf8ce6362d931d614464e1dc5cad4181 Binary files /dev/null and b/fla/models/transformer_mtp/__pycache__/configuration_transformer.cpython-312.pyc differ diff --git a/fla/models/transformer_top/__pycache__/configuration_transformer.cpython-312.pyc b/fla/models/transformer_top/__pycache__/configuration_transformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c3e05dc9ee9f1c6d727066405640548d5ea922f Binary files /dev/null and b/fla/models/transformer_top/__pycache__/configuration_transformer.cpython-312.pyc differ diff --git a/fla/ops/common/__pycache__/chunk_delta_h.cpython-312.pyc b/fla/ops/common/__pycache__/chunk_delta_h.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f36948489a1a9d3a629f0b2c1254aa91ff296dc5 Binary files /dev/null and b/fla/ops/common/__pycache__/chunk_delta_h.cpython-312.pyc differ diff --git a/flame/__pycache__/__init__.cpython-312.pyc b/flame/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d75beec3a455f0b7b43ed2a233b5bbefe5547c0 Binary files /dev/null and 
b/flame/__pycache__/__init__.cpython-312.pyc differ diff --git a/flame/models/__init__.py b/flame/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/flame/models/parallelize_fla.py b/flame/models/parallelize_fla.py new file mode 100644 index 0000000000000000000000000000000000000000..37178af1bf365b3f5179cefc62000bf8f2f4ded3 --- /dev/null +++ b/flame/models/parallelize_fla.py @@ -0,0 +1,550 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file applies the PT-D parallelisms (except pipeline parallelism) and various +# training techniques (e.g. activation checkpointing and compile) to the Llama model. + +from collections import defaultdict + +import torch +import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard +from torch.distributed._composable.replicate import replicate +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + PrepareModuleOutput, + RowwiseParallel, + SequenceParallel, + parallelize_module +) + +from fla.modules.fused_linear_cross_entropy import LinearLossParallel +from fla.modules.mlp import SwiGLULinearParallel +from fla.modules.parallel import PrepareModuleWeight +from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig +from torchtitan.distributed.parallel_dims import ParallelDims +from torchtitan.tools.logging import logger + + +def parallelize_fla( + model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + 
Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + if parallel_dims.tp_enabled: + if ( + job_config.experimental.enable_async_tensor_parallel + and not job_config.training.compile + ): + raise RuntimeError("Async TP requires --training.compile") + enable_float8_linear = "float8" in job_config.model.converters + apply_tp( + model, + world_mesh["tp"], + loss_parallel=parallel_dims.loss_parallel_enabled, + enable_float8=enable_float8_linear, + enable_async_tp=job_config.experimental.enable_async_tensor_parallel, + ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac(model, job_config.activation_checkpoint) + + # turn on per-block compile after AC wrapping and before FSDP + if job_config.training.compile: + apply_compile(model) + + if ( + parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled + ): # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_shard_cp",) + + apply_fsdp( + model, + world_mesh[tuple(dp_mesh_dim_names)], + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + pp_enabled=parallel_dims.pp_enabled, + cpu_offload=job_config.training.enable_cpu_offload, + reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward, + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if parallel_dims.cp_enabled: + logger.info("Applied Context Parallel to the model") + + if job_config.training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") + elif parallel_dims.dp_replicate_enabled: + if world_mesh.ndim > 1: + raise 
RuntimeError("DDP has not supported > 1D parallelism") + apply_ddp( + model, + world_mesh, + enable_compile=job_config.training.compile, + enable_compiled_autograd=job_config.experimental.enable_compiled_autograd, + ) + + +class TPPlan: + def __init__( + self, + model=None, + loss_parallel=False, + enable_float8=False, + ): + self.model = model + self.loss_parallel = loss_parallel + self.enable_float8 = enable_float8 + self.base_model_prefix = getattr(model, "base_model_prefix", "model") + + # TODO(vkuzo): once float8 configuration supports delayed scaling, + # add a check here to enforce supported float8 all-gather configurations + # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there + try: + from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput + ) + except ImportError: + Float8ColwiseParallel = None + Float8RowwiseParallel = None + PrepareFloat8ModuleInput = None + if self.enable_float8 and Float8ColwiseParallel is not None: + self.rowwise_parallel = Float8RowwiseParallel + self.colwise_parallel = Float8ColwiseParallel + self.prepare_module_input = PrepareFloat8ModuleInput + self.prepare_module_output = PrepareModuleOutput + else: + self.rowwise_parallel = RowwiseParallel + self.colwise_parallel = ColwiseParallel + self.prepare_module_input = PrepareModuleInput + self.prepare_module_output = PrepareModuleOutput + + @property + def model_plan(self): + plans = { + f"{self.base_model_prefix}.embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + f"{self.base_model_prefix}.norm": SequenceParallel(), + } + if self.loss_parallel: + plans.update( + { + "lm_head": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if self.loss_parallel else Replicate(), + use_local_output=not self.loss_parallel, + ), + } + ) + else: + plans.update( + { + "lm_head": PrepareModuleWeight(layouts=Replicate()), + "criterion": 
LinearLossParallel(), + } + ) + return plans + + @property + def layer_plan(self): + return { + "attn_norm": SequenceParallel(), + **self.attn_plan, + "mlp_norm": SequenceParallel(), + **self.mlp_plan, + } + + @property + def attn_plan(self): + raise NotImplementedError( + f"TP plans for token mixing layers of {self.model.config.model_type} not implemented" + ) + + @property + def mlp_plan(self): + return { + "mlp": self.prepare_module_input( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "mlp.gate_proj": self.colwise_parallel(), + "mlp.up_proj": self.colwise_parallel(), + "mlp.down_proj": self.rowwise_parallel(output_layouts=Shard(1)), + "mlp.swiglu_linear": SwiGLULinearParallel(output_layouts=Shard(1)), + } + + +class TransformerTPPlan(TPPlan): + + @property + def attn_plan(self): + return { + "attn": self.prepare_module_input( + input_kwarg_layouts={"hidden_states": Shard(1)}, + desired_input_kwarg_layouts={"hidden_states": Replicate()}, + ), + "attn.q_proj": self.colwise_parallel(), + "attn.k_proj": self.colwise_parallel(), + "attn.v_proj": self.colwise_parallel(), + "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)), + } + + +class GLATPPlan(TPPlan): + + @property + def attn_plan(self): + return { + "attn": self.prepare_module_input( + input_kwarg_layouts={"hidden_states": Shard(1)}, + desired_input_kwarg_layouts={"hidden_states": Replicate()}, + ), + "attn.q_proj": self.colwise_parallel(), + "attn.k_proj": self.colwise_parallel(), + "attn.v_proj": self.colwise_parallel(), + "attn.g_proj": self.colwise_parallel(), + "attn.gk_proj.0": PrepareModuleWeight(layouts=Replicate()), + "attn.gk_proj.1": self.colwise_parallel(), + "attn.g_norm": SequenceParallel(sequence_dim=-1), + "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)), + } + + +TP_PLAN_MAP = {"transformer": TransformerTPPlan, "gla": GLATPPlan} + + +def apply_tp( + model: nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8: bool, + 
enable_async_tp: bool, +): + """Apply tensor parallelism.""" + # 1. Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. Parallelize the root norm layer over the sequence dim + # 3. Parallelize the final linear output layer + tp_plan = TP_PLAN_MAP[model.config.model_type]( + model, loss_parallel=loss_parallel, enable_float8=enable_float8 + ) + parallelize_module(model, tp_mesh, tp_plan.model_plan) + + blocks = get_blocks(model) + if blocks is None: + logger.warning("No block found for tensor parallelism") + else: + for _, block in enumerate(blocks): + parallelize_module( + module=block, + device_mesh=tp_mesh, + parallelize_plan=tp_plan.layer_plan, + ) + + if enable_async_tp: + from torch.distributed._symmetric_memory import enable_symm_mem_for_group + + torch._inductor.config._micro_pipeline_tp = True + enable_symm_mem_for_group(tp_mesh.get_group().group_name) + + logger.info( + f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}" + "Tensor Parallelism to the model" + ) + + +# for selective op activation checkpointing +_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + # for low precision training, it's useful to always save + # the result of max, since the absolute maximum is + # used to compute the scaling factor for quantization. + torch.ops.aten.max.default, +} + + +def _apply_ac_to_block(module: nn.Module, ac_config): + valid_ac_modes = ("full", "selective") + if ac_config.mode not in valid_ac_modes: + raise ValueError( + f"Invalid AC mode: {ac_config.mode}. 
Valid modes: {valid_ac_modes}" + ) + + if ac_config.mode == "full": + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + + assert ac_config.mode == "selective", f"{ac_config.mode}" + use_op_sac = ac_config.selective_ac_option == "op" + use_layer_sac = ac_config.selective_ac_option.isdigit() + if not use_op_sac and not use_layer_sac: + raise ValueError( + f"Invalid selective AC option: {ac_config.selective_ac_option}. " + f"Valid options: 'op' or a positive int representing layer frequency" + ) + if use_op_sac: + from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts + + def _get_custom_policy(meta): + def _custom_policy(ctx, func, *args, **kwargs): + mode = "recompute" if ctx.is_recompute else "forward" + mm_count_key = f"{mode}_mm_count" + if func == torch.ops.aten.mm.default: + meta[mm_count_key] += 1 + # Saves output of all compute ops, except every second mm + to_save = func in _save_list and not ( + func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 + ) + return ( + CheckpointPolicy.MUST_SAVE + if to_save + else CheckpointPolicy.PREFER_RECOMPUTE + ) + + return _custom_policy + + def selective_checkpointing_context_fn(): + meta = defaultdict(int) + return create_selective_checkpoint_contexts(_get_custom_policy(meta)) + + return ptd_checkpoint_wrapper( + module, + context_fn=selective_checkpointing_context_fn, + preserve_rng_state=False, + ) + elif use_layer_sac: + # Checkpoint every `ac_freq` of the modules passed to this function + ac_freq = int(ac_config.selective_ac_option) + ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0) + ptd_checkpoint_wrapper._count += 1 + if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0: + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + else: + return module + + +def apply_ac(model: nn.Module, ac_config): + """Apply activation checkpointing to the model.""" + blocks = get_blocks(model) + if blocks is None: + logger.warning("No block 
found for activation checkpointing") + return + + for layer_id, block in blocks.named_children(): + block = _apply_ac_to_block(block, ac_config) + blocks.register_module(layer_id, block) + + logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") + + +def apply_compile(model: nn.Module): + """ + Apply torch.compile to each block, which makes compilation efficient due to + repeated structure. Alternatively one can compile the whole model (after applying DP). + """ + + blocks = get_blocks(model) + if blocks is None: + logger.warning("No block found for torch.compile") + else: + for layer_id, block in blocks.named_children(): + block = torch.compile(block) + blocks.register_module(layer_id, block) + logger.info("Compiling each block with torch.compile") + + real_model = get_model(model) + + logger.info("Compiling the embedding, norm, and lm_head layers with torch.compile") + embeddings_key = get_components_name(real_model, "tok_embeddings") + if embeddings_key is not None: + embeddings = torch.compile(getattr(real_model, embeddings_key), fullgraph=True) + real_model.register_module(embeddings_key, embeddings) + + norm_key = get_components_name(real_model, "norm") + if norm_key is not None: + norm = torch.compile(getattr(real_model, norm_key), fullgraph=True) + real_model.register_module(norm_key, norm) + + lm_head_key = get_components_name(model, "lm_head") + if lm_head_key is not None: + lm_head = torch.compile(getattr(model, lm_head_key), fullgraph=True) + model.register_module(lm_head_key, lm_head) + + logger.info("Compiling the entire model with torch.compile") + model = torch.compile(model) + + +def apply_fsdp( + model: nn.Module, + dp_mesh: DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + pp_enabled: bool, + cpu_offload: bool = False, + reshard_after_forward_policy: str = "default", +): + """ + Apply data parallelism (via FSDP2) to the model. + + Args: + model (nn.Module): The model to apply data parallelism to. 
+ dp_mesh (DeviceMesh): The device mesh to use for data parallelism. + param_dtype (torch.dtype): The data type to use for model parameters. + reduce_dtype (torch.dtype): The data type to use for reduction operations. + pp_enabled (bool): Whether pipeline parallelism is enabled. + cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. + reshard_after_forward_policy (str, optional): + The policy to use for resharding after forward pass. Defaults to "default". + Other options: "never", "always". + - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. + - "always" will enable `reshard_after_forward` for all forward passes. + - "never" will disable `reshard_after_forward` for all forward passes. + + """ + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + if cpu_offload: + fsdp_config["offload_policy"] = CPUOffloadPolicy() + + blocks = get_blocks(model) + if blocks is None: + logger.warning("No block found for FSDP") + else: + total_blocks = len(blocks) + for layer_id, block in enumerate(blocks): + if reshard_after_forward_policy == "always": + reshard_after_forward = True + elif reshard_after_forward_policy == "never": + reshard_after_forward = False + elif reshard_after_forward_policy == "default": + if pp_enabled: + # For PP, do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = False + else: + # As an optimization, do not reshard after forward for the last + # transformer block since FSDP would prefetch it immediately + reshard_after_forward = int(layer_id) < total_blocks - 1 + else: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." 
+ ) + fully_shard( + block, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + + fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) + + +def apply_ddp( + model: nn.Module, + dp_mesh: DeviceMesh, + enable_compile: bool, + enable_compiled_autograd: bool, +): + if enable_compile: + if enable_compiled_autograd: + torch._dynamo.config.optimize_ddp = ( + "python_reducer_without_compiled_forward" + ) + else: + torch._dynamo.config.optimize_ddp = "ddp_optimizer" + + replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) + + logger.info("Applied DDP to the model") + + +def get_model(model): + base_model_prefix = getattr(model, "base_model_prefix", "model") + if not hasattr(model, base_model_prefix): + return None + model = getattr(model, base_model_prefix) + return model + + +def get_blocks(model): + # TODO[flame]: adapt for network not using 'layers' attribute + model = get_model(model) + if not hasattr(model, "layers"): + logger.warning('no "layers" in model can be found') + return None + return model.layers + + +def get_components_name(model, component_name): + """ + We try to catch tok_embeddings, norm layers and lm_head layers + We do not catch the layer names in the blocks, for blocks see `get_blocks` + We assume the model has the following structure: + LlamaForCausalLM: + Model: + embed_tokens, + layers, + norm, + lm_head + *** + so, to search 'tok_embeddings' and 'norm' we need to pass `get_model(model)` + and for 'lm_head' we need to pass `model` + *** + """ + + if component_name == "tok_embeddings": + if hasattr(model, "tok_embeddings"): + return "tok_embeddings" + elif hasattr(model, "embed_tokens"): + return "embed_tokens" + elif hasattr(model, "embeddings"): + return "embeddings" + else: + logger.warning("No tok_embeddings found in model") + return None + + elif component_name == "norm": + if hasattr(model, "norm"): + return "norm" + elif hasattr(model, "norms"): + return "norms" + elif hasattr(model, "layernorm"): + 
return "layernorm" + else: + logger.warning("No norm found in model") + return None + + elif component_name == "lm_head": + if hasattr(model, "lm_head"): + return "lm_head" + else: + logger.warning("No lm_head found in model") + return None diff --git a/flame/tools/__init__.py b/flame/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/flame/tools/utils.py b/flame/tools/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a798ec243f6054aecf7878ad62a3f818f32faeca --- /dev/null +++ b/flame/tools/utils.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch import nn +from torchtitan.tools.logging import logger + + +def get_nparams_and_flops(model: nn.Module, model_config, seq_len: int) -> tuple[int, int]: + nparams = sum(p.numel() for p in model.parameters()) + nparams_embedding = sum( + sum(p.numel() for p in m.parameters()) + for m in model.children() + if isinstance(m, nn.Embedding) + ) + + if hasattr(model_config, "num_heads"): + num_heads = model_config.num_heads + elif hasattr(model_config, "num_attention_heads"): + num_heads = model_config.num_attention_heads + else: + num_heads = 1 + logger.warning("num_heads not found in model_config, defaulting to 1. ") + + l, h, q, t = ( + model_config.num_hidden_layers, + num_heads, + model_config.hidden_size // num_heads, + seq_len, + ) + # Reasoning behind the factor of 12 for the self-attention part of the formula: + # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) + # 2. the flash attention does 1 more matmul recomputation in the backward + # but recomputation should not be counted in calculating MFU (+0) + # 3. each matmul performs 1 multiplication and 1 addition (*2) + # 4. 
we follow the convention and do not account for sparsity in causal attention + num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t + + return nparams, num_flops_per_token diff --git a/flame/utils/checkpoint.py b/flame/utils/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..839ac7df075c3bfca6747855781953c8a82a4c28 --- /dev/null +++ b/flame/utils/checkpoint.py @@ -0,0 +1,50 @@ +import os +import glob +import re +import shutil +from torchtitan.tools.logging import logger + + +def cleanup_local_checkpoints(checkpoint_dir: str, keep_latest_k: int): + """Removes older checkpoint directories locally, keeping only the latest k for both DCP and HF formats.""" + if keep_latest_k <= 0: + return # Keep all checkpoints + + logger.info(f"Cleaning up local checkpoints in {checkpoint_dir}, keeping latest {keep_latest_k}") + + # Cleanup DCP checkpoints (step-*) + dcp_checkpoints = sorted( + glob.glob(os.path.join(checkpoint_dir, "step-*")), + key=lambda x: int(re.search(r"step-(\d+)", os.path.basename(x)).group(1)) if re.search(r"step-(\d+)", os.path.basename(x)) and not x.endswith("-hf") else -1, + reverse=True + ) + # Filter out HF format directories + dcp_checkpoints = [d for d in dcp_checkpoints if not d.endswith("-hf")] + + if len(dcp_checkpoints) > keep_latest_k: + checkpoints_to_delete = dcp_checkpoints[keep_latest_k:] + logger.info(f"Deleting {len(checkpoints_to_delete)} old DCP checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}") + for ckpt_path in checkpoints_to_delete: + if os.path.isdir(ckpt_path): # Ensure it's a directory + try: + shutil.rmtree(ckpt_path) + except OSError as e: + logger.error(f"Error removing directory {ckpt_path}: {e}") + + + # Cleanup HF checkpoints (step-*-hf) + hf_checkpoints = sorted( + glob.glob(os.path.join(checkpoint_dir, "step-*-hf")), + key=lambda x: int(re.search(r"step-(\d+)-hf", os.path.basename(x)).group(1)) if re.search(r"step-(\d+)-hf", os.path.basename(x)) else 
-1, + reverse=True + ) + + if len(hf_checkpoints) > keep_latest_k: + checkpoints_to_delete = hf_checkpoints[keep_latest_k:] + logger.info(f"Deleting {len(checkpoints_to_delete)} old HF checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}") + for ckpt_path in checkpoints_to_delete: + if os.path.isdir(ckpt_path): # Ensure it's a directory + try: + shutil.rmtree(ckpt_path) + except OSError as e: + logger.error(f"Error removing directory {ckpt_path}: {e}") diff --git a/flame/utils/convert_hf_to_dcp.py b/flame/utils/convert_hf_to_dcp.py new file mode 100644 index 0000000000000000000000000000000000000000..bab94ebf80ea8822139b851e0c64b95854c2e78b --- /dev/null +++ b/flame/utils/convert_hf_to_dcp.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import argparse +from pathlib import Path + +import torch +import torch.distributed.checkpoint as DCP +from transformers import AutoModelForCausalLM + +import fla # noqa +from torchtitan.tools.logging import init_logger, logger + + +@torch.inference_mode() +def convert_hf_weights(model: str, checkpoint: str): + logger.info(f"Loading model from {model}") + model = AutoModelForCausalLM.from_pretrained(model) + state_dict = model.state_dict() + + logger.info(f"Writing to DCP at '{checkpoint}'") + checkpoint.mkdir(parents=True, exist_ok=True) + storage_writer = DCP.filesystem.FileSystemWriter(checkpoint, thread_count=8) + DCP.save({"model": state_dict}, storage_writer=storage_writer) + + +if __name__ == "__main__": + init_logger() + parser = argparse.ArgumentParser(description="Convert huggingface-style model weights to DCP format.") + parser.add_argument("--model", type=str, required=True) + parser.add_argument("--checkpoint", type=Path, required=True) + args = parser.parse_args() + + convert_hf_weights(args.model, args.checkpoint) diff --git a/flame/utils/hf_utils.py b/flame/utils/hf_utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..c8954965dbde4c33131bdf05811fdf803c247168 --- /dev/null +++ b/flame/utils/hf_utils.py @@ -0,0 +1,77 @@ +import os +import re +from huggingface_hub import HfApi, HfFolder, logging as hf_logging, create_repo +from torchtitan.tools.logging import logger + +def upload_checkpoint_to_hf( + local_path: str, + step: int, + hf_repo_id_for_run: str, + hf_keep_latest_k: int, + upload_format: str +): + """Uploads a checkpoint directory to HF Hub and manages retention.""" + if not os.path.isdir(local_path): + logger.error(f"Local path for upload does not exist or is not a directory: {local_path}") + return + + api = HfApi() + token = HfFolder.get_token() + if not token: + logger.warning("Hugging Face Hub token not found. Skipping upload. Login via `huggingface-cli login` or set HF_TOKEN.") + return + + # --- Ensure the specific repository for this run exists --- + try: + logger.info(f"Ensuring repository {hf_repo_id_for_run} exists...") + # Use create_repo which handles creation only if it doesn't exist + create_repo(repo_id=hf_repo_id_for_run, token=token, repo_type="model", exist_ok=True) + logger.info(f"Repository {hf_repo_id_for_run} ensured.") + except Exception as e: + logger.error(f"Failed to create or ensure repository {hf_repo_id_for_run}: {e}", exc_info=True) + return # Stop if repo interaction fails + + commit_message = f"Upload {upload_format.upper()} checkpoint step {step}" + path_in_repo = f"step-{step}" + + logger.info(f"Uploading {local_path} to {hf_repo_id_for_run}/{path_in_repo} on Hugging Face Hub...") + try: + api.upload_folder( + folder_path=local_path, + path_in_repo=path_in_repo, + repo_id=hf_repo_id_for_run, + repo_type="model", + commit_message=commit_message, + token=token, + ) + logger.info(f"Successfully uploaded step {step} to {hf_repo_id_for_run}.") + except Exception as e: + logger.error(f"Failed to upload checkpoint step {step} to {hf_repo_id_for_run}: {e}", exc_info=True) + if hf_keep_latest_k > 0: + 
logger.info(f"Cleaning up old checkpoints on {hf_repo_id_for_run}, keeping latest {hf_keep_latest_k}") + try: + repo_files = api.list_repo_tree(hf_repo_id_for_run, repo_type="model", token=token, recursive=False) + step_folders = [ + item.path for item in repo_files + if item.path.startswith("step-") and item.path[5:].isdigit() + ] + + step_folders.sort(key=lambda x: int(x.split('-')[1]), reverse=True) + + if len(step_folders) > hf_keep_latest_k: + folders_to_delete = step_folders[hf_keep_latest_k:] + logger.info(f"Found {len(step_folders)} checkpoints on Hub. Deleting {len(folders_to_delete)} older ones: {folders_to_delete}") + for folder in folders_to_delete: + # Deleting requires repo_id, path_in_repo, and token + api.delete_folder( + repo_id=hf_repo_id_for_run, + path_in_repo=folder, + repo_type="model", + commit_message=f"Delete old checkpoint {folder}", + token=token + ) + logger.info("Hub cleanup complete.") + else: + logger.info("No old checkpoints found on Hub to delete.") + except Exception as e: + logger.error(f"Error during Hub checkpoint cleanup for {hf_repo_id_for_run}: {e}", exc_info=True) \ No newline at end of file diff --git a/torchtitan/components/__pycache__/dataloader.cpython-312.pyc b/torchtitan/components/__pycache__/dataloader.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..407ba2d683099d718b043e019ed6017fcd232ea8 Binary files /dev/null and b/torchtitan/components/__pycache__/dataloader.cpython-312.pyc differ diff --git a/torchtitan/components/__pycache__/ft.cpython-312.pyc b/torchtitan/components/__pycache__/ft.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..845b560a7d71d551b23f4639cc4d3278f09a8d16 Binary files /dev/null and b/torchtitan/components/__pycache__/ft.cpython-312.pyc differ diff --git a/torchtitan/components/__pycache__/loss.cpython-312.pyc b/torchtitan/components/__pycache__/loss.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3812ee9d6e327c0667af5341c8d457d566d66951 Binary files /dev/null and b/torchtitan/components/__pycache__/loss.cpython-312.pyc differ diff --git a/torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc b/torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b09cd86b3cef0a24b8e8712b594df1f9affc2eb Binary files /dev/null and b/torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc differ diff --git a/torchtitan/components/__pycache__/tokenizer.cpython-312.pyc b/torchtitan/components/__pycache__/tokenizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66e070e1e9b18a6ff9e7aa136da66fc234a3da6d Binary files /dev/null and b/torchtitan/components/__pycache__/tokenizer.cpython-312.pyc differ diff --git a/torchtitan/datasets/__pycache__/hf_datasets.cpython-312.pyc b/torchtitan/datasets/__pycache__/hf_datasets.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..888bd61e5e804186e47e6043cb5fe6b357af8389 Binary files /dev/null and b/torchtitan/datasets/__pycache__/hf_datasets.cpython-312.pyc differ diff --git a/torchtitan/datasets/tokenizer/__pycache__/tiktoken.cpython-312.pyc b/torchtitan/datasets/tokenizer/__pycache__/tiktoken.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d792d34a13bcfdf35095d1d6972b9cdd3edefd4 Binary files /dev/null and b/torchtitan/datasets/tokenizer/__pycache__/tiktoken.cpython-312.pyc differ diff --git a/torchtitan/datasets/tokenizer/tiktoken.py b/torchtitan/datasets/tokenizer/tiktoken.py new file mode 100644 index 0000000000000000000000000000000000000000..401757a93e6b598a6a3a60c4ca934ea0427f25a4 --- /dev/null +++ b/torchtitan/datasets/tokenizer/tiktoken.py @@ -0,0 +1,190 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import os
from collections.abc import Collection, Iterator, Sequence, Set as AbstractSet
from pathlib import Path
from typing import cast, Literal

import tiktoken
from tiktoken.load import load_tiktoken_bpe

from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger


class TikTokenizer(Tokenizer):
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.

    Args:
        model_path (str): The path to the Tiktoken model file.
    """

    # Mapping of special-token string -> token id; populated in __init__.
    special_tokens: dict[str, int]

    # Number of token ids reserved at the top of the vocab for special tokens.
    num_reserved_special_tokens = 256

    # Llama-3 style pre-tokenization regex (contractions, letters, digits in
    # groups of up to 3, punctuation runs, newlines, trailing whitespace).
    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501, B950

    def __init__(self, model_path: str):
        super().__init__()
        assert os.path.exists(
            model_path
        ), f"The tokenizer path does not exist: {model_path}"
        assert os.path.isfile(model_path), model_path

        # Base BPE merge table; special tokens are appended after it.
        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [
            f"<|reserved_special_token_{i}|>"
            for i in range(5, self.num_reserved_special_tokens - 5)
        ]
        # Special tokens occupy the id range directly above the base vocab.
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )

        # Total vocab size (base + special); exposed via the `n_words` property.
        self._n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        # No dedicated pad token; -1 signals "no padding id".
        self.pad_id: int = -1
        # Tokens on which generation should stop.
        self.stop_tokens = {
            self.special_tokens["<|end_of_text|>"],
            self.special_tokens["<|eot_id|>"],
        }
        logger.info(
            f"TikTokenizer built: #words {self.n_words}, BOS ID {self.bos_id}, EOS ID {self.eos_id}"
        )

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Literal["all"] | AbstractSet[str] | None = None,
        disallowed_special: Literal["all"] | Collection[str] | None = None,
    ) -> list[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]|None): allowed special tokens in string
            disallowed_special ("all"|Collection[str]|None): special tokens that raise an error when in string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (insteading of raising
          an error).
        - Setting `allowed_special` to "all" will treat all text corresponding
          to special tokens to be encoded as special tokens.
        """
        assert type(s) is str
        # Normalize None defaults to tiktoken's expected empty containers.
        allowed_special = allowed_special or set()
        disallowed_special = disallowed_special or ()

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: list[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (Sequence[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(list[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            # Transitioning between whitespace and non-whitespace resets the run.
            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]


def build_tiktoken_tokenizer(job_config: JobConfig) -> TikTokenizer:
    # Factory used by the training entry point; reads the tokenizer path from config.
    return TikTokenizer(job_config.model.tokenizer_path)
diff --git a/torchtitan/distributed/__pycache__/__init__.cpython-312.pyc b/torchtitan/distributed/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8de0cb1651c84e3a819fa7bcdf99d4d7aea77b03
Binary files /dev/null and b/torchtitan/distributed/__pycache__/__init__.cpython-312.pyc differ
diff --git a/torchtitan/distributed/__pycache__/parallel_dims.cpython-312.pyc b/torchtitan/distributed/__pycache__/parallel_dims.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..86f44c9aa1be4e74e8dcb84939e4353b35b74aee
Binary files /dev/null and b/torchtitan/distributed/__pycache__/parallel_dims.cpython-312.pyc differ
diff --git a/torchtitan/distributed/__pycache__/pipeline.cpython-312.pyc b/torchtitan/distributed/__pycache__/pipeline.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..659b173cef95361d2a86a7a7da37b850d6bd1f13
Binary files /dev/null and b/torchtitan/distributed/__pycache__/pipeline.cpython-312.pyc differ
diff --git a/torchtitan/experiments/deepseek_v3/attn_mask_utils.py b/torchtitan/experiments/deepseek_v3/attn_mask_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a54899c34e021a43c8a7e090d854140afa8f9e7
--- /dev/null
+++ b/torchtitan/experiments/deepseek_v3/attn_mask_utils.py
@@ -0,0 +1,397 @@
# Copyright (c) Meta Platforms, Inc.
and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This code is based on src/transformers/modeling_attn_mask_utils.py of +# huggingface/transformers. It has been modified from its original forms to +# contain only the necessary utilities. + +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + + +@dataclass +class AttentionMaskConverter: + """ + A utility attention mask class that allows one to: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, + key_value_length) that can be multiplied with attention scores + + Examples: + + ```python + >>> import torch + >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter + + >>> converter = AttentionMaskConverter(True) + >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32) + tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 
-3.4028e+38], + [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]]) + ``` + + Parameters: + is_causal (`bool`): + Whether the attention mask should be a uni-directional (causal) or bi-directional mask. + + sliding_window (`int`, *optional*): + Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. + """ + + is_causal: bool + sliding_window: int + + def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): + self.is_causal = is_causal + self.sliding_window = sliding_window + + if self.sliding_window is not None and self.sliding_window <= 0: + raise ValueError( + "Make sure that when passing `sliding_window` that its value is a strictly positive integer, " + f"not `{self.sliding_window}`" + ) + + def to_causal_4d( + self, + batch_size: int, + query_length: int, + key_value_length: int, + dtype: torch.dtype, + device: Union[torch.device, "str"] = "cpu", + ) -> Optional[torch.Tensor]: + """ + Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative + bias to upper right hand triangular matrix (causal mask). + """ + if not self.is_causal: + raise ValueError( + f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True." 
+ ) + + # If shape is not cached, create a new causal mask and cache it + input_shape = (batch_size, query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if input_shape[-1] > 1 or self.sliding_window is not None: + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + + return causal_4d_mask + + def to_4d( + self, + attention_mask_2d: torch.Tensor, + query_length: int, + dtype: torch.dtype, + key_value_length: Optional[int] = None, + ) -> torch.Tensor: + """ + Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, + key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is + causal, a causal mask will be added. + """ + input_shape = (attention_mask_2d.shape[0], query_length) + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + if key_value_length is None: + raise ValueError( + "This attention mask converter is causal. Make sure to pass " + "`key_value_length` to correctly create a causal mask." 
            )

            # Number of already-cached positions that precede the current query.
            past_key_values_length = key_value_length - query_length
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=attention_mask_2d.device,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )
        elif self.sliding_window is not None:
            raise NotImplementedError(
                "Sliding window is currently only implemented for causal masking"
            )

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        expanded_attn_mask = self._expand_mask(
            attention_mask_2d, dtype, tgt_len=input_shape[-1]
        ).to(attention_mask_2d.device)

        if causal_4d_mask is not None:
            # Fold the 2D padding mask into the causal mask: positions the
            # expanded padding mask marks as masked (nonzero additive bias)
            # are forced to the dtype minimum in the causal mask.
            expanded_attn_mask = causal_4d_mask.masked_fill(
                expanded_attn_mask.bool(), torch.finfo(dtype).min
            )

        # expanded_attn_mask + causal_4d_mask can cause some overflow
        expanded_4d_mask = expanded_attn_mask

        return expanded_4d_mask

    @staticmethod
    def _make_causal_mask(
        input_ids_shape: torch.Size,
        dtype: torch.dtype,
        device: torch.device,
        past_key_values_length: int = 0,
        sliding_window: Optional[int] = None,
    ):
        """
        Make causal mask used for bi-directional self-attention.

        Returns an additive bias of shape `[bsz, 1, tgt_len, tgt_len + past_key_values_length]`:
        0 where attention is allowed, `torch.finfo(dtype).min` where it is masked.
        """
        bsz, tgt_len = input_ids_shape
        # Start fully masked (dtype minimum), then zero out the lower triangle
        # (each query may attend to itself and earlier positions).
        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
        mask_cond = torch.arange(mask.size(-1), device=device)
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

        mask = mask.to(dtype)

        if past_key_values_length > 0:
            # Cached (past) positions are always attendable: prepend zero bias.
            mask = torch.cat(
                [
                    torch.zeros(
                        tgt_len, past_key_values_length, dtype=dtype, device=device
                    ),
                    mask,
                ],
                dim=-1,
            )

        # add lower triangular sliding window mask if necessary
        if sliding_window is not None:
            # Keys at or below this diagonal fall outside the attention window
            # and are re-masked.
            diagonal = past_key_values_length - sliding_window - 1

            context_mask = torch.tril(
                torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal
            )
            mask.masked_fill_(context_mask, torch.finfo(dtype).min)

        return mask[None, None, :, :].expand(
            bsz, 1, tgt_len, tgt_len + past_key_values_length
        )

    @staticmethod
    def _expand_mask(
        mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
    ):
        """
        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

        Input uses the 1 = attend / 0 = masked convention; output is an additive
        bias (0 = attend, dtype-min = masked).
        """
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len

        expanded_mask = (
            mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
        )

        # Invert the convention: attended positions become 0, masked positions
        # become the dtype minimum.
        inverted_mask = 1.0 - expanded_mask

        return inverted_mask.masked_fill(
            inverted_mask.to(torch.bool), torch.finfo(dtype).min
        )

    @staticmethod
    def _unmask_unattended(
        expanded_mask: torch.FloatTensor,
        min_dtype: float,
    ):
        # fmt: off
        """
        Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
        using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        Details: https://github.com/pytorch/pytorch/issues/110213

        `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
        `attention_mask` is [bsz, src_seq_len].

        The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case
        of alibi attention bias.

        For example, if `expanded_mask` is (e.g. here left-padding case)
        ```
        [[[[0, 0, 0],
           [0, 0, 0],
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[0, 0, 0],
           [0, 1, 0],
           [0, 1, 1]]]]
        ```
        then the modified `expanded_mask` will be
        ```
        [[[[1, 1, 1], <-- modified
           [1, 1, 1], <-- modified
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[1, 1, 1], <-- modified
           [0, 1, 0],
           [0, 1, 1]]]]
        ```
        """
        # fmt: on
        if expanded_mask.dtype == torch.bool:
            raise ValueError(
                "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
            )

        # Rows that are entirely masked (all equal to min_dtype) are zeroed
        # out, i.e. made fully attendable.
        return expanded_mask.mul(
            ~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)
        )

    @staticmethod
    def _ignore_causal_mask_sdpa(
        attention_mask: Optional[torch.Tensor],
        inputs_embeds: torch.Tensor,
        past_key_values_length: int,
        sliding_window: Optional[int] = None,
        is_training: bool = False,
    ) -> bool:
        """
        Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
        ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.

        In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
        `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
        allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
        passed).
        """

        _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
        key_value_length = query_length + past_key_values_length

        # Under tracing/compile the mask decision must not be shape-dependent
        # in ways that bake in the wrong `is_causal` (see comments below).
        is_tracing = (
            torch.jit.is_tracing()
            or isinstance(inputs_embeds, torch.fx.Proxy)
            or is_torchdynamo_compiling()
        )

        ignore_causal_mask = False

        if attention_mask is None:
            # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input
            # shape, thus SDPA's `is_causal` argument is rightfully updated
            # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using
            # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
            # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`
            # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
            # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
            #
            # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
            # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
            if (
                (is_training or not is_tracing)
                and (query_length == 1 or key_value_length == query_length)
                and (sliding_window is None or key_value_length < sliding_window)
            ):
                ignore_causal_mask = True
        elif sliding_window is None or key_value_length < sliding_window:
            if len(attention_mask.shape) == 4:
                # A custom 4D mask was supplied by the caller; never ignore it.
                return False
            elif not is_tracing and torch.all(attention_mask == 1):
                if query_length == 1 or key_value_length == query_length:
                    # For query_length == 1, causal attention and bi-directional attention are the same.
                    ignore_causal_mask = True

                # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
                # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in
                # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
                # Reference: https://github.com/pytorch/pytorch/issues/108108
                # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.

        return ignore_causal_mask


def _prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`

    Args:
        attention_mask (`torch.Tensor` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        inputs_embeds (`torch.Tensor`):
            The embedded inputs as a torch Tensor.
        past_key_values_length (`int`):
            The length of the key value cache.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.
    """
    attn_mask_converter = AttentionMaskConverter(
        is_causal=True, sliding_window=sliding_window
    )

    key_value_length = input_shape[-1] + past_key_values_length

    # 4d mask is passed through the layers
    if attention_mask is not None and len(attention_mask.shape) == 2:
        attention_mask = attn_mask_converter.to_4d(
            attention_mask,
            input_shape[-1],
            key_value_length=key_value_length,
            dtype=inputs_embeds.dtype,
        )
    elif attention_mask is not None and len(attention_mask.shape) == 4:
        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
        if tuple(attention_mask.shape) != expected_shape:
            raise ValueError(
                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
            )
        else:
            # if the 4D mask has correct shape - invert it and fill with negative infinity
            inverted_mask = 1.0 - attention_mask
            attention_mask = inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
            )
    else:
        # No user mask: build a pure causal mask from shapes alone.
        attention_mask = attn_mask_converter.to_causal_4d(
            input_shape[0],
            input_shape[-1],
            key_value_length,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )

    return attention_mask
diff --git a/torchtitan/experiments/deepseek_v3/download.py b/torchtitan/experiments/deepseek_v3/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b9ec3104d716cbd6142c6564d83f042f128770f
--- /dev/null
+++ b/torchtitan/experiments/deepseek_v3/download.py
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Usage:
# Downloads a given model to the HF Cache. Pass in a listed option ala "v3" or your own custom model path.
# python download.py {model_id} [custom_model_path]
# Examples:
# python download.py v2 # Use predefined model: deepseek-ai/DeepSeek-V2
# python download.py custom "deepseek-ai/new-model" # Download a custom model path

# Available models:
# "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat",
# "v2-lite": "deepseek-ai/DeepSeek-V2-Lite",
# "v2": "deepseek-ai/DeepSeek-V2",
# "v3": "deepseek-ai/deepseek-v3",
# "v3-0324": "deepseek-ai/DeepSeek-V3-0324",
# "custom": None, # Placeholder for custom models


import sys

from transformers import AutoModelForCausalLM


# Mapping of short CLI aliases to Hugging Face model repository ids.
MODELS = {
    "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat",
    "v2-lite": "deepseek-ai/DeepSeek-V2-Lite",
    "v2": "deepseek-ai/DeepSeek-V2",
    "v3": "deepseek-ai/deepseek-v3",
    "v3-0324": "deepseek-ai/DeepSeek-V3-0324",
    "custom": None,  # For custom (any) models
}


def print_usage():
    # Prints CLI help listing the predefined aliases, then exits non-zero.
    print("Usage:")
    print("  python download.py [model_version]")
    print("  python download.py custom [custom_model_path]")
    print("\nAvailable predefined models:")
    for key, model in MODELS.items():
        if key != "custom":  # Skip the custom placeholder
            print(f"  {key}: {model}")
    print("\nFor custom models:")
    print("  custom: Specify your own model path")
    print('  Example: python download.py custom "organization/model-name"')
    sys.exit(1)


# Process command line arguments
if len(sys.argv) < 2 or sys.argv[1] not in MODELS:
    print_usage()

if sys.argv[1] == "custom":
    # "custom" requires exactly one extra argument: the HF repo path.
    if len(sys.argv) != 3:
        print("Error: Custom model requires a model path")
        print_usage()
    model_id = sys.argv[2]
    print(f"Using custom model: {model_id}")
else:
    model_id = MODELS[sys.argv[1]]
print(f"Downloading model: {model_id}")

# Instantiating the model triggers the download of its weights into the local
# Hugging Face cache; the returned object itself is not used further.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,
)
diff --git a/torchtitan/experiments/deepseek_v3/inference.sh b/torchtitan/experiments/deepseek_v3/inference.sh
new file mode 100644
index
0000000000000000000000000000000000000000..afbab8f20d54f38521e5e2683e502396102e1172
--- /dev/null
+++ b/torchtitan/experiments/deepseek_v3/inference.sh
@@ -0,0 +1,15 @@

#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Number of GPUs to launch on; override via the NGPU environment variable.
NGPU=${NGPU:-"4"}

# Get the prompt from command line argument or use a default
prompt="${1:-What is 2+2?}"

# Run the model with the prompt
torchrun --standalone --nproc-per-node ${NGPU} generate.py "$prompt"
diff --git a/torchtitan/experiments/deepseek_v3/model.py b/torchtitan/experiments/deepseek_v3/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..0669df9528b3db0de3325db36f010312b5b3eac7
--- /dev/null
+++ b/torchtitan/experiments/deepseek_v3/model.py
@@ -0,0 +1,1325 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This code is based on model definition of `deepseek-ai/DeepSeek-V3-Base` on
# Hugging Face Model Hub. Url:
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/resolve/main/configuration_deepseek.py
#
# It has been modified from its original forms to accommodate naming convention
# and usage patterns of the TorchTitan project.

# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeepSeek model.""" +import math +from typing import Optional, Tuple + +import torch +import torch.distributed as dist + +import torch.distributed._symmetric_memory as symm_mem +import torch.nn.functional as F +import torch.utils.checkpoint + +from attn_mask_utils import _prepare_4d_causal_attention_mask +from indices import generate_permute_indices +from model_config import ModelArgs +from symm_mem_recipes import OnDeviceAllToAllV +from torch import nn +from torch.distributed._functional_collectives import all_to_all_single_autograd + +from torchtitan.experiments.kernels.triton_mg_group_gemm.torchao_pr import ( + ALIGN_SIZE_M, + grouped_gemm_forward, +) + +# Get model parallel subgroup by name: +# e.g. 
"pp", "ep", None +def get_group(dim_name: Optional[str] = None) -> dist.ProcessGroup: + glob = torch.distributed.device_mesh._mesh_resources.get_current_mesh() + return glob.get_group(dim_name) + + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq.to(t.device)) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class LinearScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Deepseek +class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 +): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def yarn_find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def 
yarn_linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +class YarnRotaryEmbedding(RotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / ( + self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + freq_inter = 1.0 / ( + self.scaling_factor + * self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( + device=device, dtype=torch.float32 + ) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", (emb.sin() * 
_mscale).to(dtype), persistent=False + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MLP(nn.Module): + act_fn = nn.SiLU() + + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = ( + config.intermediate_size if intermediate_size is None else intermediate_size + ) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class MoEGate(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.scoring_func = config.scoring_func + self.seq_aux = config.seq_aux + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + + # topk selection algorithm + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.hidden_size + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, self.gating_dim)) + ) + if self.topk_method == "noaux_tc": + self.e_score_correction_bias = nn.Parameter( + # Changed from torch.empty to torch.rand to avoid non-even + # distribution for runs without actual weigths + 
torch.rand((self.n_routed_experts)) + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + import torch.nn.init as init + + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + # compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear( + hidden_states.type(torch.float32), self.weight.type(torch.float32), None + ) + if self.scoring_func == "sigmoid": + scores = logits.sigmoid() + elif self.scoring_func == "softmax": + scores = logits.softmax(dim=-1, dtype=torch.float32) + else: + raise NotImplementedError( + f"insupportable scoring function for MoE gating: {self.scoring_func}" + ) + + # select top-k experts + if self.topk_method == "noaux_tc": + scores_for_choice = scores.view( + bsz * seq_len, -1 + ) + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.view(bsz * seq_len, self.n_group, -1) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) # [n, n_group] + group_idx = torch.topk( + group_scores, k=self.topk_group, dim=-1, sorted=False + )[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group + ) + .reshape(bsz * seq_len, -1) + ) # [n, e] + tmp_scores = scores_for_choice.masked_fill( + ~score_mask.bool(), 0.0 + ) # [n, e] + _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) + topk_weight = scores.gather(1, topk_idx) + elif self.topk_method == "greedy": + topk_weight, topk_idx = torch.topk( + scores, k=self.top_k, dim=-1, sorted=False + ) + else: + raise NotImplementedError( + f"insupportable TopK function for MoE gating: {self.topk_method}" + ) + + # norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = 
topk_weight / denominator + topk_weight = ( + topk_weight * self.routed_scaling_factor + ) # must multiply the scaling factor + + return topk_idx, topk_weight + + +class MoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + # Class attributes: + # Two shuffle method supported: + # 1. "torch_all_to_all" + # 2. "symm_mem" (see `setup_symm_mem` below) + shuffle_method = "torch_all_to_all" + + # Symmetric memory buffers shared by all MoE instances across layers + token_send_buf: Optional[torch.Tensor] = None + token_gather_buf: Optional[torch.Tensor] = None + + def __init__(self, config): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + + # ep_size is the number of ranks in expert dimension + if config.ep_size <= 1: + raise ValueError( + "For code simplicity, this model only supports distributed experts, " + "thus EP size must be > 1, please modify your model config" + ) + self.ep_group = get_group("ep") + assert config.ep_size == self.ep_group.size() + self.ep_size = config.ep_size + self.ep_rank = self.ep_group.rank() + self.experts_per_rank = config.n_routed_experts // config.ep_size + # Use ModuleDict instead of ModuleList to preserve absoulte expert + # IDs while avoiding `None` experts. The absolute expert IDs match + # with checkpoint FQNs. 
+ self.experts = nn.ModuleDict() + for i in range(self.experts_per_rank): + abs_expert_id = self.ep_rank * self.experts_per_rank + i + self.experts[str(abs_expert_id)] = MLP( + config, intermediate_size=config.moe_intermediate_size + ) + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = MLP( + config=config, intermediate_size=intermediate_size + ) + + def combine_experts(self, submod_name): + all_weights = [] + for expert in self.experts.values(): + lin = expert.get_submodule(submod_name) + all_weights.append(lin.weight) + lin.weight = None + + concat_weight = torch.cat(all_weights) + self.register_parameter(f"{submod_name}_weight", nn.Parameter(concat_weight)) + + # This function is used to create a symm mem buffer for MoE's. It is for + # shuffling tokens fully "on-device", as compared to traditional torch + # all_to_all APIs which requrie a GPU-to-CPU sync of the splits. If a user + # calls this function, the `shuffle_method` would switch from + # `torch_all_to_all` to `symm_mem`. 
+ def setup_symm_mem(self, dtype: torch.dtype, device: torch.device): + # Switch shuffle method + self.shuffle_method = "symm_mem" + + # Combine expert weights + print("Combining expert weights for Group GEMM") + self.combine_experts("gate_proj") + self.combine_experts("up_proj") + self.combine_experts("down_proj") + + # Assuming worst case, 2x tokens are routed to one EP rank + overflow = 2 + OnDeviceAllToAllV.max_output_len = ( + self.config.max_seq_len * self.num_experts_per_tok * overflow + ) + + # Symmetric memory buffers are shared by all MoE instances across + # layers, we only need to initialize them once + if MoE.token_send_buf is not None: + return + + # Input buffer for DP-to-EP shuffle + MoE.token_send_buf = symm_mem.empty( + self.config.max_seq_len + * self.num_experts_per_tok, # seq len * top k (flattened) + self.config.hidden_size, # hidden dim + dtype=dtype, + device=device, + ) + # Input buffer for EP-to-DP shuffle + MoE.token_gather_buf = symm_mem.empty( + self.config.max_seq_len + * self.num_experts_per_tok # seq len * top k (flattened) + * overflow, + self.config.hidden_size, # hidden dim + dtype=dtype, + device=device, + ) + print(f"EP rank [{self.ep_rank}]: Created Symmetric Memory for MoE") + + def get_send_buf(self): + # [Why detach?] During a first forward-backward step, the buffer would + # be included in a computational graph. In a second step, autograd will + # return an error saying "Trying to backward through the graph a second + # time (or directly access saved tensors more than once)". This is + # because the buffer is still in the graph, and autograd is trying to + # backward through the graph a second time. To avoid this, we detach the + # buffer from the graph. `detach()` returns a new tensor, which shares + # the same storage with the original one. + self.token_send_buf.grad = None + return self.token_send_buf.detach() + + def get_gather_buf(self): + # See [Why detach?] 
in `get_send_buf` + self.token_gather_buf.grad = None + return self.token_gather_buf.detach() + + def forward(self, hidden_states): + identity = hidden_states + orig_shape = hidden_states.shape + # for each token, select top-k experts, and compute the weight for each expert + topk_idx, topk_weight = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + if self.shuffle_method == "symm_mem": + y = self.moe_on_device(hidden_states, topk_idx, topk_weight) + else: # "torch_all_to_all" + y = self.moe_forward(hidden_states, topk_idx, topk_weight) + + y = y.view(*orig_shape) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(identity) + return y + + def moe_forward(self, x, topk_ids, topk_weight): + # This part sorts the token indices so that tokens routed to the same expert reside consecutively. + # An implication is that tokens to the same "expert group" (i.e., device) are also consecutive. + # Since this is an "aritificial" index creation (final outcome being + # `idxs`), we don't need gradients here. + with torch.no_grad(): + # [seq_len, n_routed_experts] + cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts)) + # Fill 1 to the selected experts + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + # Token indices for each expert + idxs = topk_ids.view(-1).argsort() + sorted_tokens_shape = idxs.shape + x.shape[1:] + + sorted_tokens = x[idxs // topk_ids.shape[1]] + assert sorted_tokens.shape == sorted_tokens_shape + + # This part exchange the information about the number of tokens send and + # received by each expert. We can understand this information as "side + # band", which is not part of the actual data. Thus no gradient is + # needed. 
+ with torch.no_grad(): + # Sum the tokens over local experts, then we get tokens per EP rank, + # which is the input splits + tokens_per_expert_group = tokens_per_expert.new_empty( + tokens_per_expert.shape[0] + ) + dist.all_to_all_single( + tokens_per_expert_group, tokens_per_expert, group=self.ep_group + ) + input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) + + # DP to EP token shuffle. This part needs gradient. + if self.shuffle_method == "symm_mem": + # Move input to the `token_send_buf` symm mem + token_send_buf = self.get_send_buf() + token_send_buf[: idxs.shape[0]].copy_(sorted_tokens) + # Note: `out=` avoids copy, but it is not differentiable + # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]]) + token_gather_buf, output_splits = OnDeviceAllToAllV.apply( + token_send_buf, + input_splits, + self.ep_group, + ) + with torch.no_grad(): + # Received tokens from all other ranks. TODO: use mask instead + received = output_splits.sum() + # TODO: don't use `received` + gathered_tokens = token_gather_buf[:received] + else: # "torch_all_to_all" + # Prepare input ans output splits + with torch.no_grad(): + output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum( + dim=1 + ) + gathered_tokens = all_to_all_single_autograd( + sorted_tokens, + output_splits.tolist(), + input_splits.tolist(), + self.ep_group, + ) + + # This part prepares a 1D tensor with the same length as + # `gathered_tokens`. The 1D tensor is filled with local expert IDs which + # the tokens in `gathered_tokens` are headed for. This part doesn't need + # gradient. 
+ with torch.no_grad(): + gatherd_idxs = ( + torch.arange( + tokens_per_expert_group.numel(), + device=tokens_per_expert_group.device, + ) + % self.experts_per_rank + ) + gatherd_idxs = gatherd_idxs.repeat_interleave(tokens_per_expert_group) + + # Prepare buffer for tokens processed by experts + if self.shuffle_method == "symm_mem": + # Take necessary space from `token_gather_buf` symm mem because we are + # going to send them out after expert processing + processed_tokens = self.get_gather_buf()[: gathered_tokens.shape[0]] + else: # "torch_all_to_all" + processed_tokens = torch.empty_like(gathered_tokens) + + # This part processes the tokens routed to the local experts. + # TODO: can we use group GEMM here? + for i, expert in enumerate(self.experts.values()): + processed_tokens[gatherd_idxs == i] = expert( + gathered_tokens[gatherd_idxs == i] + ) + + # Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle. + # The input/output splits are just a reverse of the previous shuffle. + if self.shuffle_method == "symm_mem": + token_return_buf, _ = OnDeviceAllToAllV.apply( + processed_tokens, + output_splits, + self.ep_group, + ) + returned_tokens = token_return_buf[: sorted_tokens_shape[0]] + else: # "torch_all_to_all" + returned_tokens = all_to_all_single_autograd( + processed_tokens, + input_splits.tolist(), + output_splits.tolist(), + self.ep_group, + ) + + output_tokens = torch.empty_like(returned_tokens) + output_tokens[idxs] = returned_tokens + final_out = ( + output_tokens.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(returned_tokens.dtype) + ) + return final_out + + def moe_on_device(self, x, topk_ids, topk_weight): + # This part sorts the token indices so that tokens routed to the same expert reside consecutively. + # An implication is that tokens to the same "expert group" (i.e., device) are also consecutive. 
+ # Since this is an "artificial" index creation (final outcome being + # `idxs`), we don't need gradients here. + with torch.no_grad(): + # [seq_len, n_routed_experts] + cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts)) + # Fill 1 to the selected experts + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + # Token indices for each expert + idxs = topk_ids.view(-1).argsort() + sorted_tokens_shape = idxs.shape + x.shape[1:] + + sorted_tokens = x[idxs // topk_ids.shape[1]] + assert sorted_tokens.shape == sorted_tokens_shape + + # This part exchanges the information about the number of tokens sent and + # received by each expert. We can understand this information as "side + # band", which is not part of the actual data. Thus no gradient is + # needed. + with torch.no_grad(): + # Sum the tokens over local experts, then we get tokens per EP rank, + # which is the input splits + tokens_per_expert_group = tokens_per_expert.new_empty( + tokens_per_expert.shape[0] + ) + dist.all_to_all_single( + tokens_per_expert_group, tokens_per_expert, group=self.ep_group + ) + input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) + + # Move input to the `token_send_buf` symm mem + token_send_buf = self.get_send_buf() + token_send_buf[: idxs.shape[0]].copy_(sorted_tokens) + # Note: `out=` avoids copy, but it is not differentiable + # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]]) + token_gather_buf, output_splits = OnDeviceAllToAllV.apply( + token_send_buf, + input_splits, + self.ep_group, + ) + + # We need to permute the received tokens so that tokens for the same expert are contiguous. + # This part prepares a 1D tensor `permuted_indices` for such permutation. + # This part doesn't need gradient.
+ with torch.no_grad(): + permuted_indices, m_sizes = generate_permute_indices( + tokens_per_expert_group, + self.experts_per_rank, + self.ep_size, + token_gather_buf.shape[0], + ALIGN_SIZE_M, + ) + + # Permute the received tokens so that tokens for the same expert are contiguous. + contig_tokens = token_gather_buf[permuted_indices] + + # Run the first grouped GEMM + w1 = self.get_parameter("gate_proj_weight") + gate_proj = grouped_gemm_forward(contig_tokens, w1, m_sizes) + + # Run the second grouped GEMM + w3 = self.get_parameter("up_proj_weight") + up_proj = grouped_gemm_forward(contig_tokens, w3, m_sizes) + + # Apply activation + hidden_outputs = MLP.act_fn(gate_proj) * up_proj + + # Run the third grouped GEMM + w2 = self.get_parameter("down_proj_weight") + hidden_outputs = grouped_gemm_forward(hidden_outputs, w2, m_sizes) + + # Prepare buffer for tokens processed by experts + # Take necessary space from `token_gather_buf` symm mem because we are + # going to send them out after expert processing + processed_tokens = self.get_gather_buf() + + # Move into Symmetric Memory for the return shuffle + processed_tokens[permuted_indices] = hidden_outputs + + # Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle. + # The input/output splits are just a reverse of the previous shuffle. 
+ token_return_buf, _ = OnDeviceAllToAllV.apply( + processed_tokens, + output_splits, + self.ep_group, + ) + returned_tokens = token_return_buf[: sorted_tokens_shape[0]] + + output_tokens = torch.empty_like(returned_tokens) + output_tokens[idxs] = returned_tokens + final_out = ( + output_tokens.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(returned_tokens.dtype) + ) + return final_out + + +class Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: ModelArgs, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + + if self.q_lora_rank is None: + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.q_head_dim, bias=False + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, config.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear( + config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False + ) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = RMSNorm(config.kv_lora_rank) + self.kv_b_proj = nn.Linear( + config.kv_lora_rank, + self.num_heads + * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, 
+ ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + self._init_rope() + + self.softmax_scale = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LinearScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = DynamicNTKScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "yarn": + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rotary_emb = YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], 
Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if attention_mask is not None: + # Attention mask was made 4D because the `attn_weights` above is 4D. + # We probably can make this mask smarter if we want to pack sequences + # together, instead of using padding. This optimization can be used in + # inference. For training, if we want to pack sequences, data loader + # will pass in a mask containing such info. 
+ attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, # None, or user provided mask in 2D + (bsz, q_len), + hidden_states, + 0, # past_key_values_length, 0 when training + ) + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query=query_states, + key=key_states, + value=value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout, + is_causal=attention_mask is None, + scale=self.softmax_scale, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = self.o_proj(attn_output) + + return attn_output + + +class DecoderLayer(nn.Module): + def __init__(self, config: ModelArgs, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Attention(config=config, layer_idx=layer_idx) + + self.mlp = ( + MoE(config) + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else MLP(config) + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. 
+ """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +Deepseek_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. 
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). 
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class DeepseekModel(torch.nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DecoderLayer`] + + Args: + config: ModelArgs + """ + + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # Creating model parts related to my stage + assert ( + config.stage_idx < config.num_stages + ), f"Stage {config.stage_idx} is not in the model" + print(f"Creating model stage {config.stage_idx} of {config.num_stages}") + + self.embed_tokens = ( + nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + if config.stage_idx == 0 + else None + ) + + self.layers = torch.nn.ModuleDict() + division = config.num_hidden_layers // config.num_stages + residual = config.num_hidden_layers % config.num_stages + # Some earlier stages may have 1 more layer than latter stages because + # the division may have residual; this is more even than giving the + # entire residual to the last stage. 
+ layers_per_stage = [ + division + 1 if stage < residual else division + for stage in range(config.num_stages) + ] + assert sum(layers_per_stage) == config.num_hidden_layers + layer_id_start = sum(layers_per_stage[: config.stage_idx]) + layer_id_end = layer_id_start + layers_per_stage[config.stage_idx] + for layer_id in range(layer_id_start, layer_id_end): + self.layers[str(layer_id)] = DecoderLayer(config, layer_id) + + self.norm = ( + RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if config.stage_idx == config.num_stages - 1 + else None + ) + + # Initialize weights and apply final processing + self.apply(self._init_weights) + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def forward( + self, + tokens: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + # Embedding + hidden_states = ( + self.embed_tokens(tokens) if self.embed_tokens is not None else tokens + ) + + # decoder layers + for decoder_layer in self.layers.values(): + hidden_states = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + hidden_states = ( + self.norm(hidden_states) if self.norm is not None else hidden_states + ) + return hidden_states + + +class DeepseekForCausalLM(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.model = DeepseekModel(config) + self.lm_head = ( + nn.Linear(config.hidden_size, config.vocab_size, bias=False) + if config.stage_idx == config.num_stages - 1 + else None + ) + + # Initialize weights and apply final processing + # self.post_init() + + def forward( + self, + 
tokens: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ) -> Tuple: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, DeepseekForCausalLM + + >>> model = DeepseekForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + hidden_states = self.model( + tokens, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + logits = ( + self.lm_head(hidden_states) if self.lm_head is not None else hidden_states + ) + return logits + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + **kwargs, + ): + if past_key_values is not None: + # Assuming isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length.
+ elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past + + # Setup Symmetric Memory for MoE token shuffle. + # Supports inference currently. 
+ def setup_symm_mem(self, dtype: torch.dtype, device: torch.device): + for layer in self.model.layers.values(): + if not isinstance(layer.mlp, MoE): + continue + layer.mlp.setup_symm_mem(dtype, device) diff --git a/torchtitan/experiments/deepseek_v3/model_config.py b/torchtitan/experiments/deepseek_v3/model_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d559d4ee94ecf7fccc933cf1a243161d1796a123 --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/model_config.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + + +@dataclass +class ModelArgs: + r""" + This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V3. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 129280): + Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV3Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1407): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. 
+ num_nextn_predict_layers (`int`, *optional*, defaults to 1): + Number of nextn predict layers in the DeepSeekV3 Model. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + n_shared_experts (`int`, *optional*, defaults to None): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to None): + Number of routed experts, None means dense model. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for routed experts. + topk_method (`str`, *optional*, defaults to `greedy`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to None): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to None): + Number of selected groups for each token(for each token, ensuring the selected experts is only within + `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to None): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to False): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to 'softmax'): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + seq_aux = (`bool`, *optional*, defaults to True): + Whether to compute the auxiliary loss for each individual sample. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention.
If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details check out [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232).
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + """ + + vocab_size: int = 129280 + hidden_size: int = 7168 + intermediate_size: int = 18432 + moe_intermediate_size: int = 2048 + num_hidden_layers: int = 61 + num_nextn_predict_layers: int = 1 + num_attention_heads: int = 128 + num_key_value_heads: int = 128 + n_shared_experts: int = 1 + n_routed_experts: int = 256 + ep_size: int = 1 + routed_scaling_factor: float = 2.5 + kv_lora_rank: int = 512 + q_lora_rank: int = 1536 + qk_rope_head_dim: int = 64 + v_head_dim: int = 128 + qk_nope_head_dim: int = 128 + topk_method: str = "noaux_tc" + n_group: int = 8 + topk_group: int = 4 + num_experts_per_tok: int = 8 + moe_layer_freq: int = 1 + first_k_dense_replace: int = 3 + norm_topk_prob: bool = True + scoring_func: str = "sigmoid" + aux_loss_alpha: float = 0.001 + seq_aux: bool = True + hidden_act: str = "silu" + max_position_embeddings: int = 163840 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + rope_theta: float = 10000.0 + rope_scaling: dict = field( + default_factory=lambda: { + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 1.0, + "mscale_all_dim": 1.0, + 
"original_max_position_embeddings": 4096, + "type": "yarn", + } + ) + attention_bias: bool = False + attention_dropout: float = 0.0 + pad_token_id = None + # Added for symmetric memory + max_seq_len: int = 4096 + dtype: str = "bfloat16" + # Added for pipeline parallel + num_stages: int = 1 + stage_idx: int = 0 + + +# This is the configuration for deepseek-ai/DeepSeek-V2-Lite. +deepseek_v2_lite_config = ModelArgs( + vocab_size=102400, + hidden_size=2048, + intermediate_size=10944, + moe_intermediate_size=1408, + num_hidden_layers=27, + num_attention_heads=16, + num_key_value_heads=16, + n_shared_experts=2, + n_routed_experts=64, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=None, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method="greedy", + n_group=1, + topk_group=1, + num_experts_per_tok=6, + first_k_dense_replace=1, + norm_topk_prob=False, + scoring_func="softmax", + max_position_embeddings=4096, + rope_scaling={ + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 0.707, + "mscale_all_dim": 0.707, + "original_max_position_embeddings": 4096, + "type": "yarn", + }, +) + + +# Model configuration registry +# Key is the model distribution ID on HuggingFace Hub +deepseek_config_registry = { + "deepseek-ai/DeepSeek-V2-Lite": deepseek_v2_lite_config, + "deepseek-ai/DeepSeek-V2-Lite-Chat": deepseek_v2_lite_config, + "deepseek-ai/deepseek-v3": ModelArgs(), +} diff --git a/torchtitan/experiments/deepseek_v3/requirements.txt b/torchtitan/experiments/deepseek_v3/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2b66a52d87be39b1c4fb36e822c24958d40dfa81 --- /dev/null +++ b/torchtitan/experiments/deepseek_v3/requirements.txt @@ -0,0 +1,5 @@ +transformers +accelerate +torchdata >= 0.8.0 +datasets >= 2.21.0 +tomli >= 1.1.0 ; python_version < "3.11" diff --git a/torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py b/torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py 
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Symmetric-memory communication recipes.

Public API of this package: `OnDeviceAllToAllV` is the autograd-aware,
Triton-based on-device all-to-all-v collective defined in
`triton_on_device_all_to_all_v`.
"""

from .triton_on_device_all_to_all_v import OnDeviceAllToAllV

__all__ = [
    "OnDeviceAllToAllV",
]

import triton
import triton.language as tl

from .triton_utils import get_flat_bid, get_flat_tid


@triton.jit
def send_signal(addrs, sem: tl.constexpr):
    """Set each signal-pad slot in `addrs` from 0 to 1 via a CAS spin loop.

    The loop retries while the slot still holds 1, i.e. until the previous
    signal has been consumed by the waiter, so signals are never lost.
    `sem` selects the PTX memory ordering: "relaxed" or "acq_rel" (release
    on the send side).
    """
    if sem == "relaxed":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                send_signal:
                    atom.global.relaxed.sys.cas.b32 %tmp32_0, [$1], 0, 1;
                    setp.eq.u32 %p0, %tmp32_0, 0;
                    @!%p0 bra send_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    elif sem == "acq_rel":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                send_signal:
                    atom.global.release.sys.cas.b32 %tmp32_0, [$1], 0, 1;
                    setp.eq.u32 %p0, %tmp32_0, 0;
                    @!%p0 bra send_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    else:
        raise RuntimeError(f"Unrecognized sem: {sem}")


@triton.jit
def wait_signal(addrs, sem: tl.constexpr):
    """Consume a signal: spin until each slot in `addrs` flips from 1 to 0.

    Mirror of `send_signal`; "acq_rel" uses acquire ordering on the wait
    side so writes made before the peer's release-send are visible after
    this returns.
    """
    # NOTE(review): qualifier order here is `.sys.relaxed` / `.sys.acquire`,
    # while send_signal uses `.relaxed.sys` / `.release.sys` — confirm both
    # orderings assemble identically under the targeted PTX ISA.
    if sem == "relaxed":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                wait_signal:
                    atom.global.sys.relaxed.cas.b32 %tmp32_0, [$1], 1, 0;
                    setp.eq.u32 %p0, %tmp32_0, 1;
                    @!%p0 bra wait_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    elif sem == "acq_rel":
        tl.inline_asm_elementwise(
            """
            {
                .reg .u32   %tmp32_<1>;
                .reg .pred  %p<1>;

                wait_signal:
                    atom.global.sys.acquire.cas.b32 %tmp32_0, [$1], 1, 0;
                    setp.eq.u32 %p0, %tmp32_0, 1;
                    @!%p0 bra wait_signal;
            }
            """,
            "=r, l",
            [addrs],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
    else:
        raise RuntimeError(f"Unrecognized sem: {sem}")


@triton.jit
def blockwise_barrier(
    signal_pad_ptrs,
    block_id,
    rank: tl.constexpr,
    world_size: tl.constexpr,
    sem: tl.constexpr,
):
    """
    Synchronizes blocks with matching block_id across participating devices.

    Note: the function itself is not a system level barrier/fence. It is a
    building block for expressing different synchronization patterns.

    Pattern 0: Ensures that all writes to symm_mem buffers from previous
    kernels across all devices are visible to the current kernel:

        blockwise_barrier(..., sem="relaxed")
        sync_threads()

    Pattern 1: Ensures that all writes to symm_mem buffers from the current
    block are visible to all remote blocks with matching blockIdx:

        sync_threads()
        blockwise_barrier(..., sem="acq_rel")
        sync_threads()

    Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
    for writing by subsequent kernels across all devices.

        sync_threads()
        blockwise_barrier(..., sem="relaxed")

    CUDA graph friendliness:

    This barrier operates through atomic operations on a zero-filled signal
    pad, which resets to a zero-filled state after each successful
    synchronization. This design eliminates the need for incrementing a
    flag from host.
    """
    if block_id is None:
        block_id = get_flat_bid()
    flat_tid = get_flat_tid()

    # Each block owns a world_size-slot row of every rank's signal pad,
    # indexed by block_id; slot r within a row belongs to peer rank r.
    remote_ranks = tl.arange(0, world_size)
    signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64))
    remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to(
        tl.pointer_type(tl.uint32)
    )
    send_addrs = remote_signal_pad_addrs + block_id * world_size + rank

    local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to(
        tl.pointer_type(tl.uint32)
    )
    wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks

    # Only the first world_size threads participate: thread i signals peer i
    # and then waits for peer i's signal in the local pad.
    if flat_tid < world_size:
        send_signal(send_addrs, sem)
        wait_signal(wait_addrs, sem)
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import triton
import triton.language as tl

from .triton_barrier import blockwise_barrier
from .triton_utils import sync_threads


@triton.jit
def _exchange_row_offsets(
    split_sizes_ptrs,
    rank: tl.constexpr,
    world_size: tl.constexpr,
    BLOCKS_PER_REMOTE_RANK: tl.constexpr,
):
    """Compute (input_row_offset, output_row_offset, num_rows) for the
    remote rank this program instance is responsible for.

    Row offsets are exclusive prefix sums over the exchanged split-size
    vectors (inclusive sum minus the rank's own count).
    """
    remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK

    # split_sizes_ptr for all ranks
    # All these vector stacks into split_sizes_matrix
    split_sizes_ptrs = split_sizes_ptrs.to(tl.pointer_type(tl.uint64))

    # split_sizes_matrix[remote_rank, :]
    input_split_sizes_ptr = tl.load(split_sizes_ptrs + remote_rank).to(
        tl.pointer_type(tl.int64)
    )

    offsets_ = tl.arange(0, world_size)
    input_split_sizes = tl.load(
        input_split_sizes_ptr + offsets_, mask=offsets_ <= rank, other=0
    )

    num_rows = tl.load(input_split_sizes_ptr + rank)
    # Exclusive prefix sum: rows contributed by ranks < rank.
    input_row_offset = tl.sum(input_split_sizes) - num_rows

    # split_sizes_matrix[:, rank]
    output_split_sizes_ptrs = (
        tl.load(split_sizes_ptrs + offsets_).to(tl.pointer_type(tl.int64)) + rank
    )
    output_split_sizes = tl.load(
        output_split_sizes_ptrs, mask=offsets_ <= remote_rank, other=0
    )
    output_row_offset = tl.sum(output_split_sizes) - num_rows

    return input_row_offset, output_row_offset, num_rows


@triton.jit
def on_device_all_to_all_v_kernel(
    output_ptr,
    output_splits_ptr,
    input_ptrs,
    input_splits_ptr,
    signal_pad_ptrs,
    dim: tl.constexpr,  # Separate dim for easier vectorization
    rank: tl.constexpr,
    world_size: tl.constexpr,
    BLOCKS_PER_REMOTE_RANK: tl.constexpr,
    UNROLL_FACTOR: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Pull-style all-to-all-v: each group of BLOCKS_PER_REMOTE_RANK blocks
    copies this rank's slice out of one remote rank's symm-mem input buffer
    into the local output, bracketed by cross-device barriers.
    """
    # Entry barrier: make peers' prior writes to symm_mem buffers visible.
    blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
    sync_threads()

    remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
    block_offset = tl.program_id(0) % BLOCKS_PER_REMOTE_RANK

    input_row_offset, output_row_offset, num_rows = _exchange_row_offsets(
        input_splits_ptr, rank, world_size, BLOCKS_PER_REMOTE_RANK
    )

    output_splits_ptr = output_splits_ptr.to(tl.pointer_type(tl.uint64))
    if block_offset == 0:
        # Update output_splits
        tl.store(output_splits_ptr + remote_rank, num_rows)

    # NOTE(review): element type is hard-coded to bfloat16 here — confirm
    # all callers pass bf16 tensors.
    input_ptr = (
        tl.load(input_ptrs.to(tl.pointer_type(tl.uint64)) + remote_rank).to(
            tl.pointer_type(tl.bfloat16)
        )
        + input_row_offset * dim
    )
    output_ptr = output_ptr + output_row_offset * dim

    # Partition num_rows * dim elements across BLOCKS_PER_REMOTE_RANK blocks;
    # the bulk is copied in unrolled full tiles, the tail with a masked loop.
    outer_loop_step = BLOCK_SIZE * UNROLL_FACTOR
    outer_loop_iters_per_rank = tl.cdiv(
        tl.cdiv(num_rows * dim, outer_loop_step), BLOCKS_PER_REMOTE_RANK
    )
    numel_per_rank = outer_loop_step * outer_loop_iters_per_rank
    offset = numel_per_rank * block_offset
    end = tl.minimum(numel_per_rank * (block_offset + 1), num_rows * dim)

    unroll_region_size = (end - offset) // outer_loop_step * outer_loop_step
    for i in tl.range(offset, offset + unroll_region_size, outer_loop_step):
        # NOTE(review): `datas` is never used — left as-is.
        datas = []
        for j in tl.range(
            i,
            i + outer_loop_step,
            BLOCK_SIZE,
            loop_unroll_factor=UNROLL_FACTOR,
        ):
            offsets = j + tl.arange(0, BLOCK_SIZE)
            data = tl.load(input_ptr + offsets)
            tl.store(output_ptr + offsets, data)

    offset += unroll_region_size
    # Masked tail: the last partial tile of this block's range.
    while offset < end:
        offsets = offset + tl.arange(0, BLOCK_SIZE)
        mask = offsets < num_rows * dim
        data = tl.load(input_ptr + offsets, mask=mask)
        tl.store(output_ptr + offsets, data, mask=mask)
        offset += BLOCK_SIZE

    # Exit barrier: our reads are done, peers may reuse their buffers.
    sync_threads()
    blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
    return


def _on_device_all_to_all_v(
    output: torch.Tensor,
    output_splits: torch.Tensor,
    input: torch.Tensor,
    input_splits: torch.Tensor,
    group: dist.ProcessGroup = dist.group.WORLD,
    BLOCKS_PER_REMOTE_RANK=8,
    UNROLL_FACTOR: int = 8,
    BLOCK_SIZE: int = 16384,
):
    """Host-side launcher: rendezvous the symm-mem input/splits buffers and
    launch the copy kernel with world_size * BLOCKS_PER_REMOTE_RANK blocks.

    `output`/`output_splits` are plain local tensors that get filled in;
    `input` and `input_splits` must be symmetric-memory tensors. Returns
    `output`.
    """
    assert output.dim() == 2, f"{output.shape}"
    assert input.dim() == 2, f"{input.shape}"
    assert output.shape[1] == input.shape[1]

    dim = output.shape[1]
    input_hdl = symm_mem.rendezvous(input, group=group)
    input_splits_hdl = symm_mem.rendezvous(input_splits, group=group)

    num_blocks = input_hdl.world_size * BLOCKS_PER_REMOTE_RANK
    kernel = on_device_all_to_all_v_kernel[(num_blocks, 1, 1)](
        output,
        output_splits,
        input_hdl.buffer_ptrs_dev,
        input_splits_hdl.buffer_ptrs_dev,
        input_hdl.signal_pad_ptrs_dev,
        dim=dim,
        rank=input_hdl.rank,
        world_size=input_hdl.world_size,
        BLOCKS_PER_REMOTE_RANK=BLOCKS_PER_REMOTE_RANK,
        UNROLL_FACTOR=UNROLL_FACTOR,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=16,
    )
    # log_triton_kernel(kernel)
    return output


class OnDeviceAllToAllV(torch.autograd.Function):
    """Autograd-aware on-device all-to-all-v over symmetric memory.

    Class-level buffers are lazily created and shared by all invocations;
    `max_output_len` MUST be set by the caller before first use.
    """

    # A symmetric memory holding the grad_output during backward
    grad_output_buf = None
    # A symmetric memory for exchanges split sizes during both forward and backward
    splits_buf = None
    # Maximum output length (need to be set before use of OnDeviceAllToAllV)
    max_output_len = None

    @staticmethod
    def forward(
        ctx,
        input: torch.Tensor,
        input_splits: torch.Tensor,
        group: dist.ProcessGroup = dist.group.WORLD,
    ):
        """
        Args:
            input: input tensor with data for all ranks concatenated.
            input_splits: input splits of shape (group.world_size,)
            group: process group to scope the collective.

        Returns:
            (output, output_splits) — output is sized `max_output_len` rows;
            only the first sum(output_splits) rows are meaningful.
        """
        # Initialize input splits buffer (one time only)
        if OnDeviceAllToAllV.splits_buf is None:
            OnDeviceAllToAllV.splits_buf = symm_mem.empty(
                *input_splits.shape,
                dtype=input_splits.dtype,
                device=input_splits.device,
            )

        if OnDeviceAllToAllV.max_output_len is None:
            raise RuntimeError(
                "Please set max output length via `OnDeviceAllToAllV.max_output_len = ...`"
            )

        # Allocate output buffer
        output = input.new_empty(OnDeviceAllToAllV.max_output_len, *input.shape[1:])
        # Allocate output splits tensor
        output_splits = torch.empty_like(input_splits)
        # Copy input splits to the buffer
        OnDeviceAllToAllV.splits_buf.copy_(input_splits)

        # Shuffle input to output
        _on_device_all_to_all_v(
            output, output_splits, input, OnDeviceAllToAllV.splits_buf, group=group
        )

        # Output splits in forward is the input splits in backward
        ctx.save_for_backward(output_splits)
        ctx.group = group
        ctx.input_shape = input.shape
        return output, output_splits

    @staticmethod
    def backward(ctx, grad_output, grad_splits):
        """
        Backward is implemented as a shuffle of the output's gradients to the input.
        Args:
            `grad_output`: output's gradients passed from the downstream.
            `grad_splits`: unused.
        """

        # Initialize grad_output buffer (one time only)
        if OnDeviceAllToAllV.grad_output_buf is None:
            assert (
                OnDeviceAllToAllV.max_output_len is not None
            ), "`max_output_len` not set"
            OnDeviceAllToAllV.grad_output_buf = symm_mem.empty(
                OnDeviceAllToAllV.max_output_len,
                *grad_output.shape[1:],
                dtype=grad_output.dtype,
                device=grad_output.device,
            )

        # TODO: is there a way to tell autograd to feed grad_output directly to
        # our symm_mem buffer?
        OnDeviceAllToAllV.grad_output_buf.narrow(0, 0, grad_output.shape[0]).copy_(
            grad_output
        )

        # Size info
        (grad_output_splits,) = ctx.saved_tensors
        OnDeviceAllToAllV.splits_buf.copy_(grad_output_splits)
        grad_input_splits = torch.empty_like(grad_output_splits)  # unused
        grad_input = grad_output.new_empty(*ctx.input_shape)

        # Shuffle gradients back to the input
        _on_device_all_to_all_v(
            grad_input,
            grad_input_splits,
            OnDeviceAllToAllV.grad_output_buf,
            OnDeviceAllToAllV.splits_buf,
            group=ctx.group,
        )
        # One gradient per forward input: (input, input_splits, group).
        return grad_input, None, None


# Alias
on_device_all_to_all_v = OnDeviceAllToAllV.apply

import triton
import triton.language as tl


@triton.jit
def get_tid():
    """Return this thread's (tid.x, tid.y, tid.z) via inline PTX."""
    return tl.inline_asm_elementwise(
        """
        mov.u32 $0, %tid.x;
        mov.u32 $1, %tid.y;
        mov.u32 $2, %tid.z;
        """,
        "=r,=r,=r",
        [],
        dtype=(tl.uint32, tl.uint32, tl.uint32),
        is_pure=True,
        pack=1,
    )


@triton.jit
def get_ntid():
    """Return the block's thread extents (ntid.x, ntid.y, ntid.z)."""
    return tl.inline_asm_elementwise(
        """
        mov.u32 $0, %ntid.x;
        mov.u32 $1, %ntid.y;
        mov.u32 $2, %ntid.z;
        """,
        "=r,=r,=r",
        [],
        dtype=(tl.uint32, tl.uint32, tl.uint32),
        is_pure=True,
        pack=1,
    )


@triton.jit
def get_flat_tid():
    """Flatten the 3-D thread index into a single linear id within the block."""
    tid_x, tid_y, tid_z = get_tid()
    ntid_x, ntid_y, _ = get_ntid()
    return tid_z * ntid_y * ntid_x + tid_y * ntid_x + tid_x


@triton.jit
def get_flat_bid():
    """Flatten the 3-D program (block) index into a single linear id."""
    return (
        tl.program_id(2) * tl.num_programs(1) * tl.num_programs(0)
        + tl.program_id(1) * tl.num_programs(0)
        + tl.program_id(0)
    )


@triton.jit
def sync_threads():
    """Block-wide barrier (PTX `bar.sync 0`), like CUDA __syncthreads()."""
    tl.inline_asm_elementwise(
        "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
    )
Run the following command to train the model on a single GPU:
```bash
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --nproc_per_node=1 torchtitan/experiments/flux/train.py --job.config_file torchtitan/experiments/flux/train_configs/debug_model.toml
```

## TODO
- [ ] Support for multiple GPUs is coming soon (FSDP, etc)
- [ ] Implement test cases in CI for the FLUX model. Add more unit tests for the FLUX model (e.g., a unit test for the preprocessor)
- [ ] More parallelism support (Tensor Parallelism, Context Parallelism, etc)
- [ ] Support for distributed checkpointing and loading
- [ ] Implement init_weights() function to initialize the model weights
- [ ] Implement the num_flops_per_token calculation in get_nparams_and_flops() function
diff --git a/torchtitan/experiments/flux/__init__.py b/torchtitan/experiments/flux/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..314a8689b291c74db639669e7bc4943612b47a03
--- /dev/null
+++ b/torchtitan/experiments/flux/__init__.py
@@ -0,0 +1,122 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
from torchtitan.experiments.flux.loss import build_mse_loss
from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
from torchtitan.experiments.flux.parallelize_flux import parallelize_flux
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .model.model import FluxModel, FluxModelArgs

__all__ = [
    "FluxModelArgs",
    "FluxModel",
    "flux_configs",
    "parallelize_flux",
]


# Named FLUX model configurations. All three share the same autoencoder
# hyper-parameters; they differ in transformer size and text-conditioning.
# NOTE(review): "flux-dev" uses context_in_dim=512 while "flux-schnell" uses
# 4096 — confirm this asymmetry is intentional.
flux_configs = {
    # Full-size guided-distillation model.
    "flux-dev": FluxModelArgs(
        in_channels=64,
        out_channels=64,
        vec_in_dim=768,
        context_in_dim=512,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=(16, 56, 56),
        theta=10_000,
        qkv_bias=True,
        guidance_embed=True,
        autoencoder_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # Timestep-distilled variant; no guidance embedding.
    "flux-schnell": FluxModelArgs(
        in_channels=64,
        out_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=(16, 56, 56),
        theta=10_000,
        qkv_bias=True,
        guidance_embed=True if False else False,  # guidance disabled for schnell
        autoencoder_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # Tiny configuration for smoke tests / CI debugging.
    "flux-debug": FluxModelArgs(
        in_channels=64,
        out_channels=64,
        vec_in_dim=768,
        context_in_dim=512,
        hidden_size=512,
        mlp_ratio=4.0,
        num_heads=4,
        depth=2,
        depth_single_blocks=2,
        axes_dim=(16, 56, 56),
        theta=10_000,
        qkv_bias=True,
        guidance_embed=True,
        autoencoder_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}


# Register the "flux" training spec with the torchtitan trainer registry.
register_train_spec(
    TrainSpec(
        name="flux",
        cls=FluxModel,
        config=flux_configs,
        parallelize_fn=parallelize_flux,
        pipelining_fn=None,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_flux_dataloader,
        build_tokenizer_fn=None,
        build_loss_fn=build_mse_loss,
    )
)

import math
import random
from dataclasses import dataclass
from typing import Any, Callable, Optional

import numpy as np

import torch

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from PIL import Image

from torch.distributed.checkpoint.stateful import Stateful

from torch.utils.data import IterableDataset
from torchtitan.components.dataloader import ParallelAwareDataloader

from torchtitan.config_manager import JobConfig
from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
from torchtitan.tools.logging import logger


def _process_cc12m_image(
    img: Image.Image,
    output_size: int = 256,
) -> Optional[torch.Tensor]:
    """Resize-then-random-crop a CC12M image to a square `output_size` tensor.

    Returns a float tensor of shape (C, output_size, output_size) scaled to
    [0, 1], or None when the image should be skipped (too small, or
    grayscale).
    """
    width, height = img.size
    # Skip low resolution images
    if width < output_size or height < output_size:
        return None

    if width >= height:
        # Landscape: resize height to output_size, random-crop horizontally.
        new_width, new_height = math.ceil(output_size / height * width), output_size
        img = img.resize((new_width, new_height))
        left = random.randint(0, new_width - output_size)
        resized_img = img.crop((left, 0, left + output_size, output_size))
    else:
        # Portrait: resize width to output_size, random-crop vertically.
        new_width, new_height = (
            output_size,
            math.ceil(output_size / width * height),
        )
        img = img.resize((new_width, new_height))
        # BUG FIX: the crop offset must range over the resized *height*.
        # The previous code used `new_width - output_size`, which is always 0
        # in this branch (new_width == output_size), so the crop was never
        # random and always took the top of the image.
        lower = random.randint(0, new_height - output_size)
        resized_img = img.crop((0, lower, output_size, lower + output_size))

    assert resized_img.size[0] == resized_img.size[1] == output_size

    # Skip grayscale images
    # NOTE(review): other non-RGB modes (e.g. RGBA) would pass through and
    # yield a 4-channel tensor — confirm upstream data is RGB/L only.
    if resized_img.mode == "L":
        return None

    # HWC uint8 -> CHW float in [0, 1].
    np_img = np.array(resized_img).transpose((2, 0, 1))
    tensor_img = torch.tensor(np_img).float() / 255.0

    return tensor_img


def _flux_data_processor(
    sample: dict[str, Any],
    t5_tokenizer: FluxTokenizer,
    clip_tokenizer: FluxTokenizer,
    output_size: int = 256,
) -> dict[str, Any]:
    """
    Preprocess CC12M dataset sample image and text for Flux model.

    Args:
        sample: A sample from dataset (keys "jpg" and "txt")
        t5_tokenizer: T5 tokenizer for the caption
        clip_tokenizer: CLIP tokenizer for the caption
        output_size: The output image size

    Returns:
        Dict with "image" (tensor or None), "clip_tokens" and "t5_tokens".
    """
    img = _process_cc12m_image(sample["jpg"], output_size=output_size)
    t5_tokens = t5_tokenizer.encode(sample["txt"])
    clip_tokens = clip_tokenizer.encode(sample["txt"])

    return {
        "image": img,
        "clip_tokens": clip_tokens,
        "t5_tokens": t5_tokens,
    }


@dataclass
class TextToImageDatasetConfig:
    """Binds a dataset path to its loader and per-sample processor."""

    path: str
    loader: Callable
    data_processor: Callable


DATASETS = {
    "cc12m": TextToImageDatasetConfig(
        path="pixparse/cc12m-wds",
        loader=lambda path: load_dataset(path, split="train", streaming=True),
        data_processor=_flux_data_processor,
    ),
}


def _validate_dataset(
    dataset_name: str, dataset_path: Optional[str] = None
) -> tuple[str, Callable, Callable]:
    """Validate dataset name and return (path, loader, processor).

    Raises:
        ValueError: if `dataset_name` is not registered in DATASETS.
    """
    if dataset_name not in DATASETS:
        raise ValueError(
            f"Dataset {dataset_name} is not supported. "
            f"Supported datasets are: {list(DATASETS.keys())}"
        )

    config = DATASETS[dataset_name]
    path = dataset_path or config.path
    logger.info(f"Preparing {dataset_name} dataset from {path}")
    return path, config.loader, config.data_processor


class FluxDataset(IterableDataset, Stateful):
    """Dataset for FLUX text-to-image model.

    Yields (sample_dict, label) pairs where label is the image tensor and
    sample_dict holds the T5/CLIP token ids.

    Args:
        dataset_name (str): Name of the dataset.
        dataset_path (str): Path to the dataset.
        t5_tokenizer (FluxTokenizer): Tokenizer for T5 text conditioning.
        clip_tokenizer (FluxTokenizer): Tokenizer for CLIP text conditioning.
        job_config (JobConfig): Optional job configuration.
        dp_rank (int): Data parallel rank.
        dp_world_size (int): Data parallel world size.
        infinite (bool): Whether to loop over the dataset infinitely.
    """

    def __init__(
        self,
        dataset_name: str,
        dataset_path: Optional[str],
        t5_tokenizer: FluxTokenizer,
        clip_tokenizer: FluxTokenizer,
        job_config: Optional[JobConfig] = None,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:

        # Force lowercase for consistent comparison
        dataset_name = dataset_name.lower()

        path, dataset_loader, data_processor = _validate_dataset(
            dataset_name, dataset_path
        )
        ds = dataset_loader(path)

        self.dataset_name = dataset_name
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)

        self._t5_tokenizer = t5_tokenizer
        self._clip_tokenizer = clip_tokenizer
        self._data_processor = data_processor
        self.job_config = job_config

        self.infinite = infinite

        # Variables for checkpointing
        self._sample_idx = 0
        self._all_samples: list[dict[str, Any]] = []

    def _get_data_iter(self):
        """Return an iterator positioned after the last checkpointed sample."""
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        it = iter(self._data)
        # Skip samples already consumed before the checkpoint.
        for _ in range(self._sample_idx):
            next(it)
        return it

    def __iter__(self):
        while True:
            for sample in self._get_data_iter():
                # Use the dataset-specific preprocessor
                sample_dict = self._data_processor(
                    sample, self._t5_tokenizer, self._clip_tokenizer, output_size=256
                )

                # Skip low quality image or image with color channel = 1
                if sample_dict["image"] is None:
                    logger.warning(
                        f"Low quality image {sample['__key__']} is skipped in Flux Dataloader"
                    )
                    continue

                # BUG FIX: was `extend(sample_dict)`, which adds the dict's
                # *keys* (strings) to the list instead of the sample record.
                self._all_samples.append(sample_dict)
                self._sample_idx += 1

                labels = sample_dict.pop("image")
                yield sample_dict, labels

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def load_state_dict(self, state_dict):
        """Restore iteration position and sample history from a checkpoint."""
        self._sample_idx = state_dict["sample_idx"]
        self._all_samples = state_dict["all_samples"]

    def state_dict(self):
        """Capture iteration position and sample history for checkpointing."""
        return {
            "all_samples": self._all_samples,
            "sample_idx": self._sample_idx,
        }


def build_flux_dataloader(
    dp_world_size: int,
    dp_rank: int,
    job_config: JobConfig,
    # This parameter is not used, keep it for compatibility
    tokenizer: FluxTokenizer | None,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data loader for HuggingFace datasets."""
    dataset_name = job_config.training.dataset
    dataset_path = job_config.training.dataset_path
    batch_size = job_config.training.batch_size

    t5_encoder_name = job_config.encoder.t5_encoder
    clip_encoder_name = job_config.encoder.clip_encoder
    max_t5_encoding_len = job_config.encoder.max_t5_encoding_len

    ds = FluxDataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        t5_tokenizer=FluxTokenizer(t5_encoder_name, max_length=max_t5_encoding_len),
        clip_tokenizer=FluxTokenizer(
            clip_encoder_name, max_length=77
        ),  # fix max_length for CLIP
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    return ParallelAwareDataloader(
        dataset=ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )


from typing import Any, List

from torchtitan.components.tokenizer import Tokenizer
from transformers import CLIPTokenizer, T5Tokenizer


class FluxTokenizer(Tokenizer):
    """
    Tokenizing and encoding/decoding text using the T5 or Clip tokenizer.

    Args:
        model_path (str): Path to the tokenizer from Hugging Face; paths
            starting with "openai" select the CLIP tokenizer, anything else
            selects T5.
        max_length (int): fixed padded/truncated sequence length.
    """

    def __init__(self, model_path: str = "t5-small", max_length: int = 77):
        super().__init__()
        self._n_words = 8  # TODO(jianiw): check
        self._max_length = max_length

        # Heuristic: OpenAI repo ids ("openai/clip-vit-...") are CLIP models.
        self.is_clip = model_path.startswith("openai")

        if self.is_clip:
            self._tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                model_path, max_length=max_length
            )
        else:
            self._tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                model_path, max_length=max_length
            )

    def encode(
        self,
        s: str,
    ) -> Any:
        """
        Encode the prompt text into tokens.

        Returns:
            A torch.Tensor of token ids with shape (1, max_length) —
            `return_tensors="pt"` makes the tokenizer return a tensor, not a
            List[int]. (BUG FIX: the previous `-> List[int]` annotation and
            docstring were wrong about the return type; runtime behavior is
            unchanged.)
        """
        tokens = self._tokenizer(
            s,
            truncation=True,
            max_length=self._max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",  # return pytorch tensors, default return List[int]
        )["input_ids"]
        return tokens

    def decode(self, t: List[int]) -> str:
        """
        Decode function. This function will not be called.
        """
        return self._tokenizer.decode(t)

from typing import Callable, TypeAlias

import torch

from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger

# Generic loss callable signature shared with the trainer.
LossFunction: TypeAlias = Callable[..., torch.Tensor]


def mse_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Mean-squared-error between predictions and detached float targets."""
    target = labels.float().detach()
    return torch.nn.functional.mse_loss(pred.float(), target)


def build_mse_loss(job_config: JobConfig):
    """Return the MSE loss function, compiled when training.compile is set."""
    if not job_config.training.compile:
        return mse_loss
    logger.info("Compiling the loss function with torch.compile")
    return torch.compile(mse_loss)
# -----------------------------------------------------------------------------
# torchtitan/experiments/flux/model/autoencoder.py
# imported from black-forest-labs/FLUX
# -----------------------------------------------------------------------------

import os
from dataclasses import dataclass

import torch
from einops import rearrange
from safetensors.torch import load_file as load_sft
from torch import nn, Tensor


@dataclass
class AutoEncoderParams:
    """Architecture hyperparameters of the Flux image autoencoder (VAE)."""

    resolution: int = 256
    in_channels: int = 3
    ch: int = 128  # base channel count; level widths are ch * ch_mult[i]
    out_ch: int = 3
    # fixed annotation: a variadic tuple is tuple[int, ...], not tuple[int]
    ch_mult: tuple[int, ...] = (1, 2, 4, 4)
    num_res_blocks: int = 2
    z_channels: int = 16
    scale_factor: float = 0.3611  # latent scale applied after encoding
    shift_factor: float = 0.1159  # latent shift subtracted before scaling


def swish(x: Tensor) -> Tensor:
    """SiLU / swish activation: ``x * sigmoid(x)``."""
    return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        """Normalize, project to q/k/v, and attend over the h*w sequence."""
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        # flatten spatial dims into a sequence; the singleton axis is the head dim
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        # restore the (b, c, h, w) layout (redundant c=/b= constraints dropped)
        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w)

    def forward(self, x: Tensor) -> Tensor:
        # residual connection around the attention branch
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    """Two-conv residual block (GroupNorm + swish); 1x1 shortcut on width change."""

    def __init__(self, in_channels: int, out_channels: int | None):
        super().__init__()
        self.in_channels = in_channels
        # None means "keep the input width"; annotation fixed to match
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = nn.GroupNorm(
            num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
        )
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            # 1x1 projection so the residual add is shape-compatible
            self.nin_shortcut = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, x: Tensor) -> Tensor:
        h = self.conv1(swish(self.norm1(x)))
        h = self.conv2(swish(self.norm2(h)))

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    """2x spatial downsampling via a stride-2 conv with asymmetric padding."""

    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, x: Tensor) -> Tensor:
        # pad right/bottom by one so the stride-2 conv halves H and W exactly
        x = nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(x)


class Upsample(nn.Module):
    """2x spatial upsampling: nearest-neighbor interpolation followed by a conv."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x: Tensor) -> Tensor:
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(x)


class Encoder(nn.Module):
    """Downsampling CNN mapping an image to 2*z_channels (mean/logvar) features."""

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            # always left empty here, but kept so the checkpoint layout matches
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(
            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
        )
        self.conv_out = nn.Conv2d(
            block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        return self.conv_out(h)


class Decoder(nn.Module):
    """Upsampling CNN mapping a z_channels latent back to an out_ch image."""

    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # total spatial downscale factor of the matching encoder
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute block_in and curr_res at the lowest resolution
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()  # kept empty; preserved for checkpoint layout
            block_out = ch * ch_mult[i_level]
            # one extra res block per level compared to the encoder
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(
            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
        )
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # get dtype for proper tracing
        upscale_dtype = next(self.up.parameters()).dtype

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # cast to proper dtype before entering the upsampling stack
        h = h.to(upscale_dtype)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        return self.conv_out(h)


class DiagonalGaussian(nn.Module):
    """Split (mean, logvar) along ``chunk_dim`` and sample (or return the mean)."""

    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample      # if False, act deterministically (return mean)
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
        return mean


class AutoEncoder(nn.Module):
    """Full VAE: Encoder -> DiagonalGaussian -> (scale/shift) latent -> Decoder."""

    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.params = params
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian()

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        """Encode an image to a sampled, shifted-and-scaled latent."""
        z = self.reg(self.encoder(x))
        return self.scale_factor * (z - self.shift_factor)

    def decode(self, z: Tensor) -> Tensor:
        """Invert the latent scale/shift, then decode to image space."""
        z = z / self.scale_factor + self.shift_factor
        return self.decoder(z)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))


def load_ae(
    ckpt_path: str,
    autoencoder_params: AutoEncoderParams,
    device: str | torch.device = "cuda",
    dtype=torch.bfloat16,
) -> AutoEncoder:
    """Build an AutoEncoder and load its weights from a safetensors checkpoint.

    Args:
        ckpt_path: Path to the ``ae.safetensors`` checkpoint file.
        autoencoder_params: Architecture hyperparameters for the model.
        device: Device the model is materialized on.
        dtype: Dtype the returned model is cast to.

    Returns:
        AutoEncoder: The loaded autoencoder, cast to ``dtype``.

    Raises:
        ValueError: If ``ckpt_path`` does not exist on disk.
    """
    print("Init AE")
    with torch.device(device):
        ae = AutoEncoder(autoencoder_params)

    if not os.path.exists(ckpt_path):
        raise ValueError(
            f"Autoencoder path {ckpt_path} does not exist. Please download it first."
        )

    # NOTE: the original guarded this with `if ckpt_path is not None`, which is
    # dead code after the existence check above; removed.
    sd = load_sft(ckpt_path, device=str(device))
    missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
    if len(missing) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    if len(unexpected) > 0:
        print(
            f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
        )
    return ae.to(dtype=dtype)


# -----------------------------------------------------------------------------
# torchtitan/experiments/flux/model/hf_embedder.py
# -----------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch import nn, Tensor
from transformers import CLIPTextModel, T5EncoderModel


class FluxEmbedder(nn.Module):
    """Frozen HF text encoder (CLIP for ``openai/*`` versions, T5 otherwise)."""

    def __init__(self, version: str, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")
        # CLIP exposes a pooled sentence embedding; T5 yields per-token states
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
                version, **hf_kwargs
            )
        else:
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
                version, **hf_kwargs
            )

        # the encoder is used for conditioning only; freeze it
        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, batch_tokens: Tensor) -> Tensor:
        """
        batch_tokens: [bsz, embedding_length]

        For T5 Encoder, embeding_length is 768
        For CLIP, embedding_length is 256
        """
        outputs = self.hf_module(
            input_ids=batch_tokens.to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
# -----------------------------------------------------------------------------
# torchtitan/experiments/flux/model/layers.py
# -----------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# imported from black-forest-labs/FLUX
import math
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import nn, Tensor

from torchtitan.experiments.flux.model.math import attention, rope


class EmbedND(nn.Module):
    """Rotary position embedding over N position axes, concatenated per axis."""

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim  # per-axis rotary dims; sum must equal head dim

    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )

        # insert a singleton head axis so the PE broadcasts over heads
        return emb.unsqueeze(1)


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to t before embedding (timesteps in
                      [0, 1] are scaled to [0, 1000] by default).
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32)
        / half
    ).to(t.device)

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # odd dim: pad one zero column so the output has exactly `dim` features
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding


class MLPEmbedder(nn.Module):
    """Two-layer MLP with SiLU, used to embed timestep/guidance/clip vectors."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(torch.nn.Module):
    """RMS normalization computed in fp32 with a learned per-channel scale."""

    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        x = x.float()  # normalize in fp32 for numerical stability
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        return (x * rrms).to(dtype=x_dtype) * self.scale


class QKNorm(torch.nn.Module):
    """Independent RMS normalization of query and key before attention."""

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)  # TODO(jianiw): switch to pytorch nn.RMSNorm
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        # cast q/k back to v's dtype so SDPA sees a uniform dtype
        return q.to(v), k.to(v)


class SelfAttention(nn.Module):
    """Multi-head self-attention with QK-RMSNorm and rotary position encoding."""

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        return self.proj(x)


@dataclass
class ModulationOut:
    """Per-branch AdaLN-style modulation parameters."""

    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    """Project a conditioning vector into one or two ModulationOut triples."""

    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        # 3 params (shift/scale/gate) per modulated branch
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
            self.multiplier, dim=-1
        )

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )


class DoubleStreamBlock(nn.Module):
    """Flux block with separate image/text streams sharing one attention op."""

    def __init__(
        self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
    ):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(
        self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
    ) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(
            img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(
            txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
        )
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention: text tokens first, then image tokens
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp(
            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        )

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp(
            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        )
        return img, txt


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in fused in one projection
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out fused in one projection
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(
            self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
        )

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod.gate * output


class LastLayer(nn.Module):
    """Final AdaLN-modulated LayerNorm + linear projection to patch pixels."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        return self.linear(x)


# -----------------------------------------------------------------------------
# torchtitan/experiments/flux/model/math.py
# -----------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from einops import rearrange
from torch import Tensor


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    """Apply rotary embedding to q/k, run SDPA, and merge heads back to (B, L, H*D)."""
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build 2x2 rotation matrices for rotary position embedding.

    Returns a float tensor of shape (..., n, dim/2, 2, 2).
    """
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    # stack [cos, -sin, sin, cos] then reshape into 2x2 rotation matrices
    out = torch.stack(
        [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
    )
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate query/key pairs by the precomputed rotation matrices."""
    # view last dim as (dim/2, 1, 2) complex-like pairs; rotate in fp32
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

import torch

from torch import nn, Tensor
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig

from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
from torchtitan.experiments.flux.model.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)

from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol
from torchtitan.tools.logging import logger


@dataclass
class FluxModelArgs(BaseModelArgs):
    """Configuration for the Flux flow-matching transformer."""

    in_channels: int = 64
    out_channels: int = 64
    vec_in_dim: int = 768
    context_in_dim: int = 512
    hidden_size: int = 3072
    mlp_ratio: float = 4.0
    num_heads: int = 24
    depth: int = 19  # number of DoubleStreamBlocks
    depth_single_blocks: int = 38  # number of SingleStreamBlocks
    axes_dim: tuple = (16, 56, 56)  # per-axis rotary dims; must sum to head dim
    theta: int = 10_000
    qkv_bias: bool = True
    guidance_embed: bool = True
    autoencoder_params: AutoEncoderParams = field(default_factory=AutoEncoderParams)

    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        # NOTE(review): the comment claims context_in_dim equals the T5
        # *embedding dimension*, but the value assigned is the max T5
        # *encoding length* — confirm which is intended before relying on it.
        self.context_in_dim = job_config.encoder.max_t5_encoding_len

    def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
        # TODO(jianiw): Add the number of flops for the autoencoder
        nparams = sum(p.numel() for p in model.parameters())
        logger.warning("FLUX model hasn't implemented get_nparams_and_flops() yet")
        return nparams, 1


class FluxModel(nn.Module, ModelProtocol):
    """
    Transformer model for flow matching on sequences.

    Args:
        model_args: FluxModelArgs.

    Attributes:
        model_args (FluxModelArgs): Model configuration arguments.
    """

    def __init__(self, model_args: FluxModelArgs):
        super().__init__()

        self.model_args = model_args
        self.in_channels = model_args.in_channels
        self.out_channels = model_args.out_channels
        if model_args.hidden_size % model_args.num_heads != 0:
            raise ValueError(
                f"Hidden size {model_args.hidden_size} must be divisible by num_heads {model_args.num_heads}"
            )
        pe_dim = model_args.hidden_size // model_args.num_heads
        if sum(model_args.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {model_args.axes_dim} but expected positional dim {pe_dim}"
            )
        self.hidden_size = model_args.hidden_size
        self.num_heads = model_args.num_heads
        self.pe_embedder = EmbedND(
            dim=pe_dim, theta=model_args.theta, axes_dim=model_args.axes_dim
        )
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(model_args.vec_in_dim, self.hidden_size)
        # guidance embedding is optional; Identity keeps the attribute stable
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
            if model_args.guidance_embed
            else nn.Identity()
        )
        self.txt_in = nn.Linear(model_args.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=model_args.mlp_ratio,
                    qkv_bias=model_args.qkv_bias,
                )
                for _ in range(model_args.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size, self.num_heads, mlp_ratio=model_args.mlp_ratio
                )
                for _ in range(model_args.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def init_weights(self, buffer_device=None):
        # TODO(jianiw): replace placeholder with real weight init
        for param in self.parameters():
            param.data.uniform_(0, 0.1)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        """Run the denoiser.

        Args:
            img: Packed image latents, (bsz, img_seq, in_channels).
            img_ids: Positional ids for image tokens.
            txt: Text encodings, (bsz, txt_seq, context_in_dim).
            txt_ids: Positional ids for text tokens.
            timesteps: Diffusion timesteps, (bsz,).
            y: Pooled (CLIP) conditioning vector, (bsz, vec_in_dim).
            guidance: Guidance strengths; required when guidance_embed is on.

        Returns:
            Predicted latent update, (bsz, img_seq, patch_size**2 * out_channels).
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.model_args.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # text tokens precede image tokens in the joint sequence
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        # drop the text prefix, keeping only image tokens
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

    @classmethod
    def from_model_args(cls, model_args: FluxModelArgs) -> "FluxModel":
        """
        Initialize a Flux model from a FluxModelArgs object.

        Args:
            model_args (FluxModelArgs): Model configuration arguments.

        Returns:
            FluxModel: FluxModel model.

        """
        return cls(model_args)


# -----------------------------------------------------------------------------
# torchtitan/experiments/flux/parallelize_flux.py
# -----------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D parallelisms (except pipeline parallelism) and various
# training techniques (e.g. activation checkpointing and compile) to the Flux model.
import torch.nn as nn

from torch.distributed.device_mesh import DeviceMesh

from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims


def parallelize_flux(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
) -> nn.Module:
    """Apply PT-D parallelisms/training techniques to the Flux model.

    Currently a placeholder: the model is returned unchanged.
    """
    # TODO: Add model parallel strategy here
    return model


# -----------------------------------------------------------------------------
# torchtitan/experiments/flux/scripts/download_autoencoder.py
# -----------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

from requests.exceptions import HTTPError


def hf_download(
    repo_id: str, file_path: str, local_dir: str, hf_token: Optional[str] = None
) -> None:
    """Download a single file from a HuggingFace Hub repo into ``local_dir``.

    Args:
        repo_id: Hub repository, e.g. "black-forest-labs/FLUX.1-dev".
        file_path: File path inside the repo, e.g. "ae.safetensors".
        local_dir: Local destination directory.
        hf_token: Optional API token for private repositories.
    """
    from huggingface_hub import hf_hub_download

    try:
        hf_hub_download(
            repo_id=repo_id,
            filename=file_path,
            local_dir=local_dir,
            # NOTE(review): deprecated in recent huggingface_hub releases
            # (files are always copied now); kept for older versions.
            local_dir_use_symlinks=False,
            token=hf_token,
        )
    except HTTPError as e:
        if e.response.status_code == 401:
            print(
                "You need to pass a valid `--hf_token=...` to download private checkpoints."
            )
        else:
            raise e


if __name__ == "__main__":
    import argparse

    # fixed: this script downloads the autoencoder, not a tokenizer
    parser = argparse.ArgumentParser(
        description="Download autoencoder from HuggingFace."
    )
    parser.add_argument(
        "--repo_id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Repository ID to download from. default to Flux-dev model",
    )
    parser.add_argument(
        "--ae_path",
        type=str,
        default="ae.safetensors",
        help="the autoencoder path relative to repo_id",
    )
    parser.add_argument(
        "--hf_token", type=str, default=None, help="HuggingFace API token"
    )
    parser.add_argument(
        "--local_dir",
        type=str,
        default="torchtitan/experiments/flux/assets/autoencoder/",
        help="local directory to save the autoencoder",
    )

    args = parser.parse_args()
    hf_download(args.repo_id, args.ae_path, args.local_dir, args.hf_token)


# -----------------------------------------------------------------------------
# torchtitan/experiments/flux/tests/test_flux_dataloader.py
# -----------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import sys

from torchtitan.config_manager import JobConfig
from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
from torchtitan.tools.profiling import (
    maybe_enable_memory_snapshot,
    maybe_enable_profiling,
)


class TestFluxDataLoader:
    def test_flux_dataloader(self):
        """Smoke-test the Flux dataloader: parse a JobConfig with the custom
        Flux args module, build a dataloader for one DP rank, and pull a few
        batches, checking batch shapes.
        """
        dataset_name = "cc12m"
        batch_size = 32
        world_size = 4  # simulated DP world size; only rank 0 is exercised
        rank = 0

        num_steps = 10

        # register the Flux-specific CLI args before JobConfig parses argv
        # NOTE(review): mutating sys.argv leaks into other tests in the same
        # process — confirm this is intentional.
        path = "torchtitan.experiments.flux.flux_argparser"
        sys.argv.append(f"--experimental.custom_args_module={path}")
        config = JobConfig()
        config.maybe_add_custom_args()
        config.parse_args(
            [
                # Profiling options
                # "--profiling.enable_profiling",
                # "--profiling.profile_freq",
                # "5",
                # "--profiling.enable_memory_snapshot",
                # "--profiling.save_memory_snapshot_folder",
                # "memory_snapshot_flux",
                "--training.dataset",
                dataset_name,
                "--training.batch_size",
                str(batch_size),
                "--encoder.t5_encoder",
                "google/t5-v1_1-small",
                "--encoder.clip_encoder",
                "openai/clip-vit-large-patch14",
                "--encoder.max_t5_encoding_len",
                "512",
            ]
        )

        # both profilers are no-ops unless enabled via the (commented) flags above
        with maybe_enable_profiling(
            config, global_step=0
        ) as torch_profiler, maybe_enable_memory_snapshot(
            config, global_step=0
        ) as memory_profiler:
            dl = self._build_dataloader(
                config,
                world_size,
                rank,
            )
            dl = iter(dl)

            for i in range(0, num_steps):
                input_data, labels = next(dl)
                print(f"Step {i} image size: {labels.shape}")
                if torch_profiler:
                    torch_profiler.step()
                if memory_profiler:
                    memory_profiler.step()

                print(len(input_data["clip_tokens"]))
                for k, v in input_data.items():
                    print(f"Step {i} {k} value: {type(v), v.shape}")

                # each batch carries exactly the two token streams plus labels
                assert len(input_data) == 2  # (clip_encodings, t5_encodings)
                assert labels.shape == (batch_size, 3, 256, 256)
                # assert input_data["clip_tokens"].shape[0] == batch_size
                # assert input_data["t5_tokens"].shape == (batch_size, 512, 512)

            if torch_profiler:
                torch_profiler.step()
            if memory_profiler:
                # final step flushes/dumps the snapshot on context exit
                memory_profiler.step(exit_ctx=True)
+ + def test_preprocess(self): + # TODO + pass + + def _build_dataloader( + self, + job_config, + world_size, + rank, + ): + + return build_flux_dataloader( + dp_world_size=world_size, + dp_rank=rank, + job_config=job_config, + tokenizer=None, + infinite=False, + ) diff --git a/torchtitan/experiments/flux/tests/test_generate_image.py b/torchtitan/experiments/flux/tests/test_generate_image.py new file mode 100644 index 0000000000000000000000000000000000000000..86d8d16cfbbcbfaa706e6ff6713403520744efd5 --- /dev/null +++ b/torchtitan/experiments/flux/tests/test_generate_image.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +import os +import time +from typing import Callable + +import torch +from einops import rearrange + +from PIL import ExifTags, Image + +from torch import Tensor + +from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer + +from torchtitan.experiments.flux.model.autoencoder import ( + AutoEncoder, + AutoEncoderParams, + load_ae, +) +from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder + +from torchtitan.experiments.flux.model.model import FluxModel, FluxModelArgs +from torchtitan.experiments.flux.utils import ( + create_position_encoding_for_latents, + generate_noise_latent, + pack_latents, + preprocess_flux_data, + unpack_latents, +) + + +def time_shift(mu: float, sigma: float, t: Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function( + x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 +) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra 
step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # estimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +class TestGenerateImage: + def test_generate_image(self): + """ + Run a forward pass of flux model to generate an image. + """ + name = "flux-dev" + img_width = 512 + img_height = 512 + seed = None + prompt = ( + "a photo of a forest with mist swirling around the tree trunks. The word " + '"FLUX" is painted over it in big, red brush strokes with visible texture' + ) + device = "cuda" + num_steps = None + loop = False + guidance = 3.5 + output_dir = "output" + add_sampling_metadata = True + + prompt = prompt.split("|") + if len(prompt) == 1: + prompt = prompt[0] + additional_prompts = None + else: + additional_prompts = prompt[1:] + prompt = prompt[0] + + assert not ( + (additional_prompts is not None) and loop + ), "Do not provide additional prompts and set loop to True" + + torch_device = torch.device(device) + if num_steps is None: + num_steps = 30 + + # allow for packing and conversion to latent space + img_height = 16 * (img_height // 16) + img_width = 16 * (img_width // 16) + + # init all components + model = FluxModel(FluxModelArgs()).to(device=torch_device, dtype=torch.bfloat16) + + ae = load_ae( + ckpt_path="assets/autoencoder/ae.safetensors", + autoencoder_params=AutoEncoderParams(), + device=torch_device, + dtype=torch.bfloat16, + ) + clip_tokenizer = FluxTokenizer( + model_path="openai/clip-vit-large-patch14", max_length=77 + ) + t5_tokenizer = FluxTokenizer(model_path="google/t5-v1_1-small", max_length=512) + clip_encoder = FluxEmbedder(version="openai/clip-vit-large-patch14").to( + torch_device, dtype=torch.bfloat16 + ) + t5_encoder = FluxEmbedder(version="google/t5-v1_1-small").to( + 
torch_device, dtype=torch.bfloat16 + ) + + rng = torch.Generator(device="cpu") + + if seed is None: + seed = rng.seed() + print(f"Generating with seed {seed}:\n{prompt}") + t0 = time.perf_counter() + output_name = os.path.join(output_dir, f"img_{seed}.jpg") + + # Tokenize the prompt, on CPU + clip_tokens = clip_tokenizer.encode(prompt) + t5_tokens = t5_tokenizer.encode(prompt) + + batch = preprocess_flux_data( + device=torch_device, + dtype=torch.bfloat16, + autoencoder=None, + clip_encoder=clip_encoder, + t5_encoder=t5_encoder, + batch={ + "clip_tokens": clip_tokens, + "t5_tokens": t5_tokens, + }, + ) + + img = self._generate_images( + device=torch_device, + dtype=torch.bfloat16, + model=model, + decoder=ae, + img_width=img_width, + img_height=img_height, + denoising_steps=num_steps, + seed=seed, + clip_encodings=batch["clip_encodings"], + t5_encodings=batch["t5_encodings"], + guidance=guidance, + ) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + t1 = time.perf_counter() + + print(f"Done in {t1 - t0:.1f}s.") + + self._save_image(name, output_name, img, add_sampling_metadata, prompt) + + def _generate_images( + self, + device: torch.device, + dtype: torch.dtype, + model: FluxModel, + decoder: AutoEncoder, + # image params: + img_width: int, + img_height: int, + # sampling params: + denoising_steps: int, + seed: int, + clip_encodings: torch.Tensor, + t5_encodings: torch.Tensor, + guidance: float = 4.0, + ): + + bsz = clip_encodings.shape[0] + latents = generate_noise_latent(bsz, img_height, img_width, device, dtype, seed) + _, latent_channels, latent_height, latent_width = latents.shape + + # create denoising schedule + timesteps = get_schedule(denoising_steps, latent_channels, shift=True) + + # create positional encodings + POSITION_DIM = 3 # constant for Flux flow model + latent_pos_enc = create_position_encoding_for_latents( + bsz, latent_height, latent_width, POSITION_DIM + ).to(latents) + text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], 
POSITION_DIM).to(latents) + + # convert img-like latents into sequences of patches + latents = pack_latents(latents) + + # this is ignored for schnell + guidance_vec = torch.full((bsz,), guidance, device=device, dtype=dtype) + for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): + t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device) + pred = model( + img=latents, + img_ids=latent_pos_enc, + txt=t5_encodings, + txt_ids=text_pos_enc, + y=clip_encodings, + timesteps=t_vec, + guidance=guidance_vec, + ) + + latents = latents + (t_prev - t_curr) * pred + + # convert sequences of patches into img-like latents + latents = unpack_latents(latents, latent_height, latent_width) + + img = decoder.decode(latents) + return img + + def _save_image( + self, + name: str, + output_name: str, + x: torch.Tensor, + add_sampling_metadata: bool, + prompt: str, + ): + print(f"Saving {output_name}") + # bring into PIL format and save + x = x.clamp(-1, 1) + x = rearrange(x[0], "c h w -> h w c") + + img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) + + exif_data = Image.Exif() + exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" + exif_data[ExifTags.Base.Make] = "Black Forest Labs" + exif_data[ExifTags.Base.Model] = name + if add_sampling_metadata: + exif_data[ExifTags.Base.ImageDescription] = prompt + img.save(output_name, exif=exif_data, quality=95, subsampling=0) diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py new file mode 100644 index 0000000000000000000000000000000000000000..064e854b2650c4792295438247fbe37e56a1d1b2 --- /dev/null +++ b/torchtitan/experiments/flux/train.py @@ -0,0 +1,224 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
import os
from typing import Optional

import torch

from torchtitan.config_manager import JobConfig
from torchtitan.distributed import utils as dist_utils
from torchtitan.experiments.flux.model.autoencoder import load_ae
from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
from torchtitan.experiments.flux.model.model import FluxModel
from torchtitan.experiments.flux.utils import (
    create_position_encoding_for_latents,
    pack_latents,
    preprocess_flux_data,
    unpack_latents,
)
from torchtitan.tools.logging import init_logger, logger
from torchtitan.train import Trainer


class FluxTrainer(Trainer):
    def __init__(self, job_config: JobConfig):
        """Build the base Trainer, then attach Flux-specific frozen components
        (autoencoder, CLIP/T5 text encoders) kept on CPU until preprocessing."""
        super().__init__(job_config)

        self.preprocess_fn = preprocess_flux_data
        # self.dtype = job_config.encoder.dtype
        self._dtype = torch.bfloat16
        self._seed = job_config.training.seed
        self._guidance = job_config.training.guidance

        # load components
        model_config = self.train_spec.config[job_config.model.flavor]
        self.autoencoder = load_ae(
            job_config.encoder.auto_encoder_path,
            model_config.autoencoder_params,
            device="cpu",
            dtype=self._dtype,
        )
        self.clip_encoder = FluxEmbedder(version=job_config.encoder.clip_encoder).to(
            dtype=self._dtype
        )
        self.t5_encoder = FluxEmbedder(version=job_config.encoder.t5_encoder).to(
            dtype=self._dtype
        )

    def _predict_noise(
        self,
        model: FluxModel,
        latents: torch.Tensor,
        clip_encodings: torch.Tensor,
        t5_encodings: torch.Tensor,
        timesteps: torch.Tensor,
        guidance: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Use Flux's flow-matching model to predict the noise in image latents.
        Args:
            model (FluxModel): The Flux flow model.
            latents (Tensor): Image encodings from the Flux autoencoder.
                Shape: [bsz, 16, latent height, latent width]
            clip_encodings (Tensor): CLIP text encodings.
                Shape: [bsz, 768]
            t5_encodings (Tensor): T5 text encodings.
                Shape: [bsz, sequence length, 256 or 512]
            timesteps (Tensor): The amount of noise (0 to 1).
                Shape: [bsz]
            guidance (Optional[Tensor]): The guidance value (1.5 to 4) if guidance-enabled model.
                Shape: [bsz]
                Default: None
        Returns:
            Tensor: The noise prediction.
                Shape: [bsz, 16, latent height, latent width]
        """
        bsz, _, latent_height, latent_width = latents.shape

        POSITION_DIM = 3  # constant for Flux flow model
        with torch.no_grad():
            # Create positional encodings
            latent_pos_enc = create_position_encoding_for_latents(
                bsz, latent_height, latent_width, POSITION_DIM
            )
            text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM)

            # Convert latent into a sequence of patches
            latents = pack_latents(latents)

            # Predict noise
            latent_noise_pred = model(
                img=latents,
                img_ids=latent_pos_enc.to(latents),
                txt=t5_encodings.to(latents),
                txt_ids=text_pos_enc.to(latents),
                y=clip_encodings.to(latents),
                timesteps=timesteps.to(latents),
                guidance=guidance.to(latents) if guidance is not None else None,
            )

            # Convert sequence of patches to latent shape
            latent_noise_pred = unpack_latents(
                latent_noise_pred, latent_height, latent_width
            )

        return latent_noise_pred

    def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
        """One flow-matching training step: encode the batch, noise the image
        latents at a random timestep, predict the noise, and optimize."""
        # generate t5 and clip
        input_dict["image"] = labels
        input_dict = self.preprocess_fn(
            device=self.device,
            dtype=self._dtype,
            autoencoder=self.autoencoder,
            clip_encoder=self.clip_encoder,
            t5_encoder=self.t5_encoder,
            batch=input_dict,
            offload=True,
        )
        labels = input_dict["img_encodings"]

        self.optimizers.zero_grad()

        # Keep these variables local to shorten the code as these are
        # the major variables that are used in the training loop.
        model_parts = self.model_parts
        world_mesh = self.world_mesh
        parallel_dims = self.parallel_dims

        # image in latent space transformed by self.auto_encoder
        clip_encodings = input_dict["clip_encodings"]
        t5_encodings = input_dict["t5_encodings"]

        bsz = labels.shape[0]

        with torch.no_grad():
            noise = torch.randn_like(labels)
            timesteps = torch.rand((bsz,)).to(labels)
            sigmas = timesteps.view(-1, 1, 1, 1)
            # Linear interpolation between clean latents and noise.
            noisy_latents = (1 - sigmas) * labels + sigmas * noise
            guidance = torch.full((bsz,), self._guidance).to(labels)

        # Flow-matching target: velocity from clean latents toward noise.
        target = noise - labels

        assert len(model_parts) == 1
        # TODO(jianiw): model_parts will be wrapped by FSDP, which will cacluate
        model_parts[0] = model_parts[0].to(dtype=self._dtype)

        pred = self._predict_noise(
            model_parts[0],
            noisy_latents,
            clip_encodings,
            t5_encodings,
            timesteps,
            guidance,
        )
        loss = self.loss_fn(pred, target)
        # pred/noise/target are full latent-sized tensors; free them before the
        # backward pass to avoid peaking memory.
        del pred, noise, target
        loss.backward()

        dist_utils.clip_grad_norm_(
            [p for m in model_parts for p in m.parameters()],
            self.job_config.training.max_norm,
            foreach=True,
            pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None,
        )
        self.checkpointer.maybe_wait_for_staging()
        self.optimizers.step()
        self.lr_schedulers.step()

        # log metrics
        if not self.metrics_processor.should_log(self.step):
            return

        if (
            parallel_dims.dp_replicate_enabled
            or parallel_dims.dp_shard_enabled
            or parallel_dims.cp_enabled
        ):
            loss = loss.detach()
            global_avg_loss, global_max_loss = (
                dist_utils.dist_mean(loss, world_mesh["dp_cp"]),
                dist_utils.dist_max(loss, world_mesh["dp_cp"]),
            )
        else:
            global_avg_loss = global_max_loss = loss.item()

        self.metrics_processor.log(self.step, global_avg_loss, global_max_loss)


if __name__ == "__main__":
    init_logger()
    config = JobConfig()
    config.maybe_add_custom_args()
    config.parse_args()
    trainer: Optional[FluxTrainer] = None

    try:
        trainer = FluxTrainer(config)
        if config.checkpoint.create_seed_checkpoint:
            # FIX: the original `assert int(os.environ["WORLD_SIZE"])` passes for
            # ANY nonzero world size; a seed checkpoint must be created on
            # exactly one device, so compare against 1.
            assert (
                int(os.environ["WORLD_SIZE"]) == 1
            ), "Must create seed checkpoint using a single device, to disable sharding."
            assert (
                config.checkpoint.enable_checkpoint
            ), "Must enable checkpointing when creating a seed checkpoint."
            trainer.checkpointer.save(curr_step=0, force=True)
            logger.info("Created seed checkpoint")
        else:
            trainer.train()
    finally:
        if trainer:
            trainer.close()

        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
            logger.info("Process group destroyed.")


# ---------------------------------------------------------------------------
# torchtitan/experiments/flux/train_configs/debug_model.toml
# ---------------------------------------------------------------------------

[job]
dump_folder = "./outputs"
description = "Flux debug model"
print_args = false
use_for_integration_test = true

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"
enable_wandb = false

[model]
name = "flux"
flavor = "flux-debug"
norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
# test tokenizer.model, for debug purpose only
# tokenizer_path = "./tests/assets/test_tiktoken.model"
# converters = "float8"


[optimizer]
name = "AdamW"
lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.8  # lr scheduler decay ratio, 80% of the train steps
decay_type = "linear"
lr_min = 0.0

[training]
batch_size = 32
seq_len = 512
max_norm = 1.0  # grad norm clipping
+steps = 10 +compile = false +dataset = "cc12m" +guidance = 3.5 +seed = 0 + +[encoder] +t5_encoder="google/t5-v1_1-small" +clip_encoder="openai/clip-vit-large-patch14" +max_t5_encoding_len=512 +auto_encoder_path="torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = 1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[experimental] +custom_args_module = "torchtitan.experiments.flux.flux_argparser" diff --git a/torchtitan/experiments/flux/utils.py b/torchtitan/experiments/flux/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..15db50d90c81ed0fa9f5296a1c725af8005e3601 --- /dev/null +++ b/torchtitan/experiments/flux/utils.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + +from torch import Tensor + +from torchtitan.experiments.flux.model.autoencoder import AutoEncoder +from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder + + +def preprocess_flux_data( + # arguments from the recipe + device: torch.device, + dtype: torch.dtype, + *, + # arguments from the config + autoencoder: Optional[AutoEncoder], + clip_encoder: FluxEmbedder, + t5_encoder: FluxEmbedder, + batch: dict[str, Tensor], + offload: bool = False, +) -> dict[str, Tensor]: + """ + Take a batch of inputs and encoder as input and return a batch of preprocessed data. 
+ + Args: + device (torch.device): device to do preprocessing on + dtype (torch.dtype): data type to do preprocessing in + autoencoer(AutoEncoder): autoencoder to use for preprocessing + clip_encoder + t5_encoder + batch (dict[str, Tensor]): batch of data to preprocess + + Returns: + dict[str, Tensor]: batch of preprocessed data + """ + + # The input of encoder should be torch.int type + if offload: + clip_encoder.to(device) + t5_encoder.to(device) + if autoencoder is not None: + autoencoder.to(device) + + clip_tokens = batch["clip_tokens"].squeeze().to(device=device, dtype=torch.int) + t5_tokens = batch["t5_tokens"].squeeze().to(device=device, dtype=torch.int) + + clip_text_encodings = clip_encoder(clip_tokens) + t5_text_encodings = t5_encoder(t5_tokens) + + if autoencoder is not None: + images = batch["image"].to(device=device, dtype=dtype) + img_encodings = autoencoder.encode(images) + batch["img_encodings"] = img_encodings.to(device=device, dtype=dtype) + + batch["clip_encodings"] = clip_text_encodings.to(dtype) + batch["t5_encodings"] = t5_text_encodings.to(dtype) + + # offload encoders to cpu after preprocessing + if offload: + clip_encoder.to("cpu") + t5_encoder.to("cpu") + if autoencoder is not None: + autoencoder.to("cpu") + + return batch + + +def generate_noise_latent( + bsz: int, + height: int, + width: int, + device: str | torch.device, + dtype: torch.dtype, + seed: int, +) -> Tensor: + """Generate noise latents for the Flux flow model. + + Args: + bsz (int): batch_size. + height (int): The height of the image. + width (int): The width of the image. + device (str | torch.device): The device to use. + dtype (torch.dtype): The dtype to use. + seed (int): The seed to use for randomize. + + Returns: + Tensor: The noise latents. 
+ Shape: [num_samples, LATENT_CHANNELS, height // IMG_LATENT_SIZE_RATIO, width // IMG_LATENT_SIZE_RATIO] + + """ + LATENT_CHANNELS, IMAGE_LATENT_SIZE_RATIO = 16, 8 + return torch.randn( + bsz, + LATENT_CHANNELS, + height // IMAGE_LATENT_SIZE_RATIO, + width // IMAGE_LATENT_SIZE_RATIO, + dtype=dtype, + generator=torch.Generator().manual_seed(seed), + ).to(device) + + +def create_position_encoding_for_latents( + bsz: int, latent_height: int, latent_width: int, position_dim: int = 3 +) -> Tensor: + """ + Create the packed latents' position encodings for the Flux flow model. + + Args: + bsz (int): The batch size. + latent_height (int): The height of the latent. + latent_width (int): The width of the latent. + + Returns: + Tensor: The position encodings. + Shape: [bsz, (latent_height // PATCH_HEIGHT) * (latent_width // PATCH_WIDTH), POSITION_DIM) + """ + PATCH_HEIGHT, PATCH_WIDTH = 2, 2 + + height = latent_height // PATCH_HEIGHT + width = latent_width // PATCH_WIDTH + + position_encoding = torch.zeros(height, width, position_dim) + + row_indices = torch.arange(height) + position_encoding[:, :, 1] = row_indices.unsqueeze(1) + + col_indices = torch.arange(width) + position_encoding[:, :, 2] = col_indices.unsqueeze(0) + + # Flatten and repeat for the full batch + # [height, width, 3] -> [bsz, height * width, 3] + position_encoding = position_encoding.view(1, height * width, position_dim) + position_encoding = position_encoding.repeat(bsz, 1, 1) + + return position_encoding + + +def pack_latents(x: Tensor) -> Tensor: + """ + Rearrange latents from an image-like format into a sequence of patches. + Equivalent to `einops.rearrange("b c (h ph) (w pw) -> b (h w) (c ph pw)")`. + + Args: + x (Tensor): The unpacked latents. + Shape: [bsz, ch, latent height, latent width] + + Returns: + Tensor: The packed latents. 
+ Shape: (bsz, (latent_height // ph) * (latent_width // pw), ch * ph * pw) + """ + PATCH_HEIGHT, PATCH_WIDTH = 2, 2 + + b, c, latent_height, latent_width = x.shape + h = latent_height // PATCH_HEIGHT + w = latent_width // PATCH_WIDTH + + # [b, c, h*ph, w*ph] -> [b, c, h, w, ph, pw] + x = x.unfold(2, PATCH_HEIGHT, PATCH_HEIGHT).unfold(3, PATCH_WIDTH, PATCH_WIDTH) + + # [b, c, h, w, ph, PW] -> [b, h, w, c, ph, PW] + x = x.permute(0, 2, 3, 1, 4, 5) + + # [b, h, w, c, ph, PW] -> [b, h*w, c*ph*PW] + return x.reshape(b, h * w, c * PATCH_HEIGHT * PATCH_WIDTH) + + +def unpack_latents(x: Tensor, latent_height: int, latent_width: int) -> Tensor: + """ + Rearrange latents from a sequence of patches into an image-like format. + Equivalent to `einops.rearrange("b (h w) (c ph pw) -> b c (h ph) (w pw)")`. + + Args: + x (Tensor): The packed latents. + Shape: (bsz, (latent_height // ph) * (latent_width // pw), ch * ph * pw) + latent_height (int): The height of the unpacked latents. + latent_width (int): The width of the unpacked latents. + + Returns: + Tensor: The unpacked latents. 
+ Shape: [bsz, ch, latent height, latent width] + """ + PATCH_HEIGHT, PATCH_WIDTH = 2, 2 + + b, _, c_ph_pw = x.shape + h = latent_height // PATCH_HEIGHT + w = latent_width // PATCH_WIDTH + c = c_ph_pw // (PATCH_HEIGHT * PATCH_WIDTH) + + # [b, h*w, c*ph*pw] -> [b, h, w, c, ph, pw] + x = x.reshape(b, h, w, c, PATCH_HEIGHT, PATCH_WIDTH) + + # [b, h, w, c, ph, pw] -> [b, c, h, ph, w, pw] + x = x.permute(0, 3, 1, 4, 2, 5) + + # [b, c, h, ph, w, pw] -> [b, c, h*ph, w*pw] + return x.reshape(b, c, h * PATCH_HEIGHT, w * PATCH_WIDTH) diff --git a/torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py b/torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..7dbabd1317a5923545f24c9a77feca46f5a92130 --- /dev/null +++ b/torchtitan/experiments/kernels/triton_mg_group_gemm/benchmark.py @@ -0,0 +1,630 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Benchmark comparing reference PyTorch vs optimized M*G group GEMM implementation + +import argparse +import logging +import time + +# from typing import Dict, List, Optional, Tuple + +import matplotlib.pyplot as plt +import numpy as np +import torch +import triton + +# import triton.language as tl + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +# Try to import the optimized implementations +try: + from torchao_pr.mg_grouped_gemm import grouped_gemm_forward + +except ImportError: + logging.error( + "Error importing MG grouped GEMM modules. Make sure the implementation files are in the correct path." 
+ ) + raise + + +def compute_reference_forward(x, w, m_sizes): + """ + Reference PyTorch implementation of M*G grouped GEMM forward pass. + + Args: + x (torch.Tensor): Input tensor of shape (M, K) + w (torch.Tensor): Weight tensor of shape (N, K) + m_sizes (torch.Tensor): Group sizes tensor of shape (G) + + Returns: + torch.Tensor: Output tensor of shape (M, N) + """ + result = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device) + + m_start = 0 + for g in range(len(m_sizes)): + m_size = m_sizes[g].item() + if m_size > 0: + m_end = m_start + m_size + + # Extract group input + x_g = x[m_start:m_end] + + # Compute group output + y_g = torch.matmul(x_g, w.T) + + # Store result + result[m_start:m_end] = y_g + + # Update start index + m_start = m_end + + return result + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["N"], # We'll vary the output dimension + x_vals=[1024, 2048, 4096, 8192, 16384], # Different output dimensions to test + # x_vals=[8192, 16384], + line_arg="provider", # We'll compare different providers + line_vals=["pytorch_reference", "M*G grouped GEMM"], + line_names=["PyTorch Reference", "M*G grouped Kernel"], + styles=[("blue", "-"), ("red", "-")], + ylabel="TFLOPS", # We'll measure TFLOPS + plot_name="mg_grouped_gemm_comparison", + args={ + "M": 8192, # Batch dimension, fixed for all tests + "K": 7168, # Hidden dimension, fixed for all tests + "G": 8, # Number of groups + "dtype": torch.float16, + "device": "cuda", + }, + ) +) +def benchmark_forward(M, K, N, G, provider, dtype=torch.float16, device="cuda"): + """ + Benchmark the forward pass of the grouped GEMM implementation. 
+ + Args: + M (int): Total batch size dimension + K (int): Hidden dimension + N (int): Output dimension + G (int): Number of groups + provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel') + dtype (torch.dtype): Data type to use + device (str): Device to use + + Returns: + float: Performance in TFLOPS + """ + # Create group sizes for M dimension (balanced across groups) + base_size = M // G + remainder = M % G + M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)] + m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32) + + print(f"N: {N}, M: {M}, K: {K}, G: {G}, dtype: {dtype}, device: {device}") + + # Create input and weight tensors + x = torch.randn(M, K, dtype=dtype, device=device) + w = torch.randn(N, K, dtype=dtype, device=device) + + # Pre-compute for PyTorch reference to ensure fair comparison + if provider == "pytorch_reference": + # Warmup + torch.cuda.synchronize() + compute_reference_forward(x, w, m_sizes) + torch.cuda.synchronize() + + # Benchmark + start_time = time.time() + for _ in range(10): # Average over 10 runs + compute_reference_forward(x, w, m_sizes) + torch.cuda.synchronize() + end_time = time.time() + else: # Optimized kernel + # Warmup + torch.cuda.synchronize() + grouped_gemm_forward(x, w, m_sizes) + torch.cuda.synchronize() + + # Benchmark + start_time = time.time() + for _ in range(10): # Average over 10 runs + grouped_gemm_forward(x, w, m_sizes) + torch.cuda.synchronize() + end_time = time.time() + + # Calculate FLOPs + # For GEMM: 2 * M * N * K FLOPs (multiply-add counts as 2 FLOPs) + flops = 2 * M * N * K + + # Convert to TFLOPS (tera-FLOPS) + avg_time = (end_time - start_time) / 10 # Average time per run + tflops = flops / avg_time / 1e12 + + return tflops + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["G"], # We'll vary the number of groups + x_vals=[1, 2, 4, 8, 16], # Different numbers of groups to test + line_arg="provider", # We'll compare different 
providers + line_vals=["pytorch_reference", "optimized_kernel"], + line_names=["PyTorch Reference", "Optimized Kernel"], + styles=[("blue", "-"), ("red", "-")], + ylabel="TFLOPS", # We'll measure TFLOPS + plot_name="mg_grouped_gemm_group_scaling", + args={ + "M": 8192, # Batch dimension, fixed for all tests + "K": 4096, # Hidden dimension, fixed for all tests + "N": 8192, # Output dimension, fixed for all tests + "dtype": torch.float16, + "device": "cuda", + }, + ) +) +def benchmark_forward_groups(M, K, N, G, provider, dtype=torch.float16, device="cuda"): + """ + Benchmark how performance scales with number of groups. + + Args: + M (int): Total batch size dimension + K (int): Hidden dimension + N (int): Output dimension + G (int): Number of groups + provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel') + dtype (torch.dtype): Data type to use + device (str): Device to use + + Returns: + float: Performance in TFLOPS + """ + # Create group sizes for M dimension (balanced across groups) + base_size = M // G + remainder = M % G + M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)] + m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32) + + # Create input and weight tensors + x = torch.randn(M, K, dtype=dtype, device=device) + w = torch.randn(N, K, dtype=dtype, device=device) + + # Benchmark logic - same as previous function + if provider == "pytorch_reference": + torch.cuda.synchronize() + compute_reference_forward(x, w, m_sizes) + torch.cuda.synchronize() + + start_time = time.time() + for _ in range(10): + compute_reference_forward(x, w, m_sizes) + torch.cuda.synchronize() + end_time = time.time() + else: + torch.cuda.synchronize() + grouped_gemm_forward(x, w, m_sizes) + torch.cuda.synchronize() + + start_time = time.time() + for _ in range(10): + grouped_gemm_forward(x, w, m_sizes) + torch.cuda.synchronize() + end_time = time.time() + + # Calculate FLOPs and TFLOPS + flops = 2 * M * N * K + avg_time = (end_time - 
start_time) / 10 + tflops = flops / avg_time / 1e12 + + return tflops + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["group_balance"], # We'll vary the group balance factor + x_vals=[ + 0.0, + 0.25, + 0.5, + 0.75, + 0.9, + ], # Different imbalance factors (0 = balanced, 1 = max imbalance) + line_arg="provider", # We'll compare different providers + line_vals=["pytorch_reference", "optimized_kernel"], + line_names=["PyTorch Reference", "Optimized Kernel"], + styles=[("blue", "-"), ("red", "-")], + ylabel="TFLOPS", # We'll measure TFLOPS + plot_name="mg_grouped_gemm_imbalance", + args={ + "M": 8192, # Batch dimension, fixed for all tests + "K": 4096, # Hidden dimension, fixed for all tests + "N": 8192, # Output dimension, fixed for all tests + "G": 4, # Number of groups + "dtype": torch.float16, + "device": "cuda", + }, + ) +) +def benchmark_imbalance( + M, K, N, G, group_balance, provider, dtype=torch.float16, device="cuda" +): + """ + Benchmark how performance is affected by imbalanced group sizes. 
+ + Args: + M (int): Total batch size dimension + K (int): Hidden dimension + N (int): Output dimension + G (int): Number of groups + group_balance (float): Balance factor from 0 to 1 (0 = balanced, 1 = max imbalance) + provider (str): Provider to use ('pytorch_reference' or 'optimized_kernel') + dtype (torch.dtype): Data type to use + device (str): Device to use + + Returns: + float: Performance in TFLOPS + """ + # Create imbalanced group sizes for M dimension + if group_balance == 0: + # Balanced case + base_size = M // G + remainder = M % G + M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)] + else: + # Imbalanced case + # First group gets more elements, last group gets fewer + # The imbalance is controlled by the group_balance factor + remaining = M + M_sizes = [] + for g in range(G): + # Interpolate from balanced to imbalanced based on group_balance + # For balanced (group_balance=0), each group gets M/G + # For imbalanced (group_balance=1), first group gets much more than last group + balanced_size = remaining // (G - g) + + # Adjusting size based on position and imbalance factor + # First groups get more, last groups get less + if g < G // 2: + # First half of groups get more + adjustment = int(balanced_size * group_balance * (1 - g / (G - 1))) + size = balanced_size + adjustment + else: + # Second half of groups get less + adjustment = int(balanced_size * group_balance * ((g / (G - 1)) - 0.5)) + size = balanced_size - adjustment + + # Ensure we don't go below 1 or take more than remaining + size = max(1, min(size, remaining)) + M_sizes.append(size) + remaining -= size + + # Handle any remaining elements + if remaining > 0: + M_sizes[-1] += remaining + + m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32) + + # Create input and weight tensors + x = torch.randn(M, K, dtype=dtype, device=device) + w = torch.randn(N, K, dtype=dtype, device=device) + + # Benchmark logic + if provider == "pytorch_reference": + 
torch.cuda.synchronize() + compute_reference_forward(x, w, m_sizes) + torch.cuda.synchronize() + + start_time = time.time() + for _ in range(10): + compute_reference_forward(x, w, m_sizes) + torch.cuda.synchronize() + end_time = time.time() + else: + torch.cuda.synchronize() + grouped_gemm_forward(x, w, m_sizes) + torch.cuda.synchronize() + + start_time = time.time() + for _ in range(10): + grouped_gemm_forward(x, w, m_sizes) + torch.cuda.synchronize() + end_time = time.time() + + # Calculate FLOPs and TFLOPS + flops = 2 * M * N * K + avg_time = (end_time - start_time) / 10 + tflops = flops / avg_time / 1e12 + + return tflops + + +def benchmark_model_configs(): + """ + Benchmark common model configurations used in DeepSeek-like models. + """ + # Model configurations: (M, K, N, G) + configs = [ + (8192, 7168, 4096, 4), # Config 1 + (8192, 2048, 7168, 4), # Config 2 + (4096, 7168, 4096, 8), # Config 3 + (4096, 2048, 7168, 8), # Config 4 + ] + + results = [] + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float16 + + for config_idx, (M, K, N, G) in enumerate(configs): + logging.info(f"\n===== Benchmarking DeepSeek Config {config_idx + 1} =====") + logging.info(f"M={M}, K={K}, N={N}, G={G}") + + # Create group sizes for M dimension + base_size = M // G + remainder = M % G + M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)] + m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32) + + # Create tensors + x = torch.randn(M, K, dtype=dtype, device=device) + w = torch.randn(N, K, dtype=dtype, device=device) + + # Benchmark PyTorch reference + torch.cuda.synchronize() + compute_reference_forward(x, w, m_sizes) # Warmup + torch.cuda.synchronize() + + logging.info("Benchmarking PyTorch reference...") + torch.cuda.reset_peak_memory_stats() + start_time = time.time() + for _ in range(10): + compute_reference_forward(x, w, m_sizes) + torch.cuda.synchronize() + end_time = time.time() + pt_time = (end_time - 
start_time) / 10 + pt_memory = torch.cuda.max_memory_allocated() / (1024**2) # MB + + # Benchmark optimized kernel + torch.cuda.synchronize() + grouped_gemm_forward(x, w, m_sizes) # Warmup + torch.cuda.synchronize() + + logging.info("Benchmarking optimized kernel...") + torch.cuda.reset_peak_memory_stats() + start_time = time.time() + for _ in range(10): + grouped_gemm_forward(x, w, m_sizes) + torch.cuda.synchronize() + end_time = time.time() + opt_time = (end_time - start_time) / 10 + opt_memory = torch.cuda.max_memory_allocated() / (1024**2) # MB + + # Calculate FLOPs and speedup + flops = 2 * M * N * K + pt_tflops = flops / pt_time / 1e12 + opt_tflops = flops / opt_time / 1e12 + speedup = pt_time / opt_time + + # Store results + results.append( + { + "config": f"Config {config_idx + 1}", + "dimensions": f"M={M}, K={K}, N={N}, G={G}", + "pt_time_ms": pt_time * 1000, + "opt_time_ms": opt_time * 1000, + "pt_tflops": pt_tflops, + "opt_tflops": opt_tflops, + "speedup": speedup, + "pt_memory_mb": pt_memory, + "opt_memory_mb": opt_memory, + "memory_savings": ( + (pt_memory - opt_memory) / pt_memory * 100 if pt_memory > 0 else 0 + ), + } + ) + + logging.info( + f"PyTorch Reference: {pt_time * 1000:.2f} ms, {pt_tflops:.2f} TFLOPS, {pt_memory:.2f} MB" + ) + logging.info( + f"Optimized Kernel: {opt_time * 1000:.2f} ms, {opt_tflops:.2f} TFLOPS, {opt_memory:.2f} MB" + ) + logging.info( + f"Speedup: {speedup:.2f}x, Memory savings: {results[-1]['memory_savings']:.2f}%" + ) + + # Print summary table + logging.info("\n===== Benchmark Results Summary =====") + logging.info( + f"{'Config':<10} | {'Time (ms)':<20} | {'TFLOPS':<20} | {'Speedup':<10} | {'Memory (MB)':<20} | {'Memory Saved':<12}" + ) + logging.info( + f"{'':<10} | {'PyTorch':<9} {'Kernel':<9} | {'PyTorch':<9} {'Kernel':<9} | {'':<10} | " + f"{'PyTorch':<9} {'Kernel':<9} | {'':<12}" + ) + logging.info("-" * 100) + + for result in results: + logging.info( + f"{result['config']:<10} | " + f"{result['pt_time_ms']:<9.2f} 
{result['opt_time_ms']:<9.2f} | " + f"{result['pt_tflops']:<9.2f} {result['opt_tflops']:<9.2f} | " + f"{result['speedup']:<10.2f} | " + f"{result['pt_memory_mb']:<9.2f} {result['opt_memory_mb']:<9.2f} | " + f"{result['memory_savings']:<12.2f}%" + ) + + return results + + +def plot_benchmark_results(results): + """ + Plot benchmark results as bar charts. + """ + # Extract data + configs = [r["config"] for r in results] + pt_tflops = [r["pt_tflops"] for r in results] + opt_tflops = [r["opt_tflops"] for r in results] + speedups = [r["speedup"] for r in results] + + # Create figure with subplots + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5)) + + # Plot TFLOPS comparison + x = np.arange(len(configs)) + width = 0.35 + ax1.bar(x - width / 2, pt_tflops, width, label="PyTorch Reference") + ax1.bar(x + width / 2, opt_tflops, width, label="Optimized Kernel") + ax1.set_xlabel("Model Configuration") + ax1.set_ylabel("TFLOPS") + ax1.set_title("Performance Comparison (Higher is Better)") + ax1.set_xticks(x) + ax1.set_xticklabels(configs) + ax1.legend() + ax1.grid(axis="y", linestyle="--", alpha=0.7) + + # Plot speedup + ax2.bar(x, speedups, width=0.6, color="green") + ax2.set_xlabel("Model Configuration") + ax2.set_ylabel("Speedup (x)") + ax2.set_title("Speedup Factor (Higher is Better)") + ax2.set_xticks(x) + ax2.set_xticklabels(configs) + ax2.grid(axis="y", linestyle="--", alpha=0.7) + + # Add speedup values on top of bars + for i, v in enumerate(speedups): + ax2.text(i, v + 0.1, f"{v:.2f}x", ha="center") + + plt.tight_layout() + plt.savefig("mg_grouped_gemm_benchmark_results.png") + logging.info( + "Benchmark results plot saved to 'mg_grouped_gemm_benchmark_results.png'" + ) + + +def compare_mg_implementations(): + """ + Combine the M*G and N*G benchmark results for comparison. 
+ """ + # Only run this if both NG and MG benchmarks have been run + try: + import pandas as pd + + # Try to load previous benchmark results + mg_results = pd.read_csv("mg_grouped_gemm_benchmark_results.csv") + ng_results = pd.read_csv("ng_grouped_gemm_benchmark_results.csv") + + # Create comparison plot + fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + + # Plot speedup comparison + configs = mg_results["config"].unique() + mg_speedups = mg_results.groupby("config")["speedup"].mean() + ng_speedups = ng_results.groupby("config")["speedup"].mean() + + x = np.arange(len(configs)) + width = 0.35 + + axes[0].bar(x - width / 2, mg_speedups, width, label="M*G Grouping") + axes[0].bar(x + width / 2, ng_speedups, width, label="N*G Grouping") + axes[0].set_xlabel("Model Configuration") + axes[0].set_ylabel("Speedup (x)") + axes[0].set_title("Speedup Comparison: M*G vs N*G") + axes[0].set_xticks(x) + axes[0].set_xticklabels(configs) + axes[0].legend() + axes[0].grid(axis="y", linestyle="--", alpha=0.7) + + # Plot TFLOPS comparison for optimized kernels + mg_tflops = ( + mg_results[mg_results["implementation"] == "optimized"] + .groupby("config")["tflops"] + .mean() + ) + ng_tflops = ( + ng_results[ng_results["implementation"] == "optimized"] + .groupby("config")["tflops"] + .mean() + ) + + axes[1].bar(x - width / 2, mg_tflops, width, label="M*G Grouping") + axes[1].bar(x + width / 2, ng_tflops, width, label="N*G Grouping") + axes[1].set_xlabel("Model Configuration") + axes[1].set_ylabel("TFLOPS") + axes[1].set_title("Performance Comparison: M*G vs N*G") + axes[1].set_xticks(x) + axes[1].set_xticklabels(configs) + axes[1].legend() + axes[1].grid(axis="y", linestyle="--", alpha=0.7) + + plt.tight_layout() + plt.savefig("mg_vs_ng_comparison.png") + logging.info("Comparison plot saved to 'mg_vs_ng_comparison.png'") + + except Exception as e: + logging.error(f"Could not create comparison plot: {e}") + logging.info( + "Run both M*G and N*G benchmarks first to generate comparison 
plots" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Benchmark M*G Grouped GEMM implementations" + ) + parser.add_argument("--run-all", action="store_true", help="Run all benchmarks") + parser.add_argument( + "--triton-bench", action="store_true", help="Run Triton performance reports" + ) + parser.add_argument( + "--model-configs", action="store_true", help="Benchmark model configurations" + ) + parser.add_argument( + "--compare-mg-ng", + action="store_true", + help="Compare M*G and N*G implementations", + ) + args = parser.parse_args() + + # Check if CUDA is available + if not torch.cuda.is_available(): + logging.error( + "CUDA is not available. This benchmark requires a CUDA-capable GPU." + ) + exit(1) + + if args.run_all or args.model_configs: + # Benchmark model configurations + logging.info("Running benchmark for model configurations...") + results = benchmark_model_configs() + plot_benchmark_results(results) + + if args.run_all or args.triton_bench: + # Run Triton performance reports + logging.info("Running Triton performance reports...") + benchmark_forward.run(save_path="mg_grouped_gemm_benchmark_results") + benchmark_forward_groups.run(save_path="mg_grouped_gemm_benchmark_results") + benchmark_imbalance.run(save_path="mg_grouped_gemm_benchmark_results") + logging.info( + "Triton performance reports saved to 'mg_grouped_gemm_benchmark_results' directory" + ) + + if args.run_all or args.compare_mg_ng: + # Compare M*G and N*G implementations + logging.info("Comparing M*G and N*G implementations...") + compare_mg_implementations() diff --git a/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/__init__.py b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c90da16c282d4b8280f72ad8a0deb94484f59372 --- /dev/null +++ b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/__init__.py @@ -0,0 +1,13 @@ +# 
Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .mg_grouped_gemm import grouped_gemm_forward +from .tma_autotuning import ALIGN_SIZE_M + +__all__ = [ + "grouped_gemm_forward", + "ALIGN_SIZE_M", +] diff --git a/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py new file mode 100644 index 0000000000000000000000000000000000000000..76e0b12d882fa46ed1f11139352141f06d899f59 --- /dev/null +++ b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/fast_debug_ao.py @@ -0,0 +1,299 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +import logging + +import numpy as np +import torch + +from reference_utils import ( + analyze_tensor_differences, + compute_reference_backward, + compute_reference_forward, +) + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +# Import grouped GEMM implementations +try: + from mg_grouped_gemm import grouped_gemm_backward, grouped_gemm_forward + +except ImportError: + logging.error( + "Error importing grouped GEMM modules. Make sure the implementation files are in the correct path." + ) + raise + + +def test_forward_pass(): + """ + A simple test for the M*G grouped GEMM forward pass with detailed error handling. 
+ + In M*G grouping: + - M dimension is partitioned into G groups (M_total = sum(M_sizes)) + - N dimension is the same for all groups + """ + try: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Test parameters for DeepSeek-like models + G = 1 # Number of groups + M_sizes = [ + 2048, + ] # 2048, 2048, 2048] # Group sizes (will be adjusted) + M_total = sum(M_sizes) # Total M dimension + N = 4096 # Output dimension (same for all groups) + K = 7168 # Hidden dimension + + # Create group sizes tensor + m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32) + + # Create input and weight tensors - using float16 for higher precision + x = torch.randn(M_total, K, dtype=torch.float16, device=device) + w = torch.randn(N, K, dtype=torch.float16, device=device) + + # Log the setup + logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}") + logging.info(f"Group sizes: {m_sizes}") + logging.info(f"Input x shape: {x.shape}") + logging.info(f"Weight w shape: {w.shape}") + + # Run forward pass + logging.info("Running forward pass with grouped GEMM") + result = grouped_gemm_forward(x, w, m_sizes) + logging.info(f"Forward result shape: {result.shape}") + + # Compute reference result + logging.info("Computing reference result with PyTorch") + reference_result = compute_reference_forward(x, w, m_sizes) + + # Compare results + logging.info("Comparing with PyTorch reference") + forward_close = analyze_tensor_differences( + result, reference_result, "Forward output" + ) + + return forward_close + + except Exception as e: + logging.error(f"Test failed with error: {e}") + import traceback + + logging.error(traceback.format_exc()) + return False + + +def test_backward_pass(): + """ + A simple test for the M*G grouped GEMM backward pass with detailed error handling. 
+ + In M*G grouping: + - M dimension is partitioned into G groups (M_total = sum(M_sizes)) + - N dimension is the same for all groups + """ + try: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Test parameters for DeepSeek-like models + G = 4 # Number of groups + M_sizes = [2048, 2048, 2048, 2048] # Group sizes (will be adjusted) + M_total = sum(M_sizes) # Total M dimension + N = 4096 # Output dimension (same for all groups) + K = 7168 # Hidden dimension + + # Create group sizes tensor + m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32) + + # Create input and weight tensors - using float16 for higher precision + x = torch.randn( + M_total, K, dtype=torch.float16, device=device, requires_grad=True + ) + w = torch.randn(N, K, dtype=torch.float16, device=device, requires_grad=True) + + # Log the setup + logging.info(f"Test setup - G: {G}, M_total: {M_total}, N: {N}, K: {K}") + logging.info(f"Group sizes: {m_sizes}") + logging.info(f"Input x shape: {x.shape}") + logging.info(f"Weight w shape: {w.shape}") + + # Step 1: Run forward pass + logging.info("Running forward pass") + result = grouped_gemm_forward(x, w, m_sizes) + logging.info(f"Forward result shape: {result.shape}") + + # Create a gradient for backpropagation + grad_output = torch.randn_like(result) + logging.info(f"Created gradient with shape: {grad_output.shape}") + + # Step 2: Run backward pass directly + logging.info("Running backward pass directly") + grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes) + + # Verify gradient shapes + logging.info( + f"Gradient shapes - grad_x: {grad_x.shape}, grad_w: {grad_w.shape}" + ) + + # Step 3: Verify gradient computation using PyTorch's autograd + logging.info("Running PyTorch reference implementation") + + # Compute reference gradients + x_ref_grad, w_ref_grad = compute_reference_backward(x, w, m_sizes, grad_output) + + # Compare gradients + logging.info("Comparing gradients with PyTorch reference") + 
grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x") + grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w") + + # Log overall result + if grad_x_close and grad_w_close: + logging.info("✓ SUCCESS: Gradients match the PyTorch reference") + else: + logging.error("✗ FAILURE: Gradient mismatch detected") + + return grad_x_close and grad_w_close + + except Exception as e: + logging.error(f"Test failed with error: {e}") + import traceback + + logging.error(traceback.format_exc()) + return False + + +def test_multiple_deepseek_configs(): + """ + Test multiple DeepSeek model configurations with both forward and backward pass verification. + """ + # DeepSeek configurations: (G, M, K, N) + configs = [ + (4, 8192, 7168, 4096), # Config 1 + (4, 8192, 2048, 7168), # Config 2 + (8, 4096, 7168, 4096), # Config 3 + (8, 4096, 2048, 7168), # Config 4 + ] + + results = [] + + for config_idx, (G, M, K, N) in enumerate(configs): + logging.info(f"\n\n===== Testing DeepSeek Config {config_idx+1} =====") + logging.info(f"G={G}, M={M}, K={K}, N={N}") + + try: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Create even group sizes + base_size = M // G + remainder = M % G + M_sizes = [base_size + (1 if i < remainder else 0) for i in range(G)] + m_sizes = torch.tensor(M_sizes, device=device, dtype=torch.int32) + + # Create input and weight tensors using float16 for higher precision + x = torch.randn( + M, K, dtype=torch.float16, device=device, requires_grad=True + ) + w = torch.randn( + N, K, dtype=torch.float16, device=device, requires_grad=True + ) + + logging.info(f"Input x shape: {x.shape}, Weight w shape: {w.shape}") + + # Run forward pass + result = grouped_gemm_forward(x, w, m_sizes) + logging.info(f"Forward result shape: {result.shape}") + + # ===== FORWARD PASS VERIFICATION ===== + # Compute reference forward result + reference_result = compute_reference_forward(x, w, m_sizes) + + # Compare forward results + 
forward_close = analyze_tensor_differences( + result, reference_result, "Forward output" + ) + + # ===== BACKWARD PASS VERIFICATION ===== + # Create gradient for backpropagation + grad_output = torch.randn_like(result) + + # Run backward pass + grad_x, grad_w = grouped_gemm_backward(grad_output, x, w, m_sizes) + + # Compute reference gradients + x_ref_grad, w_ref_grad = compute_reference_backward( + x, w, m_sizes, grad_output + ) + + # Compare backward results + grad_x_close = analyze_tensor_differences(grad_x, x_ref_grad, "grad_x") + grad_w_close = analyze_tensor_differences(grad_w, w_ref_grad, "grad_w") + + # Overall config result + backward_close = grad_x_close and grad_w_close + config_success = forward_close and backward_close + results.append( + (config_idx + 1, config_success, forward_close, backward_close) + ) + + # Log overall config result + if config_success: + logging.info(f"✓ SUCCESS: Config {config_idx+1} passed all tests!") + else: + logging.error( + f"✗ FAILURE: Config {config_idx+1} failed one or more tests" + ) + + except Exception as e: + logging.error(f"Config {config_idx+1} test failed with error: {e}") + import traceback + + logging.error(traceback.format_exc()) + results.append((config_idx + 1, False, False, False)) + + # Summary + logging.info("\n===== Test Results Summary =====") + for config_idx, overall_success, forward_success, backward_success in results: + overall_status = "✓ PASSED" if overall_success else "✗ FAILED" + forward_status = "✓ PASSED" if forward_success else "✗ FAILED" + backward_status = "✓ PASSED" if backward_success else "✗ FAILED" + + logging.info(f"Config {config_idx}: {overall_status}") + logging.info(f" - Forward pass: {forward_status}") + logging.info(f" - Backward pass: {backward_status}") + + return all(overall_success for _, overall_success, _, _ in results) + + +if __name__ == "__main__": + logging.info( + "Running verification for both forward and backward pass of M*G grouped GEMM" + ) + + # Run basic forward 
pass test + logging.info("\n===== Running basic forward pass test =====") + success_forward = test_forward_pass() + logging.info(f"Basic forward test {'succeeded' if success_forward else 'failed'}") + + # Run basic backward pass test + logging.info("\n===== Running basic backward pass test =====") + success_backward = test_backward_pass() + logging.info(f"Basic backward test {'succeeded' if success_backward else 'failed'}") + + # Run multiple DeepSeek configs with forward and backward verification + logging.info("\n===== Running tests for all DeepSeek configs =====") + success_configs = test_multiple_deepseek_configs() + logging.info( + f"DeepSeek configs tests {'all succeeded' if success_configs else 'had failures'}" + ) + + # Overall result + overall_success = success_forward and success_backward and success_configs + logging.info( + f"\nOverall test result: {'SUCCESS' if overall_success else 'FAILURE'}" + ) diff --git a/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..37bf59f29e89b0bd3abb69d3e5d75bc14721b97b --- /dev/null +++ b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py @@ -0,0 +1,1304 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# credit - flat index forward kernel is derived from FBGemm: +# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm + +# pyre-unsafe +import functools +import logging + +import os +import sys +from typing import Any, Dict, Optional, Tuple + +import torch + +import triton +import triton.language as tl +from triton import Config as TConfig + +from triton.runtime import driver # @manual + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from tma_autotuning import ( + ALIGN_SIZE_M, + _NV_CONFIGS, + CudaUtils, + early_config_prune, + TmaDescriptorHelper, +) + + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +# ============== Start Triton Kernels =============== + + +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_forward_hopper( + a_desc_ptr, + b_desc_ptr, + c_ptr, + workspace, + m_sizes, + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + TMA_SIZE: tl.constexpr, + USE_EPILOGUE_SUBTILING: tl.constexpr, + # tiles + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +) -> None: + """ + Flat index style forward kernel for Hopper. 
+ For simplicity, we always use TMA Load and TMA Store + """ + tbidx = tl.program_id(0) # thread block index + + c_dtype = c_ptr.dtype.element_ty # output dtype + + c_desc_ptr = workspace + (tbidx * TMA_SIZE) # for TMA Store + + M_end = 0 + M_start = 0 + processed_tiles = 0 + # Size of individual weight matrix + n_size = N // G + n_start = 0 + + for g in range(G): + # Move down along groups + # reset to new M offset + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + n_start = n_size * g + + if m_size > 0: + # Process this group + + # Acquire hold on c_desc_ptr for TMA Store + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=c_ptr + M_start * n_size, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], + global_size=[m_size, n_size], + element_ty=c_dtype, + ) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + # tiles for this group + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) + group_num_tiles = num_m_tiles * num_n_tiles + + while tbidx >= processed_tiles and tbidx < ( + processed_tiles + group_num_tiles + ): + group_index = tbidx - processed_tiles + + # columnwise + tile_m_index = group_index % num_m_tiles + tile_n_index = group_index // num_m_tiles + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32) + n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) + global_n_offset = (n_start + n_offset).to(tl.int32) + + for k_offset in range(0, K, BLOCK_SIZE_K): + # input block [M,K] + a = tl._experimental_descriptor_load( + a_desc_ptr, + [m_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + c_dtype, + ) + # weight block [N, K] + b = tl._experimental_descriptor_load( + b_desc_ptr, + [global_n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + c_dtype, + ) + + accumulator += tl.dot(a, b.T) + + # Store using TMA + + m_offset = (tile_m_index * 
BLOCK_SIZE_M).to(tl.int32) + + if USE_EPILOGUE_SUBTILING: + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + c0 = acc0.to(c_dtype) + tl._experimental_descriptor_store( + c_desc_ptr, c0, [m_offset, n_offset] + ) + c1 = acc1.to(c_dtype) + tl._experimental_descriptor_store( + c_desc_ptr, c1, [m_offset, n_offset + BLOCK_SIZE_N // 2] + ) + else: + tl._experimental_descriptor_store( + c_desc_ptr, + accumulator.to(c_dtype), + [m_offset, n_offset], + ) + # move to next tile in group + tbidx += NUM_SMS + # Update the total tiles count for the next group + processed_tiles += group_num_tiles + + +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_forward_tma( + a_desc_ptr, + b_desc_ptr, + c_ptr, + workspace, + m_sizes, + a_scale_ptr, + b_scale_ptr, + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_TMA_STORE: tl.constexpr, + TMA_SIZE: tl.constexpr, + USE_FP8: tl.constexpr, + # tiles + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +) -> None: + """ + Flat index style forward kernel. 
+ For simplicity, we always use TMA Load and TMA Store + """ + tbidx = tl.program_id(0) # thread block index + + c_dtype = c_ptr.dtype.element_ty + + c_desc_ptr = workspace + (tbidx * TMA_SIZE) + + M_end = 0 + processed_tiles = 0 + + for g in range(G): + # Move down along groups + # reset to new M offset + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + + if m_size > 0: + # Process this group + n_size = N + + # TMA Store prep + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=c_ptr + M_start * N, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], + global_size=[m_size, n_size], + element_ty=c_dtype, + ) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + # tiles for this group + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) + group_num_tiles = num_m_tiles * num_n_tiles + + while tbidx >= processed_tiles and tbidx < ( + processed_tiles + group_num_tiles + ): + group_index = tbidx - processed_tiles + + tile_m_index = group_index % num_m_tiles + tile_n_index = group_index // num_m_tiles + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32) + n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) + + for k_offset in range(0, K, BLOCK_SIZE_K): + # input block [M,K] + a = tl._experimental_descriptor_load( + a_desc_ptr, + [m_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + c_dtype, + ) + # weight block [N, K] + b = tl._experimental_descriptor_load( + b_desc_ptr, + [n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + c_dtype, + ) + + accumulator += tl.dot(a, b.T) + + # Store using TMA + + m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) + # n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) + + tl._experimental_descriptor_store( + c_desc_ptr, + accumulator.to(c_dtype), + [m_offset, n_offset], + ) + + # Move to the next tile + tbidx += 
NUM_SMS + # Update the total tiles count for the next group + processed_tiles += group_num_tiles + + +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_forward_no_tma( + a_ptr, + b_ptr, + c_ptr, + workspace, + m_sizes, + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_TMA_STORE: tl.constexpr, + TMA_SIZE: tl.constexpr, + # tiles + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +) -> None: + """ + Flat index style forward kernel. + For bc and Ampere, we never use TMA Load and TMA Store + """ + tbidx = tl.program_id(0) # thread block index + + c_dtype = c_ptr.dtype.element_ty + c_desc_ptr = None + + M_end = 0 + processed_tiles = 0 + + for g in range(G): + # Move down along groups + # reset to new M offset + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + + if m_size > 0: + # Process this group + n_size = N + + # tiles for this group + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) + group_num_tiles = num_m_tiles * num_n_tiles + + while tbidx >= processed_tiles and tbidx < ( + processed_tiles + group_num_tiles + ): + group_index = tbidx - processed_tiles + + tile_m_index = group_index % num_m_tiles + tile_n_index = group_index // num_m_tiles + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32) + n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32) + + offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (M_start + offs_am[:, None]) * K + offs_k[None, :] + b_ptrs = b_ptr + (offs_bn[:, 
None]) * K + offs_k[None, :] + + for k_offset in range(0, K, BLOCK_SIZE_K): + # Load with bounds checking + a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size) + b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size) + + # Main matmul + accumulator += tl.dot(a, b.T) + + # Update pointers for next block + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + + # Store without TMA + offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + c = accumulator.to(c_dtype) + + tl.store( + c_ptr + + (M_start + offs_am[:, None]) * N # Row stride is N + + offs_bn[None, :], # Column offset + c, + mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size, + ) + # Move to the next tile + tbidx += NUM_SMS + # Update the total tiles count for the next group + processed_tiles += group_num_tiles + + +""" +Backward pass for grouped GEMM with Triton, where grouping is M*G +We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`). +""" + + +# ---- dx flat linear indexed ---- +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_dx_tma( + grad_output_desc_ptr, # [MG, N] + w_desc_ptr, # [N, K] + grad_input_ptr, # output grad_x [MG, K] + workspace, # for TMA store + m_sizes, # group sizes [G] + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_TMA_STORE: tl.constexpr, + TMA_SIZE: tl.constexpr, + # tiles + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +) -> None: + """ + TMA-optimized kernel for computing gradients with respect to input (dx). + For the forward pass Y = X @ W.T, the backward for input is: + grad_X = grad_Y @ W + + This maps to [MG, N] @ [N, K] -> [MG, K] + + Key differences from forward: + 1. 
W is used directly and not transposed + 2. The reduction dimension is now N (not K) + 3. Output is [M, K] instead of [M, N] + """ + tbidx = tl.program_id(0) # thread block index + + c_dtype = grad_input_ptr.dtype.element_ty + c_desc_ptr = workspace + (tbidx * TMA_SIZE) + + M_end = 0 + processed_tiles = 0 + + for g in range(G): + # Move down along groups - same as forward + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + + if m_size > 0: + # Process this group + # tiles for this group - now producing [M, K] output + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + group_num_tiles = num_m_tiles * num_k_tiles + + # TMA Store prep for [M, K] output + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=grad_input_ptr + M_start * K, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K], + global_size=[m_size, K], + element_ty=c_dtype, + ) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + while tbidx >= processed_tiles and tbidx < ( + processed_tiles + group_num_tiles + ): + group_index = tbidx - processed_tiles + + # Different tiling scheme for [M, K] output + tile_m_index = group_index % num_m_tiles + tile_k_index = group_index // num_m_tiles + + # for grad_input block [M, K] + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + + # Position in full matrix + m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32) + k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32) + + # reduce along N dimension (instead of K in forward) + for n_offset in range(0, N, BLOCK_SIZE_N): + # grad_output block [M, N] + grad_output = tl._experimental_descriptor_load( + grad_output_desc_ptr, + [m_offset, n_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + c_dtype, + ) + + # weight block [N, K] - no transpose needed + w = tl._experimental_descriptor_load( + w_desc_ptr, + [n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + c_dtype, + ) + + # grad_x 
= grad_output @ w + # reducing along N dimension + accumulator += tl.dot(grad_output, w) + + # Store using TMA + m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32) + # k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32) + + tl._experimental_descriptor_store( + c_desc_ptr, + accumulator.to(c_dtype), + [m_offset, k_offset], + ) + + # Move to the next tile + tbidx += NUM_SMS + + # Update the total tiles count for the next group + processed_tiles += group_num_tiles + + +# ---- dw flat linear indexed ---- + + +@triton.autotune( + configs=_NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_mg_dw_tma( + x_desc_ptr, # input descriptor [M_total, K] + grad_output_desc_ptr, # grad_output descriptor [M_total, N] + grad_weight_ptr, # output grad_w [N, K] + workspace, # workspace for TMA store + m_sizes, # group sizes [G] + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + # config + NUM_SMS: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_TMA_STORE: tl.constexpr, + TMA_SIZE: tl.constexpr, + # tiles + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, # block size for reduction dimension +) -> None: + """ + Improved TMA-optimized kernel for computing gradients with respect to weights (dw). + Uses flat index structure similar to forward. 
+ + For the forward pass Y = X @ W.T, + the backward for weights is: + grad_W = grad_Y.T @ X + + Where: + - grad_Y is [MG, N] + - X is [MG, K] + - grad_W is [N, K] + - we return [N,K] + """ + # Get thread block index l + tbidx = tl.program_id(0) + + # Get output data type + c_dtype = grad_weight_ptr.dtype.element_ty + + # Calculate number of output tiles + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + total_output_tiles = num_n_tiles * num_k_tiles + + # Process tiles in strided manner across SMs + for tile_idx in range(tbidx, total_output_tiles, NUM_SMS): + # Calculate tile indices + tile_n_idx = tile_idx % num_n_tiles + tile_k_idx = tile_idx // num_n_tiles + + # Calculate global offsets + n_offset = tile_n_idx * BLOCK_SIZE_N + k_offset = tile_k_idx * BLOCK_SIZE_K + + # Initialize accumulator for this output tile [N, K] + accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32) + + # Process each group + M_end = 0 + for g in range(G): + # Get group boundaries + M_start = M_end + m_size = tl.load(m_sizes + g) + M_end = M_start + m_size + + # Only process if group is non-empty + if m_size > 0: + # Process this group in chunks along the M dimension + for m_offset in range(0, m_size, BLOCK_SIZE_M): + # Calculate actual block size (handling boundary) + m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - m_offset) + + # Only process if we have actual work to do + if m_block_size > 0: + # Global offset for this chunk + m_global_offset = M_start + m_offset + + if USE_TMA_LOAD: + # Load input chunk [M_chunk, K] using TMA + x_block = tl._experimental_descriptor_load( + x_desc_ptr, + [m_global_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + c_dtype, + ) + + # Load grad_output chunk [M_chunk, N] using TMA + grad_output_block = tl._experimental_descriptor_load( + grad_output_desc_ptr, + [m_global_offset, n_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + c_dtype, + ) + + # Apply masks for valid regions + offs_m = tl.arange(0, 
BLOCK_SIZE_M) + m_mask = offs_m < m_block_size + + # Zero out invalid elements + x_block = tl.where(m_mask[:, None], x_block, 0.0) + grad_output_block = tl.where( + m_mask[:, None], grad_output_block, 0.0 + ) + else: + # Manual load with bounds checking + offs_m = tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Create masks + m_mask = offs_m < m_block_size + n_mask = offs_n < N - n_offset + k_mask = offs_k < K - k_offset + + # Combined masks + mk_mask = m_mask[:, None] & k_mask[None, :] + mn_mask = m_mask[:, None] & n_mask[None, :] + + # Global offsets for loading + m_global_offs = m_global_offset + offs_m + + # Load x block [M_chunk, K] + x_block = tl.load( + x_desc_ptr + + m_global_offs[:, None] * K + + (k_offset + offs_k)[None, :], + mask=mk_mask, + other=0.0, + ) + + # Load grad_output block [M_chunk, N] + grad_output_block = tl.load( + grad_output_desc_ptr + + m_global_offs[:, None] * N + + (n_offset + offs_n)[None, :], + mask=mn_mask, + other=0.0, + ) + + # Compute partial contribution: grad_W += grad_Y.T @ X + # transpose grad_output for the matmul + contribution = tl.dot( + grad_output_block.to(tl.float32).T, # [N, M_chunk] + x_block.to(tl.float32), # [M_chunk, K] + ) + + # Accumulate + accumulator += contribution + + # Store the result + if USE_TMA_STORE: + # Store using TMA + tl._experimental_descriptor_store( + workspace, # TMA store descriptor + accumulator.to(c_dtype), + [n_offset, k_offset], + ) + else: + # Manual store with bounds checking + offs_n = tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Create masks for bounds checking + n_mask = offs_n < N - n_offset + k_mask = offs_k < K - k_offset + output_mask = n_mask[:, None] & k_mask[None, :] + + # Store the result + tl.store( + grad_weight_ptr + + (n_offset + offs_n)[:, None] * K + + (k_offset + offs_k)[None, :], + accumulator.to(c_dtype), + mask=output_mask, + ) + + +# ======== End Triton kernels ======== + +# 
======== Triton wrapper functions ======== + +# ----- main forward pass wrapper ----- + + +def grouped_gemm_forward( + x: torch.Tensor, + w: torch.Tensor, + m_sizes: torch.Tensor, + tma_size: int = 128, +) -> torch.Tensor: + """ + M*G style grouped GEMM with TMA and Float8 support. + # Removed for now - FP8 support is triggered by passing x_scale and w_scale tensors. + + """ + if not CudaUtils.verify_tma(): + raise NotImplementedError("Grouped GEMM without TMA is not supported yet") + + G = m_sizes.shape[0] + + assert x.is_contiguous() + assert w.is_contiguous() + assert m_sizes.is_contiguous() + + # Total input size is now [M_total, K] where M_total is the sum of all group sizes + M_total, K = x.shape + N = w.shape[0] # N is now the same for all groups + + assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})" + + # Verify that all group sizes are multiples of ALIGN_SIZE_M + # This check is commented out because it will involve a GPU-CPU sync + # assert torch.remainder(m_sizes, ALIGN_SIZE_M).max() == 0, "Group sizes must be a multiple of ALIGN_SIZE_M" + + # Create output tensor with correct shape [M_total, N] + y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype) + + if M_total == 0: + return y + + NUM_SMS = CudaUtils.get_num_sms() + USE_TMA_LOAD = True + USE_TMA_STORE = True + USE_EPILOGUE_SUBTILING = False + + # TMA descriptor helper + desc_helper = None + desc_x = x + desc_w = w + workspace = None + + if USE_TMA_LOAD: + desc_helper = TmaDescriptorHelper(tma_size=tma_size) + desc_helper.init_tma_descriptor("x") + desc_helper.init_tma_descriptor("w") + desc_x = desc_helper.get_tma_descriptor_kernel_param("x") + desc_w = desc_helper.get_tma_descriptor_kernel_param("w") + + if USE_TMA_STORE: + workspace = torch.empty( + NUM_SMS * desc_helper.tma_size, + device=x.device, + dtype=torch.uint8, + ) + + def grid(META): + if USE_TMA_LOAD: + nonlocal desc_helper + desc_helper.fill_2d_tma_descriptor( + "x", + x.data_ptr(), + M_total, + K, 
+ META["BLOCK_SIZE_M"], + META["BLOCK_SIZE_K"], + x.element_size(), + ) + + desc_helper.fill_2d_tma_descriptor( + "w", + w.data_ptr(), + N, + K, + META["BLOCK_SIZE_N"], + META["BLOCK_SIZE_K"], + w.element_size(), + ) + return (NUM_SMS,) + + M_BUCKET = triton.next_power_of_2(M_total) + + _kernel_mg_forward_hopper[grid]( + desc_x, + desc_w, + y, + workspace, + m_sizes, + G, + M_BUCKET, + N, + K, + NUM_SMS, + TMA_SIZE=tma_size, + USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING, + ) + + return y + + +# ======== Improved Backward ============= +def grouped_gemm_backward( + grad_output: torch.Tensor, + x: torch.Tensor, + w: torch.Tensor, + m_sizes: torch.Tensor, + use_tma: bool = True, + tma_size: int = 128, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Unified backward pass for grouped GeMM with M*G grouping. + Uses optimized TMA-based implementations for both dx and dw when available. + + Args: + grad_output: Gradient of output, shape [M_total, N] + x: Input tensor from forward pass, shape [M_total, K] + w: Weight tensor from forward pass, shape [N, K] + m_sizes: Group sizes tensor, shape [G] + use_tma: Whether to try using TMA acceleration (if available) + tma_size: Size of TMA descriptor in bytes + + + Returns: + Tuple of gradients with respect to x and w: (grad_x, grad_w) + """ + logging.info("Starting unified grouped_gemm_backward") + + # do this once, seems expensive + NUM_SMS = CudaUtils.get_num_sms() + + # Basic validation + G = m_sizes.shape[0] + M_total, K_x = x.shape + M_grad, N = grad_output.shape + N_w, K_w = w.shape + + # Check dimensions + if K_x != K_w: + raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}") + if M_total != M_grad: + raise ValueError( + f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}" + ) + + # Check total M matches sum of group sizes + sum_m_sizes = m_sizes.sum().item() + if M_total != sum_m_sizes: + raise ValueError( + f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})" + ) + + # 
Make sure inputs are contiguous + grad_output = grad_output.contiguous() + x = x.contiguous() + w = w.contiguous() + m_sizes = m_sizes.contiguous() + + # Check TMA support + can_use_tma = use_tma and CudaUtils.verify_tma() + if use_tma and not can_use_tma: + logging.info("TMA requested but not supported on this device") + use_tma = False + + # Compute grad_x using flat linear implementation + try: + logging.info(f"Computing grad_x with flat linear kernel") + + # Use TMA-optimized implementation + grad_x = grouped_gemm_dx_tma( + grad_output=grad_output, + w=w, + m_sizes=m_sizes, + num_sms=NUM_SMS, + tma_size=tma_size, + ) + + except Exception as e: + logging.error(f"Error in grad_x computation: {e}") + raise + + # Compute grad_w using flat linear style implementation + try: + logging.info(f"Computing grad_w with flat linear kernel") + + grad_w = grouped_gemm_dw_tma( + x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size + ) + except Exception as e: + logging.error(f"Error in grad_w computation: {e}") + raise + + return grad_x, grad_w + + +# ----- dx backward pass wrapper ----- + + +def grouped_gemm_dx_tma( + grad_output: torch.Tensor, + w: torch.Tensor, + m_sizes: torch.Tensor, + num_sms: int = 132, + tma_size: int = 128, +) -> torch.Tensor: + """ + Optimized backward pass wrapper for computing gradient with respect to input (dx) + using TMA patterns similar to the forward pass. + + Args: + grad_output: Gradient of output, shape [M_total, N] + w: Weight tensor, shape [N, K] + m_sizes: Group sizes tensor, shape [G] + tma_size: Size of TMA descriptor + # using_fp8: Whether to use FP8 quantization + # grad_output_scale: Scale for grad_output in FP8 mode + # w_scale: Scale for w in FP8 mode + + Returns: + grad_x: Gradient with respect to x, shape [M_total, K] + """ + """ + Optimized backward pass for computing gradient with respect to input (dx) + using TMA patterns similar to the forward pass. 
+ + Args: + grad_output: Gradient of output, shape [M_total, N] + w: Weight tensor, shape [N, K] + m_sizes: Group sizes tensor, shape [G] + tma_size: Size of TMA descriptor + using_fp8: Whether to use FP8 quantization + # grad_output_scale: Scale for grad_output in FP8 mode + # w_scale: Scale for w in FP8 mode + + Returns: + grad_x: Gradient with respect to x, shape [M_total, K] + """ + if not CudaUtils.verify_tma(): + raise NotImplementedError("Optimized dx computation requires TMA support") + + G = m_sizes.shape[0] + + assert grad_output.is_contiguous() + assert w.is_contiguous() + assert m_sizes.is_contiguous() + + M_total, N_grad = grad_output.shape + N_w, K = w.shape + + # Check dimensions + assert N_grad == N_w, f"Grad_output N ({N_grad}) must match weight N ({N_w})" + + # Verify that the sum of m_sizes matches M_total + sum_m_sizes = m_sizes.sum().item() + assert ( + M_total == sum_m_sizes + ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})" + + # Create output tensor (grad_x) with shape [M_total, K] + grad_x = torch.empty( + (M_total, K), device=grad_output.device, dtype=grad_output.dtype + ) + + NUM_SMS = num_sms # CudaUtils.get_num_sms() + USE_TMA_LOAD = True + USE_TMA_STORE = True + + # Set up TMA descriptors + desc_helper = TmaDescriptorHelper(tma_size=tma_size) + desc_helper.init_tma_descriptor("grad_output") + desc_helper.init_tma_descriptor("w") + desc_grad_output = desc_helper.get_tma_descriptor_kernel_param("grad_output") + desc_w = desc_helper.get_tma_descriptor_kernel_param("w") + + # Allocate workspace for TMA store + workspace = torch.empty( + NUM_SMS * desc_helper.tma_size, + device=grad_output.device, + dtype=torch.uint8, + ) + + def grid(META): + # Fill TMA descriptors with appropriate dimensions + desc_helper.fill_2d_tma_descriptor( + "grad_output", + grad_output.data_ptr(), + M_total, + N_grad, + META["BLOCK_SIZE_M"], + META["BLOCK_SIZE_N"], + grad_output.element_size(), + ) + + desc_helper.fill_2d_tma_descriptor( + "w", + 
w.data_ptr(), + N_w, + K, + META["BLOCK_SIZE_N"], + META["BLOCK_SIZE_K"], + w.element_size(), + ) + return (NUM_SMS,) + + M_BUCKET = triton.next_power_of_2(M_total) + + # Launch the flat linear kernel for computing grad_x + _kernel_mg_dx_tma[grid]( + desc_grad_output, + desc_w, + grad_x, + workspace, + m_sizes, + G, + M_BUCKET, + N_grad, # N dimension is now the reduction dimension + K, + NUM_SMS, + USE_TMA_LOAD, + USE_TMA_STORE, + TMA_SIZE=tma_size, + ) + + return grad_x + + +# ======== dw wrapper function ========== + + +def grouped_gemm_dw_tma( + x: torch.Tensor, + grad_output: torch.Tensor, + m_sizes: torch.Tensor, + num_sms: int = 132, + tma_size: int = 128, +) -> torch.Tensor: + """ + Optimized flat linear kernel computation of gradients with respect to weights (dw) using TMA. + For the forward pass Y = X @ W.T, the backward for weights is: + grad_W = grad_Y.T @ X + + Args: + x: Input tensor, shape [M_total, K] + grad_output: Gradient of output, shape [M_total, N] + m_sizes: Group sizes tensor, shape [G] + tma_size: Size of TMA descriptor in bytes + + + Returns: + grad_w: Gradient with respect to weights, shape [N, K] + """ + # Check TMA support + has_tma_support = CudaUtils.verify_tma() + + # Get group count + G = m_sizes.shape[0] + + # Ensure contiguous tensors + x = x.contiguous() + grad_output = grad_output.contiguous() + m_sizes = m_sizes.contiguous() + + # Get dimensions + M_total, K_x = x.shape + M_grad, N = grad_output.shape + + # Check dimensions + assert M_total == M_grad, f"x M ({M_total}) must match grad_output M ({M_grad})" + + # Verify that the sum of m_sizes matches M_total + sum_m_sizes = m_sizes.sum().item() + assert ( + sum_m_sizes == M_total + ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})" + + # Create output tensor (grad_w) with shape [N, K] + grad_w = torch.zeros((N, K_x), device=x.device, dtype=x.dtype) + + NUM_SMS = num_sms + + # TODO - hardcoded for now...but should set TMA flags based on hardware support + 
USE_TMA_LOAD = True # has_tma_support + USE_TMA_STORE = True # has_tma_support + + # Set up TMA descriptors or direct pointers + if USE_TMA_LOAD or USE_TMA_STORE: + desc_helper = TmaDescriptorHelper(tma_size=tma_size) + + if USE_TMA_LOAD: + desc_helper.init_tma_descriptor("x") + desc_helper.init_tma_descriptor("grad_output") + x_desc = desc_helper.get_tma_descriptor_kernel_param("x") + grad_output_desc = desc_helper.get_tma_descriptor_kernel_param( + "grad_output" + ) + else: + x_desc = x + grad_output_desc = grad_output + + if USE_TMA_STORE: + desc_helper.init_tma_descriptor("grad_w") + workspace = desc_helper.get_tma_descriptor_kernel_param("grad_w") + else: + workspace = torch.empty(1, device=x.device, dtype=torch.uint8) + else: + # If not using TMA, just use the tensors directly + x_desc = x + grad_output_desc = grad_output + workspace = torch.empty(1, device=x.device, dtype=torch.uint8) + + # M_BUCKET for grid size + M_BUCKET = triton.next_power_of_2(M_total) + + # Define grid for kernel launch + def grid(META): + if USE_TMA_LOAD or USE_TMA_STORE: + + if USE_TMA_LOAD: + desc_helper.fill_2d_tma_descriptor( + "x", + x.data_ptr(), + M_total, + K_x, + META["BLOCK_SIZE_M"], + META["BLOCK_SIZE_K"], + x.element_size(), + ) + + desc_helper.fill_2d_tma_descriptor( + "grad_output", + grad_output.data_ptr(), + M_total, + N, + META["BLOCK_SIZE_M"], + META["BLOCK_SIZE_N"], + grad_output.element_size(), + ) + + if USE_TMA_STORE: + desc_helper.fill_2d_tma_descriptor( + "grad_w", + grad_w.data_ptr(), + N, + K_x, + META["BLOCK_SIZE_N"], + META["BLOCK_SIZE_K"], + grad_w.element_size(), + ) + + # Return grid size - one block per SM for balanced work distribution + return (NUM_SMS,) + + # Launch the optimized kernel + _kernel_mg_dw_tma[grid]( + x_desc, + grad_output_desc, + grad_w, + workspace, + m_sizes, + G, + M_BUCKET, + N, + K_x, + NUM_SMS, + USE_TMA_LOAD, + USE_TMA_STORE, + TMA_SIZE=tma_size, + ) + + return grad_w + + +# ======== End Backwards Wrapper Functions ============= 
+
+# ======== PyTorch wrapper functions ========
+
+
+class GroupedGEMM_mg(torch.autograd.Function):
+    """
+    Autograd function for GroupedGEMM with M*G grouping.
+    Supports both standard and FP8 quantized operations.
+    """
+
+    @staticmethod
+    def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128):
+        """
+        Forward pass of GroupedGEMM.
+
+        Args:
+            x: Input tensor, shape [M_total, K]
+            w: Weight tensor, shape [N, K]
+            m_sizes: Tensor of shape [G] containing the size of each group
+            use_tma: Whether to try using TMA acceleration (if available)
+            tma_size: Size of TMA descriptor in bytes
+            (FP8 quantization is not supported by this autograd op yet)
+
+        Returns:
+            Output tensor, shape [M_total, N]
+        """
+
+        # Use regular forward without quantization.
+        output = grouped_gemm_forward(
+            x=x, w=w, m_sizes=m_sizes, tma_size=tma_size
+        )
+
+        # Save inputs and parameters for backward pass
+        ctx.save_for_backward(x, w, m_sizes)
+        ctx.use_tma = use_tma
+        ctx.tma_size = tma_size
+
+        # NOTE(review): a duplicate ctx.save_for_backward call was removed here.
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """
+        Backward pass of M*G GroupedGEMM.
+
+        Args:
+            grad_output: Gradient of output, shape [M_total, N]
+
+        Returns:
+            Tuple of gradients:
+            - grad_x: Gradient with respect to x, shape [M_total, K]
+            - grad_w: Gradient with respect to w, shape [N, K]
+            - None: Gradient with respect to m_sizes (not differentiable)
+            - None: Gradient with respect to use_tma (not differentiable)
+            - None: Gradient with respect to tma_size (not differentiable)
+
+        """
+        # Retrieve saved tensors and parameters
+
+        x, w, m_sizes = ctx.saved_tensors
+
+        use_tma = ctx.use_tma
+        tma_size = ctx.tma_size
+
+        # Compute gradients using the unified implementation
+        grad_x, grad_w = grouped_gemm_backward(
+            grad_output=grad_output,
+            x=x,
+            w=w,
+            m_sizes=m_sizes,
+            use_tma=use_tma,
+            tma_size=tma_size,
+        )
+
+        # Return gradients for all inputs (None for each non-differentiable input)
+        return grad_x, grad_w, None, None, None
+
+
+def mg_grouped_gemm(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    m_sizes: torch.Tensor,
+    use_tma: bool = True,
+    tma_size: int = 128,
+    using_fp8: bool = False,
+) -> torch.Tensor:
+    """
+    Unified differentiable grouped GEMM operation for M*G grouped GEMM.
+    Supports both standard precision and FP8 quantized operations.
+
+    Args:
+        x: Input tensor, shape [M_total, K]
+        w: Weight tensor, shape [N, K]
+        m_sizes: Tensor of shape [G] containing the size of each group
+        use_tma: Whether to try using TMA acceleration (if available)
+        tma_size: Size of TMA descriptor in bytes
+        using_fp8: Accepted for API compatibility; ignored (FP8 path not implemented)
+
+    Returns:
+        Output tensor, shape [M_total, N]
+    """
+    return GroupedGEMM_mg.apply(x, w, m_sizes, use_tma, tma_size)
diff --git a/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0835132c3ebf31f8c88a066e5bf19eed4c4acd69
--- /dev/null
+++ b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py
@@ -0,0 +1,126 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import logging
+
+import numpy as np
+import torch
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+
+
+def compute_reference_forward(x, w, m_sizes):
+    """
+    Compute reference forward pass using PyTorch operations.
+
+    Args:
+        x (torch.Tensor): Input tensor of shape (M, K)
+        w (torch.Tensor): Weight tensor of shape (N, K)
+        m_sizes (torch.Tensor): Group sizes tensor of shape (G)
+
+    Returns:
+        torch.Tensor: Reference output tensor of shape (M, N)
+    """
+    result = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device)
+
+    m_start = 0
+    for g in range(len(m_sizes)):
+        m_size = m_sizes[g].item()
+        if m_size > 0:
+            m_end = m_start + m_size
+
+            # Extract group input
+            x_g = x[m_start:m_end]
+
+            # Compute group output: y_g = x_g @ w.T
+            y_g = torch.matmul(x_g, w.T)
+
+            # Store result
+            result[m_start:m_end] = y_g
+
+            # Update start index
+            m_start = m_end
+
+    return result
+
+
+def compute_reference_backward(x, w, m_sizes, grad_output):
+    """
+    Compute reference backward pass using PyTorch autograd.
+
+    Args:
+        x (torch.Tensor): Input tensor of shape (M, K)
+        w (torch.Tensor): Weight tensor of shape (N, K)
+        m_sizes (torch.Tensor): Group sizes tensor of shape (G)
+        grad_output (torch.Tensor): Gradient tensor of shape (M, N)
+
+    Returns:
+        tuple: (grad_x, grad_w) gradient tensors
+    """
+    # Create autograd-enabled copies
+    x_autograd = x.detach().clone().requires_grad_(True)
+    w_autograd = w.detach().clone().requires_grad_(True)
+
+    # Compute forward pass
+    output = compute_reference_forward(x_autograd, w_autograd, m_sizes)
+
+    # Backpropagate
+    output.backward(grad_output)
+
+    return x_autograd.grad, w_autograd.grad
+
+
+def analyze_tensor_differences(actual, expected, name):
+    """
+    Analyze differences between actual and expected tensors.
+
+    Args:
+        actual (torch.Tensor): Actual tensor
+        expected (torch.Tensor): Expected tensor
+        name (str): Name of the tensor for logging
+
+    Returns:
+        bool: True if tensors are close enough
+    """
+    rtol = 0.5  # Relative tolerance for float16
+    atol = 0.5  # Absolute tolerance for float16
+
+    # Analyze differences
+    diff = (actual - expected).abs()
+    max_idx = diff.argmax().item()
+    idx = np.unravel_index(max_idx, actual.shape)
+    max_diff = diff.max().item()
+
+    logging.info(f"Largest {name} difference: {max_diff} at {idx}")
+    logging.info(f"Values: {actual[idx].item()} vs {expected[idx].item()}")
+
+    is_close = torch.allclose(actual, expected, rtol=rtol, atol=atol)
+
+    if is_close:
+        logging.info(f"✓ SUCCESS: {name} matches PyTorch reference")
+    else:
+        logging.error(f"✗ FAILURE: {name} mismatch detected")
+
+    # Count zeros
+    zeros_actual = (actual == 0).sum().item()
+    zeros_expected = (expected == 0).sum().item()
+    logging.info(
+        f"Zeros in {name} (actual): {zeros_actual}/{actual.numel()} ({zeros_actual/actual.numel()*100:.2f}%)"
+    )
+    logging.info(
+        f"Zeros in {name} (expected): {zeros_expected}/{expected.numel()} ({zeros_expected/expected.numel()*100:.2f}%)"
+    )
+
+    # Check for NaNs
+    nan_actual = torch.isnan(actual).sum().item()
+    if nan_actual > 0:
+        logging.error(f"NaN values detected in {name}: {nan_actual}")
+
+    return is_close
diff --git a/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fdd7a66c6afc6ca2c3d5d50d55cd9e7d1ae78f1
--- /dev/null
+++ b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/tma_autotuning.py
@@ -0,0 +1,240 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +# credit - TMAHelper class, AutoTuning are derived from FBGemm: +# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm + +# pyre-unsafe +import functools + +import os +import sys +from typing import Any, Dict, Optional, Tuple + +import torch + +import triton +import triton.language as tl +from triton import Config as TConfig + +from triton.runtime import driver # @manual + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + + +# ===== Supporting utils, CUDA and TMA ===== + + +class CudaUtils: + @staticmethod + def is_cuda() -> bool: + """Check if Triton is running on CUDA backend.""" + return driver.active.get_current_target().backend == "cuda" + + @staticmethod + def verify_tma() -> bool: + """Check if TMA is supported on the current device.""" + return ( + CudaUtils.is_cuda() + and torch.cuda.is_available() + and torch.cuda.get_device_capability()[0] >= 9 + ) + + @staticmethod + def get_num_sms() -> int: + """Get the number of streaming multiprocessors on the current device.""" + if not CudaUtils.is_cuda(): + raise RuntimeError("Triton is not running on CUDA backend") + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + return torch.cuda.get_device_properties("cuda").multi_processor_count + + +class TmaDescriptorHelper: + """Helper class for managing TMA descriptors in Triton kernels.""" + + class KernelParamWrapper: + """Wrapper to implement the TmaDescKernelParam interface.""" + + def __init__(self, desc: torch.Tensor): + self.desc = desc + + def tma_desc_cpu_ptr(self) -> int: + """Return the CPU pointer to the TMA descriptor.""" + return self.desc.data_ptr() + + def __init__(self, tma_size: int = 128): + """Initialize the TMA descriptor helper. 
+ + Args: + tma_size: Size of the TMA descriptor in bytes + """ + if not CudaUtils.verify_tma(): + raise RuntimeError( + "TMA not supported on this device (requires Hopper or newer)" + ) + if "nv_tma_desc_type" not in dir(tl): + raise RuntimeError( + "TMA grid constant descriptors not supported in your Triton version" + ) + + self.tma_size = tma_size + self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor + self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor + self.descriptors: Dict[str, torch.Tensor] = {} + + def init_tma_descriptor(self, name: str) -> None: + """Initialize a TMA descriptor with the given name. + + Call this method outside of the lambda function for grid size. + """ + self.descriptors[name] = torch.empty( + self.tma_size, device="cpu", dtype=torch.int8 + ) + + def fill_1d_tma_descriptor( + self, name: str, ptr: int, dim: int, block_dim: int, element_size: int + ) -> None: + """Fill a 1D TMA descriptor. + + Call this method inside the lambda function for grid size. + """ + if name not in self.descriptors: + raise ValueError(f"TMA descriptor '{name}' not initialized") + + desc_x = self.descriptors[name] + if desc_x.data_ptr() % 64 != 0: + raise ValueError("TMA descriptor must be 64-byte aligned") + self.fill_1d_tma_descriptor_inner( + ptr, dim, block_dim, element_size, desc_x.data_ptr() + ) + + def fill_2d_tma_descriptor( + self, + name: str, + ptr: int, + dim1: int, + dim0: int, + block_dim1: int, + block_dim0: int, + element_size: int, + ) -> None: + """Fill a 2D TMA descriptor. + + Call this method inside the lambda function for grid size. 
+ """ + if name not in self.descriptors: + raise ValueError(f"TMA descriptor '{name}' not initialized") + + desc_x = self.descriptors[name] + if desc_x.data_ptr() % 64 != 0: + raise ValueError("TMA descriptor must be 64-byte aligned") + self.fill_2d_tma_descriptor_inner( + ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr() + ) + + def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper: + """Get the TMA descriptor kernel parameter for the given name.""" + if name not in self.descriptors or self.descriptors[name] is None: + raise ValueError(f"TMA descriptor '{name}' not initialized") + return self.KernelParamWrapper(self.descriptors[name]) + + +# ====== Autotuning utilities ====== +ALIGN_SIZE_M = 128 + +_NV_CONFIGS = [ + triton.Config( + { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + }, + num_stages=num_stages, + num_warps=num_warps, + num_ctas=num_ctas, + ) + for block_size_m in [ALIGN_SIZE_M, ] + for block_size_n in [64, 128, 256] + for block_size_k in [64, 128, 256] + for num_stages in [3, 4] + for num_warps in [4, 8] + for num_ctas in [1] +] + + +def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs): + device = torch.cuda.current_device() + # Check for all possible pointer parameter names + if "grad_input_ptr" in named_args: + ptr_name = "grad_input_ptr" + elif "c_ptr" in named_args: + ptr_name = "c_ptr" + elif "grad_weight_ptr" in named_args: + ptr_name = "grad_weight_ptr" + else: + raise KeyError("No recognized pointer parameter found in kernel arguments") + + if dtsize is None: + dtsize = named_args[ptr_name].element_size() + if dtype is None: + dtype = named_args[ptr_name].dtype + + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages = ( + kw["BLOCK_SIZE_M"], + kw["BLOCK_SIZE_N"], + kw["BLOCK_SIZE_K"], + config.num_stages, + ) + G, M, N, K = ( + named_args["G"], + named_args["M_BUCKET"], 
+ named_args["N"], + named_args["K"], + ) + + # 1. make sure we have enough smem + max_shared_memory = driver.active.utils.get_device_properties(device)[ + "max_shared_mem" + ] + + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory > max_shared_memory: + continue + + M_PER_GROUP = M // G + MIN_M_TILES = 64 + # 2. make sure we don't load M tiles that are too big + if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2): + continue + # 3. make sure we don't load N tiles that are too small + if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2): + continue + + num_sm = driver.active.utils.get_device_properties(device)[ + "multiprocessor_count" + ] + N_TILES = N // BLOCK_N + MIN_N_TILES = 64 + # 4. make sure we don't load N tiles that are too big + if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm: + continue + # 5. make sure we don't load N tiles that are too small + if BLOCK_N < 128 and M * N_TILES > 2 * num_sm: + continue + # 6. make sure K can be evenly divided + if K % BLOCK_K != 0: + continue + + pruned_configs.append(config) + + return pruned_configs + + +# ======== End Autotuning utilities ======== diff --git a/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py new file mode 100644 index 0000000000000000000000000000000000000000..2429432d756ae4d5bb6f91a6108c7ba8a4b9c627 --- /dev/null +++ b/torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe +import logging +import unittest +from typing import Tuple + +import torch +import torch.nn as nn + +from mg_grouped_gemm import grouped_gemm_forward + + +class TestMG_GroupedGEMM(unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(2020) + + def _run_grouped_gemm_test( + self, + shape: Tuple[int, int, int, int], + device: torch.device, + dtype: torch.dtype = torch.bfloat16, + atol: float = 1e-5, + rtol: float = 1.6e-2, + ) -> None: + G, M, N, K = shape + # In M*G grouping, input is [M*G, K] and weights are [N*G, K] + a = torch.randn(M * G, K, dtype=dtype, device=device) + b = torch.randn(N * G, K, dtype=dtype, device=device) + + # Create equal-sized groups for simplicity + m_size = M + m_sizes = torch.full((G,), m_size, device=device, dtype=torch.int32) + + result = grouped_gemm_forward(a, b, m_sizes) + self.assertTrue(result.shape == (M * G, N)) + + expected_result = torch.zeros(M * G, N, dtype=dtype, device=device) + m_start = 0 + for g in range(G): + m_end = m_start + m_sizes[g] + b_slice = b[N * g : N * (g+1), :] + expected_result[m_start:m_end, :] = a[m_start:m_end, :] @ b_slice.T + m_start = m_end + + # Convert result to match input dtype if needed + result = result.to(dtype) + torch.testing.assert_close(result, expected_result, atol=atol, rtol=rtol) + + def test_MG_grouped_gemm_bf16(self) -> None: + for G in (1, 4, 16): + for M in (128, 512, 1024): + print(f"Testing BF16 M*G GroupGeMM with G={G}, M={M}") + self._run_grouped_gemm_test( + (G, M, 1024, 1024), + torch.device("cuda"), + dtype=torch.bfloat16, + atol=1e-5, + rtol=1.6e-2, + ) + + def test_MG_grouped_gemm_deepseek_shapes(self) -> None: + """Test with shapes from Deepseek model.""" + deepseek_shapes = [ + (4, 2048, 4096, 7168), # G, M, N, K + (4, 2048, 7168, 2048), + (8, 512, 4096, 7168), + (8, 512, 7168, 2048), + ] + + device = torch.device("cuda") + + for shape in deepseek_shapes: + G, M, N, K = shape + print(f"Testing BF16 M*G Deepseek shape: G={G}, M={M}, N={N}, 
K={K}") + self._run_grouped_gemm_test( + shape, device, dtype=torch.bfloat16, atol=1e-5, rtol=1.6e-2 + ) diff --git a/torchtitan/experiments/llama4/README.md b/torchtitan/experiments/llama4/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d912caaa80b12157f24693b63f8a9a5bd75c717f --- /dev/null +++ b/torchtitan/experiments/llama4/README.md @@ -0,0 +1,29 @@ +**The Llama 4 folder is still under development.** + +#### Available features +- Llama 4 model definition (text-only), including the MoE architecture with token-choice routing +- Basic FSDP, TP, PP, CP support +- DCP checkpoint conversion scripts + +#### Download Llama 4 tokenizer +```bash +# Llama 4 tokenizer.model +python scripts/download_tokenizer.py --repo_id meta-llama/Llama-4-Scout-17B-16E --tokenizer_path "" --hf_token=... +``` + +#### To be added +- Modeling + - iRoPE implementation + - load balance loss for token-choice MoE + - alternative expert-choice MoE + - multimodal support +- Kernel integration + - efficient bfloat16 GroupedGEMM kernels (from PyTorch core) + - efficient float8 GroupedGEMM kernels (from torchao) +- Parallelism + - performant TP implementation and torch.compile support for MoE layers + - Context Parallel support for FlexAttention, iRoPE, and multimodal inputs + - Expert Parallel support +- Testing + - performance and loss converging tests + - CI integration diff --git a/torchtitan/experiments/llama4/__init__.py b/torchtitan/experiments/llama4/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0907e1892fa3840be81e7eefe12047d2e1cf1661 --- /dev/null +++ b/torchtitan/experiments/llama4/__init__.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree.
from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
from torchtitan.models.llama3 import pipeline_llama
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .infra.parallelize_llama import parallelize_llama
from .model.args import TransformerModelArgs
from .model.model import Transformer

__all__ = [
    "TransformerModelArgs",
    "Transformer",
    "llama4_configs",
]


# Named Llama 4 model configurations, keyed by flavor string.
llama4_configs = {
    # Tiny configuration for smoke tests and debugging.
    "debugmodel": TransformerModelArgs(
        dim=256,
        n_layers=8,
        n_heads=16,
        rope_theta=500000,
    ),
    # 16-expert flavor; interleave step 1 puts an MoE layer in every block.
    "17bx16e": TransformerModelArgs(
        dim=5120,
        n_layers=48,
        n_heads=40,
        n_kv_heads=8,
        ffn_dim_multiplier=1.2,
        multiple_of=2048,
        rope_theta=500000,
        num_experts=16,
        interleave_moe_layer_step=1,
    ),
    # 128-expert flavor; uses the default MoE interleaving from the args.
    "17bx128e": TransformerModelArgs(
        dim=5120,
        n_layers=48,
        n_heads=40,
        n_kv_heads=8,
        ffn_dim_multiplier=1.2,
        multiple_of=2048,
        rope_theta=500000,
        num_experts=128,
    ),
}


# Register the Llama 4 train spec so that trainers can resolve "llama4" to
# this model class and its build/parallelize functions.
register_train_spec(
    TrainSpec(
        name="llama4",
        cls=Transformer,
        config=llama4_configs,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_tiktoken_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)
a/torchtitan/experiments/llama4/infra/__pycache__/parallelize_llama.cpython-312.pyc b/torchtitan/experiments/llama4/infra/__pycache__/parallelize_llama.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0f786f56e0f13ad9d667b11314c548822d75fdf Binary files /dev/null and b/torchtitan/experiments/llama4/infra/__pycache__/parallelize_llama.cpython-312.pyc differ diff --git a/torchtitan/experiments/llama4/infra/expert_parallel.py b/torchtitan/experiments/llama4/infra/expert_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..63945e8cd6a3f9509ca34c779b09a2f2f7581c2f --- /dev/null +++ b/torchtitan/experiments/llama4/infra/expert_parallel.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from functools import partial +from typing import Optional, Tuple + +import torch.nn as nn +from torch.distributed.tensor import ( + DeviceMesh, + distribute_module, + distribute_tensor, + DTensor, + Partial, + Replicate, + Shard, +) +from torch.distributed.tensor.parallel import ParallelStyle +from torch.distributed.tensor.placement_types import Placement + + +# implementation of Tensor Parallel on the non-shared experts in MoE +class TensorParallel(ParallelStyle): + def __init__( + self, + *, + input_layouts: Optional[Tuple[Optional[Placement]]] = None, + output_layout: Optional[Placement] = None, + use_local_output: bool = True, + ): + super().__init__() + self.input_layouts = input_layouts or (Replicate(), None) + self.output_layout = output_layout or Partial() + self.desired_input_layouts = (Replicate(), None) + self.use_local_output = use_local_output + + @staticmethod + def _prepare_input_fn( + input_layouts, desired_input_layouts, mod, inputs, device_mesh + ): + # TODO: figure out dynamo support for instance method and switch this to 
instance method + + # annotate module input placements/sharding with input_layouts + input_tensor, input_layout, desired_input_layout = ( + inputs[0], + input_layouts[0], + desired_input_layouts[0], + ) + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local( + input_tensor, device_mesh, (input_layout,), run_check=False + ) + + if input_layouts != desired_input_layouts: + input_tensor = input_tensor.redistribute( + placements=(desired_input_layout,), async_op=True + ) + return (input_tensor, *inputs[1:]) + + def _partition_fn(self, name, module, device_mesh): + module.register_parameter( + "w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(2)])) + ) # Column-wise sharding + module.register_parameter( + "w2", + nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(1)])), + ) # Row-wise sharding + module.register_parameter( + "w3", + nn.Parameter(distribute_tensor(module.w3, device_mesh, [Shard(2)])), + ) # Column-wise sharding + + @staticmethod + def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh): + if outputs.placements != (output_layout,): + outputs = outputs.redistribute(placements=(output_layout,), async_op=True) + # back to local tensor + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + self._partition_fn, + partial( + self._prepare_input_fn, self.input_layouts, self.desired_input_layouts + ), + partial(self._prepare_output_fn, self.output_layout, self.use_local_output), + ) + + +# NOTE: This is to achieve replicate computation on the gate module in the MoE router. +# It does nothing other than (1) setting the module parameters as DTensors on the given mesh +# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back. 
+# TODO: The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh, +# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation. +class NoParallel(ParallelStyle): + def __init__( + self, + *, + input_layout: Optional[Placement] = None, + output_layout: Optional[Placement] = None, + use_local_output: bool = True, + ): + super().__init__() + self.input_layout = input_layout or Replicate() + self.output_layout = output_layout or Replicate() + self.desired_input_layout = Replicate() + self.use_local_output = use_local_output + + @staticmethod + def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh): + # annotate module input placements/sharding with input_layouts + input_tensor = inputs[0] + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local( + input_tensor, device_mesh, (input_layout,), run_check=False + ) + + if input_layout != desired_input_layout: + input_tensor = input_tensor.redistribute( + placements=(desired_input_layout,), async_op=True + ) + return (input_tensor, *inputs[1:]) + + @staticmethod + def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh): + if outputs.placements != (output_layout,): + outputs = outputs.redistribute(placements=(output_layout,), async_op=True) + # back to local tensor + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + None, + partial( + self._prepare_input_fn, self.input_layout, self.desired_input_layout + ), + partial(self._prepare_output_fn, self.output_layout, self.use_local_output), + ) diff --git a/torchtitan/experiments/llama4/infra/parallelize_llama.py b/torchtitan/experiments/llama4/infra/parallelize_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..72842fc04f896896772beca4ec7b50b0ce66a7b5 --- /dev/null +++ 
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims

from torchtitan.models.llama3.parallelize_llama import (
    apply_ac,
    apply_compile,
    apply_ddp,
    apply_fsdp,
    apply_tp,
)
from torchtitan.tools.logging import logger


def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """
    if parallel_dims.tp_enabled:
        async_tp = job_config.parallelism.enable_async_tensor_parallel
        if async_tp and not job_config.training.compile:
            raise RuntimeError("Async TP requires --training.compile")

        # For now, float8 all-gather with TP is only supported for tensorwise
        # float8 scaling recipes. For rowwise recipes, we use regular TP and
        # all-gather happens in high precision.
        float8_enabled = "float8" in job_config.model.converters
        rowwise_recipe = job_config.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )

        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8_tensorwise_tp=float8_enabled and not rowwise_recipe,
            enable_async_tp=async_tp,
        )

        apply_moe_tp(model, world_mesh["tp"])

    ac_mode = job_config.activation_checkpoint.mode
    if ac_mode != "none":
        if ac_mode == "selective" and job_config.model.use_flex_attn:
            raise ValueError(
                "FlexAttention is not compatible with selective AC yet. "
                "See https://github.com/pytorch/pytorch/issues/147879"
            )
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-TransformerBlock compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

    # NOTE: needed for torch.compile to work with dynamic shapes in
    # token-choice MoE
    torch._dynamo.config.capture_scalar_outputs = True

    if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
        # apply FSDP or HSDP, potentially with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            dp_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
        )

        logger.info(
            "Applied HSDP to the model"
            if parallel_dims.dp_replicate_enabled
            else "Applied FSDP to the model"
        )
        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")
        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        # Pure data parallel (no sharding): fall back to DDP.
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
        )

    return model


def apply_moe_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
):
    """Apply tensor parallelism to the MoE submodules of every block."""
    from torch.distributed.tensor import Partial, Replicate, Shard
    from torch.distributed.tensor.parallel import (
        parallelize_module,
        PrepareModuleInputOutput,
    )

    from .expert_parallel import NoParallel, TensorParallel

    for transformer_block in model.layers.values():
        moe_layer_plan = {
            # input / output sharding on the seqlen dim:
            # all-gather for input, reduce-scatter for output
            "moe": PrepareModuleInputOutput(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
                use_local_input=True,
                output_layouts=(Partial(),),
                desired_output_layouts=(Shard(1),),
            ),
            # replicate computation for the router
            "moe.router.gate": NoParallel(),
            # input Replicate, output Partial
            "moe.experts": TensorParallel(),
            "moe.shared_expert": TensorParallel(),
        }
        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=moe_layer_plan,
        )
b/torchtitan/experiments/llama4/model/__pycache__/model.cpython-312.pyc differ diff --git a/torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc b/torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcf7184ccbfe4575f21558c63677c94d072ea435 Binary files /dev/null and b/torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc differ diff --git a/torchtitan/experiments/llama4/model/args.py b/torchtitan/experiments/llama4/model/args.py new file mode 100644 index 0000000000000000000000000000000000000000..7e5757f08bced3ce6d5f92f343fd6e4beebaf400 --- /dev/null +++ b/torchtitan/experiments/llama4/model/args.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass +from typing import Optional + +from torch import nn +from torchtitan.components.tokenizer import Tokenizer +from torchtitan.config_manager import JobConfig + +from torchtitan.protocols.train_spec import BaseModelArgs +from torchtitan.tools.logging import logger + + +@dataclass +class TransformerModelArgs(BaseModelArgs): + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + rope_theta: float = 10000 + + max_seq_len: int = 2048 + # If `True`, then each transformer block init uses its layer ID, and if + # `False`, each uses the total number of transformer blocks + depth_init: bool = True + norm_type: str = "rmsnorm" + + use_flex_attn: bool = False + attn_mask_type: str = "causal" + eos_id: int = 0 + + # MoE args + moe_enabled: bool = True + 
num_experts: int = 8 + use_shared_expert: bool = True + auto_scale_hidden_dim: bool = True + # frequency of using MoE layer instead of feedforward layer in a transformer block + interleave_moe_layer_step: int = 2 + # token-choice + top_k: int = 1 + + def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None: + self.norm_type = job_config.model.norm_type + self.vocab_size = tokenizer.n_words + self.max_seq_len = job_config.training.seq_len + self.use_flex_attn = job_config.model.use_flex_attn + + def get_nparams_and_flops( + self, model: nn.Module, seq_len: int + ) -> tuple[int, float]: + nparams_embedding = 0 + nparams_moe_router = 0 + nparams_shared_expert = 0 + nparams_experts = 0 + nparams_dense = 0 + + for name, p in model.named_parameters(): + if "embedding" in name: + nparams_embedding += p.numel() + nparams_dense += p.numel() + elif "moe.shared_expert" in name: + nparams_shared_expert += p.numel() + elif "moe.router" in name: + nparams_moe_router += p.numel() + elif "moe.experts" in name: + nparams_experts += p.numel() + else: + nparams_dense += p.numel() + + nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts + nparams = nparams_dense + nparams_sparse + nparams_sparse_active = ( + nparams_moe_router + + nparams_shared_expert + + nparams_experts * self.top_k // self.num_experts + ) + + logger.info( + f"Total parameter count: dense {nparams_dense:,}, " + f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}" + ) + + l, h, q, t = ( + self.n_layers, + self.n_heads, + self.dim // self.n_heads, + seq_len, + ) + # Reasoning behind the factor of 12 for the self-attention part of the formula: + # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) + # 2. the flash attention does 1 more matmul recomputation in the backward + # but recomputation should not be counted in calculating MFU (+0) + # 3. each matmul performs 1 multiplication and 1 addition (*2) + # 4. 
we follow the convention and do not account for sparsity in causal attention + num_flops_per_token = ( + 6 * (nparams_dense - nparams_embedding + nparams_sparse_active) + + 12 * l * h * q * t + ) + + return nparams, num_flops_per_token diff --git a/torchtitan/experiments/llama4/model/model.py b/torchtitan/experiments/llama4/model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..39be49a5b0e645cc67b04a3e0957d057c3ec40d2 --- /dev/null +++ b/torchtitan/experiments/llama4/model/model.py @@ -0,0 +1,466 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn.functional as F +from torch import nn + +from torchtitan.models.attention import build_attention, init_attention_mask +from torchtitan.models.norms import build_norm +from torchtitan.protocols.train_spec import ModelProtocol + +from .args import TransformerModelArgs +from .moe import MoE + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. 
+ """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert ndim > 1 + seqlen = x.shape[1] + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. 
+ + Returns: + tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + """ + Multi-head attention module. + + Args: + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + n_kv_heads (int): Number of key and value heads. + n_heads (int): Number of query heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. 
+ + """ + + def __init__(self, model_args: TransformerModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.n_kv_heads = ( + model_args.n_heads + if model_args.n_kv_heads is None + else model_args.n_kv_heads + ) + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = model_args.dim // model_args.n_heads + + self.wq = nn.Linear( + model_args.dim, model_args.n_heads * self.head_dim, bias=False + ) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear( + model_args.n_heads * self.head_dim, model_args.dim, bias=False + ) + self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + + Returns: + torch.Tensor: Output tensor after attention. + + """ + + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. 
+ xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + + output = self.sdpa(xq, xk, xv) + + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + output = output.view(bs, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. 
+ + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: float | None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class TransformerBlock(nn.Module): + """ + TransformerBlock Module + + Args: + layer_id (int): Identifier for the layer. + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + n_heads (int): Number of attention heads. + dim (int): Dimension size of the model. + head_dim (int): Dimension size of each attention head. + attention (Attention): Attention module. + feed_forward (FeedForward): FeedForward module. + layer_id (int): Identifier for the layer. + attention_norm (RMSNorm): Layer normalization for attention output. + ffn_norm (RMSNorm): Layer normalization for feedforward output. 
+ + """ + + def __init__(self, layer_id: int, model_args: TransformerModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.dim = model_args.dim + self.attention = Attention(model_args) + + # use MoE layer for every interleave_moe_layer_step FFN layers + self.moe_enabled = ( + model_args.moe_enabled + and (layer_id + 1) % model_args.interleave_moe_layer_step == 0 + ) + if self.moe_enabled: + self.moe = MoE(model_args) + else: + self.feed_forward = FeedForward( + dim=model_args.dim, + hidden_dim=4 * model_args.dim, + multiple_of=model_args.multiple_of, + ffn_dim_multiplier=model_args.ffn_dim_multiplier, + ) + + self.layer_id = layer_id + self.num_layers = model_args.n_layers + + self.attention_norm = build_norm( + model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps + ) + self.ffn_norm = build_norm( + model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps + ) + + if model_args.depth_init: + self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5 + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. 
+ + """ + h = x + self.attention(self.attention_norm(x), freqs_cis) + if self.moe_enabled: + out = h + self.moe(self.ffn_norm(h)) + else: + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + def init_weights(self): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + if self.moe_enabled: + self.moe.init_weights(self.weight_init_std) + else: + self.feed_forward.init_weights(self.weight_init_std) + + +class Transformer(nn.Module, ModelProtocol): + """ + Transformer Module + + Args: + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + model_args (TransformerModelArgs): Model configuration arguments. + vocab_size (int): Vocabulary size. + n_layers (int): Number of layers in the model. + tok_embeddings (ParallelEmbedding): Token embeddings. + layers (torch.nn.ModuleList): List of Transformer blocks. + norm (RMSNorm): Layer normalization for the model output. + output (ColumnParallelLinear): Linear layer for final output. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + """ + + def __init__(self, model_args: TransformerModelArgs): + super().__init__() + self.model_args = model_args + self.vocab_size = model_args.vocab_size + self.n_layers = model_args.n_layers + self.eos_id = model_args.eos_id + + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + + # TODO persistent should be set to false, since this buffer can be recomputed. + # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, + # compile or pipeline-tracer will not correctly handle non-persistent buffers, + # so we need to fix that. (2) if we initialize pipeline-parallel models from + # a seed checkpoint rather than calling init_weights, we need freqs_cis to be + # initialized by the checkpoint, or we need to add a separate initializer for + # just the non-persistent buffers that is called after loading checkpoints. 
+ self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + + self.norm = build_norm( + model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps + ) + + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) + self.init_weights() + + def init_weights( + self, + buffer_device: torch.device | None = None, + ): + """ + [Note: On ``init_weights`` vs. ``reset_parameters``] + Modules may define ``reset_parameters`` to initialize parameter values. + ``reset_parameters`` is meant to only initialize directly owned + parameters/buffers, not those of their child modules, and it can be + used to give the initial values for these tensors. + Separately, users may want custom initialization for their modules, + different from that in ``reset_parameters``. For this, we define + ``init_weights``. We only call it in the constructor of this + ``Transformer`` root module to avoid reinitializing tensors. 
+ """ + buffer_device = buffer_device or self.freqs_cis.device + with torch.device(buffer_device): + self.freqs_cis = self._precompute_freqs_cis() + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights() + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def _precompute_freqs_cis(self) -> torch.Tensor: + return precompute_freqs_cis( + self.model_args.dim // self.model_args.n_heads, + # Need to compute until at least the max token limit for generation + # TODO: explain in docs/composability.md why we removed the 2x + # relaxing in our CP enablement PR + self.model_args.max_seq_len, + self.model_args.rope_theta, + ) + + def forward(self, tokens: torch.Tensor): + """ + Perform a forward pass through the Transformer model. + + Args: + tokens (torch.Tensor): Input token indices. + + Returns: + torch.Tensor: Output logits after applying the Transformer model. + + """ + # TODO: We will to change forward() signature to allow tokens to + # be always passed in. + if self.model_args.use_flex_attn: + init_attention_mask(tokens, eos_id=self.eos_id) + + # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis) + + h = self.norm(h) if self.norm else h + output = self.output(h) if self.output else h + return output + + @classmethod + def from_model_args(cls, model_args: TransformerModelArgs) -> "Transformer": + """ + Initialize a Transformer model from a TransformerModelArgs object. 
+ + Args: + model_args (TransformerModelArgs): Model configuration arguments. + + Returns: + Transformer: Transformer model. + + """ + return cls(model_args) diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py new file mode 100644 index 0000000000000000000000000000000000000000..0b925b36207875dedc13a16be10890c3671cdabb --- /dev/null +++ b/torchtitan/experiments/llama4/model/moe.py @@ -0,0 +1,228 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from torch import nn + +from .args import TransformerModelArgs + + +class GroupedExperts(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + num_experts: int, + ): + super().__init__() + self.num_experts = num_experts + self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) + self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim)) + self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) + + def forward( + self, + x: torch.Tensor, + num_local_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_local_tokens_per_expert is not None: + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x, + split_size_or_sections=num_local_tokens_per_expert.tolist(), + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + w1, w2, w3 = ( + self.w1[expert_idx], + self.w2[expert_idx], + self.w3[expert_idx], + ) + h = F.silu(torch.matmul(x_expert, w1)) + h = h * torch.matmul(x_expert, w3) + h = torch.matmul(h, w2) + # h shape (tokens_per_expert(varying), dim) + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + # TODO:optimize with GroupedGEMM + # https://github.com/pytorch/pytorch/pull/150374 + # 
_gouped_mm requires shapes to be multiple of 8 + # offsets = torch.cumsum(num_local_tokens_per_expert, dim=0, dtype=torch.int32) + # h = F.silu(torch._grouped_mm(x, self.w1.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16)) + # h = h * torch._grouped_mm(x, self.w3.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16) + # out = torch._grouped_mm(h, self.w2.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16) + else: + # x shape (num_experts, tokens_per_expert, dim) + h = F.silu(torch.bmm(x, self.w1)) + h = h * torch.bmm(x, self.w3) + # out shape (num_experts, tokens_per_expert, dim) + out = torch.bmm(h, self.w2) + return out + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std) + + +class TokenChoiceTopKRouter(nn.Module): + """This class implements token-choice routing. In token-choice top-K routing, each token is + routed to top K experts based on the router scores. + + Args: + gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts). + dim (int): Dimension of input tokens. + num_experts (int): Number of experts in each moe layer. + top_k (int): Number of experts each token will be routed to in token-choice routing. + use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is False. + """ + + def __init__( + self, + dim: int, + num_experts: int, + top_k: int, + use_sigmoid: bool = False, + ): + super().__init__() + self.gate = nn.Linear(dim, num_experts, bias=False) + self.num_experts = num_experts + self.top_k = top_k + self.use_sigmoid = use_sigmoid + + def forward( + self, x: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``. 
+ + Returns: + routed_input (torch.Tensor): + Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``. + token_indices (torch.Tensor): + Token indices for routed_input with shape ``(bs*slen*top_k,)``. + num_local_tokens_per_expert (torch.Tensor): + Number of tokens assigned to each expert with shape ``(num_experts,)``. + """ + # scores shape (bs*slen, num_experts) + scores = self.gate(x) + + # By default, sigmoid or softmax is performed in float32 to avoid loss explosion + if self.use_sigmoid: + scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype) + else: + scores = F.softmax(scores.to(torch.float32), dim=1).to(x.dtype) + + # top scores shape (bs*slen, top_k) + top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=1) + # top_scores /= top_scores.sum(dim=-1, keep_dim=True).to(x.dtype) + + # group tokens together by expert indices from 0 to num_experts and pass that to experts forward + num_local_tokens_per_expert = torch.histc( + selected_experts_indices.view(-1), + bins=self.num_experts, + min=0, + max=self.num_experts, + ) + # token_indices_experts_sorted shape (bs*slen*top_k,) + token_indices_experts_sorted = torch.argsort( + selected_experts_indices.view(-1), stable=True + ) + top_scores = top_scores.view(-1)[token_indices_experts_sorted] + token_indices_experts_sorted = token_indices_experts_sorted // self.top_k + + return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) + + +# TODO: implement load balancing auxiliary loss for token-choice routing +class MoE(nn.Module): + def __init__(self, model_args: TransformerModelArgs): + super().__init__() + dim = model_args.dim + hidden_dim = 4 * model_args.dim + ffn_dim_multiplier = model_args.ffn_dim_multiplier + hidden_dim = int(2 * hidden_dim / 3) + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + + 
num_experts = model_args.num_experts + + hidden_dim_denom = 1 + if model_args.auto_scale_hidden_dim: + hidden_dim_denom = model_args.top_k + int(model_args.use_shared_expert) + + if model_args.auto_scale_hidden_dim: + hidden_dim = int(hidden_dim / hidden_dim_denom) + hidden_dim += -hidden_dim % model_args.multiple_of + + self.experts = GroupedExperts( + dim=dim, hidden_dim=hidden_dim, num_experts=num_experts + ) + self.router = TokenChoiceTopKRouter( + dim=dim, num_experts=num_experts, top_k=model_args.top_k + ) + self.shared_expert = ( + GroupedExperts(dim=dim, hidden_dim=hidden_dim, num_experts=1) + if model_args.use_shared_expert + else None + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``. + + Returns: + out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. + """ + bs, slen, dim = x.shape + # top_scores and selected_indices shape (bs*slen*top_k,) + # num_local_tokens_per_expert shape (num_experts,) + ( + top_scores, + token_indices, + num_local_tokens_per_expert, + ) = self.router(x.reshape(bs * slen, dim)) + + # shape (bs*slen*top_k, dim) + token_indices = token_indices.reshape(-1, 1).expand(-1, dim) + + # shape (bs*slen*top_k, dim) + routed_input = torch.gather( + x.view(-1, dim), + dim=0, + index=token_indices, + ) + routed_input = routed_input * top_scores.reshape(-1, 1) + + # shape (bs*slen*top_k, dim) + routed_output = self.experts(routed_input, num_local_tokens_per_expert) + + # shared expert + if self.shared_expert is not None: + out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape( + bs * slen, dim + ) + else: + out = torch.zeros_like(x.reshape(bs * slen, dim)) + + out = out.scatter_add(dim=0, index=token_indices, src=routed_output) + out = out.reshape(bs, slen, dim) + return out + + def init_weights(self, init_std: float): + self.experts.init_weights(init_std) + self.router.init_weights(init_std) + if self.shared_expert is not None: + 
self.shared_expert.init_weights(init_std) diff --git a/torchtitan/experiments/llama4/scripts/REAME.md b/torchtitan/experiments/llama4/scripts/REAME.md new file mode 100644 index 0000000000000000000000000000000000000000..c4cd6c32412522eb6efb0fa93eb09344b69ad3cc --- /dev/null +++ b/torchtitan/experiments/llama4/scripts/REAME.md @@ -0,0 +1,17 @@ +## How to convert a Llama 4 checkpoint for use in torchtitan + +To continue training from an existing model checkpoint, the checkpoint must be in the DCP format expected by the checkpoint manager. +This folder contains the scripts for converting officially released Llama 4 checkpoints into the expected DCP format, from original Meta format, or from HuggingFace format, using GPUs. + +#### Example usage + +From Meta format: +```bash +CONFIG_FILE=../train_configs/llama4_16.toml ./convert_meta_to_dcp.sh --checkpoint.enable_checkpoint --checkpoint.convert_path=[checkpoint_folder] --checkpoint.convert_load_every_n_ranks=8 +``` + + +From HuggingFace format: +```bash +CONFIG_FILE=../train_configs/llama4_16.toml ./convert_hf_to_dcp_with_gpus.sh --checkpoint.enable_checkpoint --checkpoint.convert_path=[checkpoint_folder] --checkpoint.convert_load_every_n_ranks=8 +``` diff --git a/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..99eb36ac6ffa8e546d8895358978e937088f7ee1 --- /dev/null +++ b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py @@ -0,0 +1,545 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+import json +import math +import os +import pprint +import time +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Optional + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor, Shard +from torch.distributed.tensor._utils import compute_local_shape_and_global_offset +from torchtitan.components.checkpoint import MODEL +from torchtitan.config_manager import JobConfig +from torchtitan.tools.logging import init_logger, logger +from torchtitan.train import Trainer + + +def extract_layer_number(s): + import re + + match = re.search(r"layers\.(\d+)", s) + if match: + return int(match.group(1)) + else: + return None + + +def convert_to_titan_fqns(fqn: str) -> list[str]: + # From the stored checkpoint keys to TorchTitan keys. + if "language_model." not in fqn: + # TODO: Not support video model yet + return [fqn] + + layer = extract_layer_number(fqn) + + if layer is None: + if "embed_tokens.weight" in fqn: + return ["tok_embeddings.weight"] + elif "norm.weight" in fqn: + return ["norm.weight"] + elif "lm_head.weight" in fqn: + return ["output.weight"] + else: + raise ValueError(f"Unknown fqn {fqn}") + + if "feed_forward.experts.down_proj" in fqn: + return [f"layers.{layer}.moe.experts.w2"] + elif "feed_forward.experts.gate_up_proj" in fqn: + return [f"layers.{layer}.moe.experts.w1", f"layers.{layer}.moe.experts.w3"] + elif "feed_forward.router.weight" in fqn: + return [f"layers.{layer}.moe.router.gate.weight"] + elif "feed_forward.shared_expert.down_proj.weight" in fqn: + return [f"layers.{layer}.moe.shared_expert.w2"] + elif "feed_forward.shared_expert.gate_proj.weight" in fqn: + return [f"layers.{layer}.moe.shared_expert.w3"] + elif "feed_forward.shared_expert.up_proj.weight" in fqn: + return [f"layers.{layer}.moe.shared_expert.w1"] + elif "input_layernorm.weight" in fqn: + return [f"layers.{layer}.ffn_norm.weight"] + elif "self_attn.k_proj" in fqn: + 
return [f"layers.{layer}.attention.wk.weight"] + elif "self_attn.o_proj" in fqn: + return [f"layers.{layer}.attention.wo.weight"] + elif "self_attn.q_proj" in fqn: + return [f"layers.{layer}.attention.wq.weight"] + elif "self_attn.v_proj" in fqn: + return [f"layers.{layer}.attention.wv.weight"] + elif "post_attention_layernorm.weight" in fqn: + return [f"layers.{layer}.attention_norm.weight"] + else: + raise ValueError(f"Unknown fqn {fqn}") + + +def convert_to_hf_shape(fqn: str, titan_fqns: list[str], dtensor: DTensor) -> list[str]: + if "feed_forward.experts.gate_up_proj" in fqn: + assert len(titan_fqns) == 2 + shape = dtensor.shape + return torch.Size(list(shape[:-1]) + [shape[-1] * 2]) + elif "shared_expert" in fqn: + s = dtensor.shape + # TODO: this is not right but I have to do this to load the checkpoint. + return torch.Size((s[2], s[1])) + return dtensor.shape + + +def convert_to_titan_tensors(fqn: str, full_tensor: torch.Tensor) -> torch.Tensor: + if "feed_forward.experts.gate_up_proj" in fqn: + full_tensors = full_tensor.chunk(2, dim=-1) + elif "shared_expert" in fqn: + # TODO: this is not right but I have to do this to load the checkpoint. 
+ full_tensor = full_tensor.transpose(1, 0) + full_tensors = [full_tensor.unsqueeze(0)] + else: + full_tensors = [full_tensor] + return full_tensors + + +@dataclass +class _Assignment: + loader_id: int + filename: str + fqns: list[str] + shapes: list[torch.Size] + dtypes: list[torch.dtype] + + +@dataclass +class _AssignmentRound: + loader_assignments: dict[int, _Assignment] # List of assignments for each loader + + +@dataclass +class TensorMetadata: + fqn: str + shape: torch.Size + dtype: torch.dtype + + +class CheckpointConverter: + def __init__( + self, + process_group: dist.ProcessGroup, + path: str, + token: Optional[str] = None, + loader_every_n_ranks: int = 8, + ) -> None: + self.path = path + self.token = token + self.pg = process_group + self.my_rank = dist.get_rank(self.pg) + + self.loader_every_n_ranks = loader_every_n_ranks + self.loader_id = self.my_rank // loader_every_n_ranks + self.should_load = self.my_rank % loader_every_n_ranks == 0 + self.total_loader = dist.get_world_size(self.pg) // loader_every_n_ranks + + self.titan_fqn_to_stored_fqn: dict[str, str] = {} + self.stored_fqn_to_titan_fqn: dict[str, list[str]] = {} + self.total_send_bytes = 0 + self.total_recv_bytes = 0 + + def convert(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + begin = time.time() + self._load_metadata() + self._create_fqn_mappings(state_dict) + rounds = self._get_load_assignments(state_dict) + + logger.info(f"Got {len(rounds)} rounds of assignments.") + for idx, assignments in enumerate(rounds): + loader_assignments = assignments.loader_assignments + loaded_state_dict = None + # Let each loader to load its own data and move to its GPU. + logger.info(f"Loading round {idx}") + for i in range(self.total_loader): + # This loader doesn't have any loading assignment for this round. 
+ if i not in loader_assignments: + continue + # This rank is not the loader + if i != self.loader_id or not self.should_load: + continue + loaded_state_dict = self._load_round(loader_assignments[i]) + + torch.cuda.synchronize() + logger.info(f"Loading round {idx} finished") + for i in range(self.total_loader): + if i not in loader_assignments: + continue + + logger.info(f"Resharding round {idx} loader {i} data. ") + if i == self.loader_id and self.should_load: + # This rank is the loader. It needs to send the loaded data to + # the other ranks. + assert loaded_state_dict is not None + results = self._reshard_send( + loader_assignments[i], loaded_state_dict + ) + else: + results = self._reshard_receive(loader_assignments[i], state_dict) + torch.cuda.synchronize() + + logger.info(f"Communication round {idx} loader {i} is done.") + self._reshard(results, state_dict) + logger.info(f"Resharding round {idx} loader {i} is done.") + self._reshard(results, state_dict) + torch.cuda.synchronize() + + dist.barrier() + torch.cuda.synchronize() + logger.info(f"Checkpoint conversion took {time.time() - begin:.2f} seconds.") + logger.info(f"Total send bytes: {self.total_send_bytes / 1e9:.2f} GB") + logger.info(f"Total recv bytes: {self.total_recv_bytes / 1e9:.2f} GB") + return state_dict + + def _load_metadata(self) -> None: + metadata_path = os.path.join(self.path, "model.safetensors.index.json") + with open(metadata_path, "r") as f: + self.metadata = json.load(f)["weight_map"] + + def _create_fqn_mappings(self, state_dict: dict[str, torch.Tensor]) -> None: + if not self.metadata: + return + + # Create the mapping from the stored checkpoint keys to TorchTitan keys. 
+ for fqn in list(self.metadata.keys()): + titan_fqns = convert_to_titan_fqns(fqn) + # We don't know how to process _extra_state + if "_extra_state" in fqn: + self.metadata.pop(fqn) + continue + + if titan_fqns[0] not in state_dict: + for titan_fqn in titan_fqns: + assert titan_fqn not in state_dict + self.metadata.pop(fqn) + continue + + self.stored_fqn_to_titan_fqn[fqn] = titan_fqns + for titan_fqn in titan_fqns: + self.titan_fqn_to_stored_fqn[titan_fqn] = fqn + + torchtitan_extra = sorted( + list(set(state_dict.keys()) - set(self.titan_fqn_to_stored_fqn.keys())) + ) + converted_extra = sorted( + list(set(self.titan_fqn_to_stored_fqn.keys()) - set(state_dict.keys())) + ) + assert set(state_dict.keys()) == set(self.titan_fqn_to_stored_fqn.keys()), ( + f"{pprint.pformat(torchtitan_extra)}", + f"{pprint.pformat(converted_extra)}", + ) + + def _get_load_assignments( + self, state_dict: dict[str, Any] + ) -> list[_AssignmentRound]: + if self.my_rank == 0: + filename_to_metas = defaultdict(list) + for fqn, filename in self.metadata.items(): + titan_fqns = self.stored_fqn_to_titan_fqn[fqn] + shape = convert_to_hf_shape(fqn, titan_fqns, state_dict[titan_fqns[0]]) + meta = TensorMetadata( + fqn=fqn, + shape=shape, + # TODO: don't hardcode this + dtype=torch.bfloat16, + ) + filename_to_metas[filename].append(meta) + + loader_filename_to_metas = [{} for _ in range(self.total_loader)] + for idx, (filename, metas) in enumerate(filename_to_metas.items()): + loader_id = idx % self.total_loader + loader_filename_to_metas[loader_id][filename] = metas + + rounds = [] + while any(len(remain) > 0 for remain in loader_filename_to_metas): + round_assignment = _AssignmentRound(loader_assignments={}) + for loader_id in range(self.total_loader): + if not loader_filename_to_metas[loader_id]: + continue + + filename, metas = loader_filename_to_metas[loader_id].popitem() + round_assignment.loader_assignments[loader_id] = _Assignment( + filename=filename, + fqns=[meta.fqn for meta in metas], 
+ shapes=[meta.shape for meta in metas], + dtypes=[meta.dtype for meta in metas], + loader_id=loader_id, + ) + + rounds.append(round_assignment) + + object_list: list[Any] = [ + rounds, + self.titan_fqn_to_stored_fqn, + self.stored_fqn_to_titan_fqn, + ] + else: + object_list = [None, None, None] + + dist.broadcast_object_list(object_list, src=0, group=self.pg) + rounds = object_list[0] + self.titan_fqn_to_stored_fqn = object_list[1] + self.stored_fqn_to_titan_fqn = object_list[2] + return rounds + + def _load_round(self, assignment: _Assignment) -> dict[str, Any]: + from safetensors.torch import load_file as hf_load_file + + path = os.path.join(self.path, assignment.filename) + state_dict = hf_load_file(path) + return { + k: v.to(device="cuda") + for k, v in state_dict.items() + if k in assignment.fqns + } + + def _reshard_send( + self, + assignment: _Assignment, + loaded_state_dict: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + flatten_tensors = [t.flatten() for t in loaded_state_dict.values()] + flatten_tensor = torch.concat(flatten_tensors) + assert self.loader_id == assignment.loader_id + rank = self.loader_id * self.loader_every_n_ranks + assert rank == self.my_rank + logger.info( + f"Sending {assignment.filename} from {rank} {self.loader_id} " + f"{flatten_tensor.shape=} {flatten_tensor.dtype=} {loaded_state_dict.keys()=}." 
+ ) + logger.info(f"Sending {assignment}") + dist.broadcast(flatten_tensor, src=rank, group=self.pg) + self.total_send_bytes += flatten_tensor.numel() * flatten_tensor.element_size() + return loaded_state_dict + + def _reshard_receive( + self, assignment: _Assignment, state_dict: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + + flatten_tensor = torch.empty( + sum(math.prod(s) for s, d in zip(assignment.shapes, assignment.dtypes)), + dtype=assignment.dtypes[0], + device="cuda", + ) + rank = assignment.loader_id * self.loader_every_n_ranks + logger.info( + f"Receiving {assignment.filename} from {rank} " + f"{flatten_tensor.shape=} {flatten_tensor.dtype=}" + ) + logger.info(f"Receiving {assignment}") + dist.broadcast(flatten_tensor, src=rank, group=self.pg) + self.total_recv_bytes += flatten_tensor.numel() * flatten_tensor.element_size() + + ret: dict[str, torch.Tensor] = {} + loc = 0 + for fqn, shape, dtype in zip( + assignment.fqns, assignment.shapes, assignment.dtypes + ): + n_ele = math.prod(shape) + ret[fqn] = flatten_tensor[loc : loc + n_ele].view(shape) + loc += n_ele + return ret + + def _reshard( + self, + result: dict[str, torch.Tensor], + state_dict: dict[str, torch.Tensor], + ) -> None: + def _inplace_copy(fqn: str, full_tensors: list[torch.Tensor]): + titan_fqns = self.stored_fqn_to_titan_fqn[fqn] + assert len(titan_fqns) == len(full_tensors) + for titan_fqn, full_tensor in zip(titan_fqns, full_tensors): + dtensor = state_dict[titan_fqn] + assert isinstance(dtensor, DTensor) + assert dtensor.shape == full_tensor.shape, ( + (fqn, titan_fqn), + dtensor.shape, + full_tensor.shape, + ) + shape, offset = compute_local_shape_and_global_offset( + full_tensor.shape, dtensor.device_mesh, dtensor.placements + ) + slices = [ + slice(cur_offset, cur_offset + cur_shape) + for cur_shape, cur_offset in zip(shape, offset) + ] + logger.debug( + f"Copying {titan_fqn} with {slices=} {dtensor._local_tensor.shape=} " + f"{shape=} {offset=} {self.my_rank=} 
{dtensor.shape=} {full_tensor.shape=} " + f"{dtensor.placements=} {dtensor.device_mesh=} " + ) + dtensor.to_local().copy_(full_tensor[slices].to(dtensor.dtype)) + + for fqn, full_tensor in result.items(): + full_tensors = convert_to_titan_tensors(fqn, full_tensor) + _inplace_copy(fqn, full_tensors) + + +def _create_verified_state_dict( + pg: dist.ProcessGroup, mesh: DeviceMesh +) -> dict[str, torch.Tensor]: + placements = [Shard(0)] + state_dict = { + "vision_model.vision_adapter.mlp.fc1.weight": torch.rand( + 4096, 5632, device="cuda", dtype=torch.bfloat16 + ), + "vision_model.vision_adapter.mlp.fc2.weight": torch.rand( + 4096, 4096, device="cuda", dtype=torch.bfloat16 + ), + "language_model.model.layers.3.feed_forward.experts.gate_up_proj": torch.rand( + 16, 5120, 16384, device="cuda", dtype=torch.bfloat16 + ), + } + return {k: distribute_tensor(v, mesh, placements) for k, v in state_dict.items()} + + +def _verify_state_dict( + state_dict: dict[str, torch.Tensor], path: str, rank: int +) -> None: + metadata_path = os.path.join(path, "model.safetensors.index.json") + with open(metadata_path, "r") as f: + metadata = json.load(f)["weight_map"] + all_filenames = set() + for fqn, tensor in state_dict.items(): + filename = os.path.join(path, metadata[fqn]) + all_filenames.add(filename) + + stored_state_dict = {} + from safetensors.torch import load_file as hf_load_file + + for filename in all_filenames: + _sd = hf_load_file(filename) + for k in list(_sd.keys()): + if k not in state_dict: + _sd.pop(k) + else: + stored_state_dict[k] = _sd[k] + + def read_and_verify_tensor(fqn: str, dtensor: DTensor) -> None: + logger.info(f"Verifying {fqn} {dtensor.shape=} {dtensor.placements=} ") + stored_tensor = stored_state_dict[fqn] + full_tensor = dtensor.full_tensor() + logger.info(f"Gather {fqn} {full_tensor.shape} completely.") + + if rank > 0: + return + + stored_tensor = stored_tensor.to(device="cuda") + logger.info(f"Move to GPU {fqn} completely.") + + assert 
stored_tensor.shape == full_tensor.shape, fqn + assert stored_tensor.dtype == full_tensor.dtype, fqn + assert stored_tensor.device == full_tensor.device, fqn + assert torch.allclose(stored_tensor, full_tensor), fqn + + for k, v in state_dict.items(): + read_and_verify_tensor(k, v) + + +if __name__ == "__main__": + init_logger() + config = JobConfig() + config.parser.add_argument( + "--checkpoint.convert_path", + type=str, + default="", + help="""Specify the path of the target checkpoint to convert.""", + ) + config.parser.add_argument( + "--checkpoint.convert_hf_token", + type=str, + default="", + help="""Specify hf token.""", + ) + config.parser.add_argument( + "--checkpoint.convert_load_every_n_ranks", + type=int, + default=8, + help=""" + Specify the interval at which ranks are assigned to load checkpoints. + + For example, if this number is 4, then ranks 0, 4, 8, ... will load the + checkpoint. Each loader is responsible for loading one file. If there + are more loaders than files, only the first few loaders will be assigned + to load the checkpoint. The default value is 8. + """, + ) + config.parser.add_argument( + "--checkpoint.fake_model", + action="store_true", + help="""If true, the model will be fake.""", + ) + config.parse_args() + assert config.checkpoint.convert_path != "" + + trainer: Optional[Trainer] = None + + try: + trainer = Trainer(config) + if os.path.exists(trainer.checkpointer.folder): + raise RuntimeError( + "The checkpoint folder already exists. Abort to avoid overwriting " + f"the checkpoint. 
{trainer.checkpointer.folder=}" + ) + if config.checkpoint.fake_model: + state_dict = _create_verified_state_dict( + trainer.world_mesh.get_group(), trainer.world_mesh + ) + else: + state_dict = trainer.checkpointer.states[MODEL].state_dict() + + size = 0 + for v in state_dict.values(): + size += v.numel() * v.element_size() + logger.info(f"Total size of the model: {size / 1e9:.2f} GB") + + # Do not support PP yet, we will need to iterate over the PP dimension and + # extract the corresponding state_dict and device_mesh. + if "freqs_cis" in state_dict: + state_dict.pop("freqs_cis") + + # Our tokenizer is not up-to-date yet. + tok_embeddings_weight = state_dict.pop("tok_embeddings.weight") + output_weight = state_dict.pop("output.weight") + state_dict = CheckpointConverter( + process_group=trainer.world_mesh.get_group(), + path=config.checkpoint.convert_path, + token=config.checkpoint.convert_hf_token, + loader_every_n_ranks=config.checkpoint.convert_load_every_n_ranks, + ).convert(state_dict) + state_dict["tok_embeddings.weight"] = tok_embeddings_weight + state_dict["output.weight"] = output_weight + + class DummyModel: + def __init__(self, state_dict: dict[str, torch.Tensor]) -> None: + self._state_dict = state_dict + + def state_dict(self) -> dict[str, torch.Tensor]: + return self._state_dict + + if config.checkpoint.fake_model: + begin = time.time() + _verify_state_dict( + state_dict, + config.checkpoint.convert_path, + trainer.world_mesh.get_rank(), + ) + dist.barrier() + logger.info(f"Verifies state_dict {time.time() - begin}.") + else: + # oh, this is pretty bad, when can we get rid of the freqs_cis issue? 
+ state_dict["freqs_cis"] = None + trainer.checkpointer.states[MODEL] = DummyModel(state_dict) + trainer.checkpointer.model_weights_only = True + trainer.checkpointer.export_dtype = next(iter(state_dict.values())).dtype + trainer.checkpointer.save(curr_step=0, force=True) + time.sleep(2) + finally: + pass diff --git a/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.sh b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.sh new file mode 100644 index 0000000000000000000000000000000000000000..6530b8ce992c8c33ccec94614e026d73964710ee --- /dev/null +++ b/torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.sh @@ -0,0 +1,26 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +set -ex + +# use envs as local overrides for convenience +# e.g. +# LOG_RANK=0,1 NGPU=4 ./convert_hf_to_dcp_with_gpus.sh +NGPU=${NGPU:-"8"} +LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7} +CONFIG_FILE=${CONFIG_FILE:-"../train_configs/llama4_17bx16e.toml"} + +overrides="" +if [ $# -ne 0 ]; then + overrides="$*" +fi + +PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ +--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ +convert_hf_to_dcp_with_gpus.py --job.config_file ${CONFIG_FILE} $overrides diff --git a/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py b/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py new file mode 100644 index 0000000000000000000000000000000000000000..7756afe3de1527f469a38fc6a0bdc6c62eaa2526 --- /dev/null +++ b/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py @@ -0,0 +1,536 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +import os +import time +from dataclasses import dataclass +from typing import Any, Optional + +import torch +import torch.distributed as dist +from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor, Shard +from torch.distributed.tensor._utils import compute_local_shape_and_global_offset +from torchtitan.components.checkpoint import MODEL +from torchtitan.config_manager import JobConfig +from torchtitan.tools.logging import init_logger, logger +from torchtitan.train import Trainer + +# Sharding dims for MP checkpoints + +column_parallel = [ + "tok_embeddings", + "wq", + "wk", + "wv", + "wqkv", + "w_in_shared_FD", + "w_out_eF_D", + "w_swiglu_FD", + "output", + "_linear", + "c_fc", + "vision_projection", +] + +row_parallel = [ + "wo", + "w_out_shared_DF", + "w_in_eD_F", + "moe_w_swiglu_eD_F", + "c_proj", +] + + +def convert_to_titan_fqns(fqn: str) -> list[str]: + # From the stored checkpoint keys to TorchTitan keys. 
+ if "wqkv" in fqn and "layer_norm_weight" not in fqn: + ret = [] + for k in ("wq", "wk", "wv"): + ret.append(fqn.replace("wqkv", k)) + return ret + return [fqn] + + +def get_shard_dim(fqn: str) -> Optional[int]: + if "bias" in fqn: + # Some bias params are still sharded + if "resblocks" in fqn: + for k in ("wq", "wk", "wv", "c_fc"): + if k in fqn: + return 0 + return None + elif any([x in fqn for x in column_parallel]): + return 0 + elif any([x in fqn for x in row_parallel]): + return 1 + else: + return None + + +def split_fused_qkv(shards: list[torch.Tensor]) -> tuple[torch.Tensor, ...]: + qkvs = [torch.split(shard, [640, 128, 128]) for shard in shards] + q = torch.cat([qkv[0] for qkv in qkvs], dim=0) + k = torch.cat([qkv[1] for qkv in qkvs], dim=0) + v = torch.cat([qkv[2] for qkv in qkvs], dim=0) + return q, k, v + + +@dataclass +class _Assignment: + loader_id: int + filename: str + fqns: tuple[str, ...] + shapes: tuple[torch.Size, ...] + dtypes: tuple[torch.dtype, ...] + + +@dataclass +class _AssignmentRound: + loader_assignments: dict[int, _Assignment] # List of assignments for each loader + + +class CheckpointConverter: + TOTAL_SHARDS = 8 + + def __init__( + self, + process_group: dist.ProcessGroup, + path: str, + loader_every_n_ranks: int = 8, + ) -> None: + self.path = path + self.pg = process_group + self.my_rank = dist.get_rank(self.pg) + self.loader_every_n_ranks = loader_every_n_ranks + self.loader_id = self.my_rank // loader_every_n_ranks + self.should_load = ( + self.my_rank % loader_every_n_ranks == 0 + and self.loader_id < CheckpointConverter.TOTAL_SHARDS + ) + self.total_loader = CheckpointConverter.TOTAL_SHARDS + self.titan_fqn_to_stored_fqn: dict[str, str] = {} + self.stored_fqn_to_titan_fqn: dict[str, list[str]] = {} + self.total_send_bytes = 0 + self.total_recv_bytes = 0 + + def convert(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + begin = time.time() + self._load_metadata() + self._create_fqn_mappings(state_dict) + 
rounds = self._get_load_assignments(state_dict) + + for assignments in rounds: + loader_assignments = assignments.loader_assignments + loaded_state_dict = None + # Let each loader to load its own data and move to its GPU. + for i in range(self.total_loader): + # This loader doesn't have any loading assignment for this round. + if i not in loader_assignments: + continue + # This rank is not the loader + if i != self.loader_id or not self.should_load: + continue + loaded_state_dict = self._load_round(loader_assignments[i]) + + results = [] + for i in range(self.total_loader): + if i not in loader_assignments: + continue + + if i == self.loader_id and self.should_load: + # This rank is the loader. It needs to send the loaded data to + # the other ranks. + assert loaded_state_dict is not None + results.append( + self._reshard_send(loader_assignments[i], loaded_state_dict) + ) + else: + results.append( + self._reshard_receive(loader_assignments[i], state_dict) + ) + + self._reshard(results, state_dict) + + torch.cuda.synchronize() + logger.info(f"Checkpoint conversion took {time.time() - begin:.2f} seconds.") + logger.info(f"Total send bytes: {self.total_send_bytes / 1e9:.2f} GB") + logger.info(f"Total recv bytes: {self.total_recv_bytes / 1e9:.2f} GB") + return state_dict + + def _get_file_path(self, loader_id: int) -> str: + return os.path.join(self.path, f"consolidated.0{loader_id}.pth") + + def _load_metadata(self) -> None: + if not self.should_load: + self.read_dict = {} + return + self.read_dict = torch.load( + self._get_file_path(self.loader_id), + mmap=True, + weights_only=False, + ) + + def _create_fqn_mappings(self, state_dict: dict[str, torch.Tensor]) -> None: + if not self.read_dict: + return + + # Create the mapping from the stored checkpoint keys to TorchTitan keys. 
+ for fqn in list(self.read_dict.keys()): + titan_fqns = convert_to_titan_fqns(fqn) + # We don't know how to process _extra_state + if "_extra_state" in fqn: + self.read_dict.pop(fqn) + continue + + if titan_fqns[0] not in state_dict: + for titan_fqn in titan_fqns: + assert titan_fqn not in state_dict + self.read_dict.pop(fqn) + continue + self.stored_fqn_to_titan_fqn[fqn] = titan_fqns + for titan_fqn in titan_fqns: + self.titan_fqn_to_stored_fqn[titan_fqn] = fqn + + assert set(state_dict.keys()) == set(self.titan_fqn_to_stored_fqn.keys()), ( + set(state_dict.keys()) - set(self.titan_fqn_to_stored_fqn.keys()), + set(self.titan_fqn_to_stored_fqn.keys()) - set(state_dict.keys()), + ) + + def _get_load_assignments( + self, state_dict: dict[str, torch.Tensor] + ) -> list[_AssignmentRound]: + if self.my_rank == 0: + rounds: list[_AssignmentRound] = [] + size = 0 + fqns = [] + shapes = [] + dtypes = [] + + # All loader must load all the FQNs because the checkpoint is purely TP sharded. + all_keys = list(self.read_dict.keys()) + for fqn in all_keys: + fqns.append(fqn) + shapes.append(self.read_dict[fqn].shape) + dtypes.append(self.read_dict[fqn].dtype) + size += self.read_dict[fqn].numel() * self.read_dict[fqn].element_size() + if size < 1e9 and fqn != all_keys[-1]: + continue + + logger.info(f"Adding {fqns} to round {len(rounds)}") + round_assignment = _AssignmentRound(loader_assignments={}) + for loader_id in range(self.total_loader): + path = self._get_file_path(loader_id) + round_assignment.loader_assignments[loader_id] = _Assignment( + filename=path, + fqns=tuple(fqns), + shapes=tuple(shapes), + dtypes=tuple(dtypes), + loader_id=loader_id, + ) + rounds.append(round_assignment) + size = 0 + fqns.clear() + shapes.clear() + dtypes.clear() + + object_list: list[Any] = [ + rounds, + self.titan_fqn_to_stored_fqn, + self.stored_fqn_to_titan_fqn, + ] + else: + object_list = [None, None, None] + + dist.broadcast_object_list(object_list, src=0, group=self.pg) + rounds =
object_list[0] + self.titan_fqn_to_stored_fqn = object_list[1] + self.stored_fqn_to_titan_fqn = object_list[2] + return rounds + + def _load_round(self, assignment: _Assignment) -> dict[str, torch.Tensor]: + ret = {} + assert self.read_dict + for fqn in assignment.fqns: + ret[fqn] = self.read_dict[fqn].to(device="cuda") + return ret + + def _reshard_send( + self, + assignment: _Assignment, + loaded_state_dict: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + flatten_tensors = [t.flatten() for t in loaded_state_dict.values()] + flatten_tensor = torch.concat(flatten_tensors) + assert self.loader_id == assignment.loader_id + rank = self.loader_id * self.loader_every_n_ranks + assert rank == self.my_rank + logger.info(f"Sending {assignment.filename} from {rank} {self.loader_id}") + logger.info(f"Sending {assignment.fqns}") + dist.broadcast(flatten_tensor, src=rank, group=self.pg) + self.total_send_bytes += flatten_tensor.numel() * flatten_tensor.element_size() + return loaded_state_dict + + def _reshard_receive( + self, assignment: _Assignment, state_dict: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + flatten_tensor = torch.empty( + sum(math.prod(s) for s, d in zip(assignment.shapes, assignment.dtypes)), + dtype=assignment.dtypes[0], + device="cuda", + ) + rank = assignment.loader_id * self.loader_every_n_ranks + dist.broadcast(flatten_tensor, src=rank, group=self.pg) + self.total_recv_bytes += flatten_tensor.numel() * flatten_tensor.element_size() + + ret: dict[str, torch.Tensor] = {} + loc = 0 + for fqn, shape, dtype in zip( + assignment.fqns, assignment.shapes, assignment.dtypes + ): + n_ele = math.prod(shape) + ret[fqn] = flatten_tensor[loc : loc + n_ele].view(shape) + loc += n_ele + return ret + + def _reshard( + self, + results: list[dict[str, torch.Tensor]], + state_dict: dict[str, torch.Tensor], + ) -> None: + def _inplace_copy(fqn: str, full_tensors: tuple[torch.Tensor, ...]): + titan_fqns = self.stored_fqn_to_titan_fqn[fqn] + assert 
len(titan_fqns) == len(full_tensors) + for titan_fqn, full_tensor in zip(titan_fqns, full_tensors): + dtensor = state_dict[titan_fqn] + logger.info(f"{titan_fqn} {full_tensor.sum()}") + assert isinstance(dtensor, DTensor) + shape, offset = compute_local_shape_and_global_offset( + full_tensor.shape, dtensor.device_mesh, dtensor.placements + ) + slices = [ + slice(cur_offset, cur_offset + cur_shape) + for cur_shape, cur_offset in zip(shape, offset) + ] + logger.info( + f"Copying {titan_fqn} with {slices=} {dtensor._local_tensor.shape=} " + f"{shape=} {offset=} {self.my_rank=} {dtensor.shape=} {full_tensor.shape=} " + f"{dtensor.placements=} {dtensor.device_mesh=} " + ) + dtensor.to_local().copy_(full_tensor[slices]) + + def _concat_shards(fqn, shards: list[torch.Tensor]) -> tuple[torch.Tensor, ...]: + if "wqkv" in fqn: + if "layer_norm" in fqn: + return (shards[0],) + return split_fused_qkv(shards) + + shard_dim = get_shard_dim(fqn) + if shard_dim is None: + return (shards[0],) + return (torch.cat(shards, dim=shard_dim),) + + fqns = list(results[0].keys()) + for result in results: + assert list(result.keys()) == fqns + + for fqn in fqns: + full_tensors = _concat_shards(fqn, [result[fqn] for result in results]) + _inplace_copy(fqn, full_tensors) + + +def _create_verified_state_dict( + pg: dist.ProcessGroup, mesh: DeviceMesh +) -> dict[str, torch.Tensor]: + placements = [Shard(0)] + state_dict = { + "tok_embeddings.weight": torch.rand( + 25256 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.attention.wqkv.layer_norm_weight": torch.rand( + 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.attention.wq.weight": torch.rand( + 640 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.attention.wk.weight": torch.rand( + 128 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.attention.wv.weight": torch.rand( + 128 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.attention.wo.weight": torch.rand( + 5120, 640 
* 8, device="cuda", dtype=torch.bfloat16 + ), + # "layers.47.feed_forward.router_DE": torch.rand(5120, 128, device="cuda", dtype=torch.bfloat16), + # "layers.47.feed_forward.running_gate_stats_3E": torch.rand(3, 128, device="cuda", dtype=torch.bfloat16), + # "layers.47.feed_forward.global_gate_stats_3E": torch.rand(3, 128, device="cuda", dtype=torch.bfloat16), + "layers.47.feed_forward.w_in_shared_FD.weight": torch.rand( + 1024 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.feed_forward.w_out_shared_DF.weight": torch.rand( + 5120, 1024 * 8, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.feed_forward.w_swiglu_FD.weight": torch.rand( + 1024 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.feed_forward.norm.weight": torch.rand( + 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.feed_forward.experts.moe_w_in_eD_F": torch.rand( + 655360, 1024 * 8, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.feed_forward.experts.moe_w_out_eF_D": torch.rand( + 131072 * 8, 5120, device="cuda", dtype=torch.bfloat16 + ), + "layers.47.feed_forward.experts.moe_w_swiglu_eD_F": torch.rand( + 655360, 1024 * 8, device="cuda", dtype=torch.bfloat16 + ), + } + return {k: distribute_tensor(v, mesh, placements) for k, v in state_dict.items()} + + +def _verify_state_dict( + state_dict: dict[str, torch.Tensor], path: str, rank: int +) -> None: + stored_state_dicts = [ + torch.load( + os.path.join(path, f"consolidated.0{i}.pth"), + map_location="cpu", + weights_only=False, + mmap=True, + ) + for i in range(8) + ] + + def read_and_verify_tensor(fqn: str, dtensor: DTensor) -> None: + logger.info(f"Verifying {fqn} {dtensor.shape=} {dtensor.placements=} ") + shards = [stored_state_dicts[i][fqn] for i in range(8)] + full_tensor = dtensor.full_tensor() + logger.info(f"Gather {fqn} {full_tensor.shape} completely.") + + if rank > 0: + return + + if len(shards[0].shape) == 1: + assert full_tensor.shape == shards[0].shape, fqn + assert 
torch.allclose(shards[0].to(device="cuda"), full_tensor), fqn + return + elif shards[0].shape[0] == full_tensor.shape[0]: + concat_shards = torch.cat(shards, dim=1) + logger.info(f"Load {fqn} completely.") + elif shards[0].shape[1] == full_tensor.shape[1]: + concat_shards = torch.cat(shards, dim=0) + logger.info(f"Load {fqn} completely.") + + concat_shards = concat_shards.to(device="cuda") + logger.info(f"Move to GPU {fqn} completely.") + + assert concat_shards.shape == full_tensor.shape, fqn + assert concat_shards.dtype == full_tensor.dtype, fqn + assert concat_shards.device == full_tensor.device, fqn + assert torch.allclose(concat_shards, full_tensor), fqn + + for k, v in state_dict.items(): + if "wq" in k and "wqkv" not in k: + pass + elif "wk" in k: + pass + elif "wv" in k: + pass + else: + assert v is not None, k + read_and_verify_tensor(k, v) + + +if __name__ == "__main__": + init_logger() + config = JobConfig() + config.parser.add_argument( + "--checkpoint.convert_path", + type=str, + default="", + help="""Specify the path of the target checkpoint to convert.""", + ) + config.parser.add_argument( + "--checkpoint.convert_load_every_n_ranks", + type=int, + default=8, + help=""" + Specify the interval at which ranks are assigned to load checkpoints. + + For example, if this number is 4, then ranks 0, 4, 8, ... will load the + checkpoint. Each loader is responsible for loading one file. If there + are more loaders than files, only the first few loaders will be assigned + to load the checkpoint. The default value is 8. + """, + ) + config.parser.add_argument( + "--checkpoint.fake_model", + action="store_true", + help="""If true, the model will be fake.""", + ) + config.parse_args() + assert config.checkpoint.convert_path != "" + + trainer: Optional[Trainer] = None + + try: + trainer = Trainer(config) + if os.path.exists(trainer.checkpointer.folder): + raise RuntimeError( + "The checkpoint folder already exists. Abort to avoid overwriting " + f"the checkpoint. 
{trainer.checkpointer.folder=}" + ) + if config.checkpoint.fake_model: + state_dict = _create_verified_state_dict( + trainer.world_mesh.get_group(), trainer.world_mesh + ) + else: + state_dict = trainer.checkpointer.states[MODEL].state_dict() + + size = 0 + for v in state_dict.values(): + size += v.numel() * v.element_size() + logger.info(f"Total size of the model: {size / 1e9:.2f} GB") + + # Do not support PP yet, we will need to iterate over the PP dimension and + # extract the corresponding state_dict and device_mesh. + if "freqs_cis" in state_dict: + state_dict.pop("freqs_cis") + + state_dict = CheckpointConverter( + process_group=trainer.world_mesh.get_group(), + path=config.checkpoint.convert_path, + loader_every_n_ranks=config.checkpoint.convert_load_every_n_ranks, + ).convert(state_dict) + + class DummyModel: + def __init__(self, state_dict: dict[str, torch.Tensor]) -> None: + self._state_dict = state_dict + + def state_dict(self) -> dict[str, torch.Tensor]: + return self._state_dict + + if config.checkpoint.fake_model: + begin = time.time() + _verify_state_dict( + state_dict, + config.checkpoint.convert_path, + trainer.world_mesh.get_rank(), + ) + dist.barrier() + logger.info(f"Verifies state_dict {time.time() - begin}.") + else: + # oh, this is pretty bad, when can we get rid of the freqs_cis issue?
+ state_dict["freqs_cis"] = None + trainer.checkpointer.states[MODEL] = DummyModel(state_dict) + trainer.checkpointer.model_weights_only = True + trainer.checkpointer.export_dtype = next(iter(state_dict.values())).dtype + trainer.checkpointer.save(curr_step=0, force=True) + time.sleep(2) + finally: + pass diff --git a/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh b/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh new file mode 100644 index 0000000000000000000000000000000000000000..f3fd310934b1181ed83fa9fc4463f0c2336b46fc --- /dev/null +++ b/torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh @@ -0,0 +1,25 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +# use envs as local overrides for convenience +# e.g. +# LOG_RANK=0,1 NGPU=4 ./convert_meta_to_dcp_with_gpus.sh +NGPU=${NGPU:-"8"} +LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7} +CONFIG_FILE=${CONFIG_FILE:-"../train_configs/llama4_17bx16e.toml"} +

overrides="" +if [ $# -ne 0 ]; then + overrides="$*" +fi + +PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ +--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ +convert_meta_to_dcp_with_gpus.py --job.config_file ${CONFIG_FILE} $overrides diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml new file mode 100644 index 0000000000000000000000000000000000000000..139a1f28bfff5136e1ff625ee00d6e015b7729ba --- /dev/null +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -0,0 +1,74 @@ +[job] +dump_folder = "./outputs" +description = "Llama 4 debug training" +print_args = false +use_for_integration_test = true + +[profiling]
+enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "llama4" +flavor = "debugmodel" +norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm +# test tokenizer.model, for debug purpose only +tokenizer_path = "./tests/assets/test_tiktoken.model" +# converters = "float8" +use_flex_attn = false +attn_mask_type = "causal" # causal / block_causal + +[optimizer] +name = "AdamW" +lr = 4e-3 +eps = 1e-15 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +lr_min = 0.1 + +[training] +batch_size = 8 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 10 +compile = false +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 10 +model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = 'none' # ['none', 'selective', 'full'] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = "output,router.gate" diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml new file mode 
100644 index 0000000000000000000000000000000000000000..e947afba56fd3b8ee5bf1fe45e65160c99a6fd18 --- /dev/null +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml @@ -0,0 +1,65 @@ +# TODO: this toml config is still under development + +[job] +dump_folder = "./outputs" +description = "Llama 4 Maverick 17Bx128E training" + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 10 +enable_tensorboard = false +save_tb_folder = "tb" + +[model] +name = "llama4" +flavor = "17bx128e" +norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm +tokenizer_path = "./assets/tokenizer/tokenizer.model" +# converters = "float8" + +[optimizer] +name = "AdamW" +lr = 4e-3 +eps = 1e-15 + +[lr_scheduler] +warmup_steps = 600 +lr_min = 0.1 + +[training] +batch_size = 1 +seq_len = 8192 +max_norm = 1.0 # grad norm clipping +steps = 3000 +compile = false +dataset = "c4" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 8 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 4 +# pipeline_parallel_schedule = "interleaved1f1b" +# pipeline_parallel_microbatches = 2 +context_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 500 +model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = 'full' # ['none', 'selective', 'full'] + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = "output,router.gate" diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml new file mode 100644 index 0000000000000000000000000000000000000000..d464d2d8cfddecb0e338a48926d0650a8ecb7930 --- /dev/null +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml 
@@ -0,0 +1,63 @@ +# NOTE: this toml config is a preset for 64 H100 GPUs. + +[job] +dump_folder = "./outputs" +description = "Llama 4 Scout 17Bx16E training" + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 10 +enable_tensorboard = false +save_tb_folder = "tb" + +[model] +name = "llama4" +flavor = "17bx16e" +norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm +tokenizer_path = "./assets/tokenizer/tokenizer.model" +# converters = "float8" + +[optimizer] +name = "AdamW" +lr = 4e-3 +eps = 1e-15 + +[lr_scheduler] +warmup_steps = 600 +lr_min = 0.1 + +[training] +batch_size = 8 +seq_len = 8192 +max_norm = 1.0 # grad norm clipping +steps = 3000 +compile = false +dataset = "c4" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 8 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 500 +model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = 'full' # ['none', 'selective', 'full'] + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = "output,router.gate" diff --git a/torchtitan/experiments/multimodal/__init__.py b/torchtitan/experiments/multimodal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fe08681bbc532dc23a734fd961648890cec7d497 --- /dev/null +++ b/torchtitan/experiments/multimodal/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from mm_dataset import build_mm_dataloader + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer +from torchtitan.models.llama3 import parallelize_llama, pipeline_llama +from torchtitan.protocols.train_spec import register_train_spec, TrainSpec + +from .model import ModelArgs, MultimodalDecoder, VisionEncoder + +__all__ = ["VisionEncoder", "ModelArgs", "MultimodalDecoder"] + +llama4_mm_configs = { + # TODO: add configs for llama4 multimodal +} + +register_train_spec( + TrainSpec( + name="llama4_multimodal", + cls=MultimodalDecoder, + config=llama4_mm_configs, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llama, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_mm_dataloader, + build_tokenizer_fn=build_tiktoken_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) +) diff --git a/torchtitan/experiments/multimodal/check_padding_mm.py b/torchtitan/experiments/multimodal/check_padding_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..0345009256e80ccd3e010ed270d36bff0271555a --- /dev/null +++ b/torchtitan/experiments/multimodal/check_padding_mm.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+import click + +from mm_dataset import build_mm_dataloader +from tokenizer.tiktoken import build_tiktoken_tokenizer + +from torchtitan.config_manager import JobConfig +from torchtitan.tools.logging import init_logger + + +@click.command() +@click.option("--dataset", default="OBELICS") +@click.option("--batch-size", default=4) +@click.option("--seq-len", default=4096) +@click.option("--tokenizer-path", required=True) +@click.option("--dp-rank", default=0) +@click.option("--dp-world-size", default=2) +@click.option("--batch-number", default=4) +def main( + dataset: str, + batch_size: int, + seq_len: int, + tokenizer_path: str, + dp_rank: int, + dp_world_size: int, + batch_number: int, +): + init_logger() + job_config = JobConfig() + job_config.parse_args( + [ + "--training.dataset", + dataset, + "--training.batch_size", + str(batch_size), + "--training.seq_len", + str(seq_len), + "--model.tokenizer_path", + tokenizer_path, + ] + ) + tokenizer = build_tiktoken_tokenizer(job_config) + dl = build_mm_dataloader( + dp_world_size=dp_world_size, + dp_rank=dp_rank, + tokenizer=tokenizer, + job_config=job_config, + ) + dl_iter = iter(dl) + + for _ in range(batch_number): + batch = next(dl_iter) + + # Analyze Batch + # input_ids + total_input_ids = batch["input_ids"].shape[0] * batch["input_ids"].shape[1] + total_non_padding_tokens = total_input_ids - int( + (batch["input_ids"] == 128004).sum() + ) + total_padding_tokens = total_input_ids - total_non_padding_tokens + print(f"Padding tokens in each sample: {(batch['input_ids'] == 128004).sum(dim=1)}") + print( + f"Unpadded tokens: {total_non_padding_tokens}, Total tokens in batch: {total_input_ids}" + ) + print( + f"Padded text tokens: {total_padding_tokens}, {(total_padding_tokens) / total_input_ids * 100:.2f}%" + ) + print(80 * "#") + # Images + padded_images = 0 + padded_tiles = 0 + for sample in batch["encoder_input"]["images"]: + for image in sample: + if int(image.sum()) == 0: + padded_images += 1 + for tile in image: + 
if int(tile.sum()) == 0: + padded_tiles += 1 + + total_images = ( + batch["encoder_input"]["images"].shape[0] + * batch["encoder_input"]["images"].shape[1] + ) + + print( + f"Unpadded images: {total_images - padded_images}, Total images in batch: {total_images}" + ) + print( + f'Padded images: {padded_images}, {padded_images / total_images * 100:.2f}% (Each image with shape {list(batch["encoder_input"]["images"][0, 0].shape)})' # noqa: B950 + ) + print(80 * "#") + # Tiles + total_number_of_tiles = total_images * batch["encoder_input"]["images"].shape[2] + + print( + f"Unpadded number of tiles: {total_number_of_tiles - padded_tiles}, Total number of tiles: {total_number_of_tiles}" + ) + print( + f'Padded tiles: {padded_tiles}, {padded_tiles / total_number_of_tiles * 100:.2f}% (Each with shape {list(batch["encoder_input"]["images"][0, 0, 0].shape)})' # noqa: B950 + ) + print(80 * "#") + + +if __name__ == "__main__": + main() diff --git a/torchtitan/experiments/multimodal/mm_collator.py b/torchtitan/experiments/multimodal/mm_collator.py new file mode 100644 index 0000000000000000000000000000000000000000..98793a7f6f9f9ad51a3f0b34a18fd102f8b99802 --- /dev/null +++ b/torchtitan/experiments/multimodal/mm_collator.py @@ -0,0 +1,227 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. 
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F

from tokenizer.tiktoken import IGNORE_INDEX

from torch.nn.utils.rnn import pad_sequence


def padded_collate(
    batch: List[Dict[str, torch.Tensor]],
    padding_idx: int = 0,
    ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
    """Pad a batch of sequences to the longest sequence length in the batch.

    Note: ``pad_sequence`` requires tensor inputs, so each sample must carry
    1-D integer tensors (the previous ``List[int]`` annotation was incorrect
    and matched neither ``pad_sequence`` nor the actual callers).

    Args:
        batch (List[Dict[str, torch.Tensor]]): A list of dictionaries with
            "input_ids" and "labels" 1-D integer tensors of varying length.
        padding_idx (int): Padding index for input ids. Defaults to 0.
        ignore_idx (int): Padding index for labels. Defaults to -100.

    Returns:
        Dict[str, torch.Tensor]: Collated "input_ids" and "labels" tensors,
        both of shape (batch_size, max_seq_len).

    Example:
        >>> token_pairs = [
        ...     {"input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([4, 5, 6])},
        ...     {"input_ids": torch.tensor([7]), "labels": torch.tensor([10])},
        ... ]
        >>> collated = padded_collate(batch=token_pairs)
        >>> collated["input_ids"]
        tensor([[1, 2, 3], [7, 0, 0]])
        >>> collated["labels"]
        tensor([[4, 5, 6], [10, -100, -100]])
    """
    input_ids = pad_sequence(
        [x["input_ids"] for x in batch],
        batch_first=True,
        padding_value=padding_idx,
    )
    labels = pad_sequence(
        [x["labels"] for x in batch],
        batch_first=True,
        padding_value=ignore_idx,
    )

    input_ids_seq_len = input_ids.shape[-1]
    labels_seq_len = labels.shape[-1]

    # input_ids and labels may have been padded to different lengths; pad the
    # shorter one so both share the same trailing dimension. This avoids
    # needing max_seq_len up front, which would be costly.
    if input_ids_seq_len > labels_seq_len:
        labels = F.pad(
            labels, (0, input_ids_seq_len - labels_seq_len), value=ignore_idx
        )
    elif labels_seq_len > input_ids_seq_len:
        input_ids = F.pad(
            input_ids,
            (0, labels_seq_len - input_ids_seq_len),
            value=padding_idx,
        )
    return {"input_ids": input_ids, "labels": labels}


# NOTE Inspired from torchtune.data._collate.py
@dataclass
class MultiModalCollator:
    # Padding index for input token ids.
    padding_idx: int = 128004
    # Index ignored by the loss for padded label positions.
    ignore_idx: int = IGNORE_INDEX
    # If set, pad every image to this tile count; otherwise pad to the batch max.
    pad_max_tiles: Optional[int] = None
    # NOTE(review): currently unused — images are always padded to the batch
    # maximum number of images; honor this field when wiring up JobConfig.
    pad_max_images: Optional[int] = None

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate a batch of multimodal samples for training or inference.

        Each element of ``batch`` is expected to contain::

            - "input_ids": 1-D LongTensor of token ids, length varies per sample
            - "labels": 1-D LongTensor of label ids, length varies per sample
            - "encoder_input": Dict with
                - "images": List[torch.Tensor], each of shape (n_tiles, c, h, w)
                - "aspect_ratio": List[torch.Tensor], each of shape (2,)
                  giving (h_ratio, w_ratio)

        For each sample, ``len(images) == len(aspect_ratio)``.

        The collator:
            (1) pads text sequences to the longest sequence in the batch;
            (2) zero-pads each image along the tile dimension up to
                ``pad_max_tiles`` (or the largest tile count in the batch);
            (3) adds all-zero padding images so every sample has the same
                number of images, with aspect ratio (1, 1) for padded images.

        Args:
            batch (List[Dict[str, Any]]): samples as described above.

        Returns:
            Dict[str, torch.Tensor]: "input_ids" (bsz, max_seq_len),
            "labels" (bsz, max_seq_len), and "encoder_input" holding
            "images" (bsz, max_num_images, max_num_tiles, c, h, w) and
            "aspect_ratio" (bsz, max_num_images, 2).

        Raises:
            ValueError: If a sample's image has more tiles than
                ``pad_max_tiles``.
        """
        # Text tokens can be handled independently by the existing collater
        text_only = [
            {"input_ids": sample["input_ids"], "labels": sample["labels"]}
            for sample in batch
        ]
        collated_text = padded_collate(text_only, self.padding_idx, self.ignore_idx)

        if self.pad_max_tiles is None:
            # BUGFIX: samples carry no "images_tiles" key; the tile count lives
            # on each image tensor under encoder_input["images"] (dim 0).
            max_num_tiles = max(
                image.shape[0]
                for sample in batch
                for image in sample["encoder_input"]["images"]
            )
        else:
            max_num_tiles = self.pad_max_tiles

        # Pad images and aspect ratios to max number of tiles
        batch_images = []
        batch_aspect_ratios = []

        for sample in batch:
            sample_images = []
            for image in sample["encoder_input"]["images"]:
                # Single image in each sample has shape (n_tiles, c, h, w)
                n_tiles = image.shape[0]
                padding_tiles = max_num_tiles - n_tiles
                if padding_tiles < 0:
                    # Negative padding would silently crop tiles via F.pad;
                    # fail loudly instead.
                    raise ValueError(
                        f"Image has {n_tiles} tiles, more than the maximum of "
                        f"{max_num_tiles}; increase pad_max_tiles."
                    )

                # Image now has shape (max_num_tiles, c, h, w); F.pad pads
                # trailing dims first, so only the tile dim receives padding.
                padded_image = F.pad(
                    image, (0, 0, 0, 0, 0, 0, 0, padding_tiles), value=0
                )

                sample_images.append(padded_image)
            # Stack multiple images per sample in the num_images dimension
            batch_images.append(torch.stack(sample_images))
            batch_aspect_ratios.append(
                torch.stack(sample["encoder_input"]["aspect_ratio"])
            )
        # Finally, pad images and aspect ratios to max number of images in batch
        # (bsz, max_num_images, max_num_tiles, c, h, w)
        collated_images = pad_sequence(batch_images, batch_first=True, padding_value=0)
        # (bsz, max_num_images, 2); padding value 1 yields aspect ratio (1, 1)
        collated_aspect_ratios = pad_sequence(
            batch_aspect_ratios, batch_first=True, padding_value=1
        )

        return {
            "input_ids": collated_text["input_ids"],
            "labels": collated_text["labels"],
            "encoder_input": {
                "images": collated_images,
                "aspect_ratio": collated_aspect_ratios,
            },
        }
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union

import torch

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node

from mm_collator import MultiModalCollator
from tokenizer.tiktoken import IGNORE_INDEX, Tokenizer
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from transform import CLIPTransform
from utils import load_image

from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger


def _load_obelics_dataset(dataset_path: str):
    """Load the OBELICS dataset as a streaming train split."""
    return load_dataset(dataset_path, split="train", streaming=True)


def _process_obelics_sample(
    sample: dict[str, Any], image_token: str = "<|image|>"
) -> Dict[str, Any]:
    """Format a raw OBELICS sample into text + images.

    OBELICS interleaves images and text: ``sample["images"]`` and
    ``sample["texts"]`` are parallel lists where exactly one of the pair is
    ``None`` at each position. ``None`` text entries are replaced by
    ``image_token`` so the token stream marks image positions.

    Returns:
        Dict[str, Any]: The transformed sample with the following fields:
            - "images": List[PIL.Image.Image] with the loaded images
            - "text": str ready to be tokenized, including the image tokens

    Example:
        >>> formatted_sample = _process_obelics_sample(sample, image_token="<|image|>")
        >>> print(formatted_sample["text"])
        "<|image|><|image|><|image|> The elephant look cute!<|image|><|image|> The cats are sad :("
    """
    sample_images = [image for image in sample["images"] if image is not None]
    sample_text = [
        text if text is not None else image_token for text in sample["texts"]
    ]
    return {
        "images": [load_image(image) for image in sample_images],
        "text": "".join(map(str, sample_text)),
    }


@dataclass
class DatasetConfig:
    # Default HuggingFace hub path (overridable via dataset_path).
    path: str
    # Callable that loads the raw dataset.
    loader: Callable
    # Callable that converts one raw sample into {"images", "text"}.
    sample_processor: Callable


# Add your dataset here - more information at docs/datasets.md
MM_DATASETS = {
    "obelics": DatasetConfig(
        path="HuggingFaceM4/OBELICS",
        loader=_load_obelics_dataset,
        sample_processor=_process_obelics_sample,
    ),
}


def _validate_mm_dataset(
    dataset_name: str, dataset_path: Optional[str] = None
) -> tuple[str, Callable, Callable]:
    """Validate the dataset name and resolve its path, loader and processor.

    Raises:
        ValueError: If ``dataset_name`` is not registered in ``MM_DATASETS``.
    """
    if dataset_name not in MM_DATASETS:
        raise ValueError(
            f"Dataset {dataset_name} is not supported. "
            f"Supported datasets are: {list(MM_DATASETS.keys())}"
        )

    config = MM_DATASETS[dataset_name]
    path = dataset_path or config.path
    logger.info(f"Preparing {dataset_name} dataset from {path}")
    return path, config.loader, config.sample_processor


class MultiModalDataset(IterableDataset, Stateful):
    """PyTorch MultiModal Dataset.

    Args:
        dataset_name (str): name of the dataset to load
        dataset_path (Optional[str]): path override for the dataset
        tokenizer (Tokenizer): Tokenizer used to encode data. Must implement
            ``encode`` and ``decode`` methods.
        image_token (str): special token marking image positions in the text
        tile_size (int): tile side length used by the CLIP transform
        max_num_tiles (int): maximum tiles an image may be cropped into
        seq_len (int): maximum sequence length after tokenization
        dp_rank (int): rank of the current data parallel process
        dp_world_size (int): number of data parallel processes in training
        infinite (bool): whether to loop infinitely over the dataset

    We currently ONLY support the OBELICS dataset

    Example use:
    >>> ds = MultiModalDataset(dataset_name="OBELICS", tokenizer=tokenizer)
    >>> for batch in Dataloader(ds, batch_size=8):
            print(f"Batch size: {len(batch)}")
        Batch size: 8
    """

    def __init__(
        self,
        dataset_name: str,
        dataset_path: Optional[str],
        tokenizer: Tokenizer,
        image_token: str = "<|image|>",
        tile_size: int = 448,
        max_num_tiles: int = 4,
        seq_len: int = 2048,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:
        # Force lowercase for consistent comparison
        dataset_name = dataset_name.lower()

        path, dataset_loader, sample_processor = _validate_mm_dataset(
            dataset_name, dataset_path
        )
        ds = dataset_loader(path)

        # TODO: support shuffling
        self.dataset_name = dataset_name
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
        self._tokenizer = tokenizer
        self.seq_len = seq_len
        self.infinite = infinite
        self._sample_processor = sample_processor
        self.image_token = (
            image_token  # TODO(tj.solergibert) Add `image_token` to JobConfig
        )
        # TODO(tj.solergibert) Add `tile_size` & `max_num_tiles` to JobConfig
        self.transform_image = CLIPTransform(
            image_mean=(
                0.48145466,
                0.4578275,
                0.40821073,
            ),  # TODO(tj.solergibert) What should we do with `image_mean` & `image_std`?,
            image_std=(0.26862954, 0.26130258, 0.27577711),
            tile_size=tile_size,
            possible_resolutions=None,
            max_num_tiles=max_num_tiles,
            resample="bilinear",
            resize_to_max_canvas=False,
        )

        # Number of raw samples consumed from the stream, for checkpointing.
        self._sample_idx = 0

    def __iter__(self):

        while True:
            for sample in self._get_data_iter():
                # BUGFIX: count every raw sample consumed, including ones that
                # fail processing. Previously only successful samples were
                # counted, so on checkpoint resume `_get_data_iter` skipped
                # fewer raw samples than had actually been consumed and
                # replayed already-seen data.
                self._sample_idx += 1
                try:
                    sample = self._sample_processor(
                        sample, image_token=self.image_token
                    )
                except Exception:
                    # Best-effort: skip malformed samples (e.g. broken image
                    # URLs) but leave a trace so data issues stay visible.
                    logger.warning(
                        "Skipping malformed sample in dataset %s", self.dataset_name
                    )
                    continue

                # CLIP Transform
                encoder_input = {"images": [], "aspect_ratio": []}
                for image in sample["images"]:
                    out = self.transform_image(image)
                    encoder_input["images"].append(out["image"])
                    encoder_input["aspect_ratio"].append(out["aspect_ratio"])
                sample["encoder_input"] = encoder_input

                # Tokenize
                tokens = self._tokenizer.encode(
                    sample["text"],
                    bos=True,
                    eos=True,
                    allowed_special=set(["<|image|>"]),
                )
                # Next-token prediction: inputs are tokens[:-1], labels tokens[1:]
                sample["input_ids"] = torch.LongTensor(tokens[:-1])
                sample["labels"] = torch.LongTensor(tokens[1:])
                # Mask BOS, EOS & image tokens from the loss
                sample["labels"] = torch.where(
                    torch.isin(
                        sample["labels"],
                        torch.LongTensor(
                            [
                                self._tokenizer.bos_id,
                                self._tokenizer.eos_id,
                                self._tokenizer.image_id,
                            ]
                        ),
                    ),
                    IGNORE_INDEX,
                    sample["labels"],
                )
                # Truncate
                sample["input_ids"], sample["labels"] = (
                    sample["input_ids"][: self.seq_len],
                    sample["labels"][: self.seq_len],
                )
                yield sample

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def _get_data_iter(self):
        # Map-style datasets that are fully consumed yield nothing on resume;
        # streaming datasets have no len() and fall through to the skip loop.
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        it = iter(self._data)
        # Skip samples already consumed before the checkpoint.
        for _ in range(self._sample_idx):
            next(it)
        return it

    def load_state_dict(self, state_dict):
        self._sample_idx = state_dict["sample_idx"]

    def state_dict(self):
        return {"sample_idx": self._sample_idx}
def build_mm_dataloader(
    dp_world_size: int,
    dp_rank: int,
    tokenizer: Tokenizer,
    job_config: JobConfig,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data-parallel-aware dataloader over a HuggingFace multimodal dataset.

    Args:
        dp_world_size: number of data parallel processes.
        dp_rank: rank of the current data parallel process.
        tokenizer: tokenizer used by the dataset to encode text.
        job_config: training configuration (dataset name/path, batch size, seq len).
        infinite: whether the underlying dataset loops forever.

    Returns:
        ParallelAwareDataloader yielding collated multimodal batches.
    """
    # TODO(tj.solergibert) Add `pad_max_tiles` & `padding_idx` to JobConfig
    pad_max_tiles = 4
    padding_idx = 128004

    dataset = MultiModalDataset(
        dataset_name=job_config.training.dataset,
        dataset_path=job_config.training.dataset_path,
        tokenizer=tokenizer,
        seq_len=job_config.training.seq_len,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    collator = MultiModalCollator(padding_idx=padding_idx, pad_max_tiles=pad_max_tiles)

    return ParallelAwareDataloader(
        dataset=dataset,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=job_config.training.batch_size,
        collate_fn=collator,
    )
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Llama 3 is licensed under the LLAMA 3 Community License,
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import math
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ModelArgs:
    """Configuration for the multimodal (vision encoder + text decoder) model."""

    # encoder part
    encoder_embed_dim: int = 4096
    encoder_num_layers: int = 32
    num_layers_projection: int = 32
    encoder_num_heads: int = 32
    encoder_num_kv_heads: Optional[int] = None
    patch_size: int = 1
    tile_size: int = 128
    max_num_tiles: int = 8
    # BUGFIX: use default_factory so each ModelArgs instance gets its own
    # activation module instead of all instances sharing one nn.GELU() object
    # created at class-definition time.
    activation: nn.Module = field(default_factory=nn.GELU)
    # in_channels (int): The number of image input channels.
    in_channels: int = 3
    # return_intermediates (Optional[List[int]]): The indices of hidden layers to return.
    # If provided, it will return the intermediate results of the transformer layers
    # before they go through a next layer. For example, ``return_intermediates=[0,3]``
    # will return the tokens before they go through the first and fourth layers.
    return_intermediates: Optional[List[int]] = None
    is_causal: bool = True

    # decoder part
    decoder_embed_dim: int = 4096  # This is for linear projection to convert the output of encoder to decoder
    fusion_interval: int = 1  # This is the interval of layers that are used for fusion
    num_special_tokens: int = 2  # This is the number of special tokens in the tokenizer
    decoder_num_layers: int = 16
    decoder_num_heads: int = 32
    decoder_num_kv_heads: Optional[int] = None

    # common part
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 10000

    max_seq_len: int = 2048
    # If `True`, then each transformer block init uses its layer ID, and if
    # `False`, each uses the total number of transformer blocks
    depth_init: bool = True
    norm_type: str = "rmsnorm"
class Fp32LayerNorm(nn.LayerNorm):
    """LayerNorm that always normalizes in fp32 for mixed-precision training.

    The input is upcast to float32, normalized with fp32 copies of the affine
    parameters, then cast back to the input's dtype.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` in fp32 and return a tensor with ``x``'s dtype and shape."""
        fp32_weight = None if self.weight is None else self.weight.float()
        fp32_bias = None if self.bias is None else self.bias.float()
        normalized = nn.functional.layer_norm(
            x.float(), self.normalized_shape, fp32_weight, fp32_bias, self.eps
        )
        return normalized.type_as(x)
+ """ + output = nn.functional.layer_norm( + x.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(x) + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. 
+ """ + ndim = x.ndim + assert 0 <= 1 < ndim + seqlen = x.shape[1] + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, num_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=num_rep)""" + bsz, seq_len, num_kv_heads, head_dim = x.shape + if num_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bsz, seq_len, num_kv_heads, num_rep, head_dim) + .reshape(bsz, seq_len, num_kv_heads * num_rep, head_dim) + ) + + +class Attention(nn.Module): + """ + Multi-head attention module. 
+ + Args: + model_args (ModelArgs): Model configuration arguments. + + Attributes: + num_kv_heads (int): Number of key and value heads. + num_heads (int): Number of query heads. + num_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. + + """ + + def __init__(self, model_args: ModelArgs): + super().__init__() + self.num_heads = model_args.encoder_num_heads + self.num_kv_heads = ( + model_args.encoder_num_heads + if model_args.encoder_num_kv_heads is None + else model_args.encoder_num_kv_heads + ) + self.num_rep = self.num_heads // self.num_kv_heads + self.head_dim = model_args.encoder_embed_dim // model_args.encoder_num_heads + + self.wq = nn.Linear( + model_args.encoder_embed_dim, + model_args.encoder_num_heads * self.head_dim, + bias=False, + ) + self.wk = nn.Linear( + model_args.encoder_embed_dim, self.num_kv_heads * self.head_dim, bias=False + ) + self.wv = nn.Linear( + model_args.encoder_embed_dim, self.num_kv_heads * self.head_dim, bias=False + ) + self.wo = nn.Linear( + model_args.encoder_num_heads * self.head_dim, + model_args.encoder_embed_dim, + bias=False, + ) + self.is_causal = model_args.is_causal + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + + Returns: + torch.Tensor: Output tensor after attention. 
class FeedForward(nn.Module):
    """Two-layer feed-forward block: ``w2(activation(w1(x)))``.

    The hidden dimension is scaled by 2/3, optionally multiplied by
    ``ffn_dim_multiplier``, then rounded up to a multiple of ``multiple_of``.
    (Note: this is a plain MLP — there is no gating ``w3`` projection here.)

    Args:
        dim (int): Input dimension.
        hidden_dim (int): Nominal hidden dimension before rescaling/rounding.
        multiple_of (int): Hidden dimension is rounded up to a multiple of this.
        ffn_dim_multiplier (Optional[float]): Custom multiplier for the hidden
            dimension. Defaults to None.
        activation (nn.Module): Activation function. Defaults to nn.SiLU().

    Attributes:
        w1 (Linear): Projection from input dim to hidden dim.
        w2 (Linear): Projection from hidden dim back to input dim.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        activation: nn.Module = nn.SiLU(),
    ):
        super().__init__()
        inner_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            inner_dim = int(ffn_dim_multiplier * inner_dim)
        # Round up to the nearest multiple of `multiple_of`.
        inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)

        self.activation = activation
        self.w1 = nn.Linear(dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x):
        projected = self.w1(x)
        return self.w2(self.activation(projected))

    def init_weights(self, init_std: float):
        # w1 uses a fixed std; w2 uses the depth-scaled std from the caller.
        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.w2.weight, mean=0.0, std=init_std)
+ """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + activation: nn.Module = nn.SiLU(), + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.activation = activation + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + + def forward(self, x): + return self.w2(self.activation(self.w1(x))) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.w2.weight, mean=0.0, std=init_std) + + +class TanhGate(nn.Module): + """Implements a basic learnable gate to scale layer outputs""" + + def __init__(self) -> None: + super().__init__() + self.scale = nn.Parameter(torch.zeros(1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): input tensor to gate + + Returns: + torch.Tensor: The output tensor after gating. Has the same shape as ``x``. + """ + return x * self.scale.tanh() + + +class TilePositionalEmbedding(nn.Module): + """ + Positional embedding for tiles, different for every tile, same for every token within a tile. + + For details, please check the documentation of :class:`ViT`. + + Args: + max_num_tiles (int): The maximum number of tiles an image can be divided into. + emb_dim (int): The dimensionality of each tile embedding. 
+ """ + + def __init__( + self, + max_num_tiles: int, + emb_dim: int, + ): + super().__init__() + self.max_num_tiles = max_num_tiles + self.emb_dim = emb_dim + self.embedding = nn.Parameter( + torch.randn(max_num_tiles, max_num_tiles, 1, emb_dim) / math.sqrt(emb_dim) + ) + self.gate = nn.Parameter(torch.zeros(1)) + + def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor): + """ + args: + x (torch.Tensor): torch.Tensor with shape (bsz * num_imgs, num_tiles, num_tokens, emb_dim). + aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * num_imgs, 2), + representing the aspect ratio of the image before tile-cropping, e.g. (2,1). + returns: + torch.Tensor: The input tensor with added positional embeddings. + """ + bsz_and_num_imgs, num_tiles, num_tokens, emb_dim = x.shape + + for batch_idx, (num_tiles_h, num_tiles_w) in enumerate(aspect_ratio): + # When we batch images, all are padded to the same amount of tiles. + # The aspect_ratio lets us know the non padded tiles for each image. + # We only add positional encoding to those. + num_non_padded_tiles = int(num_tiles_h * num_tiles_w) + + # We get only the positional encoding for non padded tiles, + # i.e. num_tiles_h, num_tiles_w. + pos_embed = self.embedding[:num_tiles_h, :num_tiles_w, :, :] + + # Add pos encoding to the non padded tiles. + pos_embed = pos_embed.reshape(num_non_padded_tiles, 1, self.emb_dim) + x[batch_idx, :num_non_padded_tiles, :, :] += pos_embed * self.gate.tanh() + + return x + + +class TokenPositionalEmbedding(nn.Module): + """ + Token positional embedding for images, different for every token in an image. + + Args: + emb_dim (int): The dimensionality of each token embedding. + tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise, + the size of the input image. In this case, the function will consider your image as a single tile. + patch_size (int): The size of each patch. Used to divide the tiles into patches. + E.g. 
class TiledTokenPositionalEmbedding(nn.Module):
    """Token positional embedding for tiled images.

    Combines two embeddings, mixed by a learnable tanh gate:

    * local_token_positional_embedding: same for every tile, different for
      every token (a gated :class:`TokenPositionalEmbedding`).
    * global_token_positional_embedding: different for every tile AND every token.

    Note that a tile is different from a patch (token). For details, please
    check the documentation of :class:`ViT`.

    Args:
        max_num_tiles (int): The maximum number of tiles an image can be divided into.
        emb_dim (int): The dimensionality of each token embedding.
        tile_size (int): The size of your image tiles, if the image was
            tile-cropped in advance. Otherwise, the size of the input image.
        patch_size (int): The size of each patch, used to divide tiles into
            patches. E.g. for ``patch_size=40``, a tile of shape (400, 400)
            has a 10x10 grid of (40, 40) patches.
    """

    def __init__(
        self, max_num_tiles: int, emb_dim: int, tile_size: int, patch_size: int
    ) -> None:
        super().__init__()
        grid = tile_size // patch_size
        self.num_tokens_per_tile = grid**2 + 1  # +1 for cls token
        init_scale = emb_dim**-0.5

        # Different for every token, same for every tile.
        self.local_token_positional_embedding = nn.Parameter(
            init_scale * torch.randn((grid**2 + 1, emb_dim))  # +1 for CLS token
        )

        # Different for every token AND every tile; indexed by (tiles_h, tiles_w).
        self.global_token_positional_embedding = nn.Parameter(
            init_scale
            * torch.randn(
                max_num_tiles,
                max_num_tiles,
                self.num_tokens_per_tile,
                emb_dim,
            )
        )

        # Gate starts closed: purely local embedding at init.
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
        """Add gated local + global positional embeddings to ``x``.

        Args:
            x (torch.Tensor): shape (bsz * num_imgs, num_tiles, num_tokens, emb_dim).
            aspect_ratio (torch.Tensor): shape (bsz * num_imgs, 2), where row k
                is the (tiles_h, tiles_w) layout of the k-th image before
                tile-cropping, e.g. (2, 1).

        Returns:
            torch.Tensor: ``x`` with positional embeddings added.
        """
        n_imgs, num_tiles, num_tokens, emb_dim = x.shape

        # Local embedding (same for every tile), down-weighted as the gate opens.
        x = x + (self.local_token_positional_embedding * (1 - self.gate.tanh()))

        # Global embedding (different for every tile), weighted by the gate.
        x = x.view(n_imgs, num_tiles, num_tokens, emb_dim)
        for img_idx, (tiles_h, tiles_w) in enumerate(aspect_ratio):
            # Batched images are padded to a common tile count; only the real
            # tiles (given by the aspect ratio) receive positional encoding.
            real_tiles = int(tiles_h * tiles_w)

            pos = self.global_token_positional_embedding[:tiles_h, :tiles_w, :, :]
            pos = pos.reshape(real_tiles, self.num_tokens_per_tile, emb_dim)
            x[img_idx, :real_tiles, :, :] += pos * self.gate.tanh()

        return x


class Conv2dModule(torch.nn.Module):
    """Conv2D Module.

    This is like Conv2D in PyTorch except:

    - PyTorch Conv2D outputs shape (*, out_channels, h_out, w_out), while this
      module outputs (*, h_out * w_out, out_channels).
    - The conv is implemented as unfold -> permute -> linear, so the linear
      can be column-wise sharded for tensor parallelism.

    Arguments:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Size of convolution kernel (assumed square).
        stride: Stride for convolution.
        bias (default False): Use bias in the linear projection.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        bias: bool = False,
    ) -> None:
        super().__init__()
        self._unfold = torch.nn.Unfold(
            kernel_size=(kernel_size, kernel_size), stride=stride
        )
        self._linear = torch.nn.Linear(
            in_channels * kernel_size * kernel_size,
            out_channels,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project image patches: (bsz, c, h, w) -> (bsz, num_tokens, out_channels)."""
        # (bsz, in_channels * kernel_size**2, num_tokens) after unfold.
        patches = self._unfold(x)
        patches = patches.permute(0, 2, 1)
        # When stride == kernel_size, num_tokens = grid**2 and out_channels
        # plays the role of the embedding dimension.
        return self._linear(patches)
+ return self._linear(x) + + +class VitTransformerBlock(nn.Module): + def __init__( + self, + model_args: ModelArgs, + attn_scale: Optional[nn.Module] = None, + mlp_scale: Optional[nn.Module] = None, + ): + super().__init__() + self.attn = Attention(model_args) + self.ln_attn = Fp32LayerNorm(model_args.encoder_embed_dim, eps=1e-5) + self.mlp = FeedForward( + dim=model_args.encoder_embed_dim, + hidden_dim=4 * model_args.encoder_embed_dim, + multiple_of=model_args.multiple_of, + ffn_dim_multiplier=model_args.ffn_dim_multiplier, + activation=model_args.activation, + ) + self.ln_mlp = Fp32LayerNorm(model_args.encoder_embed_dim, eps=1e-5) + self.attn_scale = attn_scale or nn.Identity() + self.mlp_scale = mlp_scale or nn.Identity() + + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ): + bsz, seq_len, emd_dim = x.shape + # x = x.view(bsz * seq_len, emd_dim) + x = x + self.attn_scale(self.attn(x=self.ln_attn(x), freqs_cis=None)) + x = x + self.mlp_scale(self.mlp(self.ln_mlp(x))) + # return x.view(bsz, seq_len, emd_dim) + return x + + +class CLSEmbedding(nn.Module): + """ + Adds a CLS token to every tile of an image in the beginning of each token. + + Args: + emb_dim (int): The dimensionality of the input patch embedding. + """ + + def __init__(self, emb_dim: int) -> None: + super().__init__() + + scale = emb_dim**-0.5 + self.weight = nn.Parameter(scale * torch.randn(emb_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + # add 1 CLS token to every tile + bsz_and_num_imgs, num_tiles, _, emb_dim = x.shape + cls_emb = self.weight.broadcast_to(bsz_and_num_imgs, num_tiles, 1, emb_dim) + return torch.cat([cls_emb, x], dim=2) + + +class Vit(nn.Module): + """ + Implementation of the ViT architecture (https://arxiv.org/abs/2010.11929), + with support for tile-cropped images, outputting of hidden layers. + + (credit for the documentation below: `vision_transformer.py + + `_). 
+ + ViT is a transformer architecture that takes in images and outputs N embedded tokens that + represent this image. Each image is divided into **patches** by a convolution. + These patches are flattened and subsequently treated as **tokens** by the transformer. + + To further enhance the performance of ViT and avoid downscaling images, we support tile-cropped images, + which are images divided into **tiles** during the preprocessing stage. For example, instead of + downscaling an 800x400 image to fit 400x400, we may crop it into two 400x400 tiles, + if the ``tile_size=400``. + + Each of these tiles is further broken down into patches by a convolution operation. For example, if + your ``patch_size=40``, then each (400, 400) tile will become a grid of 10x10 patches, and your whole image will have + num_tiles * n_tokens -> num_tiles * (10x10 patches + 1 CLS token) -> num_tiles * 101. + + Before the transformer layers, a CLS token is added to each tile as the first token. + In transformers, a token called CLS is a special token that is added to the beginning of each sequence. + This token can be used to represent the whole input, instead of using a pooling operation, for example. + + To help the model "see" the whole image, we use positional embeddings. If your image + was tile-cropped, then you need to use tile positional embeddings: + + - token_pos_embedding (tiled): :class:`TiledTokenPositionalEmbedding` + - pre_tile_pos_embed: :class:`TilePositionalEmbedding` + - post_tile_pos_embed: :class:`TilePositionalEmbedding` + + Otherwise, pre and post tile_pos_embed should be None and all you need is a simple + token positional embedding: + + - token_pos_embedding (not tiled): :class:`TokenPositionalEmbedding` + + All images will be considered as a stack of tiles, even if your image was not tile-cropped. In such cases, + your image would be composed of a single tile. + + In summary: + + 1) An image is broken down into tiles during preprocessing. 
+ 2) In the ViT, the tiles will be broken down into patches. + 3) The patches will be flattened and transformed. We call them tokens, because that's how the transformer sees them. + + Image: shape (8x8) + + .. code-block:: text + + | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | + | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | + | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | + | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | + | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | + | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | + | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | + | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | + + Tiles: shape (4,4,4) # (num_tiles, tile_size, tile_size) + + .. code-block:: text + + | 1 | 2 | 3 | 4 | | 5 | 6 | 7 | 8 | + | 9 | 10 | 11 | 12 | | 13 | 14 | 15 | 16 | + | 17 | 18 | 19 | 20 | | 21 | 22 | 23 | 24 | + | 25 | 26 | 27 | 28 | | 29 | 30 | 31 | 32 | + + | 33 | 34 | 35 | 36 | | 37 | 38 | 39 | 40 | + | 41 | 42 | 43 | 44 | | 45 | 46 | 47 | 48 | + | 49 | 50 | 51 | 52 | | 53 | 54 | 55 | 56 | + | 57 | 58 | 59 | 60 | | 61 | 62 | 63 | 64 | + + Patches: shape (4,4,2,2) # (num_tiles, num_patches_per_tile, patch_size, patch_size) + + .. code-block:: text + + | 1 | 2 | | 3 | 4 | | 5 | 6 | | 7 | 8 | + | 9 | 10 | | 11 | 12 | | 13 | 14 | | 15 | 16 | + + | 17 | 18 | | 19 | 20 | | 21 | 22 | | 23 | 24 | + | 25 | 26 | | 27 | 28 | | 29 | 30 | | 31 | 32 | + + | 33 | 34 | | 35 | 36 | | 37 | 38 | | 39 | 40 | + | 41 | 42 | | 43 | 44 | | 45 | 46 | | 47 | 48 | + + | 49 | 50 | | 51 | 52 | | 53 | 54 | | 55 | 56 | + | 57 | 58 | | 59 | 60 | | 61 | 62 | | 63 | 64 | + + token: shape (4, 4, 4) # (num_tiles, num_patches_per_tile, emb_dim) + + .. code-block:: text + + | 1 | 2 | 9 | 10 | | 3 | 4 | 11 | 12 | | 17 | 18 | 25 | 26 | | 19 | 20 | 27 | 28 | + | ... continuation of data ... + | ... continuation of data ... + | 37 | 38 | 45 | 46 | | 39 | 40 | 47 | 48 | | 53 | 54 | 61 | 62 | | 55 | 56 | 63 | 64 | + + For the positional embeddings: + + Same for every tile, different for every token. 
+ + - :class:`TokenPositionalEmbedding` + + .. code-block:: text + + | 1 | 2 | 3 | 4 | | 1 | 2 | 3 | 4 | + | 9 | 10 | 11 | 12 | | 9 | 10 | 11 | 12 | + | 17 | 18 | 19 | 20 | | 17 | 18 | 19 | 20 | + | 25 | 26 | 27 | 28 | | 25 | 26 | 27 | 28 | + + | 1 | 2 | 3 | 4 | | 1 | 2 | 3 | 4 | + | 9 | 10 | 11 | 12 | | 9 | 10 | 11 | 12 | + | 17 | 18 | 19 | 20 | | 17 | 18 | 19 | 20 | + | 25 | 26 | 27 | 28 | | 25 | 26 | 27 | 28 | + + Different for every tile, different for every token. + + - :class:`TiledTokenPositionalEmbedding` + + .. code-block:: text + + | 1 | 2 | | 3 | 4 | | 5 | 6 | | 7 | 8 | + | 9 | 10 | | 11 | 12 | | 13 | 14 | | 15 | 16 | + + | 17 | 18 | | 19 | 20 | | 21 | 22 | | 23 | 24 | + | 25 | 26 | | 27 | 28 | | 29 | 30 | | 31 | 32 | + + | 33 | 34 | | 35 | 36 | | 37 | 38 | | 39 | 40 | + | 41 | 42 | | 43 | 44 | | 45 | 46 | | 47 | 48 | + + | 49 | 50 | | 51 | 52 | | 53 | 54 | | 55 | 56 | + | 57 | 58 | | 59 | 60 | | 61 | 62 | | 63 | 64 | + + different for every tile, same for every token within a tile. + + - :class:`TilePositionalEmbedding` + + .. code-block:: text + + | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | + | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | + | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | + | 1 | 1 | 1 | 1 | | 2 | 2 | 2 | 3 | + + | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | + | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | + | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | + | 3 | 3 | 3 | 3 | | 4 | 4 | 4 | 4 | + + Args: + model_args (ModelArgs): The model args. + + Raises: + ValueError: If `patch_size` is not greater than 0. + ValueError: If `len(return_intermediates)` is greater than `num_layers`. + """ + + def __init__( + self, + model_args: ModelArgs, + ): + super().__init__() + if model_args.patch_size <= 0: + raise ValueError(f"kernel size of conv {model_args.patch_size} must be > 0") + if model_args.return_intermediates and ( + len(model_args.return_intermediates) > model_args.encoder_num_layers + ): + raise ValueError( + "len(return_intermediates) must be <= num_layers." 
+ f" Got {model_args.return_intermediate=} and {model_args.encoder_num_layers=}" + ) + + # For test validation purposes + patch_grid_size = model_args.tile_size // model_args.patch_size + self.patches_per_tile = patch_grid_size**2 + + self.return_intermediates = model_args.return_intermediates + + self.conv = Conv2dModule( + in_channels=model_args.in_channels, + out_channels=model_args.encoder_embed_dim, + kernel_size=model_args.patch_size, + stride=model_args.patch_size, + bias=False, + ) + + self.ln_post = Fp32LayerNorm(model_args.encoder_embed_dim) + self.ln_pre = Fp32LayerNorm(model_args.encoder_embed_dim) + self.transformer_layers = nn.ModuleList( + [ + VitTransformerBlock(model_args) + for _ in range(model_args.encoder_num_layers) + ] + ) + + self.class_embedding = CLSEmbedding(model_args.encoder_embed_dim) + # pre and post tile position embedding + if model_args.max_num_tiles > 1: + self.pre_tile_pos_embed = TilePositionalEmbedding( + max_num_tiles=model_args.max_num_tiles, + emb_dim=model_args.encoder_embed_dim, + ) + self.post_tile_pos_embed = TilePositionalEmbedding( + max_num_tiles=model_args.max_num_tiles, + emb_dim=model_args.encoder_embed_dim, + ) + self.token_pos_embedding = TokenPositionalEmbedding( + emb_dim=model_args.encoder_embed_dim, + tile_size=model_args.tile_size, + patch_size=model_args.patch_size, + ) + else: + self.pre_tile_pos_embed = None + self.post_tile_pos_embed = None + self.token_pos_embedding = TiledTokenPositionalEmbedding( + max_num_tiles=model_args.max_num_tiles, + emb_dim=model_args.encoder_embed_dim, + tile_size=model_args.tile_size, + patch_size=model_args.patch_size, + ) + + def forward( + self, images: torch.Tensor, aspect_ratio: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Processes images and returns the tokens and hidden states. + + Multiple images per sample: we add a dimension num_imgs to the input. 
This is useful when a single + sample constains multiple images, for example: + + - sample 1: " what animal is this?" + - sample 2: "I like more than " + + In this case, sample 1 has one image, and sample 2 has two images. max_n_imgs = max(2,1) = 2. + So your input should have shape (bsz=2, num_imgs=2, num_tiles, num_channels, tile_size_w, tile_size_h). + + Notice that to batch it, you will have to pad num_imgs to max_num_imgs and max_num_tiles. + + Args: + images (torch.Tensor): torch.Tensor with shape (bsz, num_imgs, num_tiles, num_channels, tile_size_w, tile_size_h). + aspect_ratio (Optional[torch.Tensor]): torch.Tensor with shape (bsz, n_imgs, 2). If all + images have a single tile, i.e. they were not tile-cropped, it should be None. + Used to calculate the positional embeddings for the tiles. + + Returns: + Tuple[torch.Tensor, List[torch.Tensor]]: A tuple: (x, hidden_states), + where x is a torch.tensor of shape (bsz, num_imgs, num_tiles, num_tokens, emb_dim) and + hidden_states has shape is a list of len(out_indices) torch.tensor with shape + (bsz, num_imgs, num_tiles, num_tokens, emb_dim). + + Raises: + ValueError: If aspect_ratio is None, but num_tiles > 1 in the batch. + """ + + bsz, num_imgs, num_tiles, num_channels, width, height = images.shape + + if aspect_ratio is None: + aspect_ratio = torch.ones((bsz * num_imgs, 2), dtype=torch.int).to( + device=images.device + ) + if num_tiles > 1: + raise ValueError( + f"aspect_ratio was not provided, but found num_tiles > 1 " + f"for {images.shape=}. Please provide aspect_ratio." + ) + + aspect_ratio = aspect_ratio.reshape(bsz * num_imgs, 2) + + # patch embedding + images = images.view(bsz * num_imgs * num_tiles, num_channels, width, height) + # The op is not behaving completely same as conv2d it contains a permute inside. 
+ x = self.conv(images) # shape = [*, emb_dim, grid ** 2] + _, num_tokens, emb_dim = x.shape # num_tokens = grid ** 2 + x = x.reshape(bsz * num_imgs, num_tiles, num_tokens, emb_dim) + + # tile embeddings + if self.pre_tile_pos_embed: + x = self.pre_tile_pos_embed(x, aspect_ratio) + + # apply cls token + x = self.class_embedding(x) + num_tokens += 1 + + # apply position embeddings + x = self.token_pos_embedding(x, aspect_ratio) + + x = self.ln_pre(x) + x = x.view(bsz * num_imgs, -1, emb_dim) + + int_x = [] # intermediate outputs + for layer_idx, transformer_layer in enumerate(self.transformer_layers): + if layer_idx in self.return_intermediates: + h = x.view(bsz, num_imgs, num_tiles, num_tokens, emb_dim) + int_x.append(h) + x = transformer_layer(x) + + x = self.ln_post(x) + x = x.view(bsz * num_imgs, num_tiles, num_tokens, emb_dim) + + if self.post_tile_pos_embed: + x = self.post_tile_pos_embed(x, aspect_ratio) + + x = x.view(bsz, num_imgs, num_tiles, num_tokens, emb_dim) + return x, int_x + + +class Projection(nn.Module): + """Projection transformer to adapt the output of a + encoder (CLIP) to the decoder model. 
+ """ + + def __init__( + self, + model_args: ModelArgs, + ) -> None: + super().__init__() + self.transformer_layers = nn.ModuleList( + [ + VitTransformerBlock( + model_args, attn_scale=TanhGate(), mlp_scale=TanhGate() + ) + for _ in range(model_args.num_layers_projection) + ] + ) + + self.num_hidden = len(model_args.return_intermediates or []) + self.output = nn.Linear( + model_args.encoder_embed_dim * (self.num_hidden + 1), + model_args.decoder_embed_dim, + ) + + def forward( + self, + x: torch.Tensor, + hidden_states: Optional[List[torch.Tensor]] = None, + ) -> torch.Tensor: + bsz, num_imgs, num_tiles, num_tokens, emb_dim = x.shape + + # apply transformer layers + x = x.view(bsz * num_imgs, num_tiles * num_tokens, emb_dim) + for layer in self.transformer_layers: + x = layer(x) + x = x.view(bsz, num_imgs, num_tiles, num_tokens, emb_dim) + + # interleave hidden states and cat with x + if self.num_hidden > 0: + assert hidden_states is not None + hidden_states = torch.stack(hidden_states, dim=-1) + hidden_states = hidden_states.view(bsz, num_imgs, num_tiles, num_tokens, -1) + x = torch.cat([x, hidden_states], dim=-1) + + # [bsz x seq x decoder_emb_dim] + return self.output(x).reshape(bsz, num_imgs * num_tiles * num_tokens, -1) + + +class VisionEncoder(nn.Module): + """Vision encoder model for Llama 3.2 Vision. This combines a vision + encoder with a projection. We define two different components. + + Args: + model_args (ModelArgs): configs for the vision encoder. + """ + + def __init__(self, model_args: ModelArgs) -> None: + super().__init__() + self.vit = Vit(model_args) + self.proj = Projection(model_args) + + def forward( + self, images: torch.Tensor, aspect_ratio: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Args: + images (torch.Tensor): + Image tensor with shape [bsz x num_imgs x num_tiles x num_channels x width x height]. + aspect_ratio (Optional[torch.Tensor]): Tensor with shape [bsz x num_imgs x 2]. If all + images have a single tile, i.e. 
class FeedForwardForDecoder(nn.Module):
    """SwiGLU feed-forward used by the decoder layers.

    This is the component originally used in llama3 and differs from the
    encoder's FeedForward: the hidden width is reduced to 2/3 of the
    requested size, optionally rescaled by ``ffn_dim_multiplier``, then
    rounded up to a multiple of ``multiple_of``.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        # llama-style sizing: 2/3 shrink, optional custom multiplier,
        # then round up to the nearest multiple_of.
        width = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            width = int(ffn_dim_multiplier * width)
        width = multiple_of * ((width + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, width, bias=False)
        self.w2 = nn.Linear(width, dim, bias=False)
        self.w3 = nn.Linear(dim, width, bias=False)

    def forward(self, x):
        """SwiGLU: w2(silu(w1 x) * w3 x)."""
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))

    def init_weights(self, init_std: float):
        # w1 uses a fixed std; w2/w3 use the caller-provided (depth-scaled) std.
        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
        for linear in (self.w2, self.w3):
            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+ """ + + def __init__(self, model_args: ModelArgs): + super().__init__() + self.num_heads = model_args.decoder_num_heads + self.num_kv_heads = ( + model_args.decoder_num_heads + if model_args.decoder_num_kv_heads is None + else model_args.decoder_num_kv_heads + ) + self.n_rep = self.num_heads // self.num_kv_heads + self.head_dim = model_args.decoder_embed_dim // model_args.decoder_num_heads + + self.wq = nn.Linear( + model_args.decoder_embed_dim, + model_args.decoder_num_heads * self.head_dim, + bias=False, + ) + self.wk = nn.Linear( + model_args.decoder_embed_dim, self.num_kv_heads * self.head_dim, bias=False + ) + self.wv = nn.Linear( + model_args.decoder_embed_dim, self.num_kv_heads * self.head_dim, bias=False + ) + self.wo = nn.Linear( + model_args.decoder_num_heads * self.head_dim, + model_args.decoder_embed_dim, + bias=False, + ) + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Use -1 instead of `num_heads` (or `num_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. 
+ xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # repeat k/v heads if num_kv_heads < num_heads + keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + + # we use casual mask for training + output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + output = output.view(bs, seqlen, -1) + return self.wo(output) + + +class CrossAttention(nn.Module): + """ + Multi-head cross attention module. + """ + + def __init__(self, model_args: ModelArgs): + super().__init__() + self.num_heads = model_args.decoder_num_heads + self.num_kv_heads = ( + model_args.decoder_num_heads + if model_args.decoder_num_kv_heads is None + else model_args.decoder_num_kv_heads + ) + self.n_rep = self.num_heads // self.num_kv_heads + self.head_dim = model_args.decoder_embed_dim // model_args.decoder_num_heads + + self.wq = nn.Linear( + model_args.decoder_embed_dim, + model_args.decoder_num_heads * self.head_dim, + bias=False, + ) + self.wk = nn.Linear( + model_args.decoder_embed_dim, self.num_kv_heads * self.head_dim, bias=False + ) + self.wv = nn.Linear( + model_args.decoder_embed_dim, self.num_kv_heads * self.head_dim, bias=False + ) + self.wo = nn.Linear( + model_args.decoder_num_heads * self.head_dim, + model_args.decoder_embed_dim, + bias=False, + ) + self.q_norm = nn.RMSNorm(self.head_dim, eps=1e-05) + self.k_norm = nn.RMSNorm(self.head_dim, eps=1e-05) + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + 
nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward( + self, + x: torch.Tensor, + encoder_input: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ): + bs, seqlen_x, _ = x.shape + seqlen_y = encoder_input.shape[1] + xq, xk, xv = self.wq(x), self.wk(encoder_input), self.wv(encoder_input) + + # Use -1 instead of `num_heads` (or `num_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. + xq = xq.view(bs, seqlen_x, -1, self.head_dim) + xk = xk.view(bs, seqlen_y, -1, self.head_dim) + xv = xv.view(bs, seqlen_y, -1, self.head_dim) + + # repeat k/v heads if num_kv_heads < num_heads + keys = repeat_kv(xk, self.n_rep) # (bs, seqlen_y, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, seqlen_y, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen_x, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen_y, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen_y, head_dim) + + xq = self.q_norm(xq) + xk = self.k_norm(xk) + + # we use casual mask for training + output = F.scaled_dot_product_attention( + xq, xk, xv, attn_mask=mask, is_causal=False + ) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen_x, n_local_heads, head_dim) + output = output.view(bs, seqlen_x, -1) + return self.wo(output) + + +class DecoderTransformerSelfAttnBlock(nn.Module): + def __init__( + self, + model_args: ModelArgs, + ): + super().__init__() + self.attn = SelfAttention(model_args) + self.ln_attn = nn.RMSNorm(model_args.decoder_embed_dim, eps=1e-5) + self.mlp = FeedForwardForDecoder( + dim=model_args.decoder_embed_dim, + hidden_dim=4 * model_args.decoder_embed_dim, + multiple_of=model_args.multiple_of, + ffn_dim_multiplier=model_args.ffn_dim_multiplier, + ) + self.ln_mlp = nn.RMSNorm(model_args.decoder_embed_dim, eps=1e-5) + + def forward( + 
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        **kwargs: Dict,
    ):
        # Pre-norm residual block: attention then MLP. Extra kwargs (e.g.
        # encoder inputs) are accepted so self- and cross-attn blocks share a
        # call signature, but are unused here.
        bsz, seq_len, emd_dim = x.shape
        x = x + self.attn(self.ln_attn(x), freqs_cis)
        x = x + self.mlp(self.ln_mlp(x))
        return x


class DecoderTransformerCrossAttnBlock(nn.Module):
    # Cross-attention block that fuses encoder (vision) tokens into the
    # decoder stream; attn/mlp outputs are tanh-gated so fusion starts
    # near a no-op at initialization.
    def __init__(
        self,
        model_args: ModelArgs,
    ):
        super().__init__()
        self.attn = CrossAttention(model_args)
        self.ln_attn = nn.RMSNorm(model_args.decoder_embed_dim)
        self.mlp = FeedForward(
            dim=model_args.decoder_embed_dim,
            hidden_dim=4 * model_args.decoder_embed_dim,
            multiple_of=model_args.multiple_of,
            ffn_dim_multiplier=model_args.ffn_dim_multiplier,
        )
        self.ln_mlp = nn.RMSNorm(model_args.decoder_embed_dim)
        self.attn_scale = TanhGate()
        self.mlp_scale = TanhGate()

    def _skip_mask(self, mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Some tokens in x may not attend to any encoder inputs
        due to the cross attention mask (encoder_mask). This results in
        a full row of the attention matrix being masked out.

        In the example below, the word "the" is masked from every embedding.
        The False value means a token can't attend to an embedding.

        .. code-block:: text

            |emb||emb||emb|
            |The| F F F
            |red| T F T
            |car| F T T

        This results in no inputs into the softmax layer which causes a NaN.
        The skip mask is used to mask the outputs of attention and
        mlp resulting in the token being skipped.

        The above example would result in a skip mask of: [[True], [False], [False]]
        which specifies which tokens to fully mask out.

        """
        # no skip_mask if no masking
        if mask is None:
            return None
        # negate mask and convert to boolean mask
        if mask.dtype == torch.bool:
            mask = ~mask
        else:
            # additive float mask: -inf marks disallowed positions
            mask = torch.isneginf(mask)
        # True where all elements in a row are True
        mask = torch.all(mask, dim=-1, keepdim=True)
        return mask

    def forward(
        self,
        x: torch.Tensor,
        *,
        encoder_input: Optional[torch.Tensor] = None,
        encoder_mask: Optional[torch.Tensor] = None,
        **kwargs: Dict,
    ) -> torch.Tensor:
        # Skip cross attention when no secondary input as it's primary purpose
        # is to attend between x and encoder_input.
        if encoder_input is None:
            return x

        # A mask of tokens (x) with no encoder_input
        skip_mask = self._skip_mask(encoder_mask)

        attn_out = self.attn(
            self.ln_attn(x),
            encoder_input,
            mask=encoder_mask,
        )
        if skip_mask is not None:
            # Zero the contribution of tokens that see no encoder inputs
            # (their attention rows are fully masked and would be NaN).
            attn_out.masked_fill_(skip_mask, 0)

        h = self.attn_scale(attn_out) + x
        # Norm applied before the feedforward layer
        mlp_out = self.mlp(self.ln_mlp(h))
        if skip_mask is not None:
            mlp_out.masked_fill_(skip_mask, 0)

        # Residual connection; shape: [batch_size, seq_length, embed_dim]
        out = h + self.mlp_scale(mlp_out)

        return out
+ """ + + def __init__( + self, layer: nn.Module, fusion_layer: nn.Module, fusion_first: bool = True + ): + super().__init__() + self.layer = layer + self.fusion_layer = fusion_layer + + def forward(self, x: torch.Tensor, **kwargs: Dict) -> torch.Tensor: + x = self.fusion_layer(x, **kwargs) + x = self.layer(x, **kwargs) + return x + + +class FusionEmbedding(nn.Module): + """ + Fusion embedding supports training additional special tokens while keeping + the original embedding frozen. When fusing new models with a language model, + there may be some additional tokens needed to support the fused language model. For + example, adding a vision encoder might necessitate additional tokens like ``<|image|>`` + to indicate an images position in text and require learning an embedding for this token. + The FusionEmbedding keeps the original embeddings frozen while learning a much smaller + second embedding for the additional tokens. During forward this module routes + the tokens to the appropriate embedding table. 
+ """ + + def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> None: + super().__init__() + self.embedding = nn.Embedding(vocab_size, embed_dim) + self.fusion_embedding = nn.Embedding(fusion_vocab_size, embed_dim) + self.dim = embed_dim + self.num_embeddings = vocab_size + fusion_vocab_size + + def forward(self, input: torch.Tensor) -> torch.Tensor: + bsz, seq_len = input.size() + vocab_size = self.embedding.num_embeddings + + mask = input < vocab_size + # num_tokens = (input < vocab_size).sum() + tokens = torch.masked_select(input, mask) + # num_fusion_tokens = (input >= vocab_size).sum() + fusion_tokens = torch.masked_select(input, ~mask) - vocab_size + + # [batch_size x num_tokens x embed_dim] + embeds = self.embedding(tokens) + # [batch_size x num_fusion_tokens x embed_dim] + fusion_embeds = self.fusion_embedding(fusion_tokens) + + # [batch_size x seq_length x embed_dim] + out = torch.empty( + bsz, + seq_len, + self.dim, + device=self.embedding.weight.device, + dtype=self.embedding.weight.dtype, + ) + mask = mask.unsqueeze(-1).expand(bsz, seq_len, self.dim) + out.masked_scatter_(mask, embeds) + out.masked_scatter_(~mask, fusion_embeds) + return out + + +class MultimodalDecoder(nn.Module): + """Decoder multimodal model for Llama 3.2. + + Args: + model_args (ModelArgs): configs for the vision encoder. + """ + + def __init__(self, model_args: ModelArgs): + super().__init__() + + # TODO persistent should be set to false, since this buffer can be recomputed. + # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, + # compile or pipeline-tracer will not correctly handle non-persistent buffers, + # so we need to fix that. (2) if we initialize pipeline-parallel models from + # a seed checkpoint rather than calling init_weights, we need freqs_cis to be + # initialized by the checkpoint, or we need to add a separate initializer for + # just the non-persistent buffers that is called after loading checkpoints. 
class MultimodalDecoder(nn.Module):
    """Decoder multimodal model for Llama 3.2.

    A llama3-style self-attention decoder with cross-attention fusion layers
    interleaved every ``fusion_interval`` layers.

    Args:
        model_args (ModelArgs): configs for the vision encoder.
    """

    def __init__(self, model_args: ModelArgs):
        super().__init__()

        # TODO persistent should be set to false, since this buffer can be recomputed.
        # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411,
        # compile or pipeline-tracer will not correctly handle non-persistent buffers,
        # so we need to fix that. (2) if we initialize pipeline-parallel models from
        # a seed checkpoint rather than calling init_weights, we need freqs_cis to be
        # initialized by the checkpoint, or we need to add a separate initializer for
        # just the non-persistent buffers that is called after loading checkpoints.
        self.register_buffer(
            "freqs_cis", self._precompute_freqs_cis(model_args), persistent=True
        )

        # BUG FIX: layers were stored in a plain Python list, so they were
        # never registered as submodules — their parameters were invisible to
        # .parameters(), .to(), state_dict(), etc. Use nn.ModuleList.
        self.layers = nn.ModuleList()
        for idx in range(1, model_args.decoder_num_layers + 1):
            # define a llama3-like decoder layer, we don't train this part.
            decoder_layer = DecoderTransformerSelfAttnBlock(model_args)
            # cross attention layers, mixing text and vision,
            # placed every `fusion_interval` layers
            if idx % model_args.fusion_interval == 0:
                cross_attn_layer = DecoderTransformerCrossAttnBlock(model_args)
                fusion_layer = FusionLayer(
                    layer=decoder_layer, fusion_layer=cross_attn_layer
                )
                self.layers.append(fusion_layer)
            else:
                self.layers.append(decoder_layer)

        self.tok_embeddings = FusionEmbedding(
            model_args.vocab_size,
            model_args.num_special_tokens,
            model_args.decoder_embed_dim,
        )
        self.norm = nn.RMSNorm(model_args.decoder_embed_dim, eps=1e-05)
        self.output = nn.Linear(
            model_args.decoder_embed_dim, model_args.vocab_size, bias=False
        )

    def _precompute_freqs_cis(self, model_args) -> torch.Tensor:
        # Rotary embedding table shared by all self-attention layers.
        return precompute_freqs_cis(
            model_args.decoder_embed_dim // model_args.decoder_num_heads,
            # Need to compute until at least the max token limit for generation
            # (use 2x max sequence length to be safe)
            model_args.max_seq_len * 2,
            model_args.rope_theta,
        )

    def forward(
        self,
        tokens: torch.Tensor,
        *,
        encoder_input: Optional[torch.Tensor] = None,
        encoder_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            tokens (torch.Tensor): input tensor with shape ``[b x s]``
            encoder_input (Optional[torch.Tensor]): Optional input embeds from the encoder. Shape ``[b x s_e x d_e]``
            encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between
                tokens and encoder embeddings. A True value at position ``i,j`` means token ``i`` can attend
                to embedding ``j`` in the decoder. Mask has shape ``[b x s x s_e]``. Default is None,
                but this is required during inference if the model has been setup with any layers
                which use encoder embeddings and caches have been setup.
        """
        # input tensor of shape [b, s]
        bsz, seq_len = tokens.shape

        # shape: [b, s, d]
        h = self.tok_embeddings(tokens)

        for layer in self.layers:
            # shape: [b, s, d]
            h = layer(
                h,
                freqs_cis=self.freqs_cis,
                encoder_input=encoder_input,
                encoder_mask=encoder_mask,
            )

        # shape: [b, s, d]
        h = self.norm(h)
        # project to vocab logits; cast to float32 for loss stability
        output = self.output(h).float()

        return output
In general stuff + like torch.ones, torch.eye, etc can result in trivial outputs. This utility + generates a range tensor [min_val, max_val) of a specified dtype, applies + a sine function if nonlinear=True, then reshapes to the appropriate shape. + """ + n_elements = math.prod(shape) + step_size = (max_val - min_val) / n_elements + x = torch.arange(min_val, max_val, step_size, dtype=dtype) + x = x.reshape(shape) + if nonlinear: + return torch.sin(x) + return x + + +@torch.no_grad +def fixed_init_model( + model: nn.Module, + min_val: Union[float, int] = 0.0, + max_val: Union[float, int] = 1.0, + nonlinear: bool = False, + dtype: Optional[torch.dtype] = None, +): + """ + This utility initializes all parameters of a model deterministically using the + function fixed_init_tensor above. See that docstring for details of each parameter. + """ + for _, param in model.named_parameters(): + param.copy_( + fixed_init_tensor( + param.shape, + min_val=min_val, + max_val=max_val, + nonlinear=nonlinear, + dtype=param.dtype if dtype is None else dtype, + ) + ) diff --git a/torchtitan/experiments/multimodal/tokenizer/tiktoken.py b/torchtitan/experiments/multimodal/tokenizer/tiktoken.py new file mode 100644 index 0000000000000000000000000000000000000000..9d494a06f6557c0108b107dd3a3ba36832bb913f --- /dev/null +++ b/torchtitan/experiments/multimodal/tokenizer/tiktoken.py @@ -0,0 +1,232 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. 
import os
from pathlib import Path
from typing import (
    AbstractSet,
    Any,
    cast,
    Collection,
    Dict,
    Iterator,
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Union,
)

import tiktoken
import torch
from tiktoken.load import load_tiktoken_bpe

from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger

# Fixed id assigned to the special <|image|> token.
IMAGE_TOKEN_ID = 128256
# Label value ignored by the loss.
IGNORE_INDEX = -100


class TikTokenizer(Tokenizer):
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.

    Args:
        model_path (str): The path to the Tiktoken model file.
    """

    # Mapping from special-token string to its id.
    special_tokens: Dict[str, int]

    # Number of ids reserved for special tokens above the base vocab.
    num_reserved_special_tokens = 256

    # Llama 3 pre-tokenization regex: contractions, words, numbers up to
    # three digits, punctuation runs, newlines and trailing whitespace.
    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501, B950

    def __init__(self, model_path: str):
        super().__init__(model_path)
        assert os.path.isfile(model_path), model_path

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        # 10 named special tokens followed by the remaining reserved slots.
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [
            f"<|reserved_special_token_{i}|>"
            for i in range(5, self.num_reserved_special_tokens - 5)
        ]
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        # <|image|> lives at a fixed id rather than the next sequential slot.
        self.special_tokens["<|image|>"] = IMAGE_TOKEN_ID
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )

        self._n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        self.pad_id: int = -1
        self.image_id = IMAGE_TOKEN_ID
        # Generation stops on either end-of-text or end-of-turn.
        self.stop_tokens = {
            self.special_tokens["<|end_of_text|>"],
            self.special_tokens["<|eot_id|>"],
        }
        logger.info(
            f"TikTokenizer built: #words {self.n_words}, BOS ID {self.bos_id}, EOS ID {self.eos_id}, IMAGE ID {self.image_id}"
        )

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
        disallowed_special: Optional[Union[Literal["all"], Collection[str]]] = None,
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): allowed special tokens in string
            disallowed_special ("all"|set[str]): special tokens that raise an error when in string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will treat all text corresponding
          to special tokens to be encoded as special tokens.
        """
        assert type(s) is str
        allowed_special = allowed_special or set()
        disallowed_special = disallowed_special or ()

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        # Lazily split the input into chunks tiktoken can handle safely.
        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: List[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(List[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                # Character class flipped: reset the run counter.
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    # Run exceeded the cap: emit a slice and start a new one.
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]

    def encode_multimodal(self, sample: Mapping[str, Any]) -> List[int]:
        """
        Tokenizes a `str` of text and creates `labels` masking BOS, EOS and `image_id` tokens.
        """
        # TODO(tj.solergibert) Should we keep `input_ids` OR `tokens` across this class, VisionCrossAttentionMask & the collator?
+ # For me it makes more sense to split `tokens` between `input_ids` & `labels` as in train.py BUT the `MultimodalDecoder` + # & everything else expects `tokens` + text = sample["text"] + tokens = self.encode( + text, bos=True, eos=True, allowed_special=set(["<|image|>"]) + ) + input_ids = torch.LongTensor(tokens[:-1]) + labels = torch.LongTensor(tokens[1:]) + labels = torch.where( + torch.isin( + labels, torch.LongTensor([self.bos_id, self.eos_id, self.image_id]) + ), + IGNORE_INDEX, + labels, + ) + + assert len(input_ids) == len(labels) # TODO(tj.solergibert) Delete + + sample.update({"tokens": input_ids, "labels": labels}) + + return sample + + +def build_tiktoken_tokenizer(job_config: JobConfig) -> TikTokenizer: + return TikTokenizer(job_config.model.tokenizer_path) diff --git a/torchtitan/experiments/multimodal/transform.py b/torchtitan/experiments/multimodal/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..ecb0f989acd0b818f20116a60813c26e68438cec --- /dev/null +++ b/torchtitan/experiments/multimodal/transform.py @@ -0,0 +1,185 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List, Mapping, Optional, Tuple + +import torch + +import torchvision +from torchvision.transforms.v2 import functional as F + +from utils import ( + find_supported_resolutions, + get_canvas_best_fit, + resize_with_pad, + tile_crop, +) + +from torchtitan.tools.logging import logger + + +class CLIPTransform: + """ + This class accepts images of any size and dynamically resizes, pads, normalizes and tiles it + based on the image aspect ratio and the number of image tiles we allow. + + The algorithm will NOT distort the image to fit a certain aspect ratio, because + that leads to a significant degradation in image quality. 
+ + The user can choose if they want to allow upscaling by using the flag ``resize_to_max_canvas``. + + For example, if an input image is of size 300x800, and we want to allow + a maximum of 16 image tiles, with side 224px, then: + + If ``resize_to_max_canvas=False``, then: + best_resolution = (448, 896) -> smallest canvas, up to 16 tiles, that doesn't require downscaling + image is NOT resized + image is padded (300, 800) -> 448,896 + Image is tiled 2x4, for a final output shape of (8, 3, 224, 224) + + If ``resize_to_max_canvas=True``, then: + best_resolution = (448, 1344) # canvas that allows maximum upscaling, with minimum padding, up to 16 tiles + image is resized without distortion (300,800) -> (448, 1194) #448 is the limiting side for the resize + image is padded (448, 1194) -> (448, 1344) + Image is tiled 2x6, for a final output shape of (10, 3, 224, 224) + + Args: + image_mean (Optional[List[float]]): Mean values of each channel, used for normalization. + Should be the same used for the pre-trained model. If None, no normalization is performed. Default None. + image_std (Optional[List[float]]): Standard deviation values of each channel, used for normalization. + Should be the same used for the pre-trained model. If None, no normalization is performed. Default None. + possible_resolutions (Optional[List[Tuple[int, int]]]): List of possible resolutions as tuples (height, width). + where each tuple represents a possible canvas to fit the image into when calling ``get_canvas_best_fit``. + If None, this will be calculated using max_num_tiles and tile_size. Default None. + tile_size (int): Size of the tiles to divide the image into. Default 224. + max_num_tiles (Optional[int]): Only used if possible_resolutions is NOT given. + Maximum number of tiles to break an image into. + This will be used to generate possible_resolutions, + e.g. [(224, 224), (224, 448), (448, 224)] if max_num_tiles = 2 and tile_size = 224. + Default 4. 
+ dtype (torch.dtype): Data type of the output image. Default torch.bfloat16. + resample (str): Resampling method used when resizing images. Supports any enum of + ``torchvision.transforms.InterpolationMode``, e.g. "nearest", "nearest_exact", "bilinear", "bicubic". + Default 'bilinear'. + resize_to_max_canvas (bool): "If True, the image will be upscaled without distortion to fit the largest possible + resolution from possible_resolutions. + If False, it will pick the resolution that minimizes downscaling, including no downscaling at all. + In this case, the image will only be upscaled if it's size < tile_size. Default False. + + Examples: + >>> image_transform = CLIPImageTransform( + ... image_mean=None, + ... image_std=None, + ... tile_size=224, + ... possible_resolutions=None, + ... max_num_tiles=4, + ... resample="bilinear", + ... resize_to_max_canvas=True, + ...) + >>> # create random image + >>> image = (np.random.rand(100,200,3) * 255).astype(np.uint8) + >>> image = PIL.Image.fromarray(image) + >>> output = image_transform(image) + >>> output['image'].shape # [num_tiles, num_channels, tile_size, tile_size] + torch.Size([2, 3, 224, 224]) + >>> output['ar'] # image best fits the canvas 224x448 + torch.tensor([1,2]) + """ + + def __init__( + self, + *, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, + possible_resolutions: Optional[List[Tuple[int, int]]] = None, + tile_size: int = 224, + max_num_tiles: Optional[int] = 4, + dtype: torch.dtype = torch.bfloat16, + resample: str = "bilinear", + resize_to_max_canvas: bool = False, + ) -> None: + + # get_canvas_best_fit + assert ( + possible_resolutions is not None or max_num_tiles is not None + ), f"Either possible_resolutions or max_num_tiles must be given. 
Got {possible_resolutions} and {max_num_tiles}" + + # If possible_resolutions are not given, then calculate possible ones based on max_num_tiles + if not possible_resolutions and max_num_tiles: + possible_resolutions = find_supported_resolutions( + max_num_tiles=max_num_tiles, tile_size=tile_size + ) + else: + possible_resolutions = possible_resolutions + + self.possible_resolutions = torch.tensor(possible_resolutions).reshape(-1, 2) + logger.debug( + f"Found possible_resolutions: {self.possible_resolutions}. Will fit the images into the canvas with best fit." + ) + + self.resize_to_max_canvas = resize_to_max_canvas + + # normalize + assert (image_mean is None) == ( + image_std is None + ), f"Need to provide both or none of image_mean and image_std. Got {image_mean=} and {image_std=}" + self.mean = image_mean + self.std = image_std + + # resize_with_pad + self.max_size = None if resize_to_max_canvas else tile_size + self.dtype = dtype + self.resample = torchvision.transforms.InterpolationMode[resample.upper()] + + # tile_crop + self.tile_size = tile_size + + def __call__(self, image: torch.Tensor) -> Mapping[str, Any]: + """ + Apply image decoding and transformations to the "image" field in the sample. + + Args: + sample (Mapping[str, Any]): A sample with an "image" field containing + a List[Message] to tokenize + + Returns: + Mapping[str, Any]: The sample with an updated "image" filed and added + "aspect_ratio" field. + """ + assert isinstance(image, torch.Tensor), "Input image must be a torch.Tensor." 
+ + image = F.to_image(image) + image = F.grayscale_to_rgb_image(image) + image = F.to_dtype(image, dtype=self.dtype, scale=True) + + # Find the best canvas to fit the image without distortion + best_resolution = get_canvas_best_fit( + image=image, + possible_resolutions=self.possible_resolutions, + resize_to_max_canvas=self.resize_to_max_canvas, + ) + + # resize without distortion + pad to fit best_resolution + image = resize_with_pad( + image=image, + target_size=best_resolution, + resample=self.resample, + max_size=self.max_size, + ) + + # Normalize + if self.mean: + image = F.normalize(image, mean=self.mean, std=self.std) + + # Divide the image into equally sized tiles + image = tile_crop(image=image, tile_size=self.tile_size) + + aspect_ratio = torch.tensor(best_resolution).reshape(-1) // self.tile_size + + return { + "image": image, + "aspect_ratio": aspect_ratio, + } diff --git a/torchtitan/experiments/multimodal/utils.py b/torchtitan/experiments/multimodal/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c927772a5ef95ba65123c9387de4ead1e732490f --- /dev/null +++ b/torchtitan/experiments/multimodal/utils.py @@ -0,0 +1,437 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from collections import defaultdict + +from pathlib import Path +from typing import List, Optional, Set, Tuple, Union +from urllib import request + +import torch +import torchvision +from torchvision.transforms.v2 import functional as F + +# NOTE Copied from torchtune.modules.transforms.vision_utils.tile_crop.py +def tile_crop(image: torch.Tensor, tile_size: int) -> torch.Tensor: + """ + Divides a tensor into equally sized tiles. The tensor should be divisible by tile_size. + + Args: + image (torch.Tensor): Input image to crop into tiles. + tile_size (int): Size of each tile. 
+ + Returns: + torch.Tensor: torch.Tensor of shape [num_tiles, channel_size, tile_size, tile_size] + + Examples: + >>> image = torch.rand(3, 200, 300) + >>> tiles = tile_crop(image, tile_size=50) + >>> tiles.shape # 4x6 = 24 tiles + torch.Size([24, 3, 50, 50]) + + >>> image = torch.rand(3, 400, 600) + >>> tiles = tile_crop(image, tile_size=200) + >>> tiles.shape # 2x3 = 6 tiles + torch.Size([6, 3, 200, 200]) + """ + + channel_size, height, width = image.shape + + # assert sizes are divisible + assert ( + height % tile_size == 0 and width % tile_size == 0 + ), f"Image size {height}x{width} is not divisible by tile size {tile_size}" + + # Reshape to split height and width into tile_size blocks + tiles_height = height // tile_size + tiles_width = width // tile_size + + reshaped = image.view(channel_size, tiles_height, tile_size, tiles_width, tile_size) + + # Transpose to bring tiles together + # We want [tiles_height, tiles_width, channel_size, tile_size, tile_size] + transposed = reshaped.permute(1, 3, 0, 2, 4) + + # Flatten the tiles + tiles = transposed.contiguous().view( + tiles_height * tiles_width, channel_size, tile_size, tile_size + ) + + return tiles + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py +def resize_with_pad( + image: torch.Tensor, + target_size: Tuple[int, int], + resample: torchvision.transforms.InterpolationMode, + max_size: Optional[int] = None, +) -> torch.Tensor: + """ + Resizes and pads an image to target_size without causing distortion. + The user can set max_size to limit upscaling when target_size exceeds image_size. + + Args: + image (torch.Tensor): The input image tensor in the format [..., H, W]. + target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width]. + resample (torchvision.transforms.InterpolationMode): Resampling method used when resizing images. 
+ Supports torchvision.transforms.InterpolationMode.NEAREST, InterpolationMode.NEAREST_EXACT, + InterpolationMode.BILINEAR and InterpolationMode.BICUBIC. + max_size (Optional[int]): The maximum size to upscale the image to. + If None, will upscale up to target_size. + + Returns: + torch.Tensor: The resized and padded image tensor in the format [..., H, W]. + + Examples: + + Example 1: The image will be upscaled from (300, 800) to (448, 1194), since 448 is the limiting side, + and then padded from (448, 1194) to (448, 1344). + + >>> max_size = None + >>> image = torch.rand([3, 300, 800]) + >>> target_size = (448, 1344) + >>> resample = torchvision.transforms.InterpolationMode.BILINEAR + >>> output = resize_with_pad(image, target_size, resample, max_size) + + Example 2: The image will stay as is, since 800 > 600, and then padded from (300, 800) to (448, 1344). + + >>> max_size = 600 + >>> image = torch.rand([3, 300, 800]) + >>> target_size = (448, 1344) + >>> resample = torchvision.transforms.InterpolationMode.BILINEAR + >>> output = resize_with_pad(image, target_size, resample, max_size) + + Example 3: The image will be downscaled from (500, 1000) to (224, 448), + and padded from (224, 448) to (448, 448). 
+ + >>> max_size = 600 + >>> image = torch.rand([3, 500, 1000]) + >>> target_size = (448, 488) + >>> resample = torchvision.transforms.InterpolationMode.BILINEAR + >>> output = resize_with_pad(image, target_size, resample, max_size) + + """ + + image_height, image_width = image.shape[-2:] + image_size = (image_height, image_width) + + # If target_size requires upscaling, we might want to limit the upscaling to max_size + if max_size is not None: + new_target_height = min(max(image_height, max_size), target_size[0]) + new_target_width = min(max(image_width, max_size), target_size[1]) + target_size_resize = (new_target_height, new_target_width) + else: + target_size_resize = target_size + + # resize to target_size while preserving aspect ratio + new_size_preserving_aspect_ratio = _get_max_res_without_distortion( + image_size=image_size, + target_size=target_size_resize, + ) + + image = F.resize( + inpt=image, + size=list(new_size_preserving_aspect_ratio), + interpolation=resample, + antialias=True, + ) + + image = _pad_image_top_left(image=image, target_size=target_size) + + return image + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py +def _pad_image_top_left( + image: torch.Tensor, + target_size: Tuple[int, int], +) -> torch.Tensor: + """ + Places the image at the top left of the canvas and pads with 0 the right and bottom + to fit to the target resolution. If target_size < image_size, it will crop the image. + + Args: + image (torch.Tensor): The input image tensor in the format [..., H, W]. + target_size (Tuple[int, int]): The desired resolution to fit the image into in the format [height, width]. + + Returns: + torch.Tensor: The padded image tensor in the format [..., H, W]. 
+ """ + + image_size = image.shape[-2:] + + height, width = image_size + target_height, target_width = target_size + + pad_x = target_width - width + pad_y = target_height - height + + padding = [0, 0, pad_x, pad_y] + return F.pad(inpt=image, padding=padding) + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.resize_with_pad.py +def _get_max_res_without_distortion( + image_size: Tuple[int, int], + target_size: Tuple[int, int], +) -> Tuple[int, int]: + """ + Determines the maximum resolution to which an image can be resized to without distorting its + aspect ratio, based on the target resolution. + + For example, if image_size = (200,400) and target_size = (600,800), + scale_h = 600/200 = 3 + scale_w = 800/400 = 2 + So the maximum that we can upscale without distortion is min(scale_h, scale_w) = 2 + + Since scale_w is the limiting side, then new_w = target_w, and new_h = old_h*scale_w + + Args: + image_size (Tuple[int, int]): The original resolution of the image. + target_size (Tuple[int, int]): The desired resolution to fit the image into. + Returns: + Tuple[int, int]: The optimal dimensions to which the image should be resized. + Examples: + >>> _get_max_res_without_distortion([200, 300], target_size = (450, 200)) + (133, 200) + >>> _get_max_res_without_distortion([800, 600], target_size = (450, 1300)) + (450, 337) + """ + + original_height, original_width = image_size + target_height, target_width = target_size + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.floor(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.floor(original_width * scale_h), target_width) + + return new_height, new_width + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py +def _get_factors(n: int) -> Set[int]: + """ + Calculate all factors of a given number, i.e. 
a divisor that leaves no remainder. + + Args: + n (int): The number to find factors for. + + Returns: + set: A set containing all factors of the number. + + Examples: + >>> _get_factors(n=12) + {1, 2, 3, 4, 6, 12} + """ + factors_set = set() + + for i in range(1, int(n**0.5) + 1): + if n % i == 0: + factors_set.add(i) + factors_set.add(n // i) + return factors_set + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py +def get_canvas_best_fit( + image: torch.Tensor, possible_resolutions: torch.Tensor, resize_to_max_canvas: bool +) -> Tuple[int, int]: + """ + Determines the best canvas possible from a list of possible resolutions to + resize an image to, without distortion. + + For each possible resolution, calculates the scaling factors for + width and height, and selects the smallest one, which is the limiting side. + E.g. if to match a canvas shape you have to upscale an image's height by 2x, and width by 1.5x, + then the maximum upscaling without distortion is min(2, 1.5) = 1.5. + + If there are multiple canvases that satisfy the conditions, + we pick the one with the lowest area to minimize padding. + + Args: + image (torch.Tensor): The image we want to fit into a canvas. + possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each + row represents a possible canvas. + resize_to_max_canvas (bool): If True, pick the canvas that allows maximum scaling. + If False, pick the canvas that minimizes downscaling, including no downscaling at all. + + Returns: + Tuple[int, int]: The best resolution to fit the image into. + + Examples: + >>> image = torch.rand(3, 200, 300) + >>> possible_resolutions = torch.tensor([ + ... [224, 672], + ... [672, 224], + ... [224, 448], + ... [448, 224], + ... [224, 224] + ... 
]) + >>> get_canvas_best_fit(image, possible_resolutions, resize_to_max_canvas=False) + (224, 448) + + In the example above, we calculate the scaling factors for each possible resolution + + >>> scale_height = torch.tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200]) + >>> scale_width = torch.tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467]) + >>> scales = torch.tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467]) + + Two options have scaling_factor > 1, since resize_to_max_canvas is False, we pick the smallest + + >>> upscaling_options = torch.tensor([1.1200, 1.1200]) + >>> selected_scale = torch.tensor(1.1200) + + There are two possible options, so we pick the one with the smallest area + + >>> areas = torch.tensor([150528, 100352]) # for resolutions [672, 224] and [224, 448], respectively + >>> optimal_canvas = torch.tensor([224, 448]) # resolution with the smallest area + """ + + original_height, original_width = image.shape[-2:] + + # possible resolutions heights/widths + target_heights, target_widths = ( + possible_resolutions[:, 0], + possible_resolutions[:, 1], + ) + + # scaling factors to resize the image without distortion + scale_w = target_widths / original_width + scale_h = target_heights / original_height + + # get limiting side scaling -> no distortion + scales = torch.where(scale_w > scale_h, scale_h, scale_w) + + # filter only scales that allow upscaling + upscaling_options = scales[scales >= 1] + if len(upscaling_options) > 0: + if resize_to_max_canvas: + selected_scale = torch.max(upscaling_options) + else: + selected_scale = torch.min(upscaling_options) + else: + # no upscaling possible, + # get the minimum downscaling (max scale for scales<1) + downscaling_options = scales[scales < 1] + selected_scale = torch.max(downscaling_options) + + # get all resolutions that support this scaling factor, + # e.g. 
you can upscale to 224x224, 224x448, 224x672 without distortion + chosen_canvas = possible_resolutions[scales == selected_scale] + + # if there are multiple resolutions, + # get the one with minimum area to reduce padding + if len(chosen_canvas) > 1: + areas = chosen_canvas[:, 0] * chosen_canvas[:, 1] + optimal_idx = torch.argmin(areas) + optimal_canvas = chosen_canvas[optimal_idx] + else: + optimal_canvas = chosen_canvas[0] + + return tuple(optimal_canvas.tolist()) + + +# NOTE Copied from torchtune.modules.transforms.vision_utils.get_canvas_best_fit.py +def find_supported_resolutions( + max_num_tiles: int, tile_size: int +) -> List[Tuple[int, int]]: + """ + Computes all combinations of resolutions, multiple of tile_size, + that contain up to max_num_tiles. Useful for when dividing an image into tiles. + + For example, if we want at most 2 tiles per image, then we can support the + following resolutions: (1x1, 1x2, 2x1) * tile_size + + Args: + max_num_tiles (int): Maximum number of tiles. + tile_size (int): Size of the side of the tile. + + Returns: + List[Tuple[int, int]]: List of possible resolutions as tuples (height, width). 
+ + Examples: + + >>> max_num_tiles = 4 + >>> tile_size = 224 + >>> find_supported_resolutions(max_num_tiles, tile_size) + [(224, 896), (448, 448), (224, 224), (896, 224), (224, 672), (672, 224), (224, 448), (448, 224)] + """ + + # create dictionary {aspect_ratio: [resolution1, ..., resolution n]} + # example {0.25: [(1,4)], 1.0: [(2,2), (1,1)], 4.0: [(4,1)]} + asp_dict = defaultdict(list) + for _tile_size in range(max_num_tiles, 0, -1): + factors = sorted(_get_factors(_tile_size)) + asp_ratios = [(factor, _tile_size // factor) for factor in factors] + for height, width in asp_ratios: + ratio_float = height / width + asp_dict[ratio_float].append((height, width)) + + # get the resolutions multiplied by the tile_size + possible_resolutions = [] + for ar, resolution in asp_dict.items(): + for height, width in resolution: + possible_resolutions.append((height * tile_size, width * tile_size)) + + return possible_resolutions + + +# NOTE Copied from torchtune.data._utils.py +def load_image(image_loc: Union[Path, str]) -> torch.Tensor: + """ + Convenience method to load an image in torch.Tensor format from a local file path or remote source. + + Args: + image_loc (Union[Path, str]): Local file path or remote source pointing to the image + which will be loaded in PIL format. + + Note: + If loading an image from a remote source, the function expects the URL provided in ``image_loc`` + to start with "http" or "https" e.g. "https://www.wikipedia.org/en/bird.jpg". + + Raises: + ValueError: If the image cannot be loaded from remote source, **or** + if the image cannot be opened as a :class:`~torch.Tensor`. + + Examples: + >>> # Load from remote source + >>> image = load_image("https://www.wikipedia.org/en/bird.jpg") + + >>> # Load from local file path + >>> image = load_image(Path("/home/user/bird.jpg")) + + Returns: + torch.Tensor: The loaded image. 
+ """ + + # If pointing to remote source, try to load to local + if isinstance(image_loc, str) and image_loc.startswith("http"): + try: + image_loc = request.urlopen(image_loc).read() + image = torchvision.io.decode_image( + torch.frombuffer(image_loc, dtype=torch.uint8), + mode="RGB", + ) + except Exception as e: + raise ValueError("Failed to load remote image as torch.Tensor") from e + + # Open the local image as a Tensor image + else: + try: + image = torchvision.io.decode_image(image_loc, mode="RGB") + except Exception as e: + raise ValueError("Failed to load local image as torch.Tensor") from e + + return image diff --git a/torchtitan/experiments/simple_fsdp/README.md b/torchtitan/experiments/simple_fsdp/README.md new file mode 100644 index 0000000000000000000000000000000000000000..887653ac0298369a04df9b791b9676bd7c6107c1 --- /dev/null +++ b/torchtitan/experiments/simple_fsdp/README.md @@ -0,0 +1,40 @@ +## SimpleFSDP + +This folder includes an experimental frontend implementation for [SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile](https://arxiv.org/abs/2411.00284). SimpleFSDP is a compiler-based Fully Sharded Data Parallel (FSDP) framework, which has a simple implementation for maintenance and composability, allows full computation-communication graph tracing, and brings performance enhancement via compiler backend optimizations. + +### Enable SimpleFSDP Training + +```bash +CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ./run_train.sh --model.name llama3_simple_fsdp --training.compile --training.mixed_precision_param float32 +``` + +Note: The mixed precision training support is on-going. We set `training.mixed_precision_param` to `float32` for now and will remove it once the integration is completed. 
+ +### Composability Support + +Some of the features below require updates to PyTorch; we are working with the PyTorch team to provide composability support for the following features: + +| Feature | Support | +| :--------: | :--------: | +|Meta Initialization| ✅ | +|Activation Checkpointing| ✅ | +|Mixed Precision Training| 🚧 | +|Tensor Parallelism| 🚧 | +|Context Parallelism| ✅ | +|Pipeline Parallelism| ✅ | +|Distributed Checkpointing| 🚧 | +|Float8 Training| ❌ | + + +### Citation + +If you find SimpleFSDP useful, please kindly consider citing the following paper: + +```latex +@article{zhang2024simplefsdp, + title={SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile}, + author={Zhang, Ruisi and Liu, Tianyu and Feng, Will and Gu, Andrew and Purandare, Sanket and Liang, Wanchao and Massa, Francisco}, + journal={arXiv preprint arXiv:2411.00284}, + year={2024} +} +``` diff --git a/torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc b/torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95ce36f96551b910901aff75d4c96c5d27b240cf Binary files /dev/null and b/torchtitan/experiments/simple_fsdp/__pycache__/__init__.cpython-312.pyc differ diff --git a/torchtitan/experiments/simple_fsdp/__pycache__/model.cpython-312.pyc b/torchtitan/experiments/simple_fsdp/__pycache__/model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..960459f76940ba7ad9eab050a0f67df5186c6b37 Binary files /dev/null and b/torchtitan/experiments/simple_fsdp/__pycache__/model.cpython-312.pyc differ diff --git a/torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-312.pyc b/torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..124bb39c98d765648d4ada2b5ed5a93af33f7cad Binary files /dev/null and
b/torchtitan/experiments/simple_fsdp/__pycache__/parallelize_llama.cpython-312.pyc differ diff --git a/torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-312.pyc b/torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff91850a63120f94746242c306acc1938eb24c84 Binary files /dev/null and b/torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-312.pyc differ diff --git a/torchtitan/experiments/simple_fsdp/model.py b/torchtitan/experiments/simple_fsdp/model.py new file mode 100644 index 0000000000000000000000000000000000000000..63104169b8fa14ed7032182c1ad08b782cd715fe --- /dev/null +++ b/torchtitan/experiments/simple_fsdp/model.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.models.llama3 import Transformer, TransformerModelArgs +from .simple_fsdp import disable_data_parallel + + +class SimpleFSDPTransformer(Transformer): + def __init__(self, model_args: TransformerModelArgs): + super().__init__(model_args) + self.init_weights() + + def init_weights(self, *args, **kwargs): + with disable_data_parallel(): + super().init_weights(*args, **kwargs) diff --git a/torchtitan/experiments/simple_fsdp/parallelize_llama.py b/torchtitan/experiments/simple_fsdp/parallelize_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..25d696db27e90e292465aa7b9c6ffa20ae8f0508 --- /dev/null +++ b/torchtitan/experiments/simple_fsdp/parallelize_llama.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """
    # TODO(ruisizhang123): Tensor parallelism (including async TP and float8
    # tensorwise-TP all-gather) is not wired up for SimpleFSDP yet; see the
    # llama3 parallelize_llama for the reference apply_tp call sites.

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    # Apply data parallelism. Pick the SimpleFSDP mode and the mesh dims it
    # operates over from which DP/CP degrees are enabled.
    if (
        parallel_dims.dp_replicate_enabled
        or parallel_dims.dp_shard_enabled
        or parallel_dims.cp_enabled
    ):
        if not parallel_dims.dp_replicate_enabled:
            # Pure FSDP: shard over the combined dp_shard/cp mesh dim.
            dp_mode = "fully_shard"
            dp_mesh_dim_names = ("dp_shard_cp",)
        elif parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
            # HSDP: replicate across dp_replicate, shard across dp_shard_cp.
            dp_mode = "hybrid_shard"
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            # DDP-style replication only.
            dp_mode = "replicate"
            dp_mesh_dim_names = ("dp_replicate",)

        # Param/reduce dtypes come from the training mixed-precision config.
        mp_policy = MixedPrecisionPolicy(
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
        )

        model = data_parallel(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            mode=dp_mode,
            ac_mode=job_config.activation_checkpoint.mode,
            mp_policy=mp_policy,
        )
        logger.info("Applied Data Parallel (dp mode=%s) to the model", dp_mode)

    if job_config.training.compile:
        # NOTE(review): reorder_for_peak_memory currently interacts badly with
        # SimpleFSDP's graph, hence disabled — confirm against inductor docs.
        torch._inductor.config.reorder_for_peak_memory = False
        model = torch.compile(model, fullgraph=True)

    return model
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchtitan/experiments/simple_fsdp/tests/test_numerics.py b/torchtitan/experiments/simple_fsdp/tests/test_numerics.py new file mode 100644 index 0000000000000000000000000000000000000000..3c15ce573b9c65f9f26cefcbdbcd0f5b2f5c9713 --- /dev/null +++ b/torchtitan/experiments/simple_fsdp/tests/test_numerics.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import copy + +import torch +from torch.distributed._composable.fsdp import fully_shard + +from torch.testing._internal.common_fsdp import FSDPTest + +from torchtitan.components.loss import cross_entropy_loss +from torchtitan.distributed import ParallelDims +from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel + + +class TestSimpleFSDP(FSDPTest): + def init_test(self): + self.optimizer = torch.optim.Adam + self.loss_fn = cross_entropy_loss + data_parallel_shard_degree = -1 + if self.mode == "replicate": + self.dp_mesh_dim_names = ("dp_replicate",) + data_parallel_replicate_degree = self.world_size + elif self.mode == "fully_shard": + self.dp_mesh_dim_names = ("dp_shard_cp",) + data_parallel_replicate_degree = 1 + elif self.mode == "hybrid_shard": + self.dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + data_parallel_replicate_degree = self.world_size // 2 + else: + raise ValueError(f"Unsupported mode {mode}") + + self.parallel_dims = ParallelDims( + dp_shard=data_parallel_shard_degree, + dp_replicate=data_parallel_replicate_degree, + cp=1, + tp=1, + pp=1, + world_size=self.world_size, + enable_loss_parallel=True, + ) + self.device_mesh = self.parallel_dims.build_mesh(device_type="cuda") + + def get_input(self): + inputs = torch.randn(8, 8).cuda() + labels = 
torch.randn(8, 8).cuda() + model = torch.nn.Linear(8, 8) + return model, inputs, labels + + def run_fsdp2(self, model, inputs, labels, epoch=20): + fully_shard(model, mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)]) + optim = self.optimizer(model.parameters(), lr=1e-4) + losses = [] + for _ in range(epoch): + optim.zero_grad() + out = model(inputs) + loss = self.loss_fn(out, labels) + loss.backward() + optim.step() + losses.append(loss) + return losses + + def run_simple_fsdp(self, model, inputs, labels, epoch=20): + model = data_parallel( + model, + device_mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)], + mode=self.mode, + ) + optim = self.optimizer(model.parameters(), lr=1e-4) + losses = [] + for _ in range(epoch): + optim.zero_grad() + out = model(inputs) + loss = self.loss_fn(out, labels) + loss.backward() + optim.step() + losses.append(loss) + return losses + + def test_replicate_convergence(self): + # unit test for replicate mode + self.mode = "replicate" + self.init_test() + model, inputs, labels = self.get_input() + + fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels) + simple_fsdp_replicate_losses = self.run_simple_fsdp( + copy.deepcopy(model), inputs, labels + ) + + for fsdp2_loss, simple_fsdp_replicate_loss in zip( + fsdp2_losses, simple_fsdp_replicate_losses + ): + assert torch.allclose(fsdp2_loss, simple_fsdp_replicate_loss) + + def test_fullyshard_convergence(self): + # unit test for fully_shard mode + self.mode = "fully_shard" + self.init_test() + model, inputs, labels = self.get_input() + + fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels) + simple_fsdp_fullyshard_losses = self.run_simple_fsdp( + copy.deepcopy(model), inputs, labels + ) + + for fsdp2_loss, simple_fsdp_fullyshard_loss in zip( + fsdp2_losses, simple_fsdp_fullyshard_losses + ): + assert torch.allclose(fsdp2_loss, simple_fsdp_fullyshard_loss) + + def test_hybridshard_convergence(self): + # unit test for hybrid_shard mode + self.mode = 
"hybrid_shard" + self.init_test() + model, inputs, labels = self.get_input() + + fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels) + simple_fsdp_hybridshard_losses = self.run_simple_fsdp( + copy.deepcopy(model), inputs, labels + ) + + for fsdp2_loss, simple_fsdp_hybridshard_loss in zip( + fsdp2_losses, simple_fsdp_hybridshard_losses + ): + assert torch.allclose(fsdp2_loss, simple_fsdp_hybridshard_loss) diff --git a/torchtitan/models/llama3/__pycache__/model.cpython-312.pyc b/torchtitan/models/llama3/__pycache__/model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab25d02968d44b2cca513aaa2dd88b486cad06e1 Binary files /dev/null and b/torchtitan/models/llama3/__pycache__/model.cpython-312.pyc differ diff --git a/torchtitan/models/llama3/__pycache__/parallelize_llama.cpython-312.pyc b/torchtitan/models/llama3/__pycache__/parallelize_llama.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d05b89f3b85c58a09ef22d880e251a3f984531c7 Binary files /dev/null and b/torchtitan/models/llama3/__pycache__/parallelize_llama.cpython-312.pyc differ