|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect |
|
|
import itertools |
|
|
import os |
|
|
from collections import defaultdict |
|
|
from contextlib import contextmanager |
|
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Mapping, Optional, Protocol, TypeVar |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
from nemo.lightning.megatron_init import initialize_model_parallel_for_nemo |
|
|
from nemo.utils import logging |
|
|
|
|
|
NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE = "NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE" |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from lightning.fabric.utilities.types import Optimizable |
|
|
from megatron.core.model_parallel_config import ModelParallelConfig |
|
|
|
|
|
|
|
|
class SharedStateDictProtocol(Protocol): |
|
|
""" """ |
|
|
|
|
|
def sharded_state_dict(self, prefix=""): |
|
|
""" """ |
|
|
... |
|
|
|
|
|
|
|
|
def init_parallel_ranks( |
|
|
world_size: int, |
|
|
global_rank: int, |
|
|
local_rank: int, |
|
|
parallel_config: "ModelParallelConfig", |
|
|
seed=1234, |
|
|
fp8=False, |
|
|
) -> None: |
|
|
""" |
|
|
Initializes the parallel ranks for distributed training. |
|
|
|
|
|
This function sets up the parallel ranks based on the provided world size, global rank, local rank, |
|
|
and parallel configuration. It also sets the seed for random number generation and determines whether |
|
|
to use fp8 precision. |
|
|
|
|
|
Args: |
|
|
world_size (int): The total number of processes participating in the distributed training. |
|
|
global_rank (int): The rank of the current process in the distributed training setup. |
|
|
local_rank (int): The rank of the current process within its machine. |
|
|
parallel_config (ModelParallelConfig): The configuration object containing settings for model parallelism. |
|
|
seed (int, optional): The seed for random number generation. Defaults to 1234. |
|
|
fp8 (bool, optional): Whether to use fp8 precision for model parameters. Defaults to False. |
|
|
""" |
|
|
from nemo.utils import AppState |
|
|
|
|
|
app_state = AppState() |
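    # When the override env var is set, trust the parallel layout already stored in AppState
    # (set up by an earlier NeMo initialization) instead of the values passed in.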
|
|
|
|
|
if os.environ.get(NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE, "false").lower() == "true": |
|
|
init_world_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size |
|
|
init_global_rank = app_state.global_rank |
|
|
init_local_rank = app_state.local_rank |
|
|
else: |
|
|
init_world_size = world_size |
|
|
pp = parallel_config.pipeline_model_parallel_size or 1 |
|
|
if world_size < pp: |
|
|
raise ValueError(f"Expected world_size ({world_size}) to be greater than/equal to pipeline size ({pp})") |
|
|
init_global_rank = global_rank |
|
|
init_local_rank = local_rank |
|
|
|
|
|
initialize_model_parallel_for_nemo( |
|
|
world_size=init_world_size, |
|
|
global_rank=init_global_rank, |
|
|
local_rank=init_local_rank, |
|
|
tensor_model_parallel_size=parallel_config.tensor_model_parallel_size, |
|
|
expert_model_parallel_size=parallel_config.expert_model_parallel_size, |
|
|
expert_tensor_parallel_size=parallel_config.expert_tensor_parallel_size, |
|
|
pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size, |
|
|
pipeline_model_parallel_comm_backend=parallel_config.pipeline_model_parallel_comm_backend, |
|
|
virtual_pipeline_model_parallel_size=parallel_config.virtual_pipeline_model_parallel_size, |
|
|
context_parallel_size=parallel_config.context_parallel_size, |
|
|
seed=seed, |
|
|
use_fp8=fp8, |
|
|
init_mpi_proc_group=getattr(parallel_config, "tp_comm_overlap", False) |
|
|
and getattr(parallel_config, "tp_comm_bootstrap_backend", None) == 'mpi', |
|
|
use_te_rng_tracker=getattr(parallel_config, "use_te_rng_tracker", False), |
|
|
use_sharp=getattr(parallel_config, "use_sharp", False), |
|
|
use_tp_pp_dp_mapping=getattr(parallel_config, "use_tp_pp_dp_mapping", False), |
|
|
num_distributed_optimizer_instances=getattr(parallel_config, "num_distributed_optimizer_instances", 1), |
|
|
nccl_communicator_config_path=getattr(parallel_config, "nccl_communicator_config_path", None), |
|
|
use_gloo_process_groups=getattr(parallel_config, "use_gloo_process_groups", True), |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
def init_model_parallel(model: Optional[nn.Module] = None) -> None: |
|
|
"""Initializes Megatron-LM model parallel if using model parallelism.""" |
|
|
import torch.distributed |
|
|
from megatron.core import parallel_state |
|
|
|
|
|
from nemo.utils import AppState |
|
|
|
|
|
app_state = AppState() |
|
|
|
|
|
|
|
|
|
|
|
if app_state.model_parallel_size is not None: |
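        # Destroy any groups from a previous initialization before creating new ones
        # (init_model_parallel can be called more than once in a process).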
|
|
|
|
|
|
|
|
parallel_state.destroy_model_parallel() |
|
|
if torch.distributed.is_initialized(): |
|
|
parallel_state.initialize_model_parallel( |
|
|
tensor_model_parallel_size=app_state.tensor_model_parallel_size, |
|
|
pipeline_model_parallel_size=app_state.pipeline_model_parallel_size, |
|
|
virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size, |
|
|
pipeline_model_parallel_comm_backend=app_state.pipeline_model_parallel_comm_backend, |
|
|
context_parallel_size=app_state.context_parallel_size, |
|
|
expert_model_parallel_size=app_state.expert_model_parallel_size, |
|
|
expert_tensor_parallel_size=app_state.expert_tensor_parallel_size, |
|
|
use_sharp=app_state.use_sharp, |
|
|
order="tp-cp-ep-pp-dp" if app_state.use_tp_pp_dp_mapping else "tp-cp-ep-dp-pp", |
|
|
num_distributed_optimizer_instances=app_state.num_distributed_optimizer_instances, |
|
|
nccl_communicator_config_path=app_state.nccl_communicator_config_path, |
|
|
create_gloo_process_groups=app_state.use_gloo_process_groups, |
|
|
) |
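            # Sanity-check that the ranks computed during NeMo init match Megatron-Core's view.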
|
|
|
|
|
|
|
|
assert app_state.tensor_model_parallel_rank == parallel_state.get_tensor_model_parallel_rank() |
|
|
assert app_state.pipeline_model_parallel_rank == parallel_state.get_pipeline_model_parallel_rank() |
|
|
assert app_state.expert_tensor_parallel_rank == parallel_state.get_expert_tensor_parallel_rank() |
|
|
|
|
|
app_state.tensor_model_parallel_group = parallel_state.get_tensor_model_parallel_group() |
|
|
app_state.data_parallel_group = parallel_state.get_data_parallel_group() |
|
|
app_state.data_parallel_rank = parallel_state.get_data_parallel_rank() |
|
|
app_state.data_parallel_size = parallel_state.get_data_parallel_world_size() |
|
|
app_state.pipeline_model_parallel_group = parallel_state.get_pipeline_model_parallel_group() |
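            # Create an MPI process group when tensor-parallel communication overlap is bootstrapped via MPI.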
|
|
|
|
|
|
|
|
if app_state.init_mpi_proc_group: |
|
|
torch.distributed.new_group(backend="mpi") |
|
|
|
|
|
|
|
|
def set_model_parallel_attributes(model, parallelism): |
|
|
""" """ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from megatron.core.transformer.transformer_config import TransformerConfig |
|
|
|
|
|
has_mcore_config = isinstance(getattr(model, "config", None), TransformerConfig) |
|
|
if has_mcore_config and hasattr(model, "configure_model"): |
|
|
config: TransformerConfig = model.config |
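        # Copy every attribute of the parallelism spec that also exists on the model's TransformerConfig.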
|
|
for attr_name in filter(lambda x: not x.startswith('__'), dir(parallelism)): |
|
|
if not hasattr(config, attr_name): |
|
|
continue |
|
|
setattr(config, attr_name, getattr(parallelism, attr_name)) |
|
|
if hasattr(config, "__io__"): |
|
|
setattr(config.__io__, attr_name, getattr(parallelism, attr_name)) |
|
|
if hasattr(config, '__post_init__'): |
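            # Re-run __post_init__ so fields derived from the parallel sizes are recomputed and validated.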
|
|
|
|
|
|
|
|
config.__post_init__() |
|
|
|
|
|
return config |
|
|
|
|
|
return None |
|
|
|
|
|
|
|
|
@contextmanager |
|
|
def megatron_lazy_init_context(config) -> Generator[None, None, None]: |
|
|
""" """ |
|
|
try: |
|
|
from megatron.core.extensions import transformer_engine as _te |
|
|
|
|
|
original = _te._get_extra_te_kwargs |
|
|
|
|
|
def _get_extra_te_kwargs_meta(c): |
|
|
"""Forces device to meta""" |
|
|
kwargs = original(c) |
|
|
kwargs['device'] = 'meta' |
|
|
return kwargs |
|
|
|
|
|
_te._get_extra_te_kwargs = _get_extra_te_kwargs_meta |
|
|
except ImportError: |
|
|
pass |
|
|
|
|
|
_orig_perform_initialization = config.perform_initialization |
|
|
_orig_use_cpu_initialization = config.use_cpu_initialization |
|
|
|
|
|
config.perform_initialization = False |
|
|
config.use_cpu_initialization = True |
|
|
|
|
|
yield |
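    # Restore the patched Transformer Engine helper and the original initialization flags.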
|
|
|
|
|
try: |
|
|
from megatron.core.extensions import transformer_engine as _te |
|
|
|
|
|
_te._get_extra_te_kwargs = original |
|
|
except ImportError: |
|
|
pass |
|
|
|
|
|
config.perform_initialization = _orig_perform_initialization |
|
|
config.use_cpu_initialization = _orig_use_cpu_initialization |
|
|
|
|
|
|
|
|
@contextmanager |
|
|
def megatron_cpu_init_context(config) -> Generator[None, None, None]: |
|
|
""" """ |
|
|
_orig_use_cpu_initialization = config.use_cpu_initialization |
|
|
|
|
|
config.use_cpu_initialization = True |
|
|
|
|
|
yield |
|
|
|
|
|
config.use_cpu_initialization = _orig_use_cpu_initialization |
|
|
|
|
|
|
|
|
ModelT = TypeVar("ModelT", bound=nn.Module) |
|
|
|
|
|
|
|
|
class GradScaler(torch.cuda.amp.GradScaler): |
|
|
""" |
|
|
    Gradient scaler with a model-parallel inf check. Infs in the gradients are checked across
    model-parallel ranks in (1) the optimizer step and (2) the gradient-scaler update.
|
|
|
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
init_scale=2.0**16, |
|
|
growth_factor=2.0, |
|
|
backoff_factor=0.5, |
|
|
growth_interval=2000, |
|
|
enabled=True, |
|
|
hysteresis=1, |
|
|
): |
|
|
super().__init__( |
|
|
init_scale=init_scale, |
|
|
growth_factor=growth_factor, |
|
|
backoff_factor=backoff_factor, |
|
|
growth_interval=growth_interval, |
|
|
enabled=enabled, |
|
|
) |
|
|
self.optimizer_update_skipped: Optional[bool] = None |
|
|
self.hysteresis = hysteresis |
|
|
self._hysteresis_tracker = self.hysteresis |
|
|
|
|
|
def _unscale_grads_(self, optimizer, *args): |
|
|
if getattr(optimizer, "_custom_amp_unscale_grads", False): |
|
|
return optimizer.unscale_grads(*args) |
|
|
else: |
|
|
return super()._unscale_grads_(optimizer, *args) |
|
|
|
|
|
def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs): |
|
|
from megatron.core import parallel_state |
|
|
|
|
|
retval = None |
|
|
found_inf = torch.cuda.FloatTensor([sum(v.item() for v in optimizer_state["found_inf_per_device"].values())]) |
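        # Agree on inf/NaN across the whole model-parallel group so every rank skips (or takes) the step together.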
|
|
|
|
|
|
|
|
torch.distributed.all_reduce( |
|
|
found_inf, |
|
|
op=torch.distributed.ReduceOp.MAX, |
|
|
group=parallel_state.get_model_parallel_group(), |
|
|
) |
|
|
|
|
|
if found_inf.item() == 0: |
|
|
retval = optimizer.step(*args, **kwargs) |
|
|
self.optimizer_update_skipped = False |
|
|
else: |
|
|
self.optimizer_update_skipped = True |
|
|
return retval |
|
|
|
|
|
def update(self, new_scale=None): |
|
|
""" |
|
|
        Extends the native grad scaler ``update`` function to:
|
|
1. Check inf across model-parallel ranks. |
|
|
2. Update hysteresis tracker. |
|
|
3. Apply hysteresis to grad scale update. |
|
|
""" |
|
|
from megatron.core import parallel_state |
|
|
|
|
|
if not self._enabled: |
|
|
return |
|
|
|
|
|
_scale, _growth_tracker = self._check_scale_growth_tracker("update") |
|
|
|
|
|
if new_scale is not None: |
|
|
|
|
|
if isinstance(new_scale, float): |
|
|
self._scale.fill_(new_scale) |
|
|
else: |
|
|
reason = ( |
|
|
"new_scale should be a float or a 1-element torch.cuda.FloatTensor with" " requires_grad=False." |
|
|
) |
|
|
assert isinstance(new_scale, torch.cuda.FloatTensor), reason |
|
|
assert new_scale.numel() == 1, reason |
|
|
assert new_scale.requires_grad is False, reason |
|
|
self._scale.copy_(new_scale) |
|
|
else: |
|
|
|
|
|
|
|
|
found_infs = [ |
|
|
found_inf.to(device=_scale.device, non_blocking=True) |
|
|
for state in self._per_optimizer_states.values() |
|
|
for found_inf in state["found_inf_per_device"].values() |
|
|
] |
|
|
|
|
|
assert len(found_infs) > 0, "No inf checks were recorded prior to update." |
|
|
|
|
|
found_inf_combined = found_infs[0] |
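            # Make the inf check global across the model-parallel group, not just this rank.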
|
|
|
|
|
|
|
|
torch.distributed.all_reduce( |
|
|
found_inf_combined, |
|
|
op=torch.distributed.ReduceOp.MAX, |
|
|
group=parallel_state.get_model_parallel_group(), |
|
|
) |
|
|
|
|
|
if len(found_infs) > 1: |
|
|
for i in range(1, len(found_infs)): |
|
|
found_inf = found_infs[i] |
|
|
|
|
|
torch.distributed.all_reduce( |
|
|
found_inf, |
|
|
op=torch.distributed.ReduceOp.MAX, |
|
|
group=parallel_state.get_model_parallel_group(), |
|
|
) |
|
|
found_inf_combined += found_inf |
|
|
|
|
|
if found_inf_combined > 0: |
|
|
self._hysteresis_tracker -= 1 |
|
|
if self._hysteresis_tracker <= 0: |
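                    # Hysteresis exhausted: let the native update back off the scale.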
|
|
|
|
|
|
|
|
torch._amp_update_scale_( |
|
|
_scale, |
|
|
_growth_tracker, |
|
|
found_inf_combined, |
|
|
self._growth_factor, |
|
|
self._backoff_factor, |
|
|
self._growth_interval, |
|
|
) |
|
|
else: |
|
|
|
|
|
_growth_tracker.fill_(0.0) |
|
|
else: |
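                # No inf/NaN this step: apply the normal (possibly growing) scale update and reset hysteresis.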
|
|
|
|
|
|
|
|
|
|
|
torch._amp_update_scale_( |
|
|
_scale, |
|
|
_growth_tracker, |
|
|
found_inf_combined, |
|
|
self._growth_factor, |
|
|
self._backoff_factor, |
|
|
self._growth_interval, |
|
|
) |
|
|
self._hysteresis_tracker = self.hysteresis |
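        # Clear per-optimizer state; it is repopulated on the next unscale_/step cycle.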
|
|
|
|
|
|
|
|
self._per_optimizer_states = defaultdict( |
|
|
torch.cuda.amp.grad_scaler._refresh_per_optimizer_state |
|
|
) |
|
|
|
|
|
def state_dict(self): |
|
|
""" |
|
|
        Add hysteresis_tracker to the native function's state_dict.
|
|
""" |
|
|
return ( |
|
|
{ |
|
|
"scale": self.get_scale(), |
|
|
"growth_factor": self._growth_factor, |
|
|
"backoff_factor": self._backoff_factor, |
|
|
"growth_interval": self._growth_interval, |
|
|
"_growth_tracker": self._get_growth_tracker(), |
|
|
"_hysteresis_tracker": self._hysteresis_tracker, |
|
|
} |
|
|
if self._enabled |
|
|
else {} |
|
|
) |
|
|
|
|
|
def load_state_dict(self, state_dict): |
|
|
""" |
|
|
Load hysteresis_tracker in addition to the state dict of the native function. |
|
|
""" |
|
|
if not self._enabled: |
|
|
return |
|
|
|
|
|
if len(state_dict) == 0: |
|
|
raise RuntimeError( |
|
|
"The source state dict is empty, possibly because it was saved " |
|
|
"from a disabled instance of GradScaler." |
|
|
) |
|
|
|
|
|
self._init_scale = state_dict["scale"] |
|
|
if self._scale is not None: |
|
|
self._scale.fill_(state_dict["scale"]) |
|
|
self._growth_factor = state_dict["growth_factor"] |
|
|
self._backoff_factor = state_dict["backoff_factor"] |
|
|
self._growth_interval = state_dict["growth_interval"] |
|
|
self._init_growth_tracker = state_dict["_growth_tracker"] |
|
|
if self._growth_tracker is not None: |
|
|
self._growth_tracker.fill_(state_dict["_growth_tracker"]) |
|
|
if "_hysterisis_tracker" in state_dict: |
|
|
self._hysteresis_tracker = state_dict["_hysterisis_tracker"] |
|
|
else: |
|
|
self._hysteresis_tracker = 1 |
|
|
|
|
|
|
|
|
def enable_nvidia_optimizations() -> None: |
|
|
"""These optimizations are present in NVIDIA NGC PyTorch Containers.""" |
|
|
|
|
|
nvidia_torch_version = os.getenv("NVIDIA_PYTORCH_VERSION", None) |
|
|
if nvidia_torch_version is not None: |
|
|
try: |
|
|
NVIDIA_TORCH_MAJOR = int(nvidia_torch_version.split(".")[0]) |
|
|
except Exception: |
|
|
NVIDIA_TORCH_MAJOR = 0 |
|
|
try: |
|
|
NVIDIA_TORCH_MINOR = int(nvidia_torch_version.split(".")[1]) |
|
|
except Exception: |
|
|
NVIDIA_TORCH_MINOR = 0 |
|
|
|
|
|
|
|
|
        if NVIDIA_TORCH_MAJOR > 21 or (NVIDIA_TORCH_MAJOR == 21 and NVIDIA_TORCH_MINOR >= 11):
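            # NVFuser-era containers: enable the profiling executor and disable the legacy CPU/GPU and
            # tensor-expression fusers.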
|
|
|
|
|
torch._C._jit_set_profiling_executor(True) |
|
|
torch._C._jit_set_profiling_mode(True) |
|
|
torch._C._jit_override_can_fuse_on_cpu(False) |
|
|
torch._C._jit_override_can_fuse_on_gpu(False) |
|
|
torch._C._jit_set_texpr_fuser_enabled(False) |
|
|
|
|
|
torch._C._debug_set_autodiff_subgraph_inlining(False) |
|
|
else: |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
def optimizer_sharded_state_dict( |
|
|
model: SharedStateDictProtocol, |
|
|
optimizer: "Optimizable", |
|
|
is_loading=False, |
|
|
sharding_type='fully_sharded_model_space', |
|
|
) -> Dict[str, torch.Tensor]: |
|
|
""" |
|
|
    Sharded state dictionary for a MainParamsOptimizerWrapper (or any optimizer exposing
    ``sharded_state_dict``). Used to save and load the optimizer state when training with
    distributed checkpointing.

    Returns:
        dict: The sharded state dictionary for the optimizer.

    Raises:
        ValueError: If a parameter ID does not match any model sharded parameter.
|
|
""" |
|
|
from megatron.core.dist_checkpointing.optimizer import ( |
|
|
get_param_id_to_sharded_param_map, |
|
|
make_sharded_optimizer_tensor, |
|
|
optim_state_to_sharding_state, |
|
|
) |
|
|
|
|
|
from nemo.core.optim import MainParamsOptimizerWrapper |
|
|
from nemo.core.optim.optimizers import init_optimizer_states |
|
|
|
|
|
model_sharded_state_dict = model.sharded_state_dict() |
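    # Drop Transformer Engine "_extra_state" entries; they do not correspond to optimizer parameters.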
|
|
|
|
|
|
|
|
model_sharded_state_dict = { |
|
|
key: value for key, value in model_sharded_state_dict.items() if not key.endswith("_extra_state") |
|
|
} |
|
|
|
|
|
if hasattr(optimizer, "sharded_state_dict"): |
|
|
return optimizer.sharded_state_dict( |
|
|
model_sharded_state_dict, is_loading=is_loading, sharding_type=sharding_type |
|
|
) |
|
|
|
|
|
if not isinstance(optimizer, MainParamsOptimizerWrapper): |
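        # Plain (non-wrapped) optimizer: materialize its state, then map param ids to the model's sharded params.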
|
|
|
|
|
init_optimizer_states(optimizer) |
|
|
optimizer_state_dict = optimizer.state_dict() |
|
|
id_to_sharded_param_map = get_param_id_to_sharded_param_map( |
|
|
model_sharded_state_dict=model_sharded_state_dict, |
|
|
optim_params_iter=itertools.chain.from_iterable(g['params'] for g in optimizer.param_groups), |
|
|
) |
|
|
optim_state_to_sharding_state(optimizer_state_dict, id_to_sharded_param_map) |
|
|
return optimizer_state_dict |
|
|
|
|
|
optimizer_state_dict: Dict[str, Any] = optimizer.state_dict() |
|
|
|
|
|
id_to_sharded_param_map = get_param_id_to_sharded_param_map( |
|
|
model_sharded_state_dict=model_sharded_state_dict, |
|
|
optim_params_iter=itertools.chain.from_iterable(g for g in optimizer.float16_groups), |
|
|
) |
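    # MainParamsOptimizerWrapper keeps an fp32 master copy per fp16 param group; the two lists must line up.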
|
|
|
|
|
|
|
|
assert len(optimizer_state_dict["fp32_from_fp16_params"]) == len(optimizer_state_dict["optimizer"]["param_groups"]) |
|
|
|
|
|
def get_safe(param_id): |
|
|
try: |
|
|
return id_to_sharded_param_map[param_id] |
|
|
except KeyError as e: |
|
|
raise ValueError(f"Param id {param_id} does not match any model sharded param") from e |
|
|
|
|
|
optimizer_state_dict["fp32_from_fp16_params"] = [ |
|
|
[ |
|
|
make_sharded_optimizer_tensor(get_safe(param_id), fp32_param, prefix="optimizer.state.fp32_param") |
|
|
for param_id, fp32_param in zip(state_group["params"], fp32_group) |
|
|
] |
|
|
for fp32_group, state_group in zip( |
|
|
optimizer_state_dict["fp32_from_fp16_params"], |
|
|
optimizer_state_dict["optimizer"]["param_groups"], |
|
|
) |
|
|
] |
|
|
|
|
|
|
|
|
optim_state_to_sharding_state(optimizer_state_dict["optimizer"], id_to_sharded_param_map) |
|
|
|
|
|
return optimizer_state_dict |
|
|
|
|
|
|
|
|
def load_model_state_dict(megatron_parallel, checkpoint: Mapping[str, Any], strict: bool = True) -> None: |
|
|
""" """ |
|
|
from megatron.core import parallel_state |
|
|
from megatron.core.dist_checkpointing.validation import StrictHandling, parse_strict_flag |
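    # "strict" may arrive as a Megatron-Core StrictHandling value; collapse it to the boolean expected below.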
|
|
|
|
|
|
|
|
if strict is not None and not isinstance(strict, bool): |
|
|
strict = parse_strict_flag(strict) |
|
|
strict_options = [ |
|
|
StrictHandling.ASSUME_OK_UNEXPECTED, |
|
|
StrictHandling.RAISE_UNEXPECTED, |
|
|
StrictHandling.RAISE_ALL, |
|
|
] |
|
|
strict = strict in strict_options |
|
|
|
|
|
try: |
|
|
from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel |
|
|
|
|
|
have_custom_fsdp = True |
|
|
    except (ImportError, ModuleNotFoundError):
|
|
have_custom_fsdp = False |
|
|
|
|
|
for index, module in enumerate(megatron_parallel): |
|
|
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: |
|
|
if "state_dict" in checkpoint: |
|
|
checkpoint_state_dict = checkpoint["state_dict"][f"model_{index}"] |
|
|
else: |
|
|
checkpoint_state_dict = checkpoint[f"model_{index}"] |
|
|
else: |
|
|
if "state_dict" in checkpoint: |
|
|
checkpoint_state_dict = checkpoint["state_dict"] |
|
|
else: |
|
|
checkpoint_state_dict = checkpoint |
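        # Count how many "module." wrappers the live model has so checkpoint keys can be re-prefixed to match.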
|
|
|
|
|
n_nesting = 0 |
|
|
mcore_model = megatron_parallel.module |
|
|
while hasattr(mcore_model, "module"): |
|
|
mcore_model = mcore_model.module |
|
|
n_nesting += 1 |
|
|
|
|
|
_state_dict = {} |
|
|
for key, value in checkpoint_state_dict.items(): |
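            # Strip the existing "module." prefixes from the key, then re-add exactly n_nesting of them.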
|
|
|
|
|
count, _key = 0, key |
|
|
while _key.startswith("module."): |
|
|
_key = _key[len("module.") :] |
|
|
count += 1 |
|
|
|
|
|
|
|
|
if count < n_nesting: |
|
|
to_add = "module." * (n_nesting - count) |
|
|
_state_dict[f"{to_add}{key}"] = value |
|
|
elif count > n_nesting: |
|
|
to_remove = "module." * (count - n_nesting) |
|
|
_state_dict[key[len(to_remove) :]] = value |
|
|
else: |
|
|
_state_dict[key] = value |
|
|
|
|
|
if have_custom_fsdp and hasattr(module, "module") and isinstance(module.module, FullyShardedDataParallel): |
|
|
module.module.load_state_dict(_state_dict, strict=strict) |
|
|
continue |
|
|
|
|
|
try: |
|
|
module.load_state_dict(_state_dict, strict=strict) |
|
|
except RuntimeError as e: |
|
|
            missing_keys, unexpected_keys = module.load_state_dict(checkpoint_state_dict, strict=False)
|
|
if all(s.endswith('_extra_state') for s in missing_keys): |
|
|
logging.warning( |
|
|
                    f'Loading a checkpoint created with Transformer Engine version lower than 1.13. '
|
|
f'Missing layers {missing_keys} will be ignored.' |
|
|
) |
|
|
else: |
|
|
raise e |
|
|
|
|
|
|
|
|
def _sync_from_last_pipeline_stage(value: torch.Tensor, broadcast: bool = False): |
|
|
""" |
|
|
    When pipeline parallelism is enabled, sends a tensor defined on the last pipeline stage
    to other ranks.
|
|
|
|
|
Args: |
|
|
        value (torch.Tensor): A tensor to be sent from the final pipeline stage of
|
|
a pipeline parallelism group (e.g. loss). |
|
|
Note that this tensor should already be defined on the target rank(s) to fill with received data. |
|
|
broadcast (bool): When True, broadcasts value from the final pipeline stage rank to all ranks in its group. |
|
|
When False, only rank zero receives value from the final pipeline stage rank in its group. |
|
|
This mode exists to avoid slow one-to-many communication when not necessary. Defaults to False. |
|
|
""" |
|
|
from megatron.core import parallel_state |
|
|
|
|
|
if parallel_state.get_pipeline_model_parallel_world_size() > 1: |
|
|
src_rank = parallel_state.get_pipeline_model_parallel_last_rank() |
|
|
|
|
|
if not broadcast: |
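            # Point-to-point mode: only global rank 0 receives the value, and only from the pipeline
            # group that contains it.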
|
|
pp_ranks = torch.distributed.get_process_group_ranks(parallel_state.get_pipeline_model_parallel_group()) |
|
|
if torch.distributed.get_rank() == src_rank and 0 in pp_ranks: |
|
|
torch.distributed.send(value, 0) |
|
|
elif torch.distributed.get_rank() == 0: |
|
|
torch.distributed.recv(value, src_rank) |
|
|
else: |
|
|
torch.distributed.broadcast( |
|
|
value, |
|
|
src_rank, |
|
|
group=parallel_state.get_pipeline_model_parallel_group(), |
|
|
) |
|
|
|
|
|
|
|
|
def setup_megatron_optimizer( |
|
|
model, |
|
|
config, |
|
|
no_weight_decay_cond: Optional[Callable] = None, |
|
|
scale_lr_cond: Optional[Callable] = None, |
|
|
lr_mult: float = 1.0, |
|
|
): |
|
|
""" """ |
|
|
from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer |
|
|
|
|
|
from nemo.core.optim import McoreDistributedOptimizer |
|
|
from nemo.utils import AppState |
|
|
|
|
|
app_state = AppState() |
|
|
|
|
|
assert isinstance(config, OptimizerConfig), f"Expected OptimizerConfig, got {type(config)}" |
|
|
|
|
|
class McoreOpt(McoreDistributedOptimizer): |
|
|
""" """ |
|
|
|
|
|
def sharded_state_dict( |
|
|
self, |
|
|
model_sharded_state_dict, |
|
|
optimizer_state_dict=None, |
|
|
is_loading=False, |
|
|
sharding_type='fully_sharded_model_space', |
|
|
): |
|
|
mcore_optimizer_sig = inspect.signature(self.mcore_optimizer.sharded_state_dict).parameters |
|
|
distrib_optim_kwargs = {} |
|
|
if "sharding_type" in mcore_optimizer_sig: |
|
|
distrib_optim_kwargs["sharding_type"] = sharding_type |
|
|
state_dict = self.mcore_optimizer.sharded_state_dict( |
|
|
model_sharded_state_dict, is_loading=is_loading, **distrib_optim_kwargs |
|
|
) |
|
|
return state_dict |
|
|
|
|
|
|
|
|
ddp_modules = [m.module for m in model] |
|
|
mcore_opt = get_megatron_optimizer( |
|
|
config, |
|
|
ddp_modules, |
|
|
no_weight_decay_cond=no_weight_decay_cond, |
|
|
scale_lr_cond=scale_lr_cond, |
|
|
lr_mult=lr_mult, |
|
|
use_gloo_process_groups=app_state.use_gloo_process_groups, |
|
|
) |
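    # Fold per-group lr_mult / wd_mult into lr and weight_decay once up front, keeping the
    # pre-multiplication values, so the multipliers are not applied again later or after a reload.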
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for pg in mcore_opt.param_groups: |
|
|
        if 'pre_mult_lr' in pg or 'pre_mult_wd' in pg:
|
|
|
|
|
|
|
|
continue |
|
|
pg['pre_mult_lr'] = pg["lr"] |
|
|
pg['pre_mult_wd'] = pg['weight_decay'] |
|
|
new_lr = pg["lr"] * pg.get('lr_mult', 1.0) |
|
|
new_wd = pg["weight_decay"] * pg.get("wd_mult", 1.0) |
|
|
pg['lr'] = new_lr |
|
|
pg['weight_decay'] = new_wd |
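        # Stash and remove the multipliers so downstream code does not apply them a second time.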
|
|
|
|
|
|
|
|
|
|
|
if 'lr_mult' in pg: |
|
|
pg['pre_lr_mult'] = pg['lr_mult'] |
|
|
del pg['lr_mult'] |
|
|
if 'wd_mult' in pg: |
|
|
pg['pre_wd_mult'] = pg['wd_mult'] |
|
|
del pg['wd_mult'] |
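    # When DDP overlaps and aligns parameter all-gathers, give each model chunk's config a param_sync_func
    # so the forward pass can trigger the gathers at the right time.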
|
|
|
|
|
if getattr(model.ddp_config, "overlap_param_gather", False) and getattr( |
|
|
model.ddp_config, "align_param_gather", False |
|
|
): |
|
|
param_sync_func = [model_chunk.start_param_sync for model_chunk in model] |
|
|
param_sync_func = param_sync_func[0] if len(model) == 1 else param_sync_func |
|
|
for module in model: |
|
|
module.config.param_sync_func = param_sync_func |
|
|
|
|
|
return McoreOpt(mcore_opt) |
|
|
|