Spaces:
Sleeping
Sleeping
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. | |
| # SPDX-FileCopyrightText: All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import Any | |
| import numpy as np | |
| import torch | |
| def weight_init(shape: tuple, mode: str, fan_in: int, fan_out: int): | |
| """ | |
| Unified routine for initializing weights and biases. | |
| This function provides a unified interface for various weight initialization | |
| strategies like Xavier (Glorot) and Kaiming (He) initializations. | |
| Parameters | |
| ---------- | |
| shape : tuple | |
| The shape of the tensor to initialize. It could represent weights or biases | |
| of a layer in a neural network. | |
| mode : str | |
| The mode/type of initialization to use. Supported values are: | |
| - "xavier_uniform": Xavier (Glorot) uniform initialization. | |
| - "xavier_normal": Xavier (Glorot) normal initialization. | |
| - "kaiming_uniform": Kaiming (He) uniform initialization. | |
| - "kaiming_normal": Kaiming (He) normal initialization. | |
| fan_in : int | |
| The number of input units in the weight tensor. For convolutional layers, | |
| this typically represents the number of input channels times the kernel height | |
| times the kernel width. | |
| fan_out : int | |
| The number of output units in the weight tensor. For convolutional layers, | |
| this typically represents the number of output channels times the kernel height | |
| times the kernel width. | |
| Returns | |
| ------- | |
| torch.Tensor | |
| The initialized tensor based on the specified mode. | |
| Raises | |
| ------ | |
| ValueError | |
| If the provided `mode` is not one of the supported initialization modes. | |
| """ | |
| if mode == "xavier_uniform": | |
| return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) | |
| if mode == "xavier_normal": | |
| return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) | |
| if mode == "kaiming_uniform": | |
| return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) | |
| if mode == "kaiming_normal": | |
| return np.sqrt(1 / fan_in) * torch.randn(*shape) | |
| raise ValueError(f'Invalid init mode "{mode}"') | |
| def _recursive_property(prop_name: str, prop_type: type, doc: str) -> property: | |
| """ | |
| Property factory that sets the property on a Module ``self`` and | |
| recursively on all submodules. | |
| For ``self``, the property is stored under a semi-private ``_<prop_name>`` attribute | |
| and for submodules the setter is delegated to the ``setattr`` function. | |
| Parameters | |
| ---------- | |
| prop_name : str | |
| The name of the property. | |
| prop_type : type | |
| The type of the property. | |
| doc : str | |
| The documentation string for the property. | |
| Returns | |
| ------- | |
| property | |
| The property object. | |
| """ | |
| def _setter(self, value: Any): | |
| if not isinstance(value, prop_type): | |
| raise TypeError( | |
| f"{prop_name} must be a {prop_type.__name__} value, but got {type(value).__name__}." | |
| ) | |
| # Set for self | |
| setattr(self, f"_{prop_name}", value) | |
| # Set for submodules | |
| submodules = iter(self.modules()) | |
| next(submodules) # Skip self | |
| for m in submodules: | |
| if hasattr(m, prop_name): | |
| setattr(m, prop_name, value) | |
| def _getter(self): | |
| return getattr(self, f"_{prop_name}") | |
| return property(_getter, _setter, doc=doc) | |
| def _wrapped_property(prop_name: str, wrapped_obj_name: str, doc: str) -> property: | |
| """ | |
| Property factory to define a property on a Module ``self`` that is | |
| wraps another Module in an attribute ``self.<wrapped_obj_name>``. The | |
| property delegates the setter and getter to the wrapped object's. | |
| Parameters | |
| ---------- | |
| prop_name : str | |
| The name of the property. | |
| wrapped_obj_name : str | |
| The name of the attribute that wraps the other Module. | |
| doc : str | |
| The documentation string for the property. | |
| Returns | |
| ------- | |
| property | |
| The property object. | |
| """ | |
| def _setter(self, value: Any): | |
| wrapped_obj = getattr(self, wrapped_obj_name) | |
| if hasattr(wrapped_obj, prop_name): | |
| setattr(wrapped_obj, prop_name, value) | |
| else: | |
| raise AttributeError(f"{prop_name} is not supported by the wrapped model.") | |
| def _getter(self): | |
| wrapped_obj = getattr(self, wrapped_obj_name) | |
| return getattr(wrapped_obj, prop_name) | |
| return property(_getter, _setter, doc=doc) | |