# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import torch
import torch.distributed as dist
from megatron.core import mpu, parallel_state
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.autograd import Function
from torch.distributed import broadcast, get_process_group_ranks
from transformer_engine.pytorch.jit import no_torch_dynamo
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.module.rmsnorm import RMSNorm as RMSNormTE
from transformer_engine.pytorch.module.rmsnorm import _RMSNorm

from cosmos_predict1.utils import log


def get_batch_on_this_cp_rank(inputs):
    """Slice batch input along sequence dimension into multiple chunks,
    which are parallelized across GPUs in a context parallel group.
    """
    # With causal masking, each token only attends to its prior tokens, so naively splitting
    # the sequence into CP chunks causes severe load imbalance: chunks at the end of the
    # sequence carry a bigger workload than the others. To address this, we split the sequence
    # into 2*CP chunks. Assuming CP=2, we get 4 chunks: chunk_0 and chunk_3 are assigned to
    # GPU0, while chunk_1 and chunk_2 are assigned to GPU1, so that the workload is balanced
    # among the GPUs in a context parallel group.
    cp_size = parallel_state.get_context_parallel_world_size()
    if cp_size > 1:
        cp_rank = mpu.get_context_parallel_rank()
        seq_dim = 1  # if key != 'attention_mask' else 2
        inputs = inputs.view(
            *inputs.shape[0:seq_dim],
            2 * cp_size,
            inputs.shape[seq_dim] // (2 * cp_size),
            *inputs.shape[(seq_dim + 1) :],
        )
        index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
            non_blocking=True
        )
        inputs = inputs.index_select(seq_dim, index)
        inputs = inputs.view(*inputs.shape[0:seq_dim], -1, *inputs.shape[(seq_dim + 2) :])
    return inputs
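

# Illustrative example (not part of the original module): with cp_size=2 and an input of
# shape (B, 8, D), the sequence axis is viewed as 4 chunks of length 2. CP rank 0 selects
# chunks [0, 3] and CP rank 1 selects chunks [1, 2], so each rank keeps a (B, 4, D) slice
# with a comparable causal-attention workload:
#
#   x = torch.randn(B, 8, D)                # full sequence; B and D are illustrative
#   x_local = get_batch_on_this_cp_rank(x)  # (B, 4, D) on each CP rank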


def gather_batch_from_cp_ranks(outputs):
    """
    Gather and reconstruct the full batch from chunks distributed across GPUs in a context parallel group.
    """
    cp_size = parallel_state.get_context_parallel_world_size()
    cp_rank = mpu.get_context_parallel_rank()
    if cp_size > 1:
        seq_dim = 1  # Assuming sequence dimension is 1
        try:
            # Reshape output to separate the two chunks
            chunk_size = outputs.shape[seq_dim] // 2
            outputs = outputs.view(*outputs.shape[:seq_dim], 2, chunk_size, *outputs.shape[seq_dim + 1 :])
            # Prepare a list to gather all chunks from all ranks
            gathered_chunks = [torch.zeros_like(outputs) for _ in range(cp_size)]
            # Gather all chunks
            dist.barrier()
            dist.all_gather(gathered_chunks, outputs, group=parallel_state.get_context_parallel_group())
            dist.barrier()
            # Reorder chunks
            reordered_chunks = [None] * (2 * cp_size)
            for i in range(cp_size):
                reordered_chunks[i] = gathered_chunks[i].select(seq_dim, 0)
                reordered_chunks[2 * cp_size - 1 - i] = gathered_chunks[i].select(seq_dim, 1)
            # Concatenate all chunks
            outputs = torch.cat(reordered_chunks, dim=seq_dim)
        except Exception as e:
            log.info(f"[Rank {cp_rank}] Error in gather_batch_from_cp_ranks: {str(e)}")
            raise
    return outputs
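

# Note: gather_batch_from_cp_ranks is the collective inverse of get_batch_on_this_cp_rank.
# Continuing the cp_size=2 example above, rank 0 contributes global chunks (0, 3) and rank 1
# contributes (1, 2); after the all_gather the chunks are reordered to (0, 1, 2, 3) and
# concatenated back into the full (B, 8, D) sequence. Every CP rank must call it, since it
# performs an all_gather over the context parallel group.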


def broadcast_data_batch_in_tp_cp_group(data_batch):
    """
    Broadcast data batch across tensor model parallel and context parallel groups.
    """
    keys = sorted(data_batch.keys())
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    cp_size = parallel_state.get_context_parallel_world_size()
    tp_group = parallel_state.get_tensor_model_parallel_group() if tp_size > 1 else None
    cp_group = parallel_state.get_context_parallel_group() if cp_size > 1 else None
    tp_ranks = get_process_group_ranks(tp_group) if tp_size > 1 else None
    cp_ranks = get_process_group_ranks(cp_group) if cp_size > 1 else None
    if tp_size > 1 or cp_size > 1:
        for key in keys:
            tensor = data_batch[key]
            if isinstance(tensor, torch.Tensor):
                tensor = tensor.contiguous()
                if tp_size > 1:
                    broadcast(tensor, min(tp_ranks), group=tp_group)
                if cp_size > 1:
                    broadcast(tensor, min(cp_ranks), group=cp_group)
                # Write back so the broadcast result is kept even when .contiguous() returned a copy.
                data_batch[key] = tensor
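

# Usage sketch (hedged; key names are illustrative): every rank in the TP/CP groups calls this
# with a dict of same-shaped tensors, and afterwards each tensor matches the copy held by the
# lowest global rank of each group:
#
#   data_batch = {"video": video, "caption_emb": caption_emb}
#   broadcast_data_batch_in_tp_cp_group(data_batch)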


def allreduce_layernorm_grads(model: List[torch.nn.Module], tensor_model_parallel_size: int, sequence_parallel: bool):
    """
    All-reduce layernorm grads (for sequence parallelism).

    Note:
        - We skip QK-normalization layers and the last normalization layer of the Transformer,
          since we use AllReduceBWDRMSNormTE for these layers, which already applies all-reduce in the backward pass.
        - TransformerEngine's LayernormLinear and LayernormMLP modules have `*.layer_norm_weight` parameters
          whose gradients must also be all-reduced in the backward pass; this function covers those parameters.
    """
    # All-reduce layernorm parameters across model parallel nodes
    # when sequence parallelism is used
    if tensor_model_parallel_size > 1 and sequence_parallel:
        grads = []
        for model_chunk in model:
            for name, param in model_chunk.named_parameters():
                if not param.requires_grad:
                    continue
                if name.endswith(".layer_norm_weight"):  # TE LayernormLinear / LayernormMLP norm weights
                    grad = param.grad
                    if grad is not None:
                        grads.append(grad.data)
        if grads:
            coalesced = _flatten_dense_tensors(grads)
            torch.distributed.all_reduce(coalesced, group=parallel_state.get_tensor_model_parallel_group())
            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)
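

# Usage sketch (hedged): this is intended to run once per optimization step, after the backward
# pass has populated `param.grad` and before the optimizer step. `model` is a list to match the
# Megatron convention of per-model-chunk iteration:
#
#   loss.backward()
#   allreduce_layernorm_grads([net], tensor_model_parallel_size=tp_size, sequence_parallel=True)
#   optimizer.step()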


def sync_1d_parameters(model: torch.nn.Module, process_group=None):
    """
    Synchronize layernorm parameters (1D) across ranks by performing all-reduce with mean operation.

    LayerNorm parameters are identified by having ndim==1.
    Note: If parameters other than LayerNorm are 1D, they will also be synchronized.

    Args:
        model (torch.nn.Module): The model containing layernorm parameters
        process_group (optional): The process group to perform all-reduce.
            If None, uses the default process group.
    """
    if not torch.distributed.is_initialized():
        return
    # Synchronize each 1D parameter (layernorm parameters)
    for name, param in model.named_parameters():
        if param.ndim == 1 and param.requires_grad:  # LayerNorm weights/biases are 1D
            torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.AVG, group=process_group)
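

# Usage sketch (hedged; the group and cadence are assumptions, not prescribed by this module):
# averaging 1D (norm) parameters can be done periodically to keep them consistent across ranks,
# e.g. over the tensor model parallel group:
#
#   sync_1d_parameters(net, process_group=parallel_state.get_tensor_model_parallel_group())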


class AllReduceBWD(Function):
    """
    Custom autograd Function that performs an all-reduce operation during the backward pass.

    Args:
        tensor (Tensor): The input tensor.
        process_group: The process group to perform the all-reduce operation.

    Returns:
        Tensor: The input tensor in the forward pass, and the all-reduced gradient in the backward pass.
    """

    @staticmethod
    def forward(ctx, tensor, process_group):
        ctx.process_group = process_group
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        dist.all_reduce(grad_output, group=ctx.process_group)
        return grad_output, None
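

# Behavior sketch: AllReduceBWD is an identity in the forward pass and sums gradients over
# `process_group` in the backward pass (dist.all_reduce defaults to SUM):
#
#   w = AllReduceBWD.apply(self.weight, tp_group)  # w equals self.weight
#   loss = f(w); loss.backward()                   # self.weight.grad is summed across tp_group
#
# This is why allreduce_layernorm_grads() above skips parameters wrapped this way.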


class AllReduceBWDRMSNormTE(RMSNormTE):
    """
    A custom RMSNorm layer that applies an all-reduce operation during the backward pass.
    Used in tensor parallel training with Transformer Engine.

    Args:
        hidden_size (int): The size of the hidden dimension.
        process_group: Megatron Core's process group.
        **kwargs: Additional arguments to be passed to RMSNormTE.
    """

    def __init__(self, hidden_size, process_group, **kwargs):
        super().__init__(hidden_size, **kwargs)
        self.process_group = process_group

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """RMSNorm FWD"""
        # Set the activation type for AMP.
        TransformerEngineBaseModule.set_activation_dtype(self, inp)

        if torch.is_grad_enabled():
            fwd_fn = _RMSNorm.apply
            args = []
        else:
            fwd_fn = _RMSNorm.forward
            args = [None]
        args += (
            inp,
            AllReduceBWD.apply(self.weight, self.process_group),
            self.eps,
            self.fwd_rmsnorm_sm_margin,
            self.bwd_rmsnorm_sm_margin,
            self.inf_rmsnorm_sm_margin,
            self.zero_centered_gamma,
            torch.is_grad_enabled(),
            self.activation_dtype,
        )
        return fwd_fn(*args)
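

# Construction sketch (hedged; keyword arguments follow TransformerEngine's RMSNorm):
#
#   norm = AllReduceBWDRMSNormTE(
#       hidden_size,
#       process_group=parallel_state.get_tensor_model_parallel_group(),
#       eps=1e-6,
#   )
#   y = norm(x)  # forward matches RMSNormTE; the weight gradient is all-reduced over the group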