File size: 6,203 Bytes
5000658 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Sequence, Union
import numpy as np
# isort: off
import torch
import tensorrt as trt
# isort: on
from ._common import default_net
from ._utils import (copy_torch_to_numpy, np_dtype_to_trt, str_dtype_to_trt,
torch_to_numpy, trt_dtype_to_np, trt_dtype_to_torch)
from .functional import Tensor, constant
from .logger import logger
class Parameter:
    """Holder for a weight (or buffer) that is lazily turned into a network constant.

    A parameter tracks a declared shape and TensorRT dtype plus an optional
    backing value. The backing value (``self._value``) is in one of three
    states:

    * ``None`` -- no data has been supplied yet;
    * ``np.ndarray`` -- host data supplied via the constructor, the ``value``
      setter, or generated by ``raw_value``;
    * ``Tensor`` -- the result of ``constant(...)``, produced the first time
      the ``value`` property is read.

    ``is_buffer`` is stored on the instance as given; this class does not
    interpret it.
    """

    # Fallback dtype used (with a warning) when the caller passes ``dtype=None``.
    _DEFAULT_DTYPE = trt.DataType.FLOAT

    def __init__(self,
                 value: Optional[Union[np.ndarray, torch.Tensor]] = None,
                 shape: Sequence[int] = None,
                 dtype: Union[str, trt.DataType] = None,
                 is_buffer: bool = False):
        """Create a parameter from concrete data or from a shape declaration.

        Args:
            value: Initial data. When given, the parameter shape is taken from
                it and ``shape`` is ignored.
            shape: Declared shape; must be a list or tuple when ``value`` is
                ``None``.
            dtype: TensorRT dtype or its string name. ``None`` falls back to
                ``_DEFAULT_DTYPE`` and logs a warning.
            is_buffer: Flag stored on the instance as ``self.is_buffer``.

        Raises:
            AssertionError: If ``value`` is ``None`` and ``shape`` is not a
                list or tuple.
            TypeError: If ``value`` is neither ``np.ndarray`` nor
                ``torch.Tensor``.
        """
        if dtype is None:
            logger.warning(
                f'Parameter dtype is None, using default dtype: {self._DEFAULT_DTYPE}, it is recommended to always specify dtype explicitly'
            )
            dtype = self._DEFAULT_DTYPE
        if isinstance(dtype, str):
            dtype = str_dtype_to_trt(dtype)
        self._dtype: trt.DataType = dtype

        if value is None:
            assert isinstance(shape, (
                list,
                tuple)), f"shape must be list or tuple, receive {(type(shape))}"
            self._shape = tuple(shape)
            self._value = None
        else:
            # Normalise to a plain tuple so that ``shape`` has the same type
            # whether the parameter was built from a shape, an ndarray, or a
            # torch tensor (whose ``.shape`` is a ``torch.Size``).
            self._shape = tuple(value.shape)
            self._value = self._regularize_value(value)
        self.is_buffer = is_buffer

    @property
    def shape(self):
        """Declared shape of the parameter."""
        return self._shape

    @property
    def dtype(self):
        """Declared TensorRT dtype of the parameter."""
        return self._dtype

    @property
    def value(self) -> Tensor:
        """Return the parameter as a network ``Tensor``, creating it on first access.

        * A C-contiguous ndarray is passed straight to ``constant``.
        * Otherwise (no data yet, or a non-contiguous ndarray) a constant is
          created over a freshly allocated, uninitialised array of the
          declared shape/dtype, and that placeholder is registered on the
          default network together with the pending value (which may be
          ``None``) via ``_register_unfilled_weights``.
        * If the value is already a ``Tensor`` it is returned unchanged.
        """
        if isinstance(self._value,
                      np.ndarray) and self._value.flags['C_CONTIGUOUS']:
            self._value = constant(self._value)
        elif self._value is None or isinstance(self._value, np.ndarray):
            dtype = trt_dtype_to_np(self.dtype)
            ndarray = np.empty(self.shape, dtype)
            # Keep the pending host value (or None) before it is replaced by
            # the constant tensor below.
            value = self._value
            self._value = constant(ndarray)
            default_net()._register_unfilled_weights(self._value.producer.name,
                                                     ndarray, value)
        return self._value

    @classmethod
    def xavier_init(cls, weights: np.ndarray):
        """Fill ``weights`` in place with random values generated on the GPU.

        For 2-D arrays the range follows Xavier/Glorot uniform initialisation,
        ``sqrt(6) / sqrt(shape[0] + shape[1])``; any other rank uses a fixed
        range of ``0.1``. INT8 weights are drawn as integers scaled by 128;
        FP8 weights are sampled in the default float dtype and then cast.

        Args:
            weights: Destination host array; its dtype selects the sampling
                path. Requires a CUDA device.
        """
        # Nothing to fill; this also avoids a division by zero for a 2-D
        # shape whose dimensions sum to 0.
        if weights.size == 0:
            return

        shape = weights.shape
        dtype = np_dtype_to_trt(weights.dtype)
        if len(shape) == 2:
            # Xavier initialization see https://paperswithcode.com/method/xavier-initialization
            v_range = math.sqrt(6) / math.sqrt(shape[0] + shape[1])
        else:
            v_range = 0.1

        if dtype == trt.DataType.INT8:
            # For very small 2-D shapes v_range exceeds 1, which would put the
            # bounds outside the int8 range; clamp so randint stays valid.
            upper = min(math.ceil(128 * v_range), 128)
            value = torch.randint(-upper,
                                  upper,
                                  shape,
                                  dtype=trt_dtype_to_torch(dtype),
                                  device='cuda')
            # value ~ U[int(-128 * v_range), int(128 * v_range)]
        elif dtype == trt.DataType.FP8:
            value = torch.rand(shape, device='cuda') * 2 - 1
            # value ~ U[-v_range, v_range]
            value = value * v_range
            value = value.to(trt_dtype_to_torch(dtype))
        else:
            value = torch.rand(
                shape, dtype=trt_dtype_to_torch(dtype), device='cuda') * 2 - 1
            # value ~ U[-v_range, v_range]
            value = value * v_range

        stream = torch.cuda.Stream()
        # ``value`` was produced on the current stream. A freshly created
        # stream is not ordered after it, so make the side stream wait before
        # reading, otherwise the copy may observe unfinished data.
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            copy_torch_to_numpy(value, weights)
        # Ensure the copy has landed in ``weights`` (and ``value`` is no
        # longer in use on the side stream) before returning.
        stream.synchronize()

    def is_inited(self) -> bool:
        """Return ``True`` once any backing value (ndarray or Tensor) exists."""
        return self._value is not None

    @property
    def raw_value(self) -> np.ndarray:
        """Return the host ndarray backing this parameter.

        If no value exists yet, an array of the declared shape/dtype is
        allocated and filled by ``xavier_init``.

        Raises:
            AssertionError: If the value has already been converted to a
                ``Tensor`` by reading ``value``.
        """
        if self._value is None:
            dtype = trt_dtype_to_np(self.dtype)
            self._value = np.empty(self.shape, dtype)
            Parameter.xavier_init(self._value)
        assert isinstance(
            self._value, np.ndarray
        ), "Must be np.ndarray. Proper usage: get parameter.raw_value before getting parameter.value"
        return self._value

    @value.setter
    def value(self, v: Union[np.ndarray, torch.Tensor]):
        """Replace the backing value with ``v`` after shape validation.

        A 0-d value is reshaped to the declared shape when every declared
        dimension is 1. A dtype mismatch only logs a warning; the new data is
        stored regardless.

        Raises:
            AssertionError: If the (possibly reshaped) value's shape differs
                from the declared shape.
            TypeError: If ``v`` is neither ``np.ndarray`` nor ``torch.Tensor``.
        """
        v = self._regularize_value(v)
        if v.shape != self.shape and v.ndim == 0 and max(self.shape) == 1:
            # convert the scalar into a tensor which each dim is 1.
            v = v.reshape(self.shape)
        assert v.shape == self.shape, \
            f'The value updated is not the same shape as the original. ' \
            f'Updated: {v.shape}, original: {self.shape}'
        dtype = np_dtype_to_trt(v.dtype)
        if self.dtype != dtype:
            logger.warning(
                f"Parameter was initialized as {self.dtype} but set to {dtype}")
        self._value = v

    def set_value_or_dummy(self, v: Union[np.ndarray, torch.Tensor]):
        """Assign ``v`` if its shape matches, otherwise assign a placeholder.

        On a shape mismatch the parameter receives an uninitialised array of
        the declared shape and dtype instead of raising.
        """
        v = self._regularize_value(v)
        if v.shape != self._shape:
            self.value = np.empty(self._shape, trt_dtype_to_np(self._dtype))
            return
        self.value = v

    def _get_weights(self) -> trt.Weights:
        """Return the weights of the producing constant layer, or ``None``.

        ``None`` is returned while the value has not yet been converted to a
        ``Tensor``.
        """
        if not isinstance(self._value, Tensor):
            return None
        # NOTE(review): rebinding ``__class__`` presumably exposes the
        # IConstantLayer-specific ``weights`` attribute on a producer that is
        # typed as a more generic layer -- confirm against the TensorRT
        # Python bindings.
        self._value.producer.__class__ = trt.IConstantLayer
        return self._value.producer.weights

    def _regularize_value(self, value):
        """Convert ``value`` to an ndarray (torch tensors via ``torch_to_numpy``).

        Raises:
            TypeError: If ``value`` is neither ``np.ndarray`` nor
                ``torch.Tensor``.
        """
        if isinstance(value, np.ndarray):
            return value
        if isinstance(value, torch.Tensor):
            return torch_to_numpy(value)
        raise TypeError(
            f'Expected numpy.ndarray or torch.Tensor, got {type(value)}')
|