# Source path: lingbot-vla/lingbotvla/utils/dist_utils.py
# Uploaded by bazaar-research ("Upload folder using huggingface_hub"), commit fb11af9 (verified).
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, Union
import torch
from torch import distributed as dist
if TYPE_CHECKING:
from torch.distributed import ProcessGroup
def all_gather(
    tensor: "torch.Tensor",
    world_size: int,
    group: Optional["ProcessGroup"] = None,
) -> "torch.Tensor":
    """
    Gathers the tensor from all ranks and concats them along the first dim.

    Args:
        tensor: Local tensor to contribute. It is made contiguous before the
            collective. Every rank is expected to pass the same shape and dtype.
        world_size: Number of ranks taking part in the gather. It is used to size
            the output buffer, so it should match the size of ``group``.
        group: Process group to gather over; ``None`` uses the default group.

    Returns:
        A tensor on the same device as ``tensor`` with shape
        ``(world_size * tensor.size(0), *tensor.shape[1:])``. A 0-dim input
        yields shape ``(world_size,)``.
    """
    tensor = tensor.contiguous()

    # Spell the output shape out instead of relying on ``view(-1, ...)``, which
    # cannot infer the first dim when one of the trailing dims is zero.
    if tensor.dim() == 0:
        output_shape = (world_size,)
    else:
        output_shape = (world_size * tensor.size(0), *tensor.shape[1:])

    # Single-process run without a process group: the gather of one rank is just
    # a copy of the local tensor, so no communication is needed.
    if world_size == 1 and not (dist.is_available() and dist.is_initialized()):
        return tensor.clone().view(output_shape)

    # Allocate on the input's device rather than a hard-coded "cuda" so the buffer
    # always lives where the input tensor does.
    output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=tensor.device)
    dist.all_gather_into_tensor(output_tensor, tensor, group=group)
    return output_tensor.view(output_shape)
def all_reduce(
    data: Union[int, float, List[Union[int, float]], "torch.Tensor"],
    op: Literal["mean", "sum", "max"] = "mean",
    group: Optional["ProcessGroup"] = None,
) -> Union[int, float, List[Union[int, float]]]:
    """
    Performs all reduce in the given process group.

    Args:
        data: Value(s) to reduce. Non-tensor input is converted to a float32 tensor
            first. A tensor input is passed to ``dist.all_reduce`` as is, which
            reduces it in place, so the caller's tensor is modified.
        op: Reduction to apply. ``"mean"`` is computed as a sum followed by a
            division by the group size.
        group: Process group to reduce over; ``None`` uses the default group.

    Returns:
        A Python scalar when the reduced tensor holds exactly one element,
        otherwise a (nested) list of Python numbers.

    Raises:
        KeyError: If ``op`` is not one of ``"mean"``, ``"sum"`` or ``"max"``.
    """
    reduce_ops = {"mean": dist.ReduceOp.SUM, "sum": dist.ReduceOp.SUM, "max": dist.ReduceOp.MAX}
    reduce_op = reduce_ops[op]  # reject an unknown op before any communication happens

    # Without an initialized process group this process is the only participant,
    # so the reduction is the identity and no accelerator is required.
    is_distributed = dist.is_available() and dist.is_initialized()

    if not isinstance(data, torch.Tensor):
        data = torch.tensor(data, dtype=torch.float, device="cuda" if is_distributed else "cpu")

    world_size = 1
    if is_distributed:
        dist.all_reduce(data, op=reduce_op, group=group)
        world_size = dist.get_world_size(group=group)

    if op == "mean":  # ReduceOp.AVG is not supported by the NPU backend
        if data.is_floating_point() or data.is_complex():
            data /= world_size
        else:
            # In-place true division cannot store a float result in an integer
            # tensor, so integer input is divided out of place. The caller's
            # tensor then keeps the reduced sum, not the mean.
            data = data / world_size

    if data.numel() == 1:
        return data.item()
    return data.tolist()
@contextmanager
def main_process_first(local_only: bool = True) -> None:
    """
    Context manager that lets the main process run the wrapped body before the others.

    Args:
        local_only: If True the main process is the one whose ``LOCAL_RANK`` is 0,
            otherwise the one whose ``RANK`` is 0.

    When the ``WORLD_SIZE`` environment variable is unset or not greater than 1,
    the body simply runs without any synchronization.
    """
    # Single-process run: nothing to coordinate.
    if int(os.getenv("WORLD_SIZE", "1")) <= 1:
        yield
        return

    rank_variable = "LOCAL_RANK" if local_only else "RANK"
    is_main_process = int(os.getenv(rank_variable)) == 0

    try:
        # Every non-main process waits here until the main process reaches its
        # barrier below, i.e. until the main process has finished the body.
        if not is_main_process:
            dist.barrier()
        yield
    finally:
        # Reached even if the body raised, so waiting processes are still released.
        if is_main_process:
            dist.barrier()
def execute_in_order(task: Callable, *, local_only: bool = True, **kwargs) -> Any:
    """
    Runs ``task(**kwargs)`` on one rank at a time, in ascending rank order.

    Args:
        task: Callable to execute; it receives ``kwargs`` as keyword arguments.
        local_only: If True, ordering uses ``LOCAL_WORLD_SIZE`` / ``LOCAL_RANK``,
            otherwise ``WORLD_SIZE`` / ``RANK``.
        **kwargs: Keyword arguments forwarded to ``task``.

    Returns:
        Whatever ``task`` returned on the calling process.
    """
    if local_only:
        size_variable, rank_variable = "LOCAL_WORLD_SIZE", "LOCAL_RANK"
    else:
        size_variable, rank_variable = "WORLD_SIZE", "RANK"

    world_size = int(os.getenv(size_variable, "1"))
    # NOTE(review): the rank falls back to "1" (not "0") when the variable is
    # unset; kept as is, confirm whether this default is intentional.
    rank = int(os.getenv(rank_variable, "1"))

    # Single-process run: just call the task.
    if world_size <= 1:
        return task(**kwargs)

    dist.barrier()
    for turn in range(world_size):
        if turn == rank:
            result = task(**kwargs)
        # Every process hits exactly one barrier per turn, so the next rank only
        # starts once the current one has finished its task.
        dist.barrier()
    return result