|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect |
|
|
import random |
|
|
import socket |
|
|
import sys |
|
|
from collections import Counter |
|
|
from importlib.machinery import SourceFileLoader |
|
|
from pathlib import Path |
|
|
from typing import Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
|
|
|
from internlm.utils.common import SingletonMeta |
|
|
from internlm.utils.logger import get_logger |
|
|
from internlm.utils.timeout import LLM_NCCL_TIMEOUT |
|
|
|
|
|
from . import process_group_initializer as pgroup_initializer |
|
|
from .process_group_initializer import ParallelMode |
|
|
from .random import add_seed, get_seeds, set_mode |
|
|
|
|
|
# String constant "is_tensor_parallel"; it is not used in the visible part of this
# module -- presumably an attribute/key name checked elsewhere, verify against callers.
IS_TENSOR_PARALLEL = "is_tensor_parallel"

# Module-level logger obtained from the project's `get_logger` helper, keyed by this file's path.
logger = get_logger(__file__)
|
|
|
|
|
|
|
|
class Config(dict):
    """A ``dict`` whose entries can also be read and written as attributes.

    ``dict`` values passed through the constructor or :meth:`update` are wrapped
    recursively into :class:`Config` objects. Plain attribute assignment
    (``cfg.key = value``) and item assignment store the value unchanged.

    Args:
        config (dict, optional): The dict object to be wrapped.
    """

    def __init__(self, config: dict = None):
        if config is not None:
            for k, v in config.items():
                self._add_item(k, v)

    def __missing__(self, key):
        raise KeyError(key)

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails, so real dict
        # methods such as ``items`` keep working.
        try:
            return super().__getitem__(key)
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key, value):
        # Attributes are stored as dictionary items, never in ``__dict__``.
        super().__setitem__(key, value)

    def _add_item(self, key, value):
        """Stores ``value`` under ``key``, wrapping ``dict`` values into :class:`Config`."""
        if isinstance(value, dict):
            value = Config(value)
        self.__setattr__(key, value)

    def update(self, config):
        """Merges ``config`` into this object and returns ``self``.

        Args:
            config (dict or Config): The entries to add; nested dicts are wrapped.

        Raises:
            AssertionError: If ``config`` is neither a dict nor a :class:`Config`.
        """
        assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects."
        for k, v in config.items():
            self._add_item(k, v)
        return self

    @staticmethod
    def from_file(filename: Union[str, Path]) -> "Config":
        """Reads a python file and constructs a corresponding :class:`Config` object.

        The file is executed as a module. Every module-level name that does not
        start with ``__`` and is neither a module nor a class becomes a config entry.
        While the file executes, its parent directory is on ``sys.path`` so the
        config can import sibling modules; ``sys.path`` and ``sys.modules`` are
        restored afterwards, even if the config file raises.

        Args:
            filename (str or Path): Name of the file to construct the return object.

        Returns:
            :class:`Config`: A :class:`Config` object constructed with information in the file.

        Raises:
            AssertionError: Raises an AssertionError if the file does not exist, or the file is not .py file
        """
        # Local import: only this method needs importlib.util.
        import importlib.util

        # Path() accepts both str and Path; any other type raises TypeError here
        # instead of leaving ``filepath`` unbound.
        filepath = Path(filename).absolute()
        assert filepath.exists(), f"{filename} is not found, please check your configuration path"
        assert filepath.suffix == ".py", "only .py files are supported"

        module_name = filepath.stem

        # sys.path entries must be strings naming a directory; the import system
        # skips anything else, so insert the parent directory as str.
        parent_dir = str(filepath.parent)
        path_added = parent_dir not in sys.path
        if path_added:
            sys.path.insert(0, parent_dir)

        # Remember an already-imported module with the same name so that loading
        # the config neither overwrites nor removes it.
        had_previous = module_name in sys.modules
        previous_module = sys.modules.get(module_name)
        try:
            spec = importlib.util.spec_from_file_location(module_name, str(filepath))
            module = importlib.util.module_from_spec(spec)
            # Registered only while executing, mirroring a regular import.
            sys.modules[module_name] = module
            spec.loader.exec_module(module)
        finally:
            if had_previous:
                sys.modules[module_name] = previous_module
            else:
                sys.modules.pop(module_name, None)
            # Remove by value: the config may have changed sys.path itself.
            if path_added and parent_dir in sys.path:
                sys.path.remove(parent_dir)

        config = Config()
        for name, value in vars(module).items():
            if name.startswith("__") or inspect.ismodule(value) or inspect.isclass(value):
                continue
            config._add_item(name, value)

        return config
|
|
|
|
|
|
|
|
class ParallelContext(metaclass=SingletonMeta):
    """This class provides interface functions for users to get the parallel context,
    such as the global rank, the local rank, the world size, etc. of each device.

    Per-mode information (local rank, world size, process group, cpu group and the
    ranks in the group) is stored in dictionaries keyed by :class:`ParallelMode`
    and filled in by :meth:`init_global_dist` and :meth:`init_parallel_groups`.
    """

    def __init__(self):
        # Per-parallel-mode registries, written by `_register_dist`.
        self._global_ranks = {}
        self._local_ranks = {}
        self._world_sizes = {}
        self._groups = {}
        self._cpu_groups = {}
        self._ranks_in_group = {}

        # Set by `load_config`.
        self._config = None

        # Default parallel sizes; `init_parallel_groups` overwrites them.
        self.world_size = 1
        self.data_parallel_size = 1
        self.pipeline_parallel_size = 1
        self.tensor_parallel_size = 1
        self.zero1_parallel_size = -1
        self.nettest_parallel_size = 1
        self.num_processes_on_current_node = -1
        self.virtual_pipeline_parallel_size = None
        self.virtual_pipeline_parallel_rank = None

    @property
    def config(self):
        """The loaded :class:`Config`, or ``None`` if :meth:`load_config` was not called."""
        return self._config

    def load_config(self, config: Union[dict, str, Path]):
        """Loads the configuration from either a dict or a file.

        Args:
            config (dict or str or Path): Either a dict containing the configuration information or the
                filename of a file containing the configuration information.

        Raises:
            TypeError: Raises a TypeError if `config` is neither a dict nor a str/Path.
        """
        # Config.from_file handles Path objects as well, so accept them here too.
        if isinstance(config, (str, Path)):
            self._config = Config.from_file(config)
        elif isinstance(config, dict):
            self._config = Config(config)
        else:
            raise TypeError("Invalid type for config, only dictionary or string is supported")

    def detect_num_processes_on_current_node(self):
        """Counts how many processes of the global group report this host's hostname.

        The hostnames of all global ranks are gathered and the number of
        occurrences of the local hostname is stored in
        ``self.num_processes_on_current_node``.
        """
        hostname = socket.gethostname()
        hostname_list = [None for _ in range(self.get_world_size(ParallelMode.GLOBAL))]
        dist.all_gather_object(hostname_list, hostname, group=self.get_group(ParallelMode.GLOBAL))
        counter = Counter(hostname_list)
        self.num_processes_on_current_node = counter[hostname]

    @staticmethod
    def _check_parallel_mode(parallel_mode: ParallelMode):
        """Asserts that ``parallel_mode`` is a :class:`ParallelMode` member."""
        assert isinstance(
            parallel_mode, ParallelMode
        ), f"expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}"

    def get_global_rank(self):
        """Returns the global rank of the current device.

        Returns:
            int: The global rank of the current device

        Raises:
            KeyError: If :meth:`init_global_dist` has not registered the global rank yet.
        """
        return self._global_ranks[ParallelMode.GLOBAL]

    def get_local_rank(self, parallel_mode: ParallelMode):
        """Returns the local rank of the current device.

        Args:
            parallel_mode: The parallel mode for the rank.

        Returns:
            int: The local rank of the current device for `parallel_mode`,
                or 0 if the mode has not been registered.
        """
        self._check_parallel_mode(parallel_mode)
        return self._local_ranks.get(parallel_mode, 0)

    def get_next_global_rank(self, parallel_mode: ParallelMode):
        """Returns the global rank of the next device.

        Args:
            parallel_mode: The parallel mode for the rank.

        Returns:
            int: The global rank of the next device for `parallel_mode`.
        """
        self._check_parallel_mode(parallel_mode)

        local_rank = self.get_local_rank(parallel_mode)
        world_size = self.get_world_size(parallel_mode)
        ranks_in_group = self.get_ranks_in_group(parallel_mode)

        # Wraps around: the last rank's successor is the first rank.
        return ranks_in_group[(local_rank + 1) % world_size]

    def get_prev_global_rank(self, parallel_mode: ParallelMode):
        """Returns the global rank of the previous device.

        Args:
            parallel_mode: The chosen parallel mode.

        Returns:
            int: The global rank of the previous device for `parallel_mode`.
        """
        self._check_parallel_mode(parallel_mode)

        local_rank = self.get_local_rank(parallel_mode)
        world_size = self.get_world_size(parallel_mode)
        ranks_in_group = self.get_ranks_in_group(parallel_mode)

        # Wraps around: the first rank's predecessor is the last rank.
        return ranks_in_group[(local_rank - 1) % world_size]

    def is_using_dp(self):
        """Returns a boolean value indicating whether the current device is initialized with
        ParallelMode.DATA and its world_size is greater than 1.
        """
        return self.is_initialized(ParallelMode.DATA) and self.get_world_size(ParallelMode.DATA) > 1

    def is_using_tp(self):
        """Returns a boolean value indicating whether the current device is initialized with
        ParallelMode.TENSOR and its world_size is greater than 1.
        """
        return self.is_initialized(ParallelMode.TENSOR) and self.get_world_size(ParallelMode.TENSOR) > 1

    def is_using_pp(self):
        """Returns a boolean value indicating whether the current device is initialized with
        ParallelMode.PIPELINE and its world_size is greater than 1.
        """
        return self.is_initialized(ParallelMode.PIPELINE) and self.get_world_size(ParallelMode.PIPELINE) > 1

    def is_using_sequence(self):
        """Returns whether sequence parallelism is in use.

        No sequence-parallel check is implemented here: this always returns ``False``.
        """
        return False

    def is_first_rank(self, parallel_mode: ParallelMode):
        """Returns a boolean value indicating whether the current device is the first one
        among its group for `parallel_mode`.

        Args:
            parallel_mode: The chosen parallel mode.

        Returns:
            bool: a boolean value indicating whether the current device is the first one
            among its group for `parallel_mode`. ``True`` if the mode is not initialized.
        """
        rank = 0
        if self.is_initialized(parallel_mode):
            rank = self.get_local_rank(parallel_mode)
        return rank == 0

    def is_rank_for_log(self):
        """Returns a boolean value indicating whether the current device should print log.

        That is the device which is first in its data and tensor groups and
        last in its pipeline group.
        """
        is_log_rank = (
            self.is_first_rank(ParallelMode.DATA)
            and self.is_first_rank(ParallelMode.TENSOR)
            and self.is_last_rank(ParallelMode.PIPELINE)
        )
        return is_log_rank

    def is_last_rank(self, parallel_mode: ParallelMode):
        """Returns a boolean value indicating whether the current device is the last one
        among its group for `parallel_mode`.

        Args:
            parallel_mode: The chosen parallel mode.

        Returns:
            bool: a boolean value indicating whether the current device is the last one
            among its group for `parallel_mode`. ``True`` if the mode is not initialized.
        """
        rank = 0
        world_size = 1
        if self.is_initialized(parallel_mode):
            rank = self.get_local_rank(parallel_mode)
            world_size = self.get_world_size(parallel_mode)
        return rank == world_size - 1

    def is_pipeline_first_stage(self, ignore_virtual=False):
        """Returns whether this device is the first rank of its pipeline group.

        Args:
            ignore_virtual (bool): If False and a virtual pipeline size is set, the
                result is additionally False unless the virtual pipeline rank is 0.
        """
        if not ignore_virtual:
            if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != 0:
                return False
        return self.is_first_rank(ParallelMode.PIPELINE)

    def is_pipeline_last_stage(self, ignore_virtual=False):
        """Returns whether this device is the last rank of its pipeline group.

        Args:
            ignore_virtual (bool): If False and a virtual pipeline size is set, the
                result is additionally False unless the virtual pipeline rank is the
                last virtual rank (``virtual_pipeline_parallel_size - 1``).
        """
        if not ignore_virtual:
            if (
                self.virtual_pipeline_parallel_size is not None
                and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1
            ):
                return False
        return self.is_last_rank(ParallelMode.PIPELINE)

    def get_world_size(self, parallel_mode: ParallelMode):
        """Returns the world size for `parallel_mode`.

        Args:
            parallel_mode: The chosen parallel mode.

        Returns:
            int: The world size for `parallel_mode`, or 1 if the mode has not been registered.
        """
        self._check_parallel_mode(parallel_mode)
        return self._world_sizes.get(parallel_mode, 1)

    def get_group(self, parallel_mode: ParallelMode):
        """Returns the group of the current device for `parallel_mode`.

        Args:
            parallel_mode: The chosen parallel mode.

        Returns:
            torch.distributed.ProcessGroup: The group of the current device for `parallel_mode`.

        Raises:
            KeyError: If `parallel_mode` has not been registered.
        """
        self._check_parallel_mode(parallel_mode)
        return self._groups[parallel_mode]

    def get_ranks_in_group(self, parallel_mode: ParallelMode):
        """Returns the ranks registered for the current device's group for `parallel_mode`.

        Args:
            parallel_mode: The chosen parallel mode.

        Returns:
            The ``ranks_in_group`` value registered for `parallel_mode`; it is indexed by
            local rank to obtain a global rank (for ``ParallelMode.GLOBAL`` it is
            ``list(range(world_size))``).

        Raises:
            KeyError: If `parallel_mode` has not been registered.
        """
        self._check_parallel_mode(parallel_mode)
        return self._ranks_in_group[parallel_mode]

    def get_cpu_group(self, parallel_mode: ParallelMode):
        """Returns the cpu group registered for `parallel_mode` (may be ``None``).

        Raises:
            KeyError: If `parallel_mode` has not been registered.
        """
        self._check_parallel_mode(parallel_mode)
        return self._cpu_groups[parallel_mode]

    def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int, use_cpu: bool = False):
        """Initializes the global distributed environment

        Args:
            rank (int): rank for the default process group.
            world_size (int): world size of the default process group.
            backend (str): backend for ``torch.distributed``
            host (str): the master address for distributed training.
            port (int): the master port for distributed training.
            use_cpu (bool): whether to set up cpu process group.
        """
        init_method = f"tcp://[{host}]:{port}"
        dist.init_process_group(
            rank=rank,
            world_size=world_size,
            backend=backend,
            init_method=init_method,
            timeout=LLM_NCCL_TIMEOUT,
        )

        # Register the global process group. A separate gloo group is only
        # created when requested and the default backend is not gloo already.
        ranks = list(range(world_size))
        if use_cpu:
            cpu_group = (
                dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
                if dist.get_backend() != "gloo"
                else None
            )
        else:
            cpu_group = None
        self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL)
        self._global_ranks[ParallelMode.GLOBAL] = rank

    def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode):
        """Records the local rank, world size, groups and group ranks for ``mode``."""
        self._check_parallel_mode(mode)
        self._local_ranks[mode] = local_rank
        self._world_sizes[mode] = world_size
        self._groups[mode] = process_group
        self._cpu_groups[mode] = cpu_group
        self._ranks_in_group[mode] = ranks_in_group

    def check_sanity(self):
        """Checks sanity of the parallel context.

        Raises:
            AssertionError: Raises an AssertionError if the world size does not equal to the product
                of data parallel size, pipeline parallel size and tensor parallel size, if the
                zero1 parallel size is not positive, or if it does not divide the data parallel size.
        """
        dps = self.data_parallel_size
        pps = self.pipeline_parallel_size
        tps = self.tensor_parallel_size
        ws = self.world_size
        assert ws == dps * pps * tps, (
            f"Expected the world size {ws} to be equal to data"
            f" parallel size ({dps}) * pipeline parallel size "
            f"({pps}) * tensor parallel size ({tps})"
        )
        assert (
            self.zero1_parallel_size > 0
        ), f"Expected the zero1 parallel size ({self.zero1_parallel_size}) to be positive"
        assert self.data_parallel_size % self.zero1_parallel_size == 0, (
            f"Expected the data parallel size ({self.data_parallel_size}) to be divisible by "
            f"the zero1 parallel size ({self.zero1_parallel_size})"
        )

    def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
        """Copies a parallel size from ``config[key]`` to ``self.<attr_name>``.

        The entry may be an int, or a dict whose ``"size"`` item holds the size.
        Nothing happens when ``key`` is absent.

        Raises:
            NotImplementedError: If the entry is neither an int nor a dict.
        """
        if key in config:
            ele = config[key]
            if isinstance(ele, int):
                setattr(self, attr_name, ele)
            elif isinstance(ele, dict):
                setattr(self, attr_name, ele["size"])
            else:
                raise NotImplementedError(
                    "Parallel configuration does not support this kind of argument, please use int or dict"
                )

    def init_parallel_groups(self):
        """Initializes the parallel groups.

        Requires :meth:`init_global_dist` and :meth:`load_config` to have been called.

        Raises:
            RuntimeError: If no configuration has been loaded.
            AssertionError: If the resulting parallel sizes fail :meth:`check_sanity`.
        """
        if self._config is None:
            raise RuntimeError("No configuration loaded, please call load_config() before init_parallel_groups()")

        # Rank and world size of the global group registered by init_global_dist.
        rank = self.get_global_rank()
        world_size = self.get_world_size(ParallelMode.GLOBAL)
        self.world_size = world_size

        # Pipeline / tensor / zero1 sizes come from the optional "parallel" config section.
        parallel_config = self.config.get("parallel", None)
        if parallel_config is not None:
            self._set_parallel_size_from_config(parallel_config, "pipeline", "pipeline_parallel_size")
            self._set_parallel_size_from_config(parallel_config, "tensor", "tensor_parallel_size")
            self._set_parallel_size_from_config(parallel_config, "zero1", "zero1_parallel_size")

        # The data parallel size is never read from the config; it is derived
        # from the world size and the other parallel sizes.
        self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size)

        # Hard-coded size passed to the nettest group initializer.
        self.nettest_parallel_size = 32

        # A non-positive zero1 size means "use the data parallel size".
        if self.zero1_parallel_size <= 0:
            self.zero1_parallel_size = self.data_parallel_size

        self.check_sanity()

        initializer_args = [
            rank,
            world_size,
            self.data_parallel_size,
            self.pipeline_parallel_size,
            self.tensor_parallel_size,
            self.zero1_parallel_size,
            self.nettest_parallel_size,
        ]

        # Run the process group initializers in a fixed order.
        initializers = []
        initializers.append(pgroup_initializer.Initializer_Data(*initializer_args))
        initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
        initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
        initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
        initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
        if self.pipeline_parallel_size > 1:
            initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
        for initializer in initializers:
            parallel_setting = initializer.init_dist_group()
            # An initializer returns either one argument tuple for
            # `_register_dist` or a list of such tuples.
            if isinstance(parallel_setting, list):
                for args in parallel_setting:
                    self._register_dist(*args)
            else:
                self._register_dist(*parallel_setting)

    def is_initialized(self, parallel_mode: ParallelMode):
        """Returns a boolean value indicating whether `parallel_mode` is initialized
        in the current system.
        """
        return parallel_mode in self._groups

    def destroy(self):
        """Destroys the current distributed parallel environment."""
        # Destroy every registered non-global group first, then the default group.
        for mode, group in self._groups.items():
            if mode is not ParallelMode.GLOBAL:
                dist.destroy_process_group(group)

        dist.destroy_process_group()
        self._groups.clear()

    def set_device(self, device_ordinal: int = None):
        """Sets distributed processes to be bound to devices.

        Args:
            device_ordinal (int, optional): the device id to be bound to. When omitted, the
                global rank modulo ``torch.cuda.device_count()`` is used.
        """
        global_rank = self.get_global_rank()
        if device_ordinal is None:
            devices_per_node = torch.cuda.device_count()
            device_ordinal = global_rank % devices_per_node

        torch.cuda.set_device(device_ordinal)
        logger.info(f"process rank {global_rank} is bound to host:{socket.gethostname()} device: {device_ordinal}")

    def set_seed(self, seed: int, dpseed_with_tpoffset: bool = False):
        """Sets seeds for all random libraries.

        ``random``, ``numpy`` and ``torch`` are seeded with ``seed``. Seeds are then
        registered for the DATA and DUMMY modes and, if the tensor mode is
        initialized, for the TENSOR mode; finally the DUMMY mode is selected.

        Args:
            seed (int): seed for random states
            dpseed_with_tpoffset (bool): if True, the DATA/DUMMY seed is offset by
                ``pipeline local rank * 1024`` (note: despite the parameter name, the
                offset is derived from the pipeline rank, not the tensor rank).
        """
        pipeline_offset = self._local_ranks.get(ParallelMode.PIPELINE, 0)
        global_rank = self.get_global_rank()

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        assert torch.cuda.is_available()

        # The data parallel seed is the plain seed unless the offset is requested.
        dp_seed = seed
        if dpseed_with_tpoffset:
            dp_seed = seed + pipeline_offset * 1024
        add_seed(ParallelMode.DATA, dp_seed)
        add_seed(ParallelMode.DUMMY, dp_seed)

        # The tensor parallel seed differs per tensor rank and pipeline rank.
        if self.is_initialized(ParallelMode.TENSOR):
            tp_rank = self.get_local_rank(ParallelMode.TENSOR)
            tp_seed = seed + tp_rank + pipeline_offset * 1024
            add_seed(ParallelMode.TENSOR, tp_seed)

        set_mode(ParallelMode.DUMMY)

        seeds = get_seeds()
        seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()])
        logger.info(
            f"initialized seed on rank {global_rank}, "
            f"numpy: {seed}, python random: {seed}, {seed_str},"
            f"the default parallel seed is {ParallelMode.DATA}."
        )

    def set_virtual_pipeline_parallel_size(self, size):
        """Stores ``size`` as the virtual pipeline parallel size."""
        self.virtual_pipeline_parallel_size = size

    def set_virtual_pipeline_parallel_rank(self, rank):
        """Stores ``rank`` as the virtual pipeline parallel rank."""
        self.virtual_pipeline_parallel_rank = rank
|
|
|
|
|
|
|
|
# Module-level ParallelContext instance created at import time. The class uses
# SingletonMeta as its metaclass -- presumably every ParallelContext() call yields
# this same object; confirm against internlm.utils.common.SingletonMeta.
global_context = ParallelContext()
|
|
|