lingbot-vla / lingbotvla /distributed /parallel_state.py
bazaar-research's picture
Upload folder using huggingface_hub
fb11af9 verified
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/parallel_dims.py
import math
import os
from dataclasses import dataclass
from functools import wraps
from typing import TYPE_CHECKING, Callable, Literal, Optional
import torch
from torch import distributed as dist
from ..utils import logging
from ..utils.import_utils import is_torch_npu_available, is_torch_version_greater_than
if is_torch_version_greater_than("2.4"):
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
if TYPE_CHECKING:
from torch.distributed import ProcessGroup
from torch.distributed.device_mesh import DeviceMesh
# Module-level logger obtained from the project's `..utils.logging` helper.
logger = logging.get_logger(__name__)
# Process-wide ParallelState singleton: assigned by init_parallel_state() and
# read by get_parallel_state(); stays None until initialization has happened.
_PARALLEL_STATE: "ParallelState" = None
def requires_mesh(fn: Callable) -> Callable:
    """
    Decorator for ``ParallelState`` accessors that need a device mesh.

    The wrapped callable is only invoked when ``self.device_mesh`` is set.

    Args:
        fn (Callable): The method to guard; it receives ``self`` as first argument.

    Returns:
        Callable: A wrapper carrying the metadata of ``fn`` (via ``functools.wraps``).

    Raises:
        ValueError: When called on an object whose ``device_mesh`` is ``None``.
    """

    @wraps(fn)
    def _guarded(self: "ParallelState", *args, **kwargs):
        if self.device_mesh is not None:
            return fn(self, *args, **kwargs)
        raise ValueError("Device mesh is not initialized.")

    return _guarded
def init_ep_mesh_matrix(ep_size: int, ep_fsdp_size: int, ep_outside: bool = False) -> "torch.Tensor":
    """
    Build the rank matrix used to construct the EP / EP-FSDP device mesh.

    The result is always a CPU int tensor of shape ``(ep_size, ep_fsdp_size)``
    holding the ranks ``0 .. ep_size * ep_fsdp_size - 1``. Note that this is the
    raw rank layout, not a ``DeviceMesh``; the caller wraps it in one.

    Args:
        ep_size (int): The size of the EP.
        ep_fsdp_size (int): The size of the EP-FSDP.
        ep_outside (bool): Whether the EP is outside in ep-fsdp group.
            - ``True``: ranks are laid out row-major, so each row (fixed EP index)
              holds ``ep_fsdp_size`` consecutive ranks.
            - ``False``: ranks are laid out column-major, so each column (fixed
              EP-FSDP index) holds ``ep_size`` consecutive ranks.

    Returns:
        torch.Tensor: The ``(ep_size, ep_fsdp_size)`` rank matrix. For
        ``ep_outside=False`` it is a transposed (non-contiguous) view.
    """
    # Pin creation to CPU regardless of any ambient default device.
    with torch.device("cpu"):
        ranks = torch.arange(math.prod((ep_size, ep_fsdp_size)), dtype=torch.int)
        if ep_outside:
            return ranks.view(ep_size, ep_fsdp_size)
        # Fill as (ep_fsdp, ep) and transpose so that EP is the fastest-varying axis.
        return ranks.view(ep_fsdp_size, ep_size).transpose(0, 1)
@dataclass(frozen=True)
class ParallelState:
    """
    Immutable record of the parallel sizes of the current job plus the device
    meshes built from them, with accessors for per-dimension groups, ranks and
    sub-meshes.

    Accessors decorated with ``requires_mesh`` raise ``ValueError`` when
    ``device_mesh`` is ``None``; the undecorated ones fall back to the
    sequence-parallel helpers or to a neutral value (``None`` / ``-1``).
    """

    dp_size: int = 1
    dp_replicate_size: int = 1
    dp_shard_size: int = 1
    tp_size: int = 1
    ep_size: int = 1
    pp_size: int = 1
    cp_size: int = 1
    ulysses_size: int = 1
    dp_mode: Literal["ddp", "fsdp1", "fsdp2"] = "fsdp1"
    device_type: str = "npu" if is_torch_npu_available() else "cuda"
    include_sp_in_fsdp: bool = True
    device_mesh: Optional["DeviceMesh"] = None
    ep_fsdp_device_mesh: Optional["DeviceMesh"] = None

    def __post_init__(self):
        """
        Validate the configured sizes and register the sequence-parallel groups.

        Raises:
            NotImplementedError: If ``include_sp_in_fsdp`` is False or ``cp_size > 1``.
            ValueError: If the sizes do not multiply up to the world size, or if
                ``dp_replicate_size * dp_shard_size != dp_size``.
        """
        if not self.include_sp_in_fsdp:
            raise NotImplementedError("Decoupled sequence parallel has not been implemented.")

        if self.cp_size > 1:
            raise NotImplementedError("Ring attention is not supported yet.")

        if self.pp_size * self.dp_size * self.cp_size * self.ulysses_size * self.tp_size != self.world_size:
            raise ValueError("The product of parallel sizes should be equal to the world size.")

        if self.dp_replicate_size * self.dp_shard_size != self.dp_size:
            raise ValueError(
                f"The product of dp_replicate_size: {self.dp_replicate_size} and dp_shard_size: {self.dp_shard_size} should be equal to dp_size: {self.dp_size}."
            )

        if self.sp_enabled:
            from ..distributed.sequence_parallel import (
                init_sequence_parallel,
                set_context_parallel_group,
                set_data_parallel_group,
                set_ulysses_sequence_parallel_group,
                set_unified_sequence_parallel_group,
            )

            if self.device_mesh is not None:
                # Reuse the groups already created by the device mesh.
                set_data_parallel_group(self.device_mesh.get_group("dp"))
                if self.ulysses_size > 1:
                    set_ulysses_sequence_parallel_group(self.device_mesh.get_group("ulysses"))
                if self.cp_size > 1:
                    set_context_parallel_group(self.device_mesh.get_group("cp"))
                # set unified sequence parallel group
                set_unified_sequence_parallel_group(self.device_mesh.get_group("sp"))
            else:
                # No mesh available: let the sequence-parallel module build its own groups.
                init_sequence_parallel(
                    ulysses_size=self.ulysses_size,
                    sep_dp=True,
                    ulysses_group_key="default",
                    cp_size=self.cp_size,
                )

    @property
    def is_initialized(self) -> bool:
        """Whether ``torch.distributed`` has been initialized."""
        return dist.is_initialized()

    @property
    def local_rank(self) -> int:
        """Value of the ``LOCAL_RANK`` environment variable, or ``-1`` if unset."""
        return int(os.getenv("LOCAL_RANK", "-1"))

    @property
    def global_rank(self) -> int:
        """Rank reported by ``torch.distributed``, or ``-1`` when it is not initialized."""
        if self.is_initialized:
            return dist.get_rank()
        return -1

    @property
    def world_size(self) -> int:
        """World size reported by ``torch.distributed``, or ``1`` when it is not initialized."""
        if self.is_initialized:
            return dist.get_world_size()
        return 1

    # ------------------------------ DP ------------------------------ #
    @property
    def dp_group(self) -> Optional["ProcessGroup"]:
        """Process group of the "dp" mesh dim; falls back to the SP helper, then to ``fsdp_group``."""
        if self.device_mesh is not None:
            return self.device_mesh.get_group("dp")
        if self.sp_enabled:
            from ..distributed.sequence_parallel import get_data_parallel_group

            return get_data_parallel_group()
        return self.fsdp_group

    @property
    def dp_rank(self) -> int:
        """Local rank in the "dp" mesh dim; falls back to the SP helper, then to ``fsdp_rank``."""
        if self.device_mesh is not None:
            return self.device_mesh.get_local_rank("dp")
        if self.sp_enabled:
            from ..distributed.sequence_parallel import get_data_parallel_rank

            return get_data_parallel_rank()
        return self.fsdp_rank

    @property
    @requires_mesh
    def dp_mesh(self) -> "DeviceMesh":
        """Sub-mesh of the "dp" dim. Raises ``ValueError`` when no device mesh is set."""
        # `requires_mesh` guarantees the mesh exists here, so no fallback branch is needed
        # (the previous fallback used `raise` on a mesh object, which is never valid).
        return self.device_mesh["dp"]

    @property
    def dp_enabled(self) -> bool:
        """True when ``dp_size > 1``."""
        return self.dp_size > 1

    # ------------------------------ DP replicate ------------------------------ #
    @property
    def dp_replicate_group(self) -> Optional["ProcessGroup"]:
        """Process group of the "dp_replicate" mesh dim, or ``None`` without a device mesh."""
        if self.device_mesh is not None:
            return self.device_mesh.get_group("dp_replicate")
        return None

    @property
    def dp_replicate_rank(self) -> Optional[int]:
        """Local rank in the "dp_replicate" mesh dim, or ``None`` without a device mesh."""
        if self.device_mesh is not None:
            return self.device_mesh.get_local_rank("dp_replicate")
        return None

    @property
    @requires_mesh
    def dp_replicate_mesh(self) -> "DeviceMesh":
        """Sub-mesh of the "dp_replicate" dim. Raises ``ValueError`` when no device mesh is set."""
        return self.device_mesh["dp_replicate"]

    @property
    def dp_replicate_enabled(self) -> bool:
        """True when ``dp_replicate_size > 1``."""
        return self.dp_replicate_size > 1

    # ------------------------------ DP shard ------------------------------ #
    @property
    def dp_shard_group(self) -> Optional["ProcessGroup"]:
        """Process group of the "dp_shard" mesh dim, or ``None`` without a device mesh."""
        if self.device_mesh is not None:
            return self.device_mesh.get_group("dp_shard")
        return None

    @property
    def dp_shard_rank(self) -> Optional[int]:
        """Local rank in the "dp_shard" mesh dim, or ``None`` without a device mesh."""
        if self.device_mesh is not None:
            return self.device_mesh.get_local_rank("dp_shard")
        return None

    @property
    @requires_mesh
    def dp_shard_mesh(self) -> "DeviceMesh":
        """Sub-mesh of the "dp_shard" dim. Raises ``ValueError`` when no device mesh is set."""
        return self.device_mesh["dp_shard"]

    @property
    def dp_shard_enabled(self) -> bool:
        """True when ``dp_shard_size >= 1`` (a shard size of 1 also counts as enabled)."""
        return self.dp_shard_size >= 1

    # ----------------------------- FSDP ----------------------------- #
    @property
    def fsdp_group(self) -> Optional["ProcessGroup"]:
        """Process group of the "dp_sp" mesh dim, or ``None`` without a device mesh."""
        if self.device_mesh is not None:
            return self.device_mesh.get_group("dp_sp")
        return None

    @property
    def fsdp_rank(self) -> int:
        """Local rank in the "dp_sp" mesh dim, or ``global_rank`` without a device mesh."""
        if self.device_mesh is not None:
            return self.device_mesh.get_local_rank("dp_sp")
        return self.global_rank

    @property
    def dp_shard_sp_enabled(self) -> bool:
        """True when both ``dp_shard_enabled`` and ``sp_enabled`` hold."""
        return self.dp_shard_enabled and self.sp_enabled

    @property
    @requires_mesh
    def fsdp_mesh(self) -> "DeviceMesh":
        """
        Sub-mesh to shard/replicate parameters over, chosen from the enabled dims.

        Raises ``ValueError`` when no device mesh is set.
        """
        if self.dp_replicate_enabled:
            # HSDP
            if self.dp_shard_sp_enabled:
                return self.device_mesh["dp_replicate", "dp_shard_sp"]
            elif self.dp_shard_enabled:
                return self.device_mesh["dp_replicate", "dp_shard"]
            else:
                # DDP
                return self.device_mesh["dp_replicate"]
        # FSDP
        elif self.dp_shard_sp_enabled:
            return self.device_mesh["dp_shard_sp"]
        elif self.dp_shard_enabled:
            return self.device_mesh["dp_shard"]
        else:
            return self.device_mesh["dp"]

    @property
    def fsdp_enabled(self) -> bool:
        """True when ``fsdp_size > 1``."""
        return self.fsdp_size > 1

    @property
    def fsdp_size(self) -> int:
        """``world_size // (pp_size * tp_size)``."""
        return self.world_size // (self.pp_size * self.tp_size)

    # ------------------------------ TP ------------------------------ #
    @property
    @requires_mesh
    def tp_rank(self) -> int:
        """Local rank in the "tp" mesh dim. Raises ``ValueError`` when no device mesh is set."""
        return self.device_mesh.get_local_rank("tp")

    @property
    @requires_mesh
    def tp_mesh(self) -> "DeviceMesh":
        """Sub-mesh of the "tp" dim. Raises ``ValueError`` when no device mesh is set."""
        return self.device_mesh["tp"]

    @property
    def tp_enabled(self) -> bool:
        """True when ``tp_size > 1``."""
        return self.tp_size > 1

    # ------------------------------ PP ------------------------------ #
    @property
    @requires_mesh
    def pp_rank(self) -> int:
        """Local rank in the "pp" mesh dim. Raises ``ValueError`` when no device mesh is set."""
        return self.device_mesh.get_local_rank("pp")

    @property
    @requires_mesh
    def pp_mesh(self) -> "DeviceMesh":
        """Sub-mesh of the "pp" dim. Raises ``ValueError`` when no device mesh is set."""
        return self.device_mesh["pp"]

    @property
    def pp_enabled(self) -> bool:
        """True when ``pp_size > 1``."""
        return self.pp_size > 1

    @property
    @requires_mesh
    def is_first_pp_stage(self) -> bool:
        """True when ``pp_rank == 0``."""
        return self.pp_rank == 0

    @property
    @requires_mesh
    def is_last_pp_stage(self) -> bool:
        """True when ``pp_rank == pp_size - 1``."""
        return self.pp_rank == (self.pp_size - 1)

    # ------------------------------ EP ------------------------------ #
    # NOTE(review): the EP accessors read `ep_fsdp_device_mesh`, while `requires_mesh`
    # only checks `device_mesh`. If `ep_fsdp_device_mesh` is None they fail with a
    # TypeError/AttributeError rather than the ValueError raised for a missing mesh.
    @property
    @requires_mesh
    def ep_mesh(self) -> "DeviceMesh":
        """Sub-mesh of the "ep" dim of ``ep_fsdp_device_mesh``."""
        return self.ep_fsdp_device_mesh["ep"]

    @property
    @requires_mesh
    def ep_fsdp_mesh(self) -> "DeviceMesh":
        """The ("ep", "ep_fsdp") sub-mesh of ``ep_fsdp_device_mesh``."""
        return self.ep_fsdp_device_mesh["ep", "ep_fsdp"]

    @property
    @requires_mesh
    def ep_group(self) -> "ProcessGroup":
        """Process group of ``ep_mesh``."""
        return self.ep_mesh.get_group()

    @property
    def ep_enabled(self) -> bool:
        """True when ``ep_size > 1``."""
        return self.ep_size > 1

    @property
    def ep_rank(self) -> int:
        """Local rank in the "ep" dim of ``ep_fsdp_device_mesh``."""
        return self.ep_fsdp_device_mesh.get_local_rank("ep")

    # ------------------------------ SP ------------------------------ #
    @property
    def sp_group(self) -> Optional["ProcessGroup"]:
        """Process group of the "sp" mesh dim; SP helper without a mesh; ``None`` if SP is off."""
        if self.device_mesh is not None:
            return self.device_mesh.get_group("sp")
        if self.sp_enabled:
            from .sequence_parallel import get_unified_sequence_parallel_group

            return get_unified_sequence_parallel_group()
        return None

    @property
    def sp_rank(self) -> int:
        """Local rank in the "sp" mesh dim; SP helper without a mesh; ``-1`` if SP is off."""
        if self.device_mesh is not None:
            return self.device_mesh.get_local_rank("sp")
        if self.sp_enabled:
            from .sequence_parallel import get_unified_sequence_parallel_rank

            return get_unified_sequence_parallel_rank()
        return -1

    @property
    def sp_enabled(self) -> bool:
        """True when ``cp_size > 1`` or ``ulysses_size > 1``."""
        return self.cp_size > 1 or self.ulysses_size > 1

    @property
    def sp_size(self) -> int:
        """``ulysses_size * cp_size``."""
        return self.ulysses_size * self.cp_size

    @property
    def ulysses_group(self) -> Optional["ProcessGroup"]:
        """Process group of the "ulysses" mesh dim; SP helper without a mesh; ``None`` if SP is off."""
        if self.device_mesh is not None:
            return self.device_mesh.get_group("ulysses")
        if self.sp_enabled:
            from .sequence_parallel import get_ulysses_sequence_parallel_group

            return get_ulysses_sequence_parallel_group()
        return None

    @property
    def ulysses_rank(self) -> int:
        """Local rank in the "ulysses" mesh dim; SP helper without a mesh; ``-1`` if SP is off."""
        if self.device_mesh is not None:
            return self.device_mesh.get_local_rank("ulysses")
        if self.sp_enabled:
            from .sequence_parallel import get_ulysses_sequence_parallel_rank

            return get_ulysses_sequence_parallel_rank()
        return -1

    @property
    def ulysses_enabled(self) -> bool:
        """True when ``ulysses_size > 1``."""
        return self.ulysses_size > 1

    @property
    def cp_group(self) -> Optional["ProcessGroup"]:
        """Process group of the "cp" mesh dim; SP helper without a mesh; ``None`` if SP is off."""
        if self.device_mesh is not None:
            return self.device_mesh.get_group("cp")
        if self.sp_enabled:
            from .sequence_parallel import get_context_parallel_group

            return get_context_parallel_group()
        return None

    @property
    def cp_rank(self) -> int:
        """Local rank in the "cp" mesh dim; SP helper without a mesh; ``-1`` if SP is off."""
        if self.device_mesh is not None:
            return self.device_mesh.get_local_rank("cp")
        if self.sp_enabled:
            from .sequence_parallel import get_context_parallel_rank

            return get_context_parallel_rank()
        return -1

    @property
    def cp_enabled(self) -> bool:
        """True when ``cp_size > 1``."""
        return self.cp_size > 1
def init_parallel_state(
    dp_size: int = 1,
    dp_replicate_size: int = 1,
    dp_shard_size: int = 1,
    tp_size: int = 1,
    ep_size: int = 1,
    pp_size: int = 1,
    cp_size: int = 1,
    ulysses_size: int = 1,
    dp_mode: Literal["ddp", "fsdp1", "fsdp2"] = "fsdp1",
    device_type: str = None,
    include_sp_in_fsdp: bool = True,
    ep_outside: bool = False,
) -> None:
    """
    Build the device meshes and store the resulting global ``ParallelState``.

    A second call is a no-op apart from a warning. When ``device_type`` is not
    given it is "npu" if an NPU is available, otherwise "cuda". When data
    parallelism is on but neither ``dp_shard_size`` nor ``dp_replicate_size``
    was set, ``dp_shard_size`` takes the value of ``dp_size``.

    Device meshes are only created for torch versions greater than 2.4;
    otherwise both meshes stay ``None``.

    Args:
        dp_size, dp_replicate_size, dp_shard_size, tp_size, ep_size, pp_size,
        cp_size, ulysses_size (int): Sizes of the respective parallel dimensions.
        dp_mode: Stored on the state unchanged.
        device_type (str): Device type handed to the mesh constructors.
        include_sp_in_fsdp (bool): Stored on the state unchanged.
        ep_outside (bool): Forwarded to ``init_ep_mesh_matrix``.
    """
    global _PARALLEL_STATE

    if _PARALLEL_STATE is not None:
        logger.warning("Parallel state has already been initialized.")
        return

    if device_type is None:
        device_type = "npu" if is_torch_npu_available() else "cuda"

    # Set dp_shard_size to dp_size if dp_shard_size and dp_replicate_size are not set when dp enabled
    shard_and_replicate_unset = dp_shard_size == 1 and dp_replicate_size == 1
    if dp_size > 1 and shard_and_replicate_unset:
        dp_shard_size = dp_size

    logger.info_rank0(
        f"Initializing parallel state... dp_size {dp_size}, dp_replicate_size {dp_replicate_size}, dp_shard_size {dp_shard_size},tp_size {tp_size}, pp_size {pp_size}, cp_size {cp_size}, ulysses_size {ulysses_size}"
    )

    device_mesh = None
    ep_fsdp_device_mesh = None
    if is_torch_version_greater_than("2.4"):
        # Keep every dimension larger than 1; "dp_shard" is always kept.
        axes = (
            ("pp", pp_size),
            ("dp_replicate", dp_replicate_size),
            ("dp_shard", dp_shard_size),
            ("ulysses", ulysses_size),
            ("cp", cp_size),
            ("tp", tp_size),
        )
        kept_axes = [(axis, size) for axis, size in axes if size > 1 or axis == "dp_shard"]
        device_mesh = init_device_mesh(
            device_type=device_type,
            mesh_shape=tuple(size for _, size in kept_axes),
            mesh_dim_names=tuple(axis for axis, _ in kept_axes),
        )

        replicate_dims = ["dp_replicate"] if dp_replicate_size > 1 else []
        shard_dims = ["dp_shard"] if dp_shard_size >= 1 else []
        seq_dims = [axis for axis, size in (("ulysses", ulysses_size), ("cp", cp_size)) if size > 1]

        # Flattened views, created in this fixed order:
        #   dp          - mesh for data loading (no communication on this mesh)
        #   dp_shard_sp - mesh for param sharding
        #   dp_sp       - mesh for loss all-reduce
        #   sp          - mesh for sequence parallel
        flattened_views = (
            ("dp", replicate_dims + shard_dims),
            ("dp_shard_sp", shard_dims + seq_dims),
            ("dp_sp", replicate_dims + shard_dims + seq_dims),
            ("sp", seq_dims),
        )
        for flat_name, source_dims in flattened_views:
            if source_dims:
                device_mesh[tuple(source_dims)]._flatten(mesh_dim_name=flat_name)

        if ep_size > 1:
            world_size = dist.get_world_size()
            assert world_size % ep_size == 0, "ep_size must be a factor of world_size"
            rank_matrix = init_ep_mesh_matrix(
                ep_size=ep_size,
                ep_fsdp_size=world_size // ep_size,
                ep_outside=ep_outside,
            )
            ep_fsdp_device_mesh = DeviceMesh(
                device_type=device_type,
                mesh=rank_matrix,
                mesh_dim_names=("ep", "ep_fsdp"),
            )

        logger.info_rank0(f"Device mesh: {device_mesh}")
        logger.info_rank0(f"EP FSDP device mesh: {ep_fsdp_device_mesh}")

    _PARALLEL_STATE = ParallelState(
        dp_size=dp_size,
        dp_replicate_size=dp_replicate_size,
        dp_shard_size=dp_shard_size,
        tp_size=tp_size,
        ep_size=ep_size,
        pp_size=pp_size,
        cp_size=cp_size,
        ulysses_size=ulysses_size,
        dp_mode=dp_mode,
        device_type=device_type,
        include_sp_in_fsdp=include_sp_in_fsdp,
        device_mesh=device_mesh,
        ep_fsdp_device_mesh=ep_fsdp_device_mesh,
    )
def get_parallel_state() -> "ParallelState":
    """
    Return the global parallel state.

    If ``init_parallel_state`` has not stored a state yet, a warning is logged
    once and a fresh default ``ParallelState()`` is returned instead; that
    default is not cached in the module-level global.
    """
    state = _PARALLEL_STATE
    if state is not None:
        return state

    logger.warning_once("Parallel state has not been initialized. returning default Single-process state.")
    return ParallelState()