|
|
|
|
|
from collections.abc import Collection, Mapping, MutableMapping |
|
|
from typing import Callable, cast, Optional, TypeVar, Union |
|
|
|
|
|
import torch |
|
|
from torch.distributed._shard.sharded_tensor.api import ShardedTensor |
|
|
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE |
|
|
from torch.distributed.tensor import DTensor |
|
|
|
|
|
|
|
|
# One step of an object path: dict keys travel as ``str``, list indices as ``int``.
PATH_ITEM = Union[str, int]

# Full path from the root of a state_dict down to a value.
OBJ_PATH = tuple[PATH_ITEM, ...]

T = TypeVar("T")

# Values held by a state_dict may be of any type.
STATE_DICT_ITEM = object

# Generic mutable container (dict-like) addressed by path items.
CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM]

__all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"]
|
|
|
|
|
|
|
|
def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool: |
|
|
return isinstance(value, torch.Tensor) |
|
|
|
|
|
|
|
|
|
|
|
def traverse_state_dict(
    state_dict: STATE_DICT_TYPE,
    visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
    keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
) -> None:
    """
    Invoke ``visitor`` for each value recursively in ``state_dict``.

    Mapping will be traversed and ``visitor`` will be applied to the leaf elements.
    ``visitor`` will only be applied to elements in a list or a tuple, if the
    container contains tensors or mappings.
    """

    # True when ``value`` should be handed to ``visitor`` as a leaf instead of
    # being descended into.
    def _is_terminal(value: STATE_DICT_ITEM) -> bool:
        values: Collection[STATE_DICT_ITEM]
        if isinstance(value, Mapping):
            # Mappings are never terminal here; _traverse_obj descends into
            # them before this function is even consulted.
            return False
        elif isinstance(value, list):
            values = value
        else:
            # NOTE(review): tuples land here too, so a tuple is always treated
            # as a leaf even when it contains tensors or mappings — that
            # differs from the docstring's list/tuple claim; confirm intent.
            return True

        for entry in values:
            # A list is non-terminal if any element is itself a container that
            # needs descending...
            if isinstance(entry, (Mapping, list)) and not _is_terminal(entry):
                return False
            # ...or if the caller's predicate asks to keep traversing (by
            # default: the element is a torch.Tensor).
            if keep_traversing is not None and keep_traversing(entry):
                return False
        return True

    # Depth-first walk. Dict keys are stringified in the path; list/tuple
    # indices stay ints.
    def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
        if isinstance(value, Mapping):
            for k, v in value.items():
                _traverse_obj(path + (str(k),), v)
        elif _is_terminal(value):
            visitor(path, value)
        elif isinstance(value, (list, tuple)):
            for i, v in enumerate(value):
                _traverse_obj(path + (i,), v)

    for key, value in state_dict.items():
        _traverse_obj((str(key),), value)
|
|
|
|
|
|
|
|
def traverse_state_dict_v_2_3(
    state_dict: STATE_DICT_TYPE,
    visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
    keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
) -> None:
    """
    Invoke ``visitor`` on every terminal value reachable from ``state_dict``.

    Traversal is short-circuited when it finds a collection for which
    ``keep_traversing`` evaluates to false for all elements.
    By default, all collections with at least one ``torch.Tensor`` element are
    traversed. ``visitor`` takes a path argument that is a tuple of the keys
    used to reach it.
    """

    # A node is terminal when nothing inside it requires further descent.
    def _is_terminal(node: STATE_DICT_ITEM) -> bool:
        children: Collection[STATE_DICT_ITEM]
        if isinstance(node, Mapping):
            children = node.values()
        elif isinstance(node, list):
            children = node
        else:
            return True

        # Non-terminal if any child is itself a container needing descent, or
        # if the caller's predicate asks to keep traversing for any child.
        return not any(
            (isinstance(child, (Mapping, list)) and not _is_terminal(child))
            or (keep_traversing is not None and keep_traversing(child))
            for child in children
        )

    # Depth-first walk; dict keys are stringified, list indices stay ints.
    def _walk(prefix: OBJ_PATH, node: STATE_DICT_ITEM) -> None:
        if _is_terminal(node):
            visitor(prefix, node)
        elif isinstance(node, Mapping):
            for key, child in node.items():
                _walk(prefix + (str(key),), child)
        elif isinstance(node, list):
            for idx, child in enumerate(node):
                _walk(prefix + (idx,), child)

    for key, child in state_dict.items():
        _walk((str(key),), child)
|
|
|
|
|
|
|
|
def set_element(
    root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM
) -> None:
    """Set ``value`` in ``root_dict`` along the ``path`` object path.

    Intermediate containers are created on demand: a ``str`` path item implies
    a ``dict`` level, an ``int`` path item implies a ``list`` level (padded
    with ``None`` so the index is assignable).
    """
    cur_container = cast(CONTAINER_TYPE, root_dict)

    def extend_list(lst: list[STATE_DICT_ITEM], idx: int) -> None:
        # Pad ``lst`` with None so that ``lst[idx]`` is a valid target.
        while len(lst) <= idx:
            lst.append(None)

    for i in range(1, len(path)):
        prev_key = path[i - 1]
        key = path[i]
        # The container needed at this level is dictated by the *next* path
        # item: str -> dict, int -> list.
        # (was ``type(key) == str``; use ``is`` for exact type comparison,
        # consistent with get_element)
        def_val = cast(STATE_DICT_ITEM, {} if type(key) is str else [])

        if isinstance(cur_container, Mapping):
            cur_container = cast(
                CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)
            )
        else:
            # Current level is a list: grow it, then fill the slot with the
            # default container if it is still empty.
            extend_list(cur_container, prev_key)
            if cur_container[prev_key] is None:
                cur_container[prev_key] = def_val
            cur_container = cur_container[prev_key]

    key = path[-1]
    if type(key) is int:
        extend_list(cast(list[STATE_DICT_ITEM], cur_container), key)

    cur_container[key] = value
|
|
|
|
|
|
|
|
def get_element(
    root_dict: STATE_DICT_TYPE,
    path: OBJ_PATH,
    default_value: Optional[T] = None,
) -> Optional[T]:
    """Retrieve the value at ``path`` from ``root_dict``, returning ``default_value`` if not found."""
    cur_value = cast(CONTAINER_TYPE, root_dict)
    for part in path:
        if type(part) is int:
            # int path items index into lists; bail out when the current node
            # is not a list or the index is out of range.
            # Fix: was ``len(cur_value) < part``, which let ``part == len``
            # through and raised IndexError instead of returning the default.
            if not isinstance(cur_value, list) or part >= len(cur_value):
                return default_value
        elif not isinstance(cur_value, Mapping) or part not in cur_value:
            return default_value

        cur_value = cast(CONTAINER_TYPE, cur_value[part])
    return cast(Optional[T], cur_value)
|
|
|
|
|
|
|
|
def _print_nested( |
|
|
value: STATE_DICT_ITEM, |
|
|
prefix: str = "", |
|
|
print_fun: Callable[[str], None] = print, |
|
|
) -> None: |
|
|
if type(value) is ShardedTensor: |
|
|
print_fun(f"{prefix} ShardedTensor size: {value.size()}") |
|
|
for shard in value.local_shards(): |
|
|
_print_nested( |
|
|
shard.tensor, |
|
|
f"{shard.metadata.shard_offsets} ", |
|
|
print_fun=print_fun, |
|
|
) |
|
|
elif type(value) is (DTensor): |
|
|
print_fun(f"{prefix} DistributedTensor size: {value.size()}") |
|
|
|
|
|
_print_nested( |
|
|
value._local_tensor, |
|
|
print_fun=print_fun, |
|
|
) |
|
|
elif isinstance(value, torch.Tensor): |
|
|
print_fun(f"{prefix} Tensor size: {value.size()}") |
|
|
else: |
|
|
print_fun(f"{prefix} Type: {type(value)}") |
|
|
|
|
|
|
|
|
def print_tensor(
    path: OBJ_PATH,
    value: STATE_DICT_ITEM,
    print_fun: Callable[[str], None] = print,
) -> None:
    """
    Use this callback with traverse_state_dict to print its content.

    By default the content is printed using the builtin ``print`` but this can
    be changed by passing a different ``print_fun`` callable.
    """
    # The path becomes the line prefix so each printed entry is addressable.
    _print_nested(value, prefix=str(path), print_fun=print_fun)
|
|
|