Spaces:
Build error
Build error
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import torch | |
| from megatron.core import parallel_state | |
| from torch import Tensor | |
| from torch.distributed import ProcessGroup, all_gather, broadcast_object_list, get_process_group_ranks, get_world_size | |
| from torch.distributed.utils import _verify_param_shape_across_processes | |
| from cosmos_predict1.utils import distributed | |
def split_inputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor:
    """
    Split the input tensor along the sequence dimension for context parallelism.

    The sequence dimension is divided into ``cp_size`` equal, consecutive chunks
    (one per rank in ``cp_group``), and the chunk whose index equals the current
    rank's position in the group is returned. ``seq_dim`` may be negative, as with
    other PyTorch ``dim`` arguments, and ``x`` does not need to be contiguous.

    Args:
        x: Input tensor to be split.
        seq_dim: The dimension along which to split the input (sequence dimension).
        cp_group: The process group for context parallelism.

    Returns:
        The current rank's slice of ``x``. Its ``seq_dim`` size is the original size
        divided by ``cp_size``. The slice is produced by ``index_select``, so it is a
        new tensor rather than a view of ``x``.

    Raises:
        AssertionError: If the sequence dimension is not divisible by the number of ranks.
    """
    cp_size = cp_group.size()
    cp_rank = cp_group.rank()
    # Normalize negative dims: the shape slicing below needs a non-negative index
    # (with seq_dim == -1, `shape[seq_dim + 1:]` would otherwise be the whole shape).
    if seq_dim < 0:
        seq_dim += x.dim()
    seq_len = x.shape[seq_dim]
    assert seq_len % cp_size == 0, f"sequence length {seq_len} is not divisible by cp_size {cp_size}"
    # Expose the chunk index as its own axis:
    # (..., seq_len, ...) -> (..., cp_size, seq_len // cp_size, ...).
    # `reshape` (unlike `view`) also accepts non-contiguous inputs.
    x = x.reshape(*x.shape[:seq_dim], cp_size, seq_len // cp_size, *x.shape[seq_dim + 1 :])
    seq_idx = torch.tensor([cp_rank], device=x.device)
    x = x.index_select(seq_dim, seq_idx)
    # Merge the now size-1 chunk axis back into the sequence axis;
    # the new sequence length is the original sequence length / cp_size.
    return x.flatten(seq_dim, seq_dim + 1)
def cat_outputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor:
    """
    Concatenate outputs from all ranks in the context parallelism group.

    Every rank contributes its local ``x``. The gathered tensors are concatenated
    along ``seq_dim`` in group-rank order, which reassembles what
    ``split_inputs_cp`` split apart for the same ``seq_dim`` and group.

    Args:
        x: Local tensor to be gathered and concatenated. It must have the same
            shape and dtype on every rank of ``cp_group``.
        seq_dim: The dimension along which to concatenate the tensors (sequence dimension).
        cp_group: The process group for context parallelism.

    Returns:
        A tensor that is the concatenation of tensors from all ranks in the cp_group.

    Raises:
        RuntimeError: If the gather operation fails (the original error is chained).
    """
    # Number of processes in the group.
    world_size = get_world_size(cp_group)
    # Collective backends such as NCCL reject non-contiguous tensors. Making the
    # input dense also makes the receive buffers allocated from it dense.
    x = x.contiguous()
    # all_gather overwrites every receive buffer, so zero-filling them is wasted work.
    gathered_tensors = [torch.empty_like(x) for _ in range(world_size)]
    try:
        all_gather(gathered_tensors, x, group=cp_group)
    except RuntimeError as e:
        raise RuntimeError(f"Failed to gather tensors: {e}") from e
    # Concatenate the gathered tensors along the specified dimension.
    return torch.cat(gathered_tensors, dim=seq_dim)
def broadcast(item: torch.Tensor | str | None, to_tp: bool = True, to_cp: bool = True) -> torch.Tensor | str | None:
    """
    Broadcast ``item`` from the lowest rank of the tensor- and/or context-parallel group.

    Because global rank = tp_rank + cp_rank * tp_size + ..., broadcasting first over
    the tensor-parallel group and then over the context-parallel group propagates the
    value to every rank that shares both groups.

    Tensors are sent with ``_robust_broadcast``, so receivers may hold a tensor of a
    different size (with the same number of dimensions). Any other non-None object is
    sent with ``broadcast_object_list``. ``None`` is returned unchanged. Groups with a
    single rank are skipped. If megatron's parallel state is not initialized, the item
    is returned as-is.

    NOTE: the branch taken depends on the local ``item``, so every rank should pass the
    same kind of value (tensor / non-None object / None). Otherwise the ranks would issue
    mismatched collectives.

    Parameters:
    - item: The item to broadcast (a torch.Tensor, str, or None).
    - to_tp: Whether to broadcast within the tensor model parallel group.
    - to_cp: Whether to broadcast within the context parallel group.
    """
    if not parallel_state.is_initialized():
        return item

    tp_group = parallel_state.get_tensor_model_parallel_group()
    cp_group = parallel_state.get_context_parallel_group()

    # A group containing only this rank has nothing to broadcast to.
    use_tp = to_tp and parallel_state.get_tensor_model_parallel_world_size() > 1
    use_cp = to_cp and parallel_state.get_context_parallel_world_size() > 1

    # (source rank, group) pairs in broadcast order: tensor-parallel first, then context-parallel.
    hops = []
    if use_tp:
        hops.append((min(get_process_group_ranks(tp_group)), tp_group))
    if use_cp:
        hops.append((min(get_process_group_ranks(cp_group)), cp_group))

    if isinstance(item, torch.Tensor):  # assumes the tensor lives on a CUDA device
        for src_rank, group in hops:
            item = _robust_broadcast(item, src_rank, group)
        return item

    if item is None:
        return item

    payload = [item]
    for src_rank, group in hops:
        broadcast_object_list(payload, src_rank, group=group)
    return payload[0]
def _robust_broadcast(tensor: torch.Tensor, src: int, pg, is_check_shape: bool = False) -> torch.Tensor:
    """
    Broadcast a tensor whose size may differ between the source and receiving ranks.

    The source rank first broadcasts its shape. Receivers then allocate a matching
    buffer, and finally the data itself is broadcast. On receivers, the incoming
    ``tensor`` only supplies dtype, device and number of dimensions; its contents
    and sizes are discarded.

    Args:
        tensor (torch.Tensor): The tensor to broadcast (on ``src``) or a placeholder
            (on other ranks). It must have the same number of dimensions on every
            rank, because the shape vector is broadcast with a fixed length.
        src (int): The global rank of the broadcast source. There is no default.
        pg: The process group to broadcast within; ``src`` must be a member.
        is_check_shape (bool): If True, first verify across ``pg`` that the shape
            vector has the same length (number of dimensions) on every rank.

    Returns:
        torch.Tensor: The broadcast tensor on all ranks. On non-source ranks this is
        a newly allocated tensor.
    """
    is_src = distributed.get_rank() == src
    # Use an explicit int64 dtype on both sides: for a 0-dim tensor,
    # torch.tensor(torch.Size([])) would default to float32 and mismatch the
    # receivers' int64 buffer. Keep the shape on the data's device so that the
    # same backend handles both broadcasts.
    if is_src:
        # Collective backends such as NCCL require a dense source tensor.
        tensor = tensor.contiguous()
        shape = torch.tensor(tensor.shape, dtype=torch.long, device=tensor.device)
    else:
        shape = torch.empty(tensor.dim(), dtype=torch.long, device=tensor.device)
    if is_check_shape:
        _verify_param_shape_across_processes(pg, [shape])
    # First, broadcast the shape of the tensor.
    torch.distributed.broadcast(shape, src, group=pg)
    # Receivers allocate a buffer of the source's shape (same dtype/device as their placeholder).
    if not is_src:
        tensor = tensor.new_empty(shape.tolist())
    # Now broadcast the tensor data.
    torch.distributed.broadcast(tensor, src, group=pg)
    return tensor