| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import torch | |
| from torch import Tensor | |
| from torch.distributed import ProcessGroup, all_gather, get_world_size | |
def cat_outputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor:
    """
    Gather a tensor from every rank of a process group and concatenate the results.

    Each rank contributes its local tensor ``x``; the per-rank tensors are
    collected with ``all_gather`` and joined along ``seq_dim``. The receive
    buffers are allocated from the local tensor, so every rank in ``cp_group``
    is expected to pass a tensor with the same shape and dtype.

    Args:
        x (Tensor): The local tensor to be gathered. It may be non-contiguous;
            a contiguous copy is made for the collective when needed.
        seq_dim (int): The dimension along which the gathered tensors are
            concatenated.
        cp_group (ProcessGroup): The process group whose ranks take part in
            the gather.

    Returns:
        Tensor: The concatenation along ``seq_dim`` of the tensors gathered
        from all ranks of ``cp_group``.

    Raises:
        RuntimeError: If the gathering of tensors fails. The original error is
            attached as ``__cause__``.
    """
    # Number of processes in the group
    world_size = get_world_size(cp_group)

    # Collective backends (e.g. NCCL) reject non-contiguous tensors, and
    # ``*_like`` allocations inherit the strides of their argument, so
    # normalize the layout once before allocating the receive buffers.
    # ``contiguous()`` is a no-op when ``x`` is already contiguous.
    local = x.contiguous()

    # One receive buffer per rank. ``all_gather`` overwrites every buffer, so
    # zero-initializing them would be wasted work.
    gathered_tensors = [torch.empty_like(local) for _ in range(world_size)]

    # Keep the try body limited to the collective call itself.
    try:
        all_gather(gathered_tensors, local, group=cp_group)
    except RuntimeError as e:
        # Chain explicitly so the backend error stays visible in tracebacks.
        raise RuntimeError(f"Gathering failed: {e}") from e

    # Concatenate tensors along the specified dimension
    return torch.cat(gathered_tensors, dim=seq_dim)