Spaces:
Sleeping
Sleeping
File size: 5,061 Bytes
c3d0544 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import numpy as np
import torch
def weight_init(
    shape: tuple,
    mode: str,
    fan_in: int,
    fan_out: int,
    *,
    device=None,
    dtype=None,
):
    """
    Unified routine for initializing weights and biases.
    This function provides a unified interface for various weight initialization
    strategies like Xavier (Glorot) and Kaiming (He) initializations.
    Parameters
    ----------
    shape : tuple
        The shape of the tensor to initialize. It could represent weights or biases
        of a layer in a neural network.
    mode : str
        The mode/type of initialization to use. Supported values are:
        - "xavier_uniform": Xavier (Glorot) uniform initialization.
        - "xavier_normal": Xavier (Glorot) normal initialization.
        - "kaiming_uniform": Kaiming (He) uniform initialization.
        - "kaiming_normal": Kaiming (He) normal initialization.
    fan_in : int
        The number of input units in the weight tensor. For convolutional layers,
        this typically represents the number of input channels times the kernel height
        times the kernel width.
    fan_out : int
        The number of output units in the weight tensor. For convolutional layers,
        this typically represents the number of output channels times the kernel height
        times the kernel width.
    device : torch.device, optional
        Device on which to allocate the tensor. Defaults to ``None``
        (torch's current default device), preserving prior behavior.
    dtype : torch.dtype, optional
        Data type of the tensor. Defaults to ``None`` (torch's default
        dtype), preserving prior behavior.
    Returns
    -------
    torch.Tensor
        The initialized tensor based on the specified mode.
    Raises
    ------
    ValueError
        If the provided `mode` is not one of the supported initialization modes.
    """
    # Each branch scales a base random tensor so its variance matches the
    # corresponding initialization scheme. Uniform modes draw from
    # U(-1, 1) via ``rand * 2 - 1``; normal modes draw from N(0, 1).
    if mode == "xavier_uniform":
        return np.sqrt(6 / (fan_in + fan_out)) * (
            torch.rand(*shape, device=device, dtype=dtype) * 2 - 1
        )
    if mode == "xavier_normal":
        return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(
            *shape, device=device, dtype=dtype
        )
    if mode == "kaiming_uniform":
        return np.sqrt(3 / fan_in) * (
            torch.rand(*shape, device=device, dtype=dtype) * 2 - 1
        )
    if mode == "kaiming_normal":
        return np.sqrt(1 / fan_in) * torch.randn(*shape, device=device, dtype=dtype)
    raise ValueError(f'Invalid init mode "{mode}"')
def _recursive_property(prop_name: str, prop_type: type, doc: str) -> property:
"""
Property factory that sets the property on a Module ``self`` and
recursively on all submodules.
For ``self``, the property is stored under a semi-private ``_<prop_name>`` attribute
and for submodules the setter is delegated to the ``setattr`` function.
Parameters
----------
prop_name : str
The name of the property.
prop_type : type
The type of the property.
doc : str
The documentation string for the property.
Returns
-------
property
The property object.
"""
def _setter(self, value: Any):
if not isinstance(value, prop_type):
raise TypeError(
f"{prop_name} must be a {prop_type.__name__} value, but got {type(value).__name__}."
)
# Set for self
setattr(self, f"_{prop_name}", value)
# Set for submodules
submodules = iter(self.modules())
next(submodules) # Skip self
for m in submodules:
if hasattr(m, prop_name):
setattr(m, prop_name, value)
def _getter(self):
return getattr(self, f"_{prop_name}")
return property(_getter, _setter, doc=doc)
def _wrapped_property(prop_name: str, wrapped_obj_name: str, doc: str) -> property:
"""
Property factory to define a property on a Module ``self`` that is
wraps another Module in an attribute ``self.<wrapped_obj_name>``. The
property delegates the setter and getter to the wrapped object's.
Parameters
----------
prop_name : str
The name of the property.
wrapped_obj_name : str
The name of the attribute that wraps the other Module.
doc : str
The documentation string for the property.
Returns
-------
property
The property object.
"""
def _setter(self, value: Any):
wrapped_obj = getattr(self, wrapped_obj_name)
if hasattr(wrapped_obj, prop_name):
setattr(wrapped_obj, prop_name, value)
else:
raise AttributeError(f"{prop_name} is not supported by the wrapped model.")
def _getter(self):
wrapped_obj = getattr(self, wrapped_obj_name)
return getattr(wrapped_obj, prop_name)
return property(_getter, _setter, doc=doc)
|