# NeMo_Canary/nemo/core/optim/distributed_adam.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import contextlib
import itertools
from typing import Callable, Dict, Iterable, Optional, Tuple, Union
import torch
from apex.contrib.optimizers.distributed_fused_adam import (
DistributedFusedAdam,
_disable_pre_forward_hook,
_multi_tensor_copy,
)
try:
import apex.contrib.nccl_allocator as nccl_allocator
except ImportError:
nccl_allocator = None
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace
from megatron.core.dist_checkpointing.mapping import ShardedTensor
from megatron.core.dist_checkpointing.optimizer import get_param_id_to_sharded_param_map, optim_state_to_sharding_state
from nemo.utils import logging, str_to_dtype
from nemo.utils.te_utils import is_float8tensor, is_mxfp8tensor, te_version
if te_version() >= (2, 0):
# TE quantization logic using quantizer API
# Supported TE versions: 2.0+
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
def _quantize_param_fragment_impl(
input_: torch.Tensor,
*,
out: torch.Tensor,
param: torch.nn.Parameter,
) -> None:
quantizer = param._quantizer
out = Float8Tensor(
shape=input_.size(),
dtype=param.dtype,
requires_grad=False,
data=out,
fp8_scale_inv=param._scale_inv,
fp8_dtype=param._fp8_dtype,
quantizer=quantizer,
)
quantizer.update_quantized(input_, out)
def _get_fp8_scale_and_amax_impl(tensor: Float8Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
quantizer = tensor._quantizer
return quantizer.scale, quantizer.amax
elif te_version() >= (1, 0):
# TE quantization logic with fp8_meta dicts
# Supported TE versions: 1.0 - 1.14
from transformer_engine.pytorch.cpp_extensions import cast_to_fp8
def _quantize_param_fragment_impl(
input_: torch.Tensor,
*,
out: torch.Tensor,
param: torch.nn.Parameter,
) -> None:
        cast_to_fp8(
            input_.view(1, -1),
            param._fp8_meta["scaling_fwd"],
            param._fp8_meta_index,
            param._fp8_dtype,
            out=out.view(1, -1),
        )
def _get_fp8_scale_and_amax_impl(tensor) -> Tuple[torch.Tensor, torch.Tensor]:
fp8_meta = tensor._fp8_meta["scaling_fwd"]
fp8_meta_index = tensor._fp8_meta_index
return fp8_meta.scale[fp8_meta_index], fp8_meta.amax_history[0][fp8_meta_index]
else:
# Fallback impl if TE version is invalid
def _quantize_param_fragment_impl(*args, **kwargs) -> None:
raise RuntimeError("Invalid Transformer Engine version for FP8 distributed optimizer")
def _get_fp8_scale_and_amax_impl(*args, **kwargs):
raise RuntimeError("Invalid Transformer Engine version for FP8 distributed optimizer")
def quantize_param_fragment(
input_: torch.Tensor,
*,
out: torch.Tensor,
param: torch.nn.Parameter,
) -> None:
"""Cast values in parameter fragment to FP8
Arguments:
input_ (torch.Tensor): Values to quantize.
out (torch.Tensor): Raw UINT8 buffer to fill with FP8 values.
Dimensions should match input_.
param (torch.nn.Parameter): Parameter containing this parameter
fragment. Must be a Float8Tensor.
"""
_quantize_param_fragment_impl(input_, out=out, param=param)
def get_fp8_scale_and_amax(tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get FP8 scale and amax from Float8Tensor"""
return _get_fp8_scale_and_amax_impl(tensor)
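# A minimal usage sketch of the two helpers above (hypothetical; assumes
# `param` is a TE Float8Tensor whose raw UINT8 storage is `param._data`, and
# `master` is an FP32 master copy of its values):
#
#     n = 1024  # hypothetical fragment size
#     quantize_param_fragment(
#         master.view(-1)[:n],
#         out=param._data.view(-1)[:n],
#         param=param,
#     )
#     scale, amax = get_fp8_scale_and_amax(param)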
_distributed_pgs = {}
def create_distributed_pgs(*, distributed_size: int) -> Dict:
"""Create process groups for distributing within multiple devices.
User can reuse this function to reorder communicators for SHArP.
Arguments:
distributed_size (int): the number of devices to distribute optimizer
state over.
"""
global _distributed_pgs
assert torch.distributed.is_initialized()
if _distributed_pgs:
return _distributed_pgs
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
devices = distributed_size
nodes = world_size // devices
if nodes * devices != world_size:
logging.warning("Expected all nodes have the same amout of devices, disable distribute_within_nodes.")
return {}
node_id = rank // devices
device_id = rank % devices
distributed_pgs = []
for i in range(nodes):
ranks = [i * devices + j for j in range(devices)]
pg = torch.distributed.new_group(ranks=ranks)
distributed_pgs.append(pg)
redundant_pgs = []
for i in range(devices):
ranks = [i + j * devices for j in range(nodes)]
pg = torch.distributed.new_group(ranks=ranks)
redundant_pgs.append(pg)
    # To re-order the SHArP communicator right after distributed init,
    # we have to expose redundant_process_group to the user.
    # The user has to invoke an allreduce through redundant_process_group
    # before all other communicators in order to lock the SHArP tree.
_distributed_pgs = {
'world_size': world_size,
'rank': rank,
'devices': devices,
'nodes': nodes,
'node_id': node_id,
'device_id': device_id,
'distributed_process_group': distributed_pgs[node_id],
'redundant_process_group': redundant_pgs[device_id],
}
return _distributed_pgs
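# A worked example of the layout above: with world_size=8 and
# distributed_size=4, nodes=2 and the created groups are
#
#     distributed groups: [0, 1, 2, 3] and [4, 5, 6, 7]
#     redundant groups:   [0, 4], [1, 5], [2, 6], [3, 7]
#
# so each rank shards optimizer state within its distributed group and
# holds a replica of the ranks in the same position on other nodes.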
def create_distribute_within_nodes_pgs():
"""Create process groups for distributing within nodes.
    Users can reuse this function to reorder communicators for SHArP.
    This function is kept for backward compatibility.
"""
return create_distributed_pgs(distributed_size=torch.cuda.device_count())
class MegatronDistributedFusedAdam(DistributedFusedAdam):
"""Adam optimizer with ZeRO algorithm
Child class of Apex DistributedFusedAdam, with optimizations for
NeMo-Megatron.
Arguments:
params (iterable): iterable of parameters to optimize or dicts
defining parameter groups.
disable_distributed_parameters (bool, optional): use standard
data-parallel communication instead of ZeRO.
(default: False)
        distribute_within_nodes (bool, optional): distribute states
            within the same node, e.g. a DGX node. This can improve
            performance but requires more memory than distributing
            across all ranks, especially for purely data-parallel
            models. (default: False)
distributed_size (int, optional): the number of devices to
distribute optimizer state over.
lock_timeout (float, optional): timeout for callback mutex in
seconds.
**kwargs: keyword arguments to pass to Apex
DistributedFusedAdam.
"""
def __init__(
self,
params: Union[Iterable[torch.nn.Parameter], Iterable[dict]],
disable_distributed_parameters: bool = False,
distribute_within_nodes: bool = False,
distributed_size: Optional[int] = None,
lock_timeout: Optional[float] = None,
**kwargs,
):
# Update distributed_size settings
if distribute_within_nodes:
if distributed_size is not None and distributed_size != torch.cuda.device_count():
raise ValueError("Inconsistent distributed_size value")
distributed_size = torch.cuda.device_count()
# Initialize process groups
if 'process_group' not in kwargs and parallel_state.is_initialized():
kwargs['process_group'] = parallel_state.get_data_parallel_group(with_context_parallel=True)
if disable_distributed_parameters:
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
self_groups = [torch.distributed.new_group(ranks=[i]) for i in range(world_size)]
kwargs['distributed_process_group'] = self_groups[rank]
kwargs['redundant_process_group'] = kwargs['process_group']
elif distributed_size is not None:
dist_pg_infos = create_distributed_pgs(distributed_size=distributed_size)
if dist_pg_infos:
kwargs['distributed_process_group'] = dist_pg_infos['distributed_process_group']
kwargs['redundant_process_group'] = dist_pg_infos['redundant_process_group']
global _distributed_pgs
_distributed_pgs = {}
        # Make sure dtype arguments are torch.dtype objects
for keyword in ('dtype', 'grad_sync_dtype', 'param_sync_dtype'):
if keyword in kwargs:
kwargs[keyword] = str_to_dtype(kwargs[keyword])
# Make sure params are in consistent format (list of param group dicts)
param_groups = list(params)
assert param_groups
if not isinstance(param_groups[0], dict):
param_groups = [{'params': param_groups}]
# Construct distributed optimizer
super().__init__(param_groups, **kwargs)
# Create mutex with timeout
self._lock_with_timeout = None
if lock_timeout is not None:
@contextlib.contextmanager
def lock_with_timeout():
result = self._lock.acquire(timeout=lock_timeout)
try:
yield result
finally:
if result:
# Acquired lock before timeout
self._lock.release()
else:
# Failed to acquire lock before timeout
print(f'MegatronDistributedFusedAdam: Failed to acquire lock within {lock_timeout} seconds.')
self._lock_with_timeout = lock_with_timeout
# Check for MXFP8 parameters
if any(is_mxfp8tensor(param) for param in self.parameters()):
raise ValueError("Distributed optimizer currently does not support MXFP8 parameters")
def _broadcast_params(self) -> None:
# Assume params have already been synchronized
pass
def _make_post_backward_hook(self, param: torch.nn.Parameter, param_group_id: int, param_id: int) -> Callable:
def hook(*unused):
if getattr(param, '_pre_forward_hook_is_enabled', False):
raise RuntimeError(
'A parameter called its post-backward hook '
'before its pre-forward hook. '
'Please manually interact with the parameter '
'before the forward pass (e.g. by calling data_ptr) '
'or run DistributedFusedAdam with overlap_param_sync=False.'
)
lock = self._lock
if self._lock_with_timeout is not None:
lock = self._lock_with_timeout()
with lock:
need_to_initialize = 'fragments' not in self.state[param]
if need_to_initialize:
self._init_param_state(param, param_group_id, param_id)
if self.greedy_grad_copy and not getattr(param, '_disable_greedy_grad_copy', False):
self._grad_copy(param)
if self.overlap_grad_sync and not getattr(param, '_disable_overlap_grad_sync', False):
self._try_start_bucket_grad_sync(
params=[param],
ignore_last_bucket=need_to_initialize,
)
return hook
def init_params(
self,
params: Optional[Iterable[torch.nn.Parameter]] = None,
param_sync_dtype: Optional[torch.dtype] = None,
**kwargs,
) -> None:
"""Initialize optimizer state for parameters
Initializes FP8 and non-FP8 params separately.
"""
# Default cases
if params is None:
params = self.parameters()
elif isinstance(params, torch.Tensor):
params = [params]
# Ignore parameters that have already been initialized
params = [param for param in params if "fragments" not in self.state[param]]
if not params:
return
# Initialize FP8 and non-FP8 tensors separately
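        # FP8 params are initialized first with UINT8 param sync; the second
        # call then skips them (their state already holds "fragments") and
        # initializes the remaining params with the requested sync dtype.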
if any(is_float8tensor(param) for param in params):
super().init_params(
filter(is_float8tensor, params),
param_sync_dtype=torch.uint8,
**kwargs,
)
super().init_params(
params,
param_sync_dtype=param_sync_dtype,
**kwargs,
)
def init_params_bucket(
self,
params: Iterable[torch.nn.Parameter],
grad_sync_dtype: Optional[torch.dtype] = None,
param_sync_dtype: Optional[torch.dtype] = None,
**kwargs,
) -> None:
"""Initialize optimizer state for parameters in one effective bucket"""
# Ignore parameters that have already been initialized
if isinstance(params, torch.Tensor):
params = [params]
params = [param for param in params if "fragments" not in self.state[param]]
if not params:
return
# Initialize parameters with FP32 grads
fp32_params = []
remaining_params = []
for param in params:
if getattr(param, '_with_fp32_optimizer', False):
fp32_params.append(param)
else:
remaining_params.append(param)
params = remaining_params
start_bucket_id = len(self.state["buckets"])
super().init_params_bucket(
fp32_params,
grad_sync_dtype=torch.float32,
param_sync_dtype=param_sync_dtype,
**kwargs,
)
end_bucket_id = len(self.state["buckets"])
fp32_buckets = self.state["buckets"][start_bucket_id:end_bucket_id]
# Initialize FP8 parameters
fp8_params = []
remaining_params = []
for param in params:
if is_float8tensor(param):
fp8_params.append(param)
else:
remaining_params.append(param)
params = remaining_params
start_bucket_id = len(self.state["buckets"])
super().init_params_bucket(
fp8_params,
grad_sync_dtype=grad_sync_dtype,
param_sync_dtype=torch.uint8,
**kwargs,
)
end_bucket_id = len(self.state["buckets"])
fp8_buckets = self.state["buckets"][start_bucket_id:end_bucket_id]
# Initialize remaining parameters as usual
start_bucket_id = len(self.state["buckets"])
super().init_params_bucket(
params,
grad_sync_dtype=grad_sync_dtype,
param_sync_dtype=param_sync_dtype,
**kwargs,
)
end_bucket_id = len(self.state["buckets"])
normal_buckets = self.state["buckets"][start_bucket_id:end_bucket_id]
def add_param_to_bucket(
param: torch.nn.Parameter,
bucket: self.StateBucket,
) -> None:
"""Add trivial param fragment to bucket"""
param_fragments = self.state[param]["fragments"]
param_group_id = param_fragments[0].param_group_id
param_id = param_fragments[0].param_id
bucket_id = bucket.fragments[0].bucket_id
param_size = param.numel()
bucket_size = bucket.bucket_size
fragment = self.ParameterFragment(
param_group_id=param_group_id,
param_id=param_id,
bucket_id=bucket_id,
param_range=(param_size, param_size),
bucket_range=(bucket_size, bucket_size),
in_local_shard=False,
shard_range=None,
shard_bucket_range=None,
shard_param_range=None,
)
param_fragments.append(fragment)
bucket.fragments.append(fragment)
# Make sure all added buckets depend on provided params
for bucket in fp32_buckets:
for param in itertools.chain(fp8_params, params):
add_param_to_bucket(param, bucket)
for bucket in fp8_buckets:
for param in itertools.chain(fp32_params, params):
add_param_to_bucket(param, bucket)
for bucket in normal_buckets:
for param in itertools.chain(fp32_params, fp8_params):
add_param_to_bucket(param, bucket)
def _init_param_state(
self,
param: torch.nn.Parameter,
param_group_id: int,
param_id: int,
param_sync_dtype: Optional[torch.dtype] = None,
**kwargs,
) -> None:
"""Initialize optimizer state for a parameter
Initializing the master weights requires slicing a flattened
view of the param. FP8 tensors do not handle these operations
gracefully, so we hack around it by explicitly casting to
FP32.
"""
# Initialize non-FP8 params as usual
if not is_float8tensor(param):
super()._init_param_state(
param,
param_group_id,
param_id,
param_sync_dtype=param_sync_dtype,
**kwargs,
)
# Return immediately if already initialized
if "fragments" in self.state[param]:
return
# Initialize with FP32 copy of param
fp32_param = param.float()
super()._init_param_state(
fp32_param,
param_group_id,
param_id,
param_sync_dtype=torch.uint8,
**kwargs,
)
self.state[param].update(self.state[fp32_param])
del self.state[fp32_param]
@torch.no_grad()
def init_param_buffer(self) -> None:
"""Allocate contiguous buffers for param buckets
For FP8 params, the FP8 data buffer is made a view into a
contiguous buffer.
"""
# Make sure all params are initialized
self.contiguous_param_buffer = True
self.init_params()
# Construct param buffers
buffer_sizes = collections.defaultdict(lambda: 0)
for bucket in self.state["buckets"]:
dtypes = bucket.dtypes()
buffer_sizes[dtypes] = max(bucket.contiguous_buffer_offset + bucket.bucket_size, buffer_sizes[dtypes])
for dtypes, buffer_size in buffer_sizes.items():
_, _, param_sync_dtype = dtypes
if getattr(self, "nccl_ub", False):
                if not nccl_allocator:
                    raise RuntimeError("Failed to import the NCCL allocator, but NCCL userbuffers were requested")
with nccl_allocator.nccl_mem():
self._param_buffers[dtypes] = torch.zeros(
[buffer_size], dtype=param_sync_dtype, device=self.device
)
else:
self._param_buffers[dtypes] = torch.zeros([buffer_size], dtype=param_sync_dtype, device=self.device)
# Figure out corresponding positions in params and param buffer
params = list(self.parameters())
param_flat_views = []
param_buffer_views = []
        for param in params:
fragment = self.state[param]["fragments"][0]
bucket_id = fragment.bucket_id
bucket = self.state["buckets"][bucket_id]
param_size = param.numel()
bucket_start, _ = fragment.bucket_range
buffer_offset = bucket.contiguous_buffer_offset
buffer_start = buffer_offset + bucket_start
buffer_end = buffer_start + param_size
param_buffer = self._param_buffers[bucket.dtypes()]
param_buffer_view = param_buffer[buffer_start:buffer_end].detach()
if param_buffer_view.device != param.device:
                raise RuntimeError(
                    f"Attempted to change a parameter with device={param.device} "
                    f"into a buffer view with device={param_buffer_view.device}"
                )
if is_float8tensor(param):
param_flat_views.append(param._data.detach().view(-1))
else:
if param_buffer_view.dtype != param.dtype:
raise RuntimeError(
f"Attempted to change a parameter with dtype={param.dtype} "
f"into a buffer view with dtype={param_buffer_view.dtype}"
)
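                # Channels-last params cannot be flattened by `.view(-1)`
                # directly; permute to NHWC element order first (see the
                # matching handling in _param_copy_fragments).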
if param.is_contiguous(memory_format=torch.channels_last):
param = param.permute(0, 2, 3, 1)
param_flat_views.append(param.detach().view(-1))
param_buffer_views.append(param_buffer_view)
# Copy values into param buffer
_multi_tensor_copy(
param_flat_views,
param_buffer_views,
dummy_overflow_buf=self._dummy_overflow_buf,
)
# Make all params a view into the param buffer
for param, buffer_view in zip(params, param_buffer_views):
if is_float8tensor(param):
param._data = buffer_view.view(param.size())
else:
                # Preserve the param's memory format, e.g. for NHWC tensors.
                # `param.data.set_()` fails to change the storage and
                # `param.set_()` invalidates the backprop hook, so assign a
                # strided view to `param.data` instead.
param.data = torch.as_strided(
buffer_view,
param.size(),
param.stride(),
storage_offset=buffer_view.storage_offset(),
)
def try_grad_sync(self, params: Iterable[torch.nn.Parameter]) -> None:
"""Attempt to launch gradient synchronization"""
def is_grad_copy_enabled(param: torch.nn.Parameter) -> bool:
return not getattr(param, '_disable_greedy_grad_copy', False) and not getattr(
param, '_disable_overlap_grad_sync', False
)
params = list(filter(is_grad_copy_enabled, params))
for p in params:
self._grad_copy(p)
self._try_start_bucket_grad_sync(params=params)
def zero_grad(self, *args, **kwargs) -> None:
"""Clear parameter gradients"""
super().zero_grad(*args, **kwargs)
# Reset main grads
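        # Note: `main_grad` is a view into the contiguous grad buffer, so
        # modules that accumulate into `param.main_grad` write directly into
        # the communication bucket.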
if self.contiguous_grad_buffer:
for param in self.parameters():
with _disable_pre_forward_hook(param):
param.main_grad = self.grad_buffer_view(param)
def grad_norm(
self,
parameters: Optional[Iterable[torch.nn.Parameter]] = None,
norm_type: float = 2.0,
force: bool = False,
) -> torch.Tensor:
"""L2 norm of parameter gradients"""
assert norm_type == 2
if parameters is not None:
# Make sure we can access iterable multiple times
parameters = list(parameters)
# Compute grad norm
if force or self._grad_norm is None:
# Compute norm of local gradients for distributed optimizer
grad_norm_sq = self._local_grad_norm(parameters=parameters, norm_type=norm_type)
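            # Each local shard is replicated across the redundant group, so
            # its contribution is divided out before the global all-reduce sum.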
if self.redundant_size > 1:
grad_norm_sq /= self.redundant_size
# Sum over all procs to get grad norm
torch.distributed.all_reduce(
grad_norm_sq,
op=torch.distributed.ReduceOp.SUM,
)
self._grad_norm = grad_norm_sq.sqrt()
# Use cached grad norm
return super().grad_norm()
@torch.no_grad()
def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.ParameterFragment]) -> None:
"""Update parameter fragments with values from parameter buckets
For FP8 params, values are copied directly into the FP8 data
buffer.
"""
# Figure out corresponding positions in param buckets and params
buffers_in = []
buffers_out = []
fragments = list(fragments)
for fragment in fragments:
# Check if fragment needs to be updated
bucket_id = fragment.bucket_id
bucket_start, bucket_end = fragment.bucket_range
param_start, param_end = fragment.param_range
if param_end <= param_start or bucket_id not in self._params_buckets:
continue
# Corresponding positions in bucket and param
param_bucket = self._params_buckets[bucket_id]
param = self.parameter(fragment)
buffer_in = param_bucket.params_bucket[bucket_start:bucket_end]
if is_float8tensor(param):
                # Copy into the FP8 param's data buffer
assert (
param_bucket.params_bucket.dtype == torch.uint8
), "Expected FP8 params to perform param sync in UINT8"
buffer_out = param._data.view(-1)[param_start:param_end]
buffers_in.append(buffer_in)
buffers_out.append(buffer_out)
elif torch.is_floating_point(buffer_in) and torch.is_floating_point(param):
                # A conv param with NHWC layout, i.e. shape (N, C, H, W) and
                # stride (HWC, 1, WC, C), cannot be flattened with `.view(-1)`.
                # Permute it to shape (N, H, W, C) with stride (HWC, WC, C, 1)
                # first. See: https://github.com/NVIDIA/apex/pull/1794
if param.is_contiguous(memory_format=torch.channels_last):
param = param.permute(0, 2, 3, 1)
# Cast between floating-point dtypes
buffer_out = param.detach().view(-1)[param_start:param_end]
buffers_in.append(buffer_in)
buffers_out.append(buffer_out)
else:
                # Copy the most significant bytes for non-floating-point dtypes
                # Note: Assume dtypes are little-endian
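                # For little-endian storage the trailing bytes of each element
                # are the most significant, e.g. with an INT64 bucket shard and
                # an INT32 param, the top four bytes of each INT64 element are
                # copied into the param.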
buffer_out = param.detach().view(-1)[param_start:param_end]
in_bytes = buffer_in.unsqueeze(-1).view(torch.uint8)
out_bytes = buffer_out.unsqueeze(-1).view(torch.uint8)
copy_size = min(in_bytes.size(-1), out_bytes.size(-1))
buffers_in.append(in_bytes[..., -copy_size:])
buffers_out.append(out_bytes[..., -copy_size:])
if copy_size < out_bytes.size(-1):
out_bytes[..., :-copy_size].zero_()
# Copy data from parameter buckets to parameters
_multi_tensor_copy(
buffers_in,
buffers_out,
dummy_overflow_buf=self._dummy_overflow_buf,
)
# Update transpose caches
params = set(self.parameter(fragment) for fragment in fragments)
for param in params:
if is_float8tensor(param):
param._reset_caches()
@torch.no_grad()
def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedAdam.ParameterBucket]) -> None:
"""Make sure local shards of parameters are in expected datatypes
For FP8 params, FP32 values are cast into FP8 using per-param
scaling factors and per-param amaxes are computed and reduced.
"""
# Just call base class function if there are no FP8 tensors
num_fp8_params = sum(1 for param in self.parameters() if is_float8tensor(param))
if num_fp8_params == 0:
super()._check_params_shard_dtypes(params_buckets)
return
# Cast local data to FP8
for bucket_id, param_bucket in params_buckets.items():
state_bucket = self.state["buckets"][bucket_id]
if state_bucket.param_sync_dtype != torch.uint8:
continue
# Initialize FP8 buffer for param sync
params_shard = param_bucket.params_shard
if self.contiguous_param_buffer:
shard_size = state_bucket.shard_size
buffer_offset = state_bucket.contiguous_buffer_offset
buffer_start = buffer_offset + self.distributed_rank * shard_size
buffer_end = buffer_start + shard_size
param_buffer = self._param_buffers[state_bucket.dtypes()]
fp8_params_shard = param_buffer[buffer_start:buffer_end]
else:
fp8_params_shard = torch.empty_like(params_shard, dtype=torch.uint8)
param_bucket.params_shard = fp8_params_shard
# Cast param fragments to FP8
for fragment in self.state["buckets"][bucket_id].fragments:
param = self.parameter(fragment)
if not is_float8tensor(param):
continue
if not fragment.in_local_shard:
continue
shard_start, shard_end = fragment.shard_range
if shard_end <= shard_start:
continue
shard_range = slice(shard_start, shard_end)
quantize_param_fragment(
params_shard[shard_range],
out=fp8_params_shard[shard_range],
param=param,
)
        # Update FP8 scaling factors once all buckets have been processed
if getattr(self, "_check_params_shard_dtypes_progress", None) is None:
self._check_params_shard_dtypes_progress = []
self._check_params_shard_dtypes_progress.extend(params_buckets.keys())
if len(self._check_params_shard_dtypes_progress) == len(self.state["buckets"]):
assert len(set(self._check_params_shard_dtypes_progress)) == len(self.state["buckets"])
# FP8 scaling factors
amaxes = []
scales = []
scale_invs = []
            for param in self.parameters():
                if not is_float8tensor(param):
                    continue
scale, amax = get_fp8_scale_and_amax(param)
amaxes.append(amax.view(1))
scales.append(scale.view(1))
scale_invs.append(param._scale_inv.view(1))
# Update cached scale-inverses
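            # Pack the per-param scales into one contiguous tensor, invert
            # them with a single kernel, then scatter the results back into
            # each param's cached scale_inv.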
packed_scales = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device)
packed_scale_views = [packed_scales[i].view(1) for i in range(num_fp8_params)]
_multi_tensor_copy(
scales,
packed_scale_views,
dummy_overflow_buf=self._dummy_overflow_buf,
)
torch.reciprocal(packed_scales, out=packed_scales)
_multi_tensor_copy(
packed_scale_views,
scale_invs,
dummy_overflow_buf=self._dummy_overflow_buf,
)
# Reduce amaxes
# Note: Assume each param has a separate amax
packed_amaxes = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device)
packed_amax_views = [packed_amaxes[i].view(1) for i in range(num_fp8_params)]
_multi_tensor_copy(
amaxes,
packed_amax_views,
dummy_overflow_buf=self._dummy_overflow_buf,
)
torch.distributed.all_reduce(
packed_amaxes,
op=torch.distributed.ReduceOp.MAX,
group=self.distributed_process_group,
)
_multi_tensor_copy(
packed_amax_views,
amaxes,
dummy_overflow_buf=self._dummy_overflow_buf,
)
# Reset
self._check_params_shard_dtypes_progress = None
# Handle any remaining dtype conversions
super()._check_params_shard_dtypes(params_buckets)
def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None):
"""Create sharded state dict"""
if optimizer_state_dict is None:
optimizer_state_dict = self.state_dict()
id_to_sharded_param_map = get_param_id_to_sharded_param_map(
model_sharded_state_dict=model_sharded_state_dict,
optim_params_iter=self.parameters(),
)
# Convert state
step = optimizer_state_dict['state'].pop('step')
state_dict_format = optimizer_state_dict.pop('format', None)
optim_state_to_sharding_state(optimizer_state_dict, id_to_sharded_param_map)
optimizer_state_dict['state']['step'] = step
if state_dict_format is not None:
optimizer_state_dict['format'] = state_dict_format
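        # The distributed optimizer stores FP32 master weights, so the sharded
        # `param` state entries are renamed to `fp32_param` to distinguish
        # them from the model's own parameters in the checkpoint.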
def rename_fp32_params(x):
if isinstance(x, ShardedTensor) and x.key.startswith('optimizer.state.param'):
x.key = x.key.replace('optimizer.state.param', 'optimizer.state.fp32_param')
return x
dict_list_map_inplace(rename_fp32_params, optimizer_state_dict)
return optimizer_state_dict