File size: 6,203 Bytes
5000658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Sequence, Union

import numpy as np

# isort: off
import torch
import tensorrt as trt
# isort: on

from ._common import default_net
from ._utils import (copy_torch_to_numpy, np_dtype_to_trt, str_dtype_to_trt,
                     torch_to_numpy, trt_dtype_to_np, trt_dtype_to_torch)
from .functional import Tensor, constant
from .logger import logger


class Parameter:
    _DEFAULT_DTYPE = trt.DataType.FLOAT

    def __init__(self,
                 value: Optional[Union[np.ndarray, torch.Tensor]] = None,
                 shape: Sequence[int] = None,
                 dtype: Union[str, trt.DataType] = None,
                 is_buffer: bool = False):
        if dtype is None:
            logger.warning(
                f'Parameter dtype is None, using default dtype: {self._DEFAULT_DTYPE}, it is recommended to always specify dtype explicitly'
            )
        dtype = self._DEFAULT_DTYPE if dtype is None else dtype
        if isinstance(dtype, str):
            dtype = str_dtype_to_trt(dtype)
        self._dtype: trt.DataType = dtype
        if value is None:
            assert isinstance(shape, (
                list,
                tuple)), f"shape must be list or tuple, receive {(type(shape))}"
            self._shape = tuple(shape)
            self._value = None
        else:
            self._shape = value.shape
            self._value = self._regularize_value(value)
        self.is_buffer = is_buffer

    @property
    def shape(self):
        return self._shape

    @property
    def dtype(self):
        return self._dtype

    @property
    def value(self) -> Tensor:
        if (self._value is not None and isinstance(self._value, np.ndarray)
                and self._value.flags['C_CONTIGUOUS']):
            self._value = constant(self._value)
        elif self._value is None or isinstance(self._value, np.ndarray):
            dtype = trt_dtype_to_np(self.dtype)
            ndarray = np.empty(self.shape, dtype)
            value = self._value
            self._value = constant(ndarray)
            default_net()._register_unfilled_weights(self._value.producer.name,
                                                     ndarray, value)
        return self._value

    @classmethod
    def xavier_init(cls, weights: np.ndarray):
        shape = weights.shape
        dtype = np_dtype_to_trt(weights.dtype)
        if len(shape) == 2:
            # Xavier initialization see https://paperswithcode.com/method/xavier-initialization
            v_range = math.sqrt(6) / math.sqrt(shape[0] + shape[1])
        else:
            v_range = 0.1

        if dtype == trt.DataType.INT8:
            upper = math.ceil(128 * v_range)
            value = torch.randint(-upper,
                                  upper, (shape),
                                  dtype=trt_dtype_to_torch(dtype),
                                  device='cuda')
            # value ~ U[int(-128 * v_range), int(128 * v_range)]
        elif dtype == trt.DataType.FP8:
            value = torch.rand((shape), device='cuda') * 2 - 1
            # value ~ U[-v_range, v_range]
            value = value * v_range
            value = value.to(trt_dtype_to_torch(dtype))
        else:
            value = torch.rand(
                (shape), dtype=trt_dtype_to_torch(dtype), device='cuda') * 2 - 1
            # value ~ U[-v_range, v_range]
            value = value * v_range

        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            copy_torch_to_numpy(value, weights)

    def is_inited(self) -> bool:
        return self._value is not None

    @property
    def raw_value(self) -> np.ndarray:
        if self._value is None:
            dtype = trt_dtype_to_np(self.dtype)
            self._value = np.empty(self.shape, dtype)
            Parameter.xavier_init(self._value)
        assert isinstance(
            self._value, np.ndarray
        ), "Must be np.ndarray. Proper usage: get parameter.raw_value before getting parameter.value"
        return self._value

    @value.setter
    def value(self, v: Union[np.ndarray, torch.Tensor]):
        v = self._regularize_value(v)

        if v.shape != self.shape and v.ndim == 0 and max(self.shape) == 1:
            # convert the scalar into a tensor which each dim is 1.
            v = v.reshape(self.shape)

        assert v.shape == self.shape, \
            f'The value updated is not the same shape as the original. ' \
            f'Updated: {v.shape}, original: {self.shape}'
        dtype = np_dtype_to_trt(v.dtype)
        if self.dtype != dtype:
            logger.warning(
                f"Parameter was initialized as {self.dtype} but set to {dtype}")
        self._value = v

    def set_value_or_dummy(self, v: Union[np.ndarray, torch.Tensor]):
        v = self._regularize_value(v)
        if v.shape != self._shape:
            self.value = np.empty(self._shape, trt_dtype_to_np(self._dtype))
            return

        self.value = v

    def _get_weights(self) -> trt.Weights:
        if isinstance(self._value, Tensor):
            self._value.producer.__class__ = trt.IConstantLayer
            return self._value.producer.weights
        else:
            return None

    def _regularize_value(self, value):
        if isinstance(value, np.ndarray):
            return value
        elif isinstance(value, torch.Tensor):
            return torch_to_numpy(value)
        raise TypeError(
            f'Expected numpy.ndarray or torch.Tensor, got {type(value)}')