# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. # SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any import numpy as np import torch def weight_init(shape: tuple, mode: str, fan_in: int, fan_out: int): """ Unified routine for initializing weights and biases. This function provides a unified interface for various weight initialization strategies like Xavier (Glorot) and Kaiming (He) initializations. Parameters ---------- shape : tuple The shape of the tensor to initialize. It could represent weights or biases of a layer in a neural network. mode : str The mode/type of initialization to use. Supported values are: - "xavier_uniform": Xavier (Glorot) uniform initialization. - "xavier_normal": Xavier (Glorot) normal initialization. - "kaiming_uniform": Kaiming (He) uniform initialization. - "kaiming_normal": Kaiming (He) normal initialization. fan_in : int The number of input units in the weight tensor. For convolutional layers, this typically represents the number of input channels times the kernel height times the kernel width. fan_out : int The number of output units in the weight tensor. For convolutional layers, this typically represents the number of output channels times the kernel height times the kernel width. Returns ------- torch.Tensor The initialized tensor based on the specified mode. Raises ------ ValueError If the provided `mode` is not one of the supported initialization modes. """ if mode == "xavier_uniform": return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) if mode == "xavier_normal": return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) if mode == "kaiming_uniform": return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) if mode == "kaiming_normal": return np.sqrt(1 / fan_in) * torch.randn(*shape) raise ValueError(f'Invalid init mode "{mode}"') def _recursive_property(prop_name: str, prop_type: type, doc: str) -> property: """ Property factory that sets the property on a Module ``self`` and recursively on all submodules. For ``self``, the property is stored under a semi-private ``_`` attribute and for submodules the setter is delegated to the ``setattr`` function. Parameters ---------- prop_name : str The name of the property. prop_type : type The type of the property. doc : str The documentation string for the property. Returns ------- property The property object. """ def _setter(self, value: Any): if not isinstance(value, prop_type): raise TypeError( f"{prop_name} must be a {prop_type.__name__} value, but got {type(value).__name__}." ) # Set for self setattr(self, f"_{prop_name}", value) # Set for submodules submodules = iter(self.modules()) next(submodules) # Skip self for m in submodules: if hasattr(m, prop_name): setattr(m, prop_name, value) def _getter(self): return getattr(self, f"_{prop_name}") return property(_getter, _setter, doc=doc) def _wrapped_property(prop_name: str, wrapped_obj_name: str, doc: str) -> property: """ Property factory to define a property on a Module ``self`` that is wraps another Module in an attribute ``self.``. The property delegates the setter and getter to the wrapped object's. Parameters ---------- prop_name : str The name of the property. wrapped_obj_name : str The name of the attribute that wraps the other Module. doc : str The documentation string for the property. Returns ------- property The property object. """ def _setter(self, value: Any): wrapped_obj = getattr(self, wrapped_obj_name) if hasattr(wrapped_obj, prop_name): setattr(wrapped_obj, prop_name, value) else: raise AttributeError(f"{prop_name} is not supported by the wrapped model.") def _getter(self): wrapped_obj = getattr(self, wrapped_obj_name) return getattr(wrapped_obj, prop_name) return property(_getter, _setter, doc=doc)