bazaar-research's picture
Upload folder using huggingface_hub
fb11af9 verified
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from typing import Any, Optional
import torch.distributed as dist
from torch.distributed import ProcessGroup
# Module-level registries of process groups. They are filled in by the
# ``set_*`` functions in this module (directly or via ``init_sequence_parallel``);
# ``None`` means the corresponding group has not been configured.
_DATA_PARALLEL_GROUP = None

# Ulysses groups are stored per key so that several layouts can coexist;
# ``_ULYSSES_GROUP_KEY`` selects which entry key-less lookups use.
_ULYSSES_SEQUENCE_PARALLEL_GROUP = dict(default=None)
_ULYSSES_SEQUENCE_PARALLEL_CPU_GROUP = dict(default=None)
_ULYSSES_GROUP_KEY = "default"

_CONTEXT_PARALLEL_GROUP = None

_UNIFIED_SEQUENCE_PARALLEL_GROUP = None
_UNIFIED_SEQUENCE_PARALLEL_CPU_GROUP = None
# ------------------------------ Data Parallel ------------------------------ #
def set_data_parallel_group(group: dist.ProcessGroup):
    """
    Record the data parallel process group for later lookup.

    Args:
        group: Process group to store; replaces any previously stored group.
    """
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = group
def get_data_parallel_group() -> Optional[dist.ProcessGroup]:
    """
    Return the stored data parallel process group, or ``None`` if none has been set.
    """
    # Reading a module global needs no ``global`` declaration.
    return _DATA_PARALLEL_GROUP
def get_data_parallel_rank() -> int:
    """
    Return the caller's rank within the data parallel group.

    If no data parallel group has been registered, ``None`` is passed to
    ``dist.get_rank``, which then queries the default (global) process group.

    Returns:
        The integer rank of the calling process.
    """
    return dist.get_rank(get_data_parallel_group())
def get_data_parallel_world_size() -> int:
    """
    Return the number of ranks in the data parallel group.

    If no data parallel group has been registered, ``None`` is passed to
    ``dist.get_world_size``, which then queries the default (global) process group.

    Returns:
        The integer size of the group.
    """
    return dist.get_world_size(get_data_parallel_group())
# ----------------------------- Ulysses Parallel ---------------------------- #
def set_ulysses_sequence_parallel_group(group: dist.ProcessGroup, group_key: str = "default"):
    """
    Register ``group`` as the Ulysses sequence parallel process group under ``group_key``.

    Args:
        group: Process group to store (``None`` clears the slot).
        group_key: Registry key; ``"default"`` unless several Ulysses layouts are used.
    """
    # Item assignment mutates the module-level dict in place; no rebinding happens.
    registry = _ULYSSES_SEQUENCE_PARALLEL_GROUP
    registry[group_key] = group
def set_ulysses_sequence_parallel_cpu_group(group: dist.ProcessGroup, group_key: str = "default"):
    """
    Register ``group`` as the Ulysses sequence parallel CPU process group under ``group_key``.

    Args:
        group: CPU process group to store (``None`` clears the slot).
        group_key: Registry key; ``"default"`` unless several Ulysses layouts are used.
    """
    # Item assignment mutates the module-level dict in place; no rebinding happens.
    registry = _ULYSSES_SEQUENCE_PARALLEL_CPU_GROUP
    registry[group_key] = group
def set_ulysses_sequence_parallel_group_key(group_key: str = "default"):
    """
    Select which registered Ulysses group the key-less getters return.

    Args:
        group_key: Key to make active; ``"default"`` when omitted.
    """
    global _ULYSSES_GROUP_KEY
    _ULYSSES_GROUP_KEY = group_key
def get_ulysses_sequence_parallel_group_key() -> str:
    """
    Return the key of the currently active Ulysses sequence parallel group.
    """
    # Reading a module global needs no ``global`` declaration.
    return _ULYSSES_GROUP_KEY
def get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
    """
    Return the Ulysses sequence parallel group registered under the active key.

    The result may be ``None`` when the key is registered but no group was built.

    Raises:
        RuntimeError: if the active key has no entry in the registry.
    """
    active_key = get_ulysses_sequence_parallel_group_key()
    registry = _ULYSSES_SEQUENCE_PARALLEL_GROUP
    if active_key in registry:
        return registry[active_key]
    raise RuntimeError(f"Unknown key {active_key} in Ulysses sequence parallel group!")
def get_ulysses_sequence_parallel_cpu_group() -> Optional[dist.ProcessGroup]:
    """
    Return the Ulysses sequence parallel CPU group registered under the active key.

    The result may be ``None`` when the key is registered but no group was built.

    Raises:
        RuntimeError: if the active key has no entry in the CPU registry.
    """
    active_key = get_ulysses_sequence_parallel_group_key()
    registry = _ULYSSES_SEQUENCE_PARALLEL_CPU_GROUP
    if active_key in registry:
        return registry[active_key]
    raise RuntimeError(f"Unknown key {active_key} in Ulysses sequence parallel group!")
def get_ulysses_sequence_parallel_group_by_key(group_key: str = "default") -> Optional[dist.ProcessGroup]:
    """
    Look up the Ulysses sequence parallel group stored under an explicit ``group_key``.

    Unlike the key-less getter, this ignores the currently active key.

    Raises:
        RuntimeError: if nothing is registered under ``group_key``.
    """
    registry = _ULYSSES_SEQUENCE_PARALLEL_GROUP
    if group_key in registry:
        return registry[group_key]
    raise RuntimeError(f"Unknown key {group_key} in Ulysses sequence parallel group!")
def get_ulysses_sequence_parallel_cpu_group_by_key(group_key: str = "default") -> Optional[dist.ProcessGroup]:
    """
    Look up the Ulysses sequence parallel CPU group stored under an explicit ``group_key``.

    Unlike the key-less getter, this ignores the currently active key.

    Raises:
        RuntimeError: if nothing is registered under ``group_key``.
    """
    registry = _ULYSSES_SEQUENCE_PARALLEL_CPU_GROUP
    if group_key in registry:
        return registry[group_key]
    raise RuntimeError(f"Unknown key {group_key} in Ulysses sequence parallel group!")
def get_ulysses_sequence_parallel_rank(group: Optional[ProcessGroup] = None) -> int:
    """
    Return the caller's rank within a Ulysses sequence parallel group.

    Args:
        group: Group to query. When ``None``, the group registered under the
            active Ulysses key is used.

    Returns:
        The rank within the group, or ``0`` when no group is available
        (e.g. Ulysses parallelism was not initialized).
    """
    if group is None:
        group = get_ulysses_sequence_parallel_group()
    # Explicit ``is not None``: an unset group is ``None``, not merely "falsy".
    return dist.get_rank(group) if group is not None else 0
def get_ulysses_sequence_parallel_world_size(group: Optional[ProcessGroup] = None) -> int:
    """
    Return the number of ranks in a Ulysses sequence parallel group.

    Args:
        group: Group to query. When ``None``, the group registered under the
            active Ulysses key is used.

    Returns:
        The group size, or ``1`` when no group is available
        (e.g. Ulysses parallelism was not initialized).
    """
    if group is None:
        group = get_ulysses_sequence_parallel_group()
    # Explicit ``is not None``: an unset group is ``None``, not merely "falsy".
    return dist.get_world_size(group) if group is not None else 1
# ----------------------------- Context Parallel ---------------------------- #
def set_context_parallel_group(cp_group: dist.ProcessGroup):
    """
    Record the context parallel process group for this rank.

    Args:
        cp_group: Process group to store; replaces any previously stored group.
    """
    global _CONTEXT_PARALLEL_GROUP
    _CONTEXT_PARALLEL_GROUP = cp_group
def get_context_parallel_group(check_initialized=True):
    """
    Return the context parallel group that the calling rank belongs to.

    Args:
        check_initialized: When true (the default), assert that a group has
            been stored before returning it.

    Returns:
        The stored process group, or ``None`` if unset and ``check_initialized`` is false.
    """
    cp_group = _CONTEXT_PARALLEL_GROUP
    if check_initialized:
        assert cp_group is not None, "context parallel group is not initialized"
    return cp_group
def get_context_parallel_rank():
    """
    Return the caller's rank within the context parallel group.

    Falls back to ``0`` when ``torch.distributed`` is unavailable or not initialized.
    """
    distributed_ready = dist.is_available() and dist.is_initialized()
    if not distributed_ready:
        return 0
    return dist.get_rank(group=get_context_parallel_group())
def get_context_parallel_world_size():
    """
    Return the number of ranks in the context parallel group.

    When ``torch.distributed`` is unavailable or not initialized, the process
    runs alone, so the size is ``1`` rather than ``0``. This matches the Ulysses
    and unified world-size helpers in this module and avoids a zero divisor for
    callers.

    Raises:
        AssertionError: if distributed is initialized but no context parallel
            group has been stored (raised by ``get_context_parallel_group``).
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size(group=get_context_parallel_group())
    return 1
# ----------------------------- Unified Parallel ---------------------------- #
def set_unified_sequence_parallel_group(group: dist.ProcessGroup):
    """
    Record the unified sequence parallel process group for this rank.

    Args:
        group: Process group to store; replaces any previously stored group.
    """
    global _UNIFIED_SEQUENCE_PARALLEL_GROUP
    _UNIFIED_SEQUENCE_PARALLEL_GROUP = group
def set_unified_sequence_parallel_cpu_group(group: dist.ProcessGroup):
    """
    Record the unified sequence parallel CPU process group for this rank.

    Args:
        group: CPU process group to store; replaces any previously stored group.
    """
    global _UNIFIED_SEQUENCE_PARALLEL_CPU_GROUP
    _UNIFIED_SEQUENCE_PARALLEL_CPU_GROUP = group
def get_unified_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
    """
    Return the unified sequence parallel group, or ``None`` if it has not been set.
    """
    # Reading a module global needs no ``global`` declaration.
    return _UNIFIED_SEQUENCE_PARALLEL_GROUP
def get_unified_sequence_parallel_cpu_group() -> Optional[dist.ProcessGroup]:
    """
    Return the unified sequence parallel CPU group, or ``None`` if it has not been set.
    """
    # Reading a module global needs no ``global`` declaration.
    return _UNIFIED_SEQUENCE_PARALLEL_CPU_GROUP
def get_unified_sequence_parallel_rank() -> int:
    """
    Return the caller's rank within the unified sequence parallel group.

    Returns ``0`` when no unified group has been stored.
    """
    sp_group = get_unified_sequence_parallel_group()
    if not sp_group:
        return 0
    return dist.get_rank(sp_group)
def get_unified_sequence_parallel_world_size() -> int:
    """
    Return the number of ranks in the unified sequence parallel group.

    Returns ``1`` when no unified group has been stored.
    """
    sp_group = get_unified_sequence_parallel_group()
    if not sp_group:
        return 1
    return dist.get_world_size(sp_group)
# ------------------------------- Initialize ------------------------------- #
def init_sequence_parallel(
    ulysses_size: int = 1, sep_dp: bool = False, ulysses_group_key: str = "default", cp_size: int = 1
):
    """
    Initialize Ulysses, context and unified sequence parallel process groups.

    Global ranks are split into ``world_size // (ulysses_size * cp_size)``
    contiguous blocks of ``ulysses_size * cp_size`` ranks; each block is one
    unified sequence parallel group. Inside a block, Ulysses groups are
    contiguous runs of ``ulysses_size`` ranks, and context parallel groups
    take every ``ulysses_size``-th rank. Every rank calls ``dist.new_group``
    for every group below (torch requires all ranks to enter ``new_group``),
    so all ranks must call this function with identical arguments.

    Args:
        ulysses_size: Ranks per Ulysses sequence parallel group (``1`` = no Ulysses groups).
        sep_dp: If true, also build data parallel groups out of the ranks that
            sit at the same offset inside their unified block.
        ulysses_group_key: Registry key under which the Ulysses groups are stored.
        cp_size: Ranks per context parallel group (``1`` = no context parallel groups).

    Raises:
        AssertionError: if the sizes are not positive, the world size is not
            divisible by ``ulysses_size * cp_size``, the context parallel group
            is already set, or a non-default ``ulysses_group_key`` is already registered.
    """
    # Clear the default Ulysses slot up front. The early return below therefore
    # leaves it empty when both sizes are 1.
    # NOTE(review): this also clears the default slot when a non-default
    # ulysses_group_key is being initialized -- confirm that is intended.
    set_ulysses_sequence_parallel_group(group=None, group_key="default")
    set_ulysses_sequence_parallel_cpu_group(group=None, group_key="default")
    if ulysses_size == 1 and cp_size == 1:
        return

    # Reject zero/negative sizes before they cause a ZeroDivisionError or an empty layout.
    assert ulysses_size >= 1 and cp_size >= 1, (
        f"ulysses_size ({ulysses_size}) and cp_size ({cp_size}) must be positive"
    )
    assert dist.is_initialized()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    unified_sp_size = ulysses_size * cp_size
    assert world_size % unified_sp_size == 0, (
        f"world size ({world_size}) must be divisible by ulysses_size * cp_size ({unified_sp_size})"
    )
    data_parallel_size = world_size // unified_sp_size

    if cp_size > 1:
        assert _CONTEXT_PARALLEL_GROUP is None, "Context parallel group has already been initialized!"
    # Only validate the Ulysses registry when Ulysses groups are actually built below.
    if ulysses_size > 1:
        assert (ulysses_group_key == "default" and _ULYSSES_SEQUENCE_PARALLEL_GROUP[ulysses_group_key] is None) or (
            ulysses_group_key != "default" and ulysses_group_key not in _ULYSSES_SEQUENCE_PARALLEL_GROUP
        ), f"Ulysses sequence parallel group ({ulysses_group_key}) has already been initialized!"
        assert (
            ulysses_group_key == "default" and _ULYSSES_SEQUENCE_PARALLEL_CPU_GROUP[ulysses_group_key] is None
        ) or (ulysses_group_key != "default" and ulysses_group_key not in _ULYSSES_SEQUENCE_PARALLEL_CPU_GROUP), (
            f"Ulysses sequence parallel ({ulysses_group_key}) group has already been initialized!"
        )

    # The order of dist.new_group calls below must be the same on every rank.
    for i in range(data_parallel_size):
        block_start = i * unified_sp_size
        # Ulysses groups: contiguous runs of ``ulysses_size`` ranks.
        if ulysses_size > 1:
            for j in range(cp_size):
                start_rank = block_start + j * ulysses_size
                ulysses_ranks = range(start_rank, start_rank + ulysses_size)
                ulysses_group = dist.new_group(ulysses_ranks)
                ulysses_cpu_group = dist.new_group(ulysses_ranks, backend="gloo")
                if rank in ulysses_ranks:
                    set_ulysses_sequence_parallel_group(group=ulysses_group, group_key=ulysses_group_key)
                    set_ulysses_sequence_parallel_cpu_group(group=ulysses_cpu_group, group_key=ulysses_group_key)
        # Context parallel groups: ranks strided by ``ulysses_size`` inside the block.
        if cp_size > 1:
            for j in range(ulysses_size):
                cp_global_ranks = range(block_start + j, block_start + unified_sp_size, ulysses_size)
                cp_group = dist.new_group(cp_global_ranks)
                if rank in cp_global_ranks:
                    set_context_parallel_group(cp_group=cp_group)
        # Unified sequence parallel group: the whole block.
        unified_sp_ranks = range(block_start, block_start + unified_sp_size)
        sp_group = dist.new_group(unified_sp_ranks)
        sp_cpu_group = dist.new_group(unified_sp_ranks, backend="gloo")
        if rank in unified_sp_ranks:
            set_unified_sequence_parallel_group(group=sp_group)
            set_unified_sequence_parallel_cpu_group(group=sp_cpu_group)

    if sep_dp:
        # Data parallel groups: one rank per unified block, all at the same offset.
        for j in range(unified_sp_size):
            dp_ranks = range(j, world_size, unified_sp_size)
            dp_group = dist.new_group(dp_ranks)
            if rank in dp_ranks:
                set_data_parallel_group(dp_group)
class UlyssesGroupKeyManager:
    """
    Context manager that makes ``group_key`` the active Ulysses group key.

    On exit it restores the key that was active on entry rather than forcing
    ``"default"``. Nested managers therefore unwind correctly: after leaving an
    inner ``with`` block, the outer block's key is active again.
    """

    def __init__(self, group_key: str):
        self.group_key = group_key
        # Keys active at each __enter__, restored last-in-first-out by __exit__
        # (a list so that re-entering the same instance also unwinds correctly).
        self._saved_keys = []

    def __enter__(self):
        self._saved_keys.append(get_ulysses_sequence_parallel_group_key())
        set_ulysses_sequence_parallel_group_key(group_key=self.group_key)

    def __exit__(self, *args: Any):
        # Fall back to "default" if __exit__ is somehow called without a matching __enter__.
        previous_key = self._saved_keys.pop() if self._saved_keys else "default"
        set_ulysses_sequence_parallel_group_key(group_key=previous_key)
        # Returning None lets any in-flight exception propagate.
def is_ulysses_sequence_parallel_initialized() -> bool:
    """
    Report whether a Ulysses sequence parallel group exists under the active key.

    Propagates the ``RuntimeError`` raised by the getter when the active key is unknown.
    """
    active_group = get_ulysses_sequence_parallel_group()
    return active_group is not None
def is_context_parallel_initialized() -> bool:
    """
    Check if the context parallel group has been initialized.

    Returns ``False`` when no context parallel group has been stored. It does not
    raise, unlike ``get_context_parallel_group()`` with its default
    ``check_initialized=True``.
    """
    return _CONTEXT_PARALLEL_GROUP is not None
def get_ulysses_group_key_context(group_key: str = "default"):
    """
    Return a context manager that activates the Ulysses group ``group_key``.

    For the ``"default"`` key, a no-op ``nullcontext`` is returned instead.

    Raises:
        RuntimeError: if ``group_key`` is not a string.
    """
    if not isinstance(group_key, str):
        raise RuntimeError(f"A Ulysses group key must be specified, now get: {group_key}")
    return nullcontext() if group_key == "default" else UlyssesGroupKeyManager(group_key)