ArthurY's picture
update source
c3d0544
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import numpy as np
import torch
def weight_init(shape: tuple, mode: str, fan_in: int, fan_out: int):
    """
    Draw a randomly initialized tensor using a named initialization scheme.

    A single entry point for the Xavier (Glorot) and Kaiming (He) families,
    each available with either a uniform or a normal sampling distribution.

    Parameters
    ----------
    shape : tuple
        Dimensions of the tensor to create; unpacked into ``torch.rand`` /
        ``torch.randn``.
    mode : str
        Name of the initialization scheme. One of:

        - "xavier_uniform": uniform on ``[-b, b)``, ``b = sqrt(6 / (fan_in + fan_out))``.
        - "xavier_normal": normal with std ``sqrt(2 / (fan_in + fan_out))``.
        - "kaiming_uniform": uniform on ``[-b, b)``, ``b = sqrt(3 / fan_in)``.
        - "kaiming_normal": normal with std ``sqrt(1 / fan_in)``.
    fan_in : int
        Number of input units feeding the weight tensor (for a convolution,
        typically input channels times kernel height times kernel width).
    fan_out : int
        Number of output units of the weight tensor (for a convolution,
        typically output channels times kernel height times kernel width).
        Only the Xavier modes use this value.

    Returns
    -------
    torch.Tensor
        A freshly sampled tensor of the requested shape, scaled according to
        ``mode``.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported names.
    """

    def _uniform(bound):
        # torch.rand samples [0, 1); stretch and shift to [-1, 1) before scaling.
        return bound * (torch.rand(*shape) * 2 - 1)

    def _normal(std):
        # torch.randn samples a unit normal; scale to the requested std.
        return std * torch.randn(*shape)

    # The scale is computed inside the matching branch only, so an unknown
    # mode is rejected before any arithmetic on fan_in / fan_out happens.
    if mode == "xavier_uniform":
        return _uniform(np.sqrt(6 / (fan_in + fan_out)))
    elif mode == "xavier_normal":
        return _normal(np.sqrt(2 / (fan_in + fan_out)))
    elif mode == "kaiming_uniform":
        return _uniform(np.sqrt(3 / fan_in))
    elif mode == "kaiming_normal":
        return _normal(np.sqrt(1 / fan_in))
    else:
        raise ValueError(f'Invalid init mode "{mode}"')
def _recursive_property(prop_name: str, prop_type: type, doc: str) -> property:
    """
    Build a type-checked property whose assignment cascades to submodules.

    The owning Module keeps its own value in the semi-private attribute
    ``_<prop_name>``. On assignment, the value is additionally forwarded via
    ``setattr`` to every submodule yielded by ``self.modules()`` (excluding
    ``self``) that already exposes an attribute called ``prop_name``.

    Parameters
    ----------
    prop_name : str
        Public name of the property.
    prop_type : type
        Type that assigned values must be an instance of.
    doc : str
        Documentation string attached to the property.

    Returns
    -------
    property
        Property object with the getter and setter described above. The
        setter raises ``TypeError`` when the value is not a ``prop_type``
        instance.
    """
    backing_attr = f"_{prop_name}"

    def _read(module):
        return getattr(module, backing_attr)

    def _write(module, value: Any):
        # Reject wrongly typed values before touching any state.
        if not isinstance(value, prop_type):
            raise TypeError(
                f"{prop_name} must be a {prop_type.__name__} value, but got {type(value).__name__}."
            )

        # Store the value on the owning module itself.
        setattr(module, backing_attr, value)

        # modules() yields the owner first; drop it and visit the rest lazily.
        descendants = iter(module.modules())
        next(descendants)
        # Only submodules that already expose the attribute receive the value.
        for child in (m for m in descendants if hasattr(m, prop_name)):
            setattr(child, prop_name, value)

    return property(_read, _write, doc=doc)
def _wrapped_property(prop_name: str, wrapped_obj_name: str, doc: str) -> property:
    """
    Build a property that forwards to an object held in another attribute.

    The returned property is meant for a Module ``self`` that wraps another
    Module stored at ``self.<wrapped_obj_name>``. Reads and writes of the
    property are delegated to the attribute ``prop_name`` of that wrapped
    object.

    Parameters
    ----------
    prop_name : str
        Name of the property, and of the attribute on the wrapped object.
    wrapped_obj_name : str
        Name of the attribute on ``self`` that holds the wrapped object.
    doc : str
        Documentation string attached to the property.

    Returns
    -------
    property
        Property object delegating to the wrapped object. The setter raises
        ``AttributeError`` when the wrapped object has no attribute named
        ``prop_name``.
    """

    def _get(self):
        return getattr(getattr(self, wrapped_obj_name), prop_name)

    def _set(self, value: Any):
        target = getattr(self, wrapped_obj_name)
        # Never create new attributes on the wrapped object: only forward
        # the value when the attribute already exists there.
        if not hasattr(target, prop_name):
            raise AttributeError(f"{prop_name} is not supported by the wrapped model.")
        setattr(target, prop_name, value)

    return property(_get, _set, doc=doc)