|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
|
import logging |
|
|
import os.path |
|
|
from io import BytesIO |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, Union |
|
|
|
|
|
import numpy |
|
|
|
|
|
|
|
|
import tensorstore |
|
|
import torch |
|
|
from torch.distributed.checkpoint import FileSystemReader, load |
|
|
from torch.distributed.checkpoint.metadata import BytesStorageMetadata, TensorStorageMetadata |
|
|
|
|
|
from nemo.export.tarutils import TarPath, ZarrPathStore |
|
|
from nemo.export.utils._mock_import import _mock_import |
|
|
|
|
|
LOGGER = logging.getLogger("NeMo") |
|
|
|
|
|
|
|
|
def nemo_to_path(nemo_checkpoint: Union[Path, str]) -> Union[Path, TarPath]:
    """
    Builds a path object suitable for navigating inside a NeMo checkpoint.

    Unpacked checkpoints (plain directories) are wrapped in ``pathlib.Path``;
    anything else is treated as a tar archive and wrapped in ``TarPath``.

    Args:
        nemo_checkpoint (Path, str): Path to the NeMo checkpoint.

    Returns:
        Path | TarPath: Suitable Path object for navigating through the checkpoint.
    """
    checkpoint_location = str(nemo_checkpoint)

    if not os.path.isdir(checkpoint_location):
        # Not a directory on disk — assume a packed (.nemo tar) checkpoint.
        return TarPath(checkpoint_location)
    return Path(checkpoint_location)
|
|
|
|
|
|
|
|
class TarFileSystemReader(FileSystemReader):
    """Checkpoint reader accepting both Path and TarPath directories.

    ``FileSystemReader`` can navigate a TarPath just fine, but its
    constructor expects a pure ``Path``.  We satisfy the parent with a
    stringified path and then restore the TarPath afterwards.
    """

    def __init__(self, path: Union[Path, TarPath]) -> None:
        """Makes sure that super().__init__ gets a pure path as expected."""
        if isinstance(path, TarPath):
            super().__init__(str(path))
            # Swap the parent's Path back for the TarPath so reads go
            # through the tar archive.
            self.path = path
        else:
            super().__init__(path)
|
|
|
|
|
|
|
|
def load_sharded_metadata_torch_dist(
    checkpoint_dir: Union[Path, TarPath], load_extra_states: bool = False
) -> Dict[str, Any]:
    """
    Loads model state dictionary from torch_dist checkpoint.

    Args:
        checkpoint_dir (Path | TarPath): Path to the model weights directory.
        load_extra_states (bool): If set to true, loads BytesIO objects, related to the extra states.

    Returns:
        dict: Loaded model state dictionary (weights are stored in torch tensors).
    """
    reader = TarFileSystemReader(checkpoint_dir)
    checkpoint_metadata = reader.read_metadata()
    entries = checkpoint_metadata.state_dict_metadata.items()

    # Pre-allocate a destination for every tensor entry; `load` fills
    # these in-place from the checkpoint shards.
    state_dict: Dict[str, Any] = {}
    for key, entry_meta in entries:
        if isinstance(entry_meta, TensorStorageMetadata):
            state_dict[key] = torch.empty(entry_meta.size, dtype=entry_meta.properties.dtype)

    if load_extra_states:
        # Byte entries (extra states) are collected into lists by the loader.
        for key, entry_meta in entries:
            if isinstance(entry_meta, BytesStorageMetadata):
                state_dict[key] = []

    load(state_dict, storage_reader=reader)
    return state_dict
|
|
|
|
|
|
|
|
def load_sharded_pickle_extra_state_scale(dir: Union[Path, TarPath]) -> Dict[str, BytesIO]:
    """
    Loads model extra states from the .pt shards.

    Args:
        dir (Path | TarPath): Path to the directory with sharded extra states.

    Returns:
        dict: State dictionary corresponding to the loaded extra states.
    """
    extra_states: Dict[str, BytesIO] = {}
    for shard_file in dir.glob('shard_*_*.pt'):
        # Key format: "<parent dir name>/<shard file name without extension>".
        shard_key = f"{dir.name}/{shard_file.name.split('.')[0]}"
        with shard_file.open('rb') as shard_stream:
            # weights_only=True restricts unpickling to safe types.
            extra_states[shard_key] = torch.load(shard_stream, weights_only=True)
    return extra_states
|
|
|
|
|
|
|
|
def contains_extra_states(subdir: Union[Path, TarPath]) -> bool:
    """
    Checks if zarr directory contains extra states.

    Args:
        subdir (Path | TarPath): Directory inside the zarr checkpoint.

    Returns:
        bool: Is a directory with extra states
    """
    # The presence of a single first-shard file is enough evidence; avoid
    # materializing the full glob result.
    return next(iter(subdir.glob('shard_0_*.pt')), None) is not None
|
|
|
|
|
|
|
|
def load_sharded_metadata_zarr(
    checkpoint_dir: Union[Path, TarPath], load_extra_states: bool = False
) -> Dict[str, Any]:
    """
    Loads model dictionary from the zarr format.

    Args:
        checkpoint_dir (Path | TarPath): Path to the NeMo checkpoint.
        load_extra_states (bool): If set to True, the function will load BufferIO objects with extra states.

    Returns:
        dict: Model state dictionary.
    """
    # Lazy import kept at function scope (zarr is only needed for this legacy
    # backend), but hoisted out of the per-subdirectory loop where it used to
    # be re-executed on every iteration.
    import zarr

    if load_extra_states:
        # Extra states are pickled BytesIO objects; allow-list BytesIO so
        # torch.load(weights_only=True) in the shard loader accepts them.
        torch.serialization.add_safe_globals([BytesIO])

    sharded_state_dict = {}
    for subdir in checkpoint_dir.iterdir():
        if not subdir.is_dir():
            continue

        if load_extra_states and contains_extra_states(subdir):
            sharded_state_dict.update(load_sharded_pickle_extra_state_scale(subdir))

        elif (subdir / '.zarray').exists():
            # A .zarray marker means this directory is a zarr array; its
            # name is the state-dict key.
            key = subdir.name
            zstore = ZarrPathStore(subdir)
            arr = zarr.open(zstore, 'r')

            if arr.dtype.name == "bfloat16":
                # numpy has no native bfloat16: reinterpret the raw 16-bit
                # payload as int16, then view it back as bfloat16 in torch.
                sharded_state_dict[key] = torch.from_numpy(arr[:].view(numpy.int16)).view(torch.bfloat16)
            else:
                sharded_state_dict[key] = torch.from_numpy(arr[:])

    return sharded_state_dict
|
|
|
|
|
|
|
|
def nemo_weights_directory(nemo_path: Union[Path, TarPath]) -> Union[Path, TarPath]:
    """
    Returns a Path pointing to the weights directory inside the NeMo checkpoint.

    Args:
        nemo_path (Path | TarPath): Path to the nemo checkpoint.

    Returns:
        Path | TarPath: Path to the weights directory inside the model checkpoint.
    """
    # Different NeMo versions use different layouts; probe the known
    # subdirectory names in order and fall back to the checkpoint root.
    for weights_subdir in ("model_weights", "weights"):
        candidate = nemo_path / weights_subdir
        if candidate.exists():
            return candidate

    return nemo_path
|
|
|
|
|
|
|
|
def load_model_weights(checkpoint_path: Union[str, Path], load_extra_states: bool = False) -> Dict[str, Any]:
    """
    Loads NeMo state dictionary. Weights are stored in torch.Tensor

    Args:
        checkpoint_path (str | Path): Path to the NeMo checkpoint.
        load_extra_states (bool): If True, loads BytesIO objects, corresponding to the extra states.

    Returns:
        dict: Model state dictionary.
    """
    weights_dir = nemo_weights_directory(nemo_to_path(checkpoint_path))

    # metadata.json records which distributed-checkpoint backend wrote the weights.
    with (weights_dir / 'metadata.json').open(mode='r') as metadata_file:
        checkpoint_config = json.load(metadata_file)

    backend = checkpoint_config['sharded_backend']
    if backend == 'zarr':
        return load_sharded_metadata_zarr(weights_dir, load_extra_states=load_extra_states)
    if backend == 'torch_dist':
        # NOTE(review): _mock_import presumably stubs the megatron torch
        # strategy module so it need not be installed here — verify against
        # nemo.export.utils._mock_import.
        with _mock_import("megatron.core.dist_checkpointing.strategies.torch"):
            return load_sharded_metadata_torch_dist(weights_dir, load_extra_states=load_extra_states)

    raise NotImplementedError(f'Distributed checkpoint backend {backend} not supported')
|
|
|