|
|
import torch |
|
|
import torch.distributed as dist |
|
|
from torch.nn.parallel._functions import _get_stream |
|
|
from torch.nn.parallel.scatter_gather import ( |
|
|
_is_namedtuple |
|
|
) |
|
|
from typing import Any, Dict, List, Tuple |
|
|
|
|
|
# Empty export list: ``from <module> import *`` exposes no names from this
# module unless ``__all__`` is extended elsewhere in the file.
__all__: List[str] = []
|
|
|
|
|
def _pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
    """
    Flatten positional and keyword arguments into a value tuple and a key tuple.

    ``_unpack_kwargs`` performs the inverse operation.

    Inspiration: https://github.com/facebookresearch/fairscale/blob/eeb6684/fairscale/internal/containers.py#L70

    Usage::

        flat_args, kwarg_keys = _pack_kwargs(1, 2, a=3, b=4)
        assert flat_args == (1, 2, 3, 4)
        assert kwarg_keys == ("a", "b")
        args, kwargs = _unpack_kwargs(flat_args, kwarg_keys)
        assert args == (1, 2)
        assert kwargs == {"a": 3, "b": 4}

    Returns:
        Tuple[Tuple[Any, ...], Tuple[str, ...]]: The first tuple element
        gives both positional args and kwarg values, where the positional
        args precede the kwarg values and the kwarg values are ordered
        consistently with the kwarg keys. The second tuple element gives the
        kwarg keys. The second tuple element's length is at most the first
        tuple element's length.
    """
    # Iterating ``kwargs`` and ``kwargs.values()`` yields entries in the same
    # order, so each key lines up with its value at the tail of ``flat_args``.
    kwarg_keys = tuple(kwargs)
    flat_args = args + tuple(kwargs.values())
    return flat_args, kwarg_keys
|
|
|
|
|
|
|
|
def _unpack_kwargs(flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Split a flattened argument tuple back into positional args and kwargs.

    This is the inverse of ``_pack_kwargs``: the trailing ``len(kwarg_keys)``
    entries of ``flat_args`` are paired with ``kwarg_keys`` in order, and the
    leading entries are returned as the positional arguments.

    Returns:
        A ``(args, kwargs)`` pair. When ``kwarg_keys`` is empty, ``flat_args``
        itself is returned as ``args`` together with an empty dict.
    """
    num_keys = len(kwarg_keys)
    assert num_keys <= len(flat_args), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}"
    if num_keys == 0:
        # Guard needed: ``flat_args[:-0]`` would be empty, not the full tuple.
        return flat_args, {}
    positional = flat_args[:-num_keys]
    keyword_values = flat_args[-num_keys:]
    return positional, dict(zip(kwarg_keys, keyword_values))
|
|
|
|
|
def _recursive_to(inputs, target_gpu, use_side_stream_for_tensor_copies):
    r"""
    Recursively move ``inputs`` to the CUDA device ``target_gpu``.

    Tensors are moved with ``Tensor.to``. Namedtuples, non-empty tuples, lists
    and dicts are traversed and rebuilt with the same container type. Every
    other object (including empty containers) is passed through unchanged.

    Args:
        inputs: the (possibly nested) object to move.
        target_gpu: CUDA device index the tensors are moved to.
        use_side_stream_for_tensor_copies: when true, tensor copies are issued
            on the stream returned by ``_get_stream(target_gpu)`` and the
            device's current stream is made to wait for that copy.

    Returns:
        A one-element sequence wrapping the moved object: a tuple when
        ``inputs`` is a bare tensor, a list otherwise.
    """

    def to_map(obj):
        if isinstance(obj, torch.Tensor):
            # Already on the requested device: nothing to copy.
            if obj.device == torch.device("cuda", target_gpu):
                return (obj,)
            if not use_side_stream_for_tensor_copies:
                return (obj.to(target_gpu),)

            # Issue the copy on a separate stream.
            copy_stream = _get_stream(target_gpu)
            with torch.cuda.stream(copy_stream):
                moved = obj.to(target_gpu)
            with torch.cuda.device(target_gpu):
                main_stream = torch.cuda.current_stream()
                # Make the device's current stream wait for the copy stream.
                main_stream.wait_stream(copy_stream)
                # Register the tensor as in use on the current stream so the
                # allocator accounts for that stream before reusing its memory.
                moved.record_stream(main_stream)
            return (moved,)

        if _is_namedtuple(obj):
            return [
                type(obj)(*fields)
                for fields in zip(*(to_map(item) for item in obj))
            ]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*(to_map(item) for item in obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(group) for group in zip(*(to_map(item) for item in obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [
                type(obj)(pairs)
                for pairs in zip(*(to_map(entry) for entry in obj.items()))
            ]
        return [obj]

    # ``to_map`` refers to itself through this scope; clear the name once the
    # traversal is finished so the closure does not keep itself alive.
    try:
        return to_map(inputs)
    finally:
        to_map = None
|
|
|
|
|
|
|
|
def _to_kwargs(inputs, kwargs, device_id, use_side_stream_for_tensor_copies):
    """
    Move positional and keyword inputs to ``device_id`` and align their lengths.

    Both ``inputs`` and ``kwargs`` are passed through ``_recursive_to`` when
    truthy; a falsy value becomes an empty list instead. The shorter of the two
    results is then padded (with empty tuples for inputs, empty dicts for
    kwargs) so that both have the same length.

    Returns:
        A ``(inputs, kwargs)`` pair of equally long tuples.
    """
    moved_inputs = (
        _recursive_to(inputs, device_id, use_side_stream_for_tensor_copies)
        if inputs
        else []
    )
    moved_kwargs = (
        _recursive_to(kwargs, device_id, use_side_stream_for_tensor_copies)
        if kwargs
        else []
    )

    # Positive: inputs are short; negative: kwargs are short; zero: aligned.
    missing_inputs = len(moved_kwargs) - len(moved_inputs)
    if missing_inputs > 0:
        moved_inputs.extend(() for _ in range(missing_inputs))
    elif missing_inputs < 0:
        moved_kwargs.extend({} for _ in range(-missing_inputs))

    return tuple(moved_inputs), tuple(moved_kwargs)
|
|
|
|
|
def _verify_param_shape_across_processes(process_group, tensors, logger=None):
    """
    Forward to ``dist._verify_params_across_processes`` and return its result.

    ``process_group``, ``tensors`` and ``logger`` are passed through
    positionally and unchanged.
    """
    return dist._verify_params_across_processes(
        process_group,
        tensors,
        logger,
    )
|
|
|
|
|
def _sync_module_states(
    module,
    process_group,
    broadcast_bucket_size,
    src,
    params_and_buffers_to_ignore,
):
    """
    Synchronize ``module``'s parameters and buffers across all ranks.

    Collects the detached parameters followed by the detached buffers of
    ``module``, skipping every entry whose name is in
    ``params_and_buffers_to_ignore``, and hands them to
    ``_sync_params_and_buffers`` together with ``process_group``,
    ``broadcast_bucket_size`` and ``src``.

    Note that this API assumes that all parameter shapes are consistent before
    running the synchronization. This can be checked with
    ``_verify_param_shape_across_processes``.
    """
    ignored = params_and_buffers_to_ignore
    param_states = [
        param.detach()
        for name, param in module.named_parameters()
        if name not in ignored
    ]
    buffer_states = [
        buffer.detach()
        for name, buffer in module.named_buffers()
        if name not in ignored
    ]
    # Parameters first, then buffers: same ordering on every rank.
    _sync_params_and_buffers(
        process_group,
        param_states + buffer_states,
        broadcast_bucket_size,
        src,
    )
|
|
|
|
|
def _sync_params_and_buffers(
    process_group: dist.ProcessGroup,
    module_states: List[torch.Tensor],
    broadcast_bucket_size: int,
    src: int,
):
    """
    Synchronize ``module_states`` (list of tensors) across all processes.

    The tensors are broadcast with ``dist._broadcast_coalesced`` using the
    caller-supplied ``src`` (not a hard-coded rank 0).

    Args:
        process_group: process group forwarded to the broadcast.
        module_states: tensors to synchronize; nothing is done when empty.
        broadcast_bucket_size: bucket size forwarded to the broadcast.
        src: source rank forwarded to the broadcast.
    """
    if not module_states:
        # Nothing to broadcast; skip the collective call entirely.
        return
    dist._broadcast_coalesced(
        process_group, module_states, broadcast_bucket_size, src
    )
|
|
|
|
|
def _replace_by_prefix(
    state_dict: Dict[str, Any],
    old_prefix: str,
    new_prefix: str,
) -> None:
    """
    Rename, in place, every key of ``state_dict`` that starts with ``old_prefix``.

    For each matching key the leading ``old_prefix`` is swapped for
    ``new_prefix``; the value is stored under the new key and the old key is
    removed. Keys that do not start with ``old_prefix`` are left untouched.

    Usage::

        state_dict = {"layer.xyz": torch.tensor(1)}
        _replace_by_prefix(state_dict, "layer.", "module.layer.")
        assert state_dict == {"module.layer.xyz": torch.tensor(1)}

    Raises:
        ValueError: if ``old_prefix`` and ``new_prefix`` are equal.
    """
    if old_prefix == new_prefix:
        raise ValueError("old_prefix and new_prefix must be distinct")

    prefix_len = len(old_prefix)
    # Snapshot the matching keys up front: the dict is mutated below, and
    # freshly inserted keys must not be renamed a second time.
    matching_keys = [key for key in state_dict.keys() if key.startswith(old_prefix)]
    for old_key in matching_keys:
        renamed = new_prefix + old_key[prefix_len:]
        state_dict[renamed] = state_dict[old_key]
        del state_dict[old_key]
|
|
|