# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import contextlib
import itertools
from typing import Callable, Dict, Iterable, Optional, Tuple, Union

import torch
from apex.contrib.optimizers.distributed_fused_adam import (
    DistributedFusedAdam,
    _disable_pre_forward_hook,
    _multi_tensor_copy,
)

try:
    import apex.contrib.nccl_allocator as nccl_allocator
except ImportError:
    nccl_allocator = None

from megatron.core import parallel_state
from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace
from megatron.core.dist_checkpointing.mapping import ShardedTensor
from megatron.core.dist_checkpointing.optimizer import get_param_id_to_sharded_param_map, optim_state_to_sharding_state

from nemo.utils import logging, str_to_dtype
from nemo.utils.te_utils import is_float8tensor, is_mxfp8tensor, te_version

if te_version() >= (2, 0):
    # TE quantization logic using quantizer API
    # Supported TE versions: 2.0+
    from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor

    def _quantize_param_fragment_impl(
        input_: torch.Tensor,
        *,
        out: torch.Tensor,
        param: torch.nn.Parameter,
    ) -> None:
        """Quantize a parameter fragment into a raw UINT8 buffer via the TE 2.0 quantizer."""
        quantizer = param._quantizer
        # Wrap the raw UINT8 buffer in a Float8Tensor so the quantizer can
        # fill it in place with FP8-encoded values.
        out = Float8Tensor(
            shape=input_.size(),
            dtype=param.dtype,
            requires_grad=False,
            data=out,
            fp8_scale_inv=param._scale_inv,
            fp8_dtype=param._fp8_dtype,
            quantizer=quantizer,
        )
        quantizer.update_quantized(input_, out)

    def _get_fp8_scale_and_amax_impl(tensor: Float8Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (scale, amax) tensors from a Float8Tensor's quantizer."""
        quantizer = tensor._quantizer
        return quantizer.scale, quantizer.amax

elif te_version() >= (1, 0):
    # TE quantization logic with fp8_meta dicts
    # Supported TE versions: 1.0 - 1.14
    from transformer_engine.pytorch.cpp_extensions import cast_to_fp8

    def _quantize_param_fragment_impl(
        input_: torch.Tensor,
        *,
        out: torch.Tensor,
        param: torch.nn.Parameter,
    ) -> None:
        """Quantize a parameter fragment into a raw UINT8 buffer via TE 1.x cast_to_fp8."""
        # Fix: the original body referenced undefined names `src`/`dst`,
        # raising NameError on TE 1.x. The function parameters are
        # `input_` (source values) and `out` (destination buffer).
        cast_to_fp8(
            input_.view(1, -1),
            param._fp8_meta["scaling_fwd"],
            param._fp8_meta_index,
            param._fp8_dtype,
            out=out.view(1, -1),
        )

    def _get_fp8_scale_and_amax_impl(tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (scale, amax) from a TE 1.x Float8Tensor's fp8_meta dict."""
        fp8_meta = tensor._fp8_meta["scaling_fwd"]
        fp8_meta_index = tensor._fp8_meta_index
        return fp8_meta.scale[fp8_meta_index], fp8_meta.amax_history[0][fp8_meta_index]

else:
    # Fallback impl if TE version is invalid
    def _quantize_param_fragment_impl(*args, **kwargs) -> None:
        raise RuntimeError("Invalid Transformer Engine version for FP8 distributed optimizer")

    def _get_fp8_scale_and_amax_impl(*args, **kwargs):
        raise RuntimeError("Invalid Transformer Engine version for FP8 distributed optimizer")


def quantize_param_fragment(
    input_: torch.Tensor,
    *,
    out: torch.Tensor,
    param: torch.nn.Parameter,
) -> None:
    """Cast values in parameter fragment to FP8

    Arguments:
        input_ (torch.Tensor): Values to quantize.
        out (torch.Tensor): Raw UINT8 buffer to fill with FP8 values.
            Dimensions should match input_.
        param (torch.nn.Parameter): Parameter containing this parameter
            fragment. Must be a Float8Tensor.

    """
    _quantize_param_fragment_impl(input_, out=out, param=param)


def get_fp8_scale_and_amax(tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get FP8 scale and amax from Float8Tensor"""
    return _get_fp8_scale_and_amax_impl(tensor)


# Cached process-group info, populated once by create_distributed_pgs
_distributed_pgs = {}


def create_distributed_pgs(*, distributed_size: int) -> Dict:
    """Create process groups for distributing within multiple devices.

    User can reuse this function to reorder communicators for SHArP.

    Arguments:
        distributed_size (int): the number of devices to distribute
            optimizer state over.

    """
    global _distributed_pgs
    assert torch.distributed.is_initialized()

    # Return cached groups if they were already created
    if _distributed_pgs:
        return _distributed_pgs

    world_size = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    devices = distributed_size
    nodes = world_size // devices

    if nodes * devices != world_size:
        # Fix: typo "amout" -> "amount" in the warning message
        logging.warning("Expected all nodes have the same amount of devices, disable distribute_within_nodes.")
        return {}

    node_id = rank // devices
    device_id = rank % devices

    # One group per node, containing all devices of that node
    distributed_pgs = []
    for i in range(nodes):
        ranks = [i * devices + j for j in range(devices)]
        pg = torch.distributed.new_group(ranks=ranks)
        distributed_pgs.append(pg)

    # One group per device index, containing the same device across nodes
    redundant_pgs = []
    for i in range(devices):
        ranks = [i + j * devices for j in range(nodes)]
        pg = torch.distributed.new_group(ranks=ranks)
        redundant_pgs.append(pg)

    # To re-order SHArP communicator right after distributed init,
    # we have to expose redundant_process_group to user.
    # User has to invoke allreduce through redundant_process_group
    # before all other communicators to lock SHArP tree.
    _distributed_pgs = {
        'world_size': world_size,
        'rank': rank,
        'devices': devices,
        'nodes': nodes,
        'node_id': node_id,
        'device_id': device_id,
        'distributed_process_group': distributed_pgs[node_id],
        'redundant_process_group': redundant_pgs[device_id],
    }

    return _distributed_pgs


def create_distribute_within_nodes_pgs():
    """Create process groups for distributing within nodes.

    User can reuse this function to reorder communicators for SHArP.

    This function is kept for backward compatibility.

    """
    return create_distributed_pgs(distributed_size=torch.cuda.device_count())


class MegatronDistributedFusedAdam(DistributedFusedAdam):
    """Adam optimizer with ZeRO algorithm

    Child class of Apex DistributedFusedAdam, with optimizations for
    NeMo-Megatron.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts
            defining parameter groups.
        disable_distributed_parameters (bool, optional): use standard
            data-parallel communication instead of ZeRO. (default: False)
        distribute_within_nodes (bool, optional): distribute states
            within the same node, e.g. DGX. This can improve performance
            but requires larger memory than distributing within all
            ranks, especially for pure data parallel models.
            (default: False).
        distributed_size (int, optional): the number of devices to
            distribute optimizer state over.
        lock_timeout (float, optional): timeout for callback mutex in
            seconds.
        **kwargs: keyword arguments to pass to Apex
            DistributedFusedAdam.

    """

    def __init__(
        self,
        params: Union[Iterable[torch.nn.Parameter], Iterable[dict]],
        disable_distributed_parameters: bool = False,
        distribute_within_nodes: bool = False,
        distributed_size: Optional[int] = None,
        lock_timeout: Optional[float] = None,
        **kwargs,
    ):

        # Update distributed_size settings
        if distribute_within_nodes:
            if distributed_size is not None and distributed_size != torch.cuda.device_count():
                raise ValueError("Inconsistent distributed_size value")
            distributed_size = torch.cuda.device_count()

        # Initialize process groups
        if 'process_group' not in kwargs and parallel_state.is_initialized():
            kwargs['process_group'] = parallel_state.get_data_parallel_group(with_context_parallel=True)
        if disable_distributed_parameters:
            # Plain data-parallelism: each rank is its own "distributed"
            # group, and the full data-parallel group is redundant.
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
            self_groups = [torch.distributed.new_group(ranks=[i]) for i in range(world_size)]
            kwargs['distributed_process_group'] = self_groups[rank]
            kwargs['redundant_process_group'] = kwargs['process_group']
        elif distributed_size is not None:
            dist_pg_infos = create_distributed_pgs(distributed_size=distributed_size)
            if dist_pg_infos:
                kwargs['distributed_process_group'] = dist_pg_infos['distributed_process_group']
                kwargs['redundant_process_group'] = dist_pg_infos['redundant_process_group']
                global _distributed_pgs
                _distributed_pgs = {}

        # Make sure dtypes are in right type
        for keyword in ('dtype', 'grad_sync_dtype', 'param_sync_dtype'):
            if keyword in kwargs:
                kwargs[keyword] = str_to_dtype(kwargs[keyword])

        # Make sure params are in consistent format (list of param group dicts)
        param_groups = list(params)
        assert param_groups
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        # Construct distributed optimizer
        super().__init__(param_groups, **kwargs)

        # Create mutex with timeout
        self._lock_with_timeout = None
        if lock_timeout is not None:

            @contextlib.contextmanager
            def lock_with_timeout():
                result = self._lock.acquire(timeout=lock_timeout)
                try:
                    yield result
                finally:
                    if result:
                        # Acquired lock before timeout
                        self._lock.release()
                    else:
                        # Failed to acquire lock before timeout
                        print(f'MegatronDistributedFusedAdam: Failed to acquire lock within {lock_timeout} seconds.')

            self._lock_with_timeout = lock_with_timeout

        # Check for MXFP8 parameters
        if any(is_mxfp8tensor(param) for param in self.parameters()):
            raise ValueError("Distributed optimizer currently does not support MXFP8 parameters")

    def _broadcast_params(self) -> None:
        # Assume params have already been synchronized
        pass

    def _make_post_backward_hook(self, param: torch.nn.Parameter, param_group_id: int, param_id: int) -> Callable:
        """Construct the per-param hook that lazily initializes state and launches grad sync."""

        def hook(*unused):
            if getattr(param, '_pre_forward_hook_is_enabled', False):
                raise RuntimeError(
                    'A parameter called its post-backward hook '
                    'before its pre-forward hook. '
                    'Please manually interact with the parameter '
                    'before the forward pass (e.g. by calling data_ptr) '
                    'or run DistributedFusedAdam with overlap_param_sync=False.'
                )
            lock = self._lock
            if self._lock_with_timeout is not None:
                lock = self._lock_with_timeout()
            with lock:
                need_to_initialize = 'fragments' not in self.state[param]
                if need_to_initialize:
                    self._init_param_state(param, param_group_id, param_id)
                if self.greedy_grad_copy and not getattr(param, '_disable_greedy_grad_copy', False):
                    self._grad_copy(param)
                    if self.overlap_grad_sync and not getattr(param, '_disable_overlap_grad_sync', False):
                        self._try_start_bucket_grad_sync(
                            params=[param],
                            ignore_last_bucket=need_to_initialize,
                        )

        return hook

    def init_params(
        self,
        params: Optional[Iterable[torch.nn.Parameter]] = None,
        param_sync_dtype: Optional[torch.dtype] = None,
        **kwargs,
    ) -> None:
        """Initialize optimizer state for parameters

        Initializes FP8 and non-FP8 params separately.

        """

        # Default cases
        if params is None:
            params = self.parameters()
        elif isinstance(params, torch.Tensor):
            params = [params]

        # Ignore parameters that have already been initialized
        params = [param for param in params if "fragments" not in self.state[param]]
        if not params:
            return

        # Initialize FP8 and non-FP8 tensors separately
        # Note: FP8 params perform param sync in UINT8; the second call
        # skips params that the first call already initialized.
        if any(is_float8tensor(param) for param in params):
            super().init_params(
                filter(is_float8tensor, params),
                param_sync_dtype=torch.uint8,
                **kwargs,
            )
        super().init_params(
            params,
            param_sync_dtype=param_sync_dtype,
            **kwargs,
        )

    def init_params_bucket(
        self,
        params: Iterable[torch.nn.Parameter],
        grad_sync_dtype: Optional[torch.dtype] = None,
        param_sync_dtype: Optional[torch.dtype] = None,
        **kwargs,
    ) -> None:
        """Initialize optimizer state for parameters in one effective bucket"""

        # Ignore parameters that have already been initialized
        if isinstance(params, torch.Tensor):
            params = [params]
        params = [param for param in params if "fragments" not in self.state[param]]
        if not params:
            return

        # Initialize parameters with FP32 grads
        fp32_params = []
        remaining_params = []
        for param in params:
            if getattr(param, '_with_fp32_optimizer', False):
                fp32_params.append(param)
            else:
                remaining_params.append(param)
        params = remaining_params
        start_bucket_id = len(self.state["buckets"])
        super().init_params_bucket(
            fp32_params,
            grad_sync_dtype=torch.float32,
            param_sync_dtype=param_sync_dtype,
            **kwargs,
        )
        end_bucket_id = len(self.state["buckets"])
        fp32_buckets = self.state["buckets"][start_bucket_id:end_bucket_id]

        # Initialize FP8 parameters
        fp8_params = []
        remaining_params = []
        for param in params:
            if is_float8tensor(param):
                fp8_params.append(param)
            else:
                remaining_params.append(param)
        params = remaining_params
        start_bucket_id = len(self.state["buckets"])
        super().init_params_bucket(
            fp8_params,
            grad_sync_dtype=grad_sync_dtype,
            param_sync_dtype=torch.uint8,
            **kwargs,
        )
        end_bucket_id = len(self.state["buckets"])
        fp8_buckets = self.state["buckets"][start_bucket_id:end_bucket_id]

        # Initialize remaining parameters as usual
        normal_buckets = []
        start_bucket_id = len(self.state["buckets"])
        super().init_params_bucket(
            params,
            grad_sync_dtype=grad_sync_dtype,
            param_sync_dtype=param_sync_dtype,
            **kwargs,
        )
        end_bucket_id = len(self.state["buckets"])
        normal_buckets = self.state["buckets"][start_bucket_id:end_bucket_id]

        def add_param_to_bucket(
            param: torch.nn.Parameter,
            bucket: self.StateBucket,
        ) -> None:
            """Add trivial param fragment to bucket"""
            param_fragments = self.state[param]["fragments"]
            param_group_id = param_fragments[0].param_group_id
            param_id = param_fragments[0].param_id
            bucket_id = bucket.fragments[0].bucket_id
            param_size = param.numel()
            bucket_size = bucket.bucket_size
            # Zero-sized fragment: establishes a dependency without
            # occupying any space in the bucket.
            fragment = self.ParameterFragment(
                param_group_id=param_group_id,
                param_id=param_id,
                bucket_id=bucket_id,
                param_range=(param_size, param_size),
                bucket_range=(bucket_size, bucket_size),
                in_local_shard=False,
                shard_range=None,
                shard_bucket_range=None,
                shard_param_range=None,
            )
            param_fragments.append(fragment)
            bucket.fragments.append(fragment)

        # Make sure all added buckets depend on provided params
        for bucket in fp32_buckets:
            for param in itertools.chain(fp8_params, params):
                add_param_to_bucket(param, bucket)
        for bucket in fp8_buckets:
            for param in itertools.chain(fp32_params, params):
                add_param_to_bucket(param, bucket)
        for bucket in normal_buckets:
            for param in itertools.chain(fp32_params, fp8_params):
                add_param_to_bucket(param, bucket)

    def _init_param_state(
        self,
        param: torch.nn.Parameter,
        param_group_id: int,
        param_id: int,
        param_sync_dtype: Optional[torch.dtype] = None,
        **kwargs,
    ) -> None:
        """Initialize optimizer state for a parameter

        Initializing the master weights requires slicing a flattened
        view of the param. FP8 tensors do not handle these operations
        gracefully, so we hack around it by explicitly casting to FP32.

        """

        # Initialize non-FP8 params as usual
        if not is_float8tensor(param):
            super()._init_param_state(
                param,
                param_group_id,
                param_id,
                param_sync_dtype=param_sync_dtype,
                **kwargs,
            )

        # Return immediately if already initialized
        if "fragments" in self.state[param]:
            return

        # Initialize with FP32 copy of param
        fp32_param = param.float()
        super()._init_param_state(
            fp32_param,
            param_group_id,
            param_id,
            param_sync_dtype=torch.uint8,
            **kwargs,
        )
        self.state[param].update(self.state[fp32_param])
        del self.state[fp32_param]

    @torch.no_grad()
    def init_param_buffer(self) -> None:
        """Allocate contiguous buffers for param buckets

        For FP8 params, the FP8 data buffer is made a view into a
        contiguous buffer.

        """

        # Make sure all params are initialized
        self.contiguous_param_buffer = True
        self.init_params()

        # Construct param buffers
        buffer_sizes = collections.defaultdict(lambda: 0)
        for bucket in self.state["buckets"]:
            dtypes = bucket.dtypes()
            buffer_sizes[dtypes] = max(bucket.contiguous_buffer_offset + bucket.bucket_size, buffer_sizes[dtypes])
        for dtypes, buffer_size in buffer_sizes.items():
            _, _, param_sync_dtype = dtypes
            if getattr(self, "nccl_ub", False):
                # Allocate from NCCL user buffers when requested
                if not nccl_allocator:
                    raise RuntimeError("NCCL allocator importing failed but nccl ub is still requested")
                with nccl_allocator.nccl_mem():
                    self._param_buffers[dtypes] = torch.zeros(
                        [buffer_size], dtype=param_sync_dtype, device=self.device
                    )
            else:
                self._param_buffers[dtypes] = torch.zeros([buffer_size], dtype=param_sync_dtype, device=self.device)

        # Figure out corresponding positions in params and param buffer
        params = list(self.parameters())
        param_flat_views = []
        param_buffer_views = []
        for param in params:
            fragment = self.state[param]["fragments"][0]
            bucket_id = fragment.bucket_id
            bucket = self.state["buckets"][bucket_id]
            param_size = param.numel()
            bucket_start, _ = fragment.bucket_range
            buffer_offset = bucket.contiguous_buffer_offset
            buffer_start = buffer_offset + bucket_start
            buffer_end = buffer_start + param_size
            param_buffer = self._param_buffers[bucket.dtypes()]
            param_buffer_view = param_buffer[buffer_start:buffer_end].detach()
            if param_buffer_view.device != param.device:
                # Fix: the first half of this message was missing the
                # f-string prefix, so "{param.device}" was printed literally.
                raise RuntimeError(
                    f"Attempted to change a parameter with device={param.device} "
                    f"into a buffer view with device={param_buffer_view.device}"
                )
            if is_float8tensor(param):
                param_flat_views.append(param._data.detach().view(-1))
            else:
                if param_buffer_view.dtype != param.dtype:
                    raise RuntimeError(
                        f"Attempted to change a parameter with dtype={param.dtype} "
                        f"into a buffer view with dtype={param_buffer_view.dtype}"
                    )
                # NHWC (channels_last) tensors cannot be flattened with
                # view(-1); permute to a contiguous layout first.
                if param.is_contiguous(memory_format=torch.channels_last):
                    param = param.permute(0, 2, 3, 1)
                param_flat_views.append(param.detach().view(-1))
            param_buffer_views.append(param_buffer_view)

        # Copy values into param buffer
        _multi_tensor_copy(
            param_flat_views,
            param_buffer_views,
            dummy_overflow_buf=self._dummy_overflow_buf,
        )

        # Make all params a view into the param buffer
        for param, buffer_view in zip(params, param_buffer_views):
            if is_float8tensor(param):
                param._data = buffer_view.view(param.size())
            else:
                # Preserve memory format for param here, i.e. NHWC tensors
                # `param.data.set_()` failed to change storage.
                # `param.set_()` invalidates bprop hook.
                param.data = torch.as_strided(
                    buffer_view,
                    param.size(),
                    param.stride(),
                    storage_offset=buffer_view.storage_offset(),
                )

    def try_grad_sync(self, params: Iterable[torch.nn.Parameter]) -> None:
        """Attempt to launch gradient synchronization"""

        def is_grad_copy_enabled(param: torch.nn.Parameter) -> bool:
            return not getattr(param, '_disable_greedy_grad_copy', False) and not getattr(
                param, '_disable_overlap_grad_sync', False
            )

        params = list(filter(is_grad_copy_enabled, params))
        for p in params:
            self._grad_copy(p)
        self._try_start_bucket_grad_sync(params=params)

    def zero_grad(self, *args, **kwargs) -> None:
        """Clear parameter gradients"""
        super().zero_grad(*args, **kwargs)

        # Reset main grads
        if self.contiguous_grad_buffer:
            for param in self.parameters():
                with _disable_pre_forward_hook(param):
                    param.main_grad = self.grad_buffer_view(param)

    def grad_norm(
        self,
        parameters: Optional[Iterable[torch.nn.Parameter]] = None,
        norm_type: float = 2.0,
        force: bool = False,
    ) -> torch.Tensor:
        """L2 norm of parameter gradients"""
        assert norm_type == 2

        if parameters is not None:
            # Make sure we can access iterable multiple times
            parameters = list(parameters)

        # Compute grad norm
        if force or self._grad_norm is None:

            # Compute norm of local gradients for distributed optimizer
            grad_norm_sq = self._local_grad_norm(parameters=parameters, norm_type=norm_type)
            if self.redundant_size > 1:
                grad_norm_sq /= self.redundant_size

            # Sum over all procs to get grad norm
            torch.distributed.all_reduce(
                grad_norm_sq,
                op=torch.distributed.ReduceOp.SUM,
            )
            self._grad_norm = grad_norm_sq.sqrt()

        # Use cached grad norm
        return super().grad_norm()

    @torch.no_grad()
    def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.ParameterFragment]) -> None:
        """Update parameter fragments with values from parameter buckets

        For FP8 params, values are copied directly into the FP8 data
        buffer.

        """

        # Figure out corresponding positions in param buckets and params
        buffers_in = []
        buffers_out = []
        fragments = list(fragments)
        for fragment in fragments:

            # Check if fragment needs to be updated
            bucket_id = fragment.bucket_id
            bucket_start, bucket_end = fragment.bucket_range
            param_start, param_end = fragment.param_range
            if param_end <= param_start or bucket_id not in self._params_buckets:
                continue

            # Corresponding positions in bucket and param
            param_bucket = self._params_buckets[bucket_id]
            param = self.parameter(fragment)
            buffer_in = param_bucket.params_bucket[bucket_start:bucket_end]

            if is_float8tensor(param):
                # Copy into FP8 params's data buffer
                assert (
                    param_bucket.params_bucket.dtype == torch.uint8
                ), "Expected FP8 params to perform param sync in UINT8"
                buffer_out = param._data.view(-1)[param_start:param_end]
                buffers_in.append(buffer_in)
                buffers_out.append(buffer_out)
            elif torch.is_floating_point(buffer_in) and torch.is_floating_point(param):
                # Conv with NHWC layout, i.e. shape (N, C, H, W) and stride
                # (HWC, 1, WC, C), can't `.view(-1)`. Here to turn it to
                # tensor with shape (N, H, W, C) and stride (HWC, WC, C, 1).
                # Note: https://github.com/NVIDIA/apex/pull/1794
                if param.is_contiguous(memory_format=torch.channels_last):
                    param = param.permute(0, 2, 3, 1)
                # Cast between floating-point dtypes
                buffer_out = param.detach().view(-1)[param_start:param_end]
                buffers_in.append(buffer_in)
                buffers_out.append(buffer_out)
            else:
                # Copy most significant bytes for non-floating-point
                # dtypes
                # Note: Assume dtypes are little-endian
                buffer_out = param.detach().view(-1)[param_start:param_end]
                in_bytes = buffer_in.unsqueeze(-1).view(torch.uint8)
                out_bytes = buffer_out.unsqueeze(-1).view(torch.uint8)
                copy_size = min(in_bytes.size(-1), out_bytes.size(-1))
                buffers_in.append(in_bytes[..., -copy_size:])
                buffers_out.append(out_bytes[..., -copy_size:])
                if copy_size < out_bytes.size(-1):
                    out_bytes[..., :-copy_size].zero_()

        # Copy data from parameter buckets to parameters
        _multi_tensor_copy(
            buffers_in,
            buffers_out,
            dummy_overflow_buf=self._dummy_overflow_buf,
        )

        # Update transpose caches
        params = set(self.parameter(fragment) for fragment in fragments)
        for param in params:
            if is_float8tensor(param):
                param._reset_caches()

    @torch.no_grad()
    def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedAdam.ParameterBucket]) -> None:
        """Make sure local shards of parameters are in expected datatypes

        For FP8 params, FP32 values are cast into FP8 using per-param
        scaling factors and per-param amaxes are computed and reduced.

        """

        # Just call base class function if there are no FP8 tensors
        num_fp8_params = sum(1 for param in self.parameters() if is_float8tensor(param))
        if num_fp8_params == 0:
            super()._check_params_shard_dtypes(params_buckets)
            return

        # Cast local data to FP8
        fp8_params_shards = dict()
        for bucket_id, param_bucket in params_buckets.items():
            state_bucket = self.state["buckets"][bucket_id]
            if state_bucket.param_sync_dtype != torch.uint8:
                continue

            # Initialize FP8 buffer for param sync
            params_shard = param_bucket.params_shard
            if self.contiguous_param_buffer:
                shard_size = state_bucket.shard_size
                buffer_offset = state_bucket.contiguous_buffer_offset
                buffer_start = buffer_offset + self.distributed_rank * shard_size
                buffer_end = buffer_start + shard_size
                param_buffer = self._param_buffers[state_bucket.dtypes()]
                fp8_params_shard = param_buffer[buffer_start:buffer_end]
            else:
                fp8_params_shard = torch.empty_like(params_shard, dtype=torch.uint8)
            param_bucket.params_shard = fp8_params_shard

            # Cast param fragments to FP8
            for fragment in self.state["buckets"][bucket_id].fragments:
                param = self.parameter(fragment)
                if not is_float8tensor(param):
                    continue
                if not fragment.in_local_shard:
                    continue
                shard_start, shard_end = fragment.shard_range
                if shard_end <= shard_start:
                    continue
                shard_range = slice(shard_start, shard_end)
                quantize_param_fragment(
                    params_shard[shard_range],
                    out=fp8_params_shard[shard_range],
                    param=param,
                )

        # Update FP8 scaling factors when all buckets have processed
        if getattr(self, "_check_params_shard_dtypes_progress", None) is None:
            self._check_params_shard_dtypes_progress = []
        self._check_params_shard_dtypes_progress.extend(params_buckets.keys())
        if len(self._check_params_shard_dtypes_progress) == len(self.state["buckets"]):
            assert len(set(self._check_params_shard_dtypes_progress)) == len(self.state["buckets"])

            # FP8 scaling factors
            amaxes = []
            scales = []
            scale_invs = []
            for param in self.parameters():
                if not is_float8tensor(param):
                    continue
                scale, amax = get_fp8_scale_and_amax(param)
                amaxes.append(amax.view(1))
                scales.append(scale.view(1))
                scale_invs.append(param._scale_inv.view(1))

            # Update cached scale-inverses
            packed_scales = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device)
            packed_scale_views = [packed_scales[i].view(1) for i in range(num_fp8_params)]
            _multi_tensor_copy(
                scales,
                packed_scale_views,
                dummy_overflow_buf=self._dummy_overflow_buf,
            )
            torch.reciprocal(packed_scales, out=packed_scales)
            _multi_tensor_copy(
                packed_scale_views,
                scale_invs,
                dummy_overflow_buf=self._dummy_overflow_buf,
            )

            # Reduce amaxes
            # Note: Assume each param has a separate amax
            packed_amaxes = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device)
            packed_amax_views = [packed_amaxes[i].view(1) for i in range(num_fp8_params)]
            _multi_tensor_copy(
                amaxes,
                packed_amax_views,
                dummy_overflow_buf=self._dummy_overflow_buf,
            )
            torch.distributed.all_reduce(
                packed_amaxes,
                op=torch.distributed.ReduceOp.MAX,
                group=self.distributed_process_group,
            )
            _multi_tensor_copy(
                packed_amax_views,
                amaxes,
                dummy_overflow_buf=self._dummy_overflow_buf,
            )

            # Reset
            self._check_params_shard_dtypes_progress = None

        # Handle any remaining dtype conversions
        super()._check_params_shard_dtypes(params_buckets)

    def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None):
        """Create sharded state dict"""
        if optimizer_state_dict is None:
            optimizer_state_dict = self.state_dict()

        id_to_sharded_param_map = get_param_id_to_sharded_param_map(
            model_sharded_state_dict=model_sharded_state_dict,
            optim_params_iter=self.parameters(),
        )

        # Convert state
        # Note: 'step' and 'format' are not per-param entries, so they are
        # removed before sharding conversion and restored afterwards.
        step = optimizer_state_dict['state'].pop('step')
        state_dict_format = optimizer_state_dict.pop('format', None)
        optim_state_to_sharding_state(optimizer_state_dict, id_to_sharded_param_map)
        optimizer_state_dict['state']['step'] = step
        if state_dict_format is not None:
            optimizer_state_dict['format'] = state_dict_format

        def rename_fp32_params(x):
            if isinstance(x, ShardedTensor) and x.key.startswith('optimizer.state.param'):
                x.key = x.key.replace('optimizer.state.param', 'optimizer.state.fp32_param')
            return x

        dict_list_map_inplace(rename_fp32_params, optimizer_state_dict)

        return optimizer_state_dict