# Web-page capture residue, preserved as comments so the file stays valid Python:
# ArthurY's picture
# Commit: "update source" (c3d0544)
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
import torch
from torch import nn
# Optional dependency: Transformer Engine. TE_AVAILABLE records whether the
# import succeeded; the name ``te`` is only bound when it did.
try:
    import transformer_engine.pytorch as te

    TE_AVAILABLE = True
except ImportError:
    TE_AVAILABLE = False
def remove_extra_state_hook_for_torch(
    module: nn.Module,
    state_dict: dict,
    prefix: str,
    local_metadata: dict,
    strict: bool,
    missing_keys: list,
    unexpected_keys: list,
    error_msgs: list,
) -> None:
    """
    Pre-hook to remove Transformer Engine's extra state from the state_dict when loading into a PyTorch LayerNorm.

    The hook is registered on the LayerNorm module itself, so ``prefix`` is that
    module's own key prefix (for example ``"ln."``, or ``""`` at the root).
    Transformer Engine stores its extra state under ``f"{prefix}_extra_state"``.
    Every key in ``state_dict`` that starts with that string is deleted in place.
    A standard PyTorch LayerNorm has no such state, so the key is not needed
    there and may cause errors (e.g. an unexpected key under ``strict=True``).

    Args:
        module (nn.Module): The module into which the state_dict is being loaded.
        state_dict (dict): The state dictionary being loaded; modified in place.
        prefix (str): The prefix for parameters in this module.
        local_metadata (dict): Metadata for this module (unused).
        strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict function (unused).
        missing_keys (list): List of missing keys (unused).
        unexpected_keys (list): List of unexpected keys (unused).
        error_msgs (list): List of error messages (unused).
    """
    extra_state_key = prefix + "_extra_state"
    # Collect the matching keys first: a dict cannot change size while it is
    # being iterated.
    keys_to_remove = [key for key in state_dict if key.startswith(extra_state_key)]
    for key in keys_to_remove:
        del state_dict[key]
def ignore_missing_extra_state_key(
    module: nn.Module,
    incompatible_keys: torch.nn.modules.module._IncompatibleKeys,
    problem_key: str = "ln._extra_state",
) -> None:
    """
    Post-hook to ignore a missing Transformer Engine extra-state key when loading a state_dict.

    Removes ``problem_key`` (default ``'ln._extra_state'``) from
    ``incompatible_keys.missing_keys`` in place, if present. This hook is meant
    for a Transformer Engine LayerNorm that loads a checkpoint saved from a
    PyTorch LayerNorm: that checkpoint has no extra state, so the key would
    otherwise be reported as missing.

    Note that the default key is fixed. It only matches a LayerNorm stored as the
    attribute ``ln`` of the module ``load_state_dict`` is called on. For other
    locations, pass the full key, e.g. via
    ``functools.partial(ignore_missing_extra_state_key, problem_key="blocks.0.ln._extra_state")``.

    Args:
        module (nn.Module): The module into which the state_dict is being loaded (unused).
        incompatible_keys: An object with a 'missing_keys' list attribute (typically torch.nn.modules.module._IncompatibleKeys); modified in place.
        problem_key (str): The missing key to ignore.
    """
    missing = incompatible_keys.missing_keys
    if problem_key in missing:
        missing.remove(problem_key)
def get_layer_norm_class() -> type:
    """
    Dynamically pick the layer norm provider based on availability of transformer engine.

    By default, the returned class subclasses ``transformer_engine.pytorch.LayerNorm``
    when transformer engine was imported successfully AND CUDA is available.
    Otherwise it subclasses ``torch.nn.LayerNorm``.

    Override the default behavior with the PHYSICSNEMO_FORCE_TE environment
    variable (case-insensitive, surrounding whitespace ignored):

    * ``"true"`` or ``"1"``: use transformer engine, even without CUDA.
    * ``"false"`` or ``"0"``: use the PyTorch implementation.
    * any other value: a ``UserWarning`` is emitted and the default is used.

    Returns:
        type: A ``LayerNorm`` class whose instances register load_state_dict hooks,
        so that checkpoints can be moved between the two implementations.

    Raises:
        ImportError: If PHYSICSNEMO_FORCE_TE requests transformer engine but
            ``transformer_engine.pytorch`` could not be imported.
    """
    # This is to allow users to force the use of TE or pytorch layer norm.
    force_te_setting = os.environ.get("PHYSICSNEMO_FORCE_TE")

    # Local decision; the module-level TE_AVAILABLE flag is never modified.
    # Transformer engine cannot be used without CUDA.
    use_te = TE_AVAILABLE and torch.cuda.is_available()

    # Let the users force the setting no matter what:
    if force_te_setting is not None:
        normalized = force_te_setting.strip().lower()
        if normalized in ("true", "1"):
            if not TE_AVAILABLE:
                # Without this check the failure would surface as an opaque
                # NameError on ``te`` below.
                raise ImportError(
                    "PHYSICSNEMO_FORCE_TE requests transformer engine, but "
                    "'transformer_engine.pytorch' could not be imported."
                )
            use_te = True
        elif normalized in ("false", "0"):
            use_te = False
        else:
            # PHYSICSNEMO_FORCE_TE was set, but not to a value we expect.
            warnings.warn(
                f"The PHYSICSNEMO_FORCE_TE environment variable was set to an invalid value. "
                f"Expected 'True'/'1' or 'False'/'0', but got '{force_te_setting}'. "
                "Ignoring the variable and using the default behavior.",
                UserWarning,
                stacklevel=2,
            )

    base = te.LayerNorm if use_te else nn.LayerNorm

    def _remember_load_prefix(
        module,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ) -> None:
        # Pre-hook: load_state_dict post-hooks are not given the module's key
        # prefix, so record it here for _ignore_own_missing_extra_state.
        module._physicsnemo_load_prefix = prefix

    def _ignore_own_missing_extra_state(module, incompatible_keys) -> None:
        # Post-hook: drop exactly this module's "_extra_state" key from the
        # missing keys, wherever the LayerNorm sits in the model (not only at
        # a root-level attribute named ``ln``).
        prefix = getattr(module, "_physicsnemo_load_prefix", None)
        if prefix is None:
            return
        own_key = prefix + "_extra_state"
        if own_key in incompatible_keys.missing_keys:
            incompatible_keys.missing_keys.remove(own_key)

    class LayerNorm(base):
        """
        Wrapper around layer norm utilities.

        This class will default to using the transformer engine implementation of
        LayerNorm - it is significantly faster in the backwards pass.
        If transformer engine is not available, it will fall back to the
        pytorch implementation of LayerNorm.

        Additionally, this class registers pre or post hooks to allow you to
        train with / without transformer engine, and run inference
        with / without transformer engine.

        .. note::
            Transformer engine adds additional state parameters that affect
            fp8 stability. **Do NOT** switch from transformer engine to pytorch
            or from pytorch to transformer engine with a checkpoint if you
            are using fp8 precision in the layer norm regions.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            if use_te:
                # Allow loading checkpoints that lack TE's extra state.
                self.register_load_state_dict_pre_hook(_remember_load_prefix)
                self.register_load_state_dict_post_hook(
                    _ignore_own_missing_extra_state
                )
                # Kept for backward compatibility with the original root-level
                # "ln._extra_state" handling.
                self.register_load_state_dict_post_hook(
                    ignore_missing_extra_state_key
                )
            else:
                # Strip TE's extra state before a plain PyTorch LayerNorm sees it.
                self.register_load_state_dict_pre_hook(
                    remove_extra_state_hook_for_torch
                )

    return LayerNorm
# Resolve the LayerNorm implementation once, when this module is imported.
# get_layer_norm_class reads PHYSICSNEMO_FORCE_TE at this point; changing the
# variable afterwards does not affect this binding.
LayerNorm = get_layer_norm_class()