File size: 25,571 Bytes
5000658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union

import numpy as np
import tensorrt as trt

from .._common import default_net, default_trtnet
from .._utils import str_dtype_to_np, str_dtype_to_trt
from ..functional import (Tensor, _add_plugin_info, _create_tensor, cast, clip,
                          constant, matmul, repeat_interleave, round)
from ..plugin import TRT_LLM_PLUGIN_NAMESPACE
from .mode import QuantMode


def smooth_quant_gemm(input: Tensor, weights: Tensor, scales_a: Tensor,
                      scales_b: Tensor, per_token_scaling: bool,
                      per_channel_scaling: bool) -> Tensor:
    """Add a 'SmoothQuantGemm' plugin layer to the default TensorRT network.

    Args:
        input: First plugin input.
        weights: Second plugin input.
        scales_a: Third plugin input.
        scales_b: Fourth plugin input.
        per_token_scaling: Sent to the plugin as the 0/1 INT32 field
            "has_per_token_scaling".
        per_channel_scaling: Sent to the plugin as the 0/1 INT32 field
            "has_per_channel_scaling".

    Returns:
        A Tensor wrapping output 0 of the plugin layer.

    Raises:
        TypeError: If `plugin_config.smooth_quant_gemm_plugin` is falsy;
            there is no non-plugin implementation.
    """
    # The config entry is both the on/off switch and the dtype string that
    # is forwarded to the plugin as "type_id".
    plugin_dtype = default_net().plugin_config.smooth_quant_gemm_plugin
    if not plugin_dtype:
        raise TypeError("Smooth Quant GEMM is only supported with plugin")

    creator = trt.get_plugin_registry().get_plugin_creator(
        'SmoothQuantGemm', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert creator is not None

    channel_field = trt.PluginField(
        "has_per_channel_scaling",
        np.array(1 if per_channel_scaling else 0, dtype=np.int32),
        trt.PluginFieldType.INT32)
    token_field = trt.PluginField(
        "has_per_token_scaling",
        np.array(1 if per_token_scaling else 0, dtype=np.int32),
        trt.PluginFieldType.INT32)
    type_field = trt.PluginField(
        "type_id", np.array([int(str_dtype_to_trt(plugin_dtype))], np.int32),
        trt.PluginFieldType.INT32)

    fields = trt.PluginFieldCollection(
        [channel_field, token_field, type_field])
    plugin = creator.create_plugin("sq_gemm", fields)

    plugin_inputs = [
        operand.trt_tensor
        for operand in (input, weights, scales_a, scales_b)
    ]
    layer = default_trtnet().add_plugin_v2(plugin_inputs, plugin)
    _add_plugin_info(layer, creator, "sq_gemm", fields)

    # Explicit dynamic ranges are only set when the network is not strongly
    # typed; they cover the first two plugin inputs (input and weights).
    if not default_net().strongly_typed:
        for index in (0, 1):
            layer.get_input(index).set_dynamic_range(-127, 127)
    return _create_tensor(layer.get_output(0), layer)


def fp8_rowwise_gemm(input: Tensor, weights: Tensor, scales_a: Tensor,
                     scales_b: Tensor, per_token_scaling: bool,
                     per_channel_scaling: bool) -> Tensor:
    """Add an 'Fp8RowwiseGemm' plugin layer to the default TensorRT network.

    Args:
        input: First plugin input.
        weights: Second plugin input.
        scales_a: Third plugin input.
        scales_b: Fourth plugin input.
        per_token_scaling: Sent to the plugin as the 0/1 INT32 field
            "has_per_token_scaling".
        per_channel_scaling: Sent to the plugin as the 0/1 INT32 field
            "has_per_channel_scaling".

    Returns:
        A Tensor wrapping output 0 of the plugin layer.

    Raises:
        TypeError: If `plugin_config.fp8_rowwise_gemm_plugin` is falsy;
            there is no non-plugin implementation.
    """
    # The config entry is both the on/off switch and the dtype string that
    # is forwarded to the plugin as "type_id".
    plugin_dtype = default_net().plugin_config.fp8_rowwise_gemm_plugin
    if not plugin_dtype:
        raise TypeError("Fp8 Rowwise GEMM is only supported with plugin")

    creator = trt.get_plugin_registry().get_plugin_creator(
        'Fp8RowwiseGemm', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert creator is not None

    channel_field = trt.PluginField(
        "has_per_channel_scaling",
        np.array(1 if per_channel_scaling else 0, dtype=np.int32),
        trt.PluginFieldType.INT32)
    token_field = trt.PluginField(
        "has_per_token_scaling",
        np.array(1 if per_token_scaling else 0, dtype=np.int32),
        trt.PluginFieldType.INT32)
    type_field = trt.PluginField(
        "type_id", np.array([int(str_dtype_to_trt(plugin_dtype))], np.int32),
        trt.PluginFieldType.INT32)

    fields = trt.PluginFieldCollection(
        [channel_field, token_field, type_field])
    plugin = creator.create_plugin("fp8_rowwise_gemm", fields)

    plugin_inputs = [
        operand.trt_tensor
        for operand in (input, weights, scales_a, scales_b)
    ]
    layer = default_trtnet().add_plugin_v2(plugin_inputs, plugin)
    _add_plugin_info(layer, creator, "fp8_rowwise_gemm", fields)

    # Explicit dynamic ranges are only set when the network is not strongly
    # typed; they cover the first two plugin inputs (input and weights).
    if not default_net().strongly_typed:
        for index in (0, 1):
            layer.get_input(index).set_dynamic_range(-448, 448)
    return _create_tensor(layer.get_output(0), layer)


def weight_only_quant_matmul(input: Tensor,
                             weights: Tensor,
                             scales: Tensor,
                             weightTypeId: int,
                             dtype: str = 'float16',
                             transa: bool = False,
                             transb: bool = False) -> Tensor:
    """Multiply `input` by weight-only-quantized `weights`.

    The 'WeightOnlyQuantMatmul' plugin is used when it is enabled in the
    plugin config and neither operand is transposed. Otherwise the weights
    are dequantized and a regular matmul is emitted.

    Args:
        input: Left-hand operand.
        weights: Right-hand operand. In the fallback path, weights whose
            dtype is not int8 are first quantized to int8 with `scales`.
        scales: Scale factors used for quantize/dequantize, or passed as the
            third plugin input.
        weightTypeId: Forwarded to the plugin as the INT32 field
            "weight_type_id"; unused in the fallback path.
        dtype: Output dtype of the fallback path (the result is cast to it);
            unused in the plugin path.
        transa: Passed to matmul; a true value forces the fallback path.
        transb: Passed to matmul; a true value forces the fallback path and
            selects axis 0 instead of axis 1 for dequantization.

    Returns:
        The product as a Tensor.
    """
    plugin_dtype = default_net().plugin_config.weight_only_quant_matmul_plugin

    if plugin_dtype and not transa and not transb:
        creator = trt.get_plugin_registry().get_plugin_creator(
            'WeightOnlyQuantMatmul', '1', TRT_LLM_PLUGIN_NAMESPACE)
        assert creator is not None

        weight_type_field = trt.PluginField(
            "weight_type_id", np.array(weightTypeId, dtype=np.int32),
            trt.PluginFieldType.INT32)
        type_field = trt.PluginField(
            "type_id",
            np.array([int(str_dtype_to_trt(plugin_dtype))], np.int32),
            trt.PluginFieldType.INT32)

        fields = trt.PluginFieldCollection([type_field, weight_type_field])
        plugin = creator.create_plugin("woq_matmul", fields)

        plugin_inputs = [
            operand.trt_tensor for operand in (input, weights, scales)
        ]
        layer = default_trtnet().add_plugin_v2(plugin_inputs, plugin)
        _add_plugin_info(layer, creator, "woq_matmul", fields)
        # Only the weights input (index 1) gets an explicit dynamic range.
        if not default_net().strongly_typed:
            layer.get_input(1).set_dynamic_range(-127, 127)
        return _create_tensor(layer.get_output(0), layer)

    # Fallback: emulate the quantized weights with Q->DQ (or DQ only when the
    # weights are already int8) followed by a plain matmul.
    dequant_axis = 0 if transb else 1
    if weights.dtype != trt.int8:
        weights = quantize(weights, scales, dtype='int8', axis=1)
    weights = dequantize(weights, scales, dequant_axis, input.dtype)

    product = matmul(input, weights, transa=transa, transb=transb)
    return cast(product, dtype)


def weight_only_groupwise_quant_matmul(input: Tensor,
                                       pre_quant_scale: Tensor,
                                       weights: Tensor,
                                       scales: Tensor,
                                       zeros: Tensor,
                                       biases: Tensor,
                                       alpha: Tensor,
                                       quant_algo: int,
                                       group_size: int,
                                       dtype: str = 'float16') -> Tensor:
    """Multiply `input` by group-wise weight-only-quantized `weights`.

    `quant_algo` is a bit mask selecting which optional tensors take part:

        quant_algo = fp8_alpha * 8 + pre_quant_scale * 4 + zero * 2 + bias

    Each bit is independent of the others. When the
    'WeightOnlyGroupwiseQuantMatmul' plugin is enabled in the plugin config
    it is used; otherwise the computation is emulated with
    quantize/dequantize layers and a regular matmul.

    Args:
        input: Left-hand operand.
        pre_quant_scale: Used only when bit 4 of `quant_algo` is set.
        weights: Right-hand operand.
        scales: Group-wise scale factors. In the fallback path they are
            repeated `group_size` times along axis 0.
        zeros: Used only when bit 2 of `quant_algo` is set.
        biases: Used only when bit 1 of `quant_algo` is set.
        alpha: Used only when bit 8 of `quant_algo` is set.
        quant_algo: Bit mask described above; also forwarded to the plugin
            as the INT32 field "quant_algo".
        group_size: Group size; also forwarded to the plugin as the INT32
            field "group_size".
        dtype: Output dtype of the fallback path (the result is cast to it);
            unused in the plugin path.

    Returns:
        The product as a Tensor.
    """
    # Flags indicating whether the corresponding inputs are applied in
    # quant_algo. They are shared by the fallback and the plugin paths so
    # both interpret the mask the same way.
    BIAS = 1
    ZERO = 2
    PRE_QUANT_SCALE = 4
    FP8_ALPHA = 8

    if not default_net(
    ).plugin_config.weight_only_groupwise_quant_matmul_plugin:
        scales = repeat_interleave(scales, group_size, 0)
        weights = quantize(weights, scales, dtype='int8', axis=1)
        weights = dequantize(weights, scales, 1, input.dtype)

        if quant_algo & FP8_ALPHA:
            input = input * alpha
        if quant_algo & PRE_QUANT_SCALE:
            input = input * pre_quant_scale
        # Zero-points are independent of the pre-quant scale: the plugin
        # path consumes both when both bits are set, so the fallback must
        # not skip zeros just because PRE_QUANT_SCALE is enabled.
        if quant_algo & ZERO:
            zeros = repeat_interleave(zeros, group_size, 0)
            weights += zeros
        res = matmul(input, weights)
        if quant_algo & BIAS:
            res += biases

        return cast(res, dtype)
    else:
        plg_creator = trt.get_plugin_registry().get_plugin_creator(
            'WeightOnlyGroupwiseQuantMatmul', '1', TRT_LLM_PLUGIN_NAMESPACE)
        assert plg_creator is not None

        quant_algo_ = trt.PluginField("quant_algo",
                                      np.array(quant_algo, dtype=np.int32),
                                      trt.PluginFieldType.INT32)
        group_size_ = trt.PluginField("group_size",
                                      np.array(group_size, dtype=np.int32),
                                      trt.PluginFieldType.INT32)

        # The config entry is also the dtype string forwarded as "type_id".
        p_dtype = default_net(
        ).plugin_config.weight_only_groupwise_quant_matmul_plugin
        pf_type_ = trt.PluginField(
            "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
            trt.PluginFieldType.INT32)

        pfc = trt.PluginFieldCollection([pf_type_, quant_algo_, group_size_])

        matmul_plug = plg_creator.create_plugin("woq_groupwise_matmul", pfc)

        # Plugin inputs are positional: input, [pre_quant_scale], weights,
        # scales, [zeros], [biases], [alpha]. Optional entries are present
        # only when their flag is set in quant_algo.
        plug_inputs = [input.trt_tensor]

        if quant_algo & PRE_QUANT_SCALE:
            plug_inputs += [pre_quant_scale.trt_tensor]

        plug_inputs += [weights.trt_tensor, scales.trt_tensor]

        if quant_algo & ZERO:
            plug_inputs += [zeros.trt_tensor]
        if quant_algo & BIAS:
            plug_inputs += [biases.trt_tensor]
        if quant_algo & FP8_ALPHA:
            plug_inputs += [alpha.trt_tensor]

        layer = default_trtnet().add_plugin_v2(plug_inputs, matmul_plug)
        _add_plugin_info(layer, plg_creator, "woq_groupwise_matmul", pfc)

        return _create_tensor(layer.get_output(0), layer)


def smooth_quant_layer_norm(input: Tensor,
                            normalized_shape: Union[int, Tuple[int]],
                            weight: Optional[Tensor] = None,
                            bias: Optional[Tensor] = None,
                            scale: Optional[Tensor] = None,
                            eps: float = 1e-05,
                            use_diff_of_squares: bool = True,
                            dynamic_act_scaling: bool = False) -> Tensor:
    """Add a 'LayernormQuantization' plugin layer to the default network.

    Args:
        input: Tensor to normalize and quantize.
        normalized_shape: Shape used for the default weight/bias constants;
            a bare int is treated as a one-element shape.
        weight: Second plugin input. Defaults to a constant of ones with
            shape `normalized_shape` in the plugin dtype.
        bias: Third plugin input. Defaults to a constant of zeros with shape
            `normalized_shape` in the plugin dtype.
        scale: Fourth plugin input, cast to float32 first.
            NOTE(review): the default is None, but the value is passed to
            `cast` unconditionally - callers presumably always supply it;
            confirm.
        eps: Forwarded to the plugin as the FLOAT32 field "eps".
        use_diff_of_squares: Forwarded as the INT32 field
            "use_diff_of_squares".
        dynamic_act_scaling: Forwarded as the INT32 field "dyn_act_scaling";
            also selects the return shape (see below).

    Returns:
        Output 0 of the plugin layer when `dynamic_act_scaling` is False;
        otherwise a tuple of output 0 and output 1 (presumably the
        quantized tensor and its dynamic scales - confirm against the
        plugin).

    Raises:
        TypeError: If `plugin_config.layernorm_quantization_plugin` is
            falsy; there is no non-plugin implementation.
    """
    # The config entry is both the on/off switch and the plugin dtype string.
    plugin_dtype = default_net().plugin_config.layernorm_quantization_plugin
    if not plugin_dtype:
        raise TypeError("Smooth Quant Layer Norm is only supported with plugin")

    creator = trt.get_plugin_registry().get_plugin_creator(
        'LayernormQuantization', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert creator is not None

    eps_field = trt.PluginField("eps", np.array(eps, dtype=np.float32),
                                trt.PluginFieldType.FLOAT32)
    diff_of_squares_field = trt.PluginField(
        "use_diff_of_squares",
        np.array([int(use_diff_of_squares)], dtype=np.int32),
        trt.PluginFieldType.INT32)
    dyn_scaling_field = trt.PluginField(
        "dyn_act_scaling", np.array([int(dynamic_act_scaling)], np.int32),
        trt.PluginFieldType.INT32)
    type_field = trt.PluginField(
        "type_id", np.array([int(str_dtype_to_trt(plugin_dtype))], np.int32),
        trt.PluginFieldType.INT32)

    fields = trt.PluginFieldCollection(
        [eps_field, diff_of_squares_field, dyn_scaling_field, type_field])
    plugin = creator.create_plugin("layernorm_quantized", fields)

    if isinstance(normalized_shape, int):
        normalized_shape = [normalized_shape]
    np_dtype = str_dtype_to_np(plugin_dtype)
    if weight is None:
        weight = constant(np.ones(normalized_shape, dtype=np_dtype))
    if bias is None:
        bias = constant(np.zeros(normalized_shape, dtype=np_dtype))

    # LayerNorm plugin only supports float32 scale
    scale = cast(scale, "float32")

    plugin_inputs = [
        operand.trt_tensor for operand in (input, weight, bias, scale)
    ]
    layer = default_trtnet().add_plugin_v2(plugin_inputs, plugin)
    if not default_net().strongly_typed:
        layer.get_output(0).set_dynamic_range(-127, 127)
    _add_plugin_info(layer, creator, "layernorm_quantized", fields)

    quantized = _create_tensor(layer.get_output(0), layer)
    if dynamic_act_scaling:
        return quantized, _create_tensor(layer.get_output(1), layer)
    return quantized


def smooth_quant_rms_norm(
        input: Tensor,
        normalized_shape: Union[int, Tuple[int]],
        weight: Optional[Tensor] = None,
        bias: Optional[Tensor] = None,
        scale: Optional[Tensor] = None,
        clamp_val: Optional[Tensor] = None,
        eps: float = 1e-05,
        dynamic_act_scaling: bool = False
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Add a 'RmsnormQuantization' plugin layer with int8 output.

    Args:
        input: Tensor to normalize and quantize.
        normalized_shape: Shape used for the default weight/bias constants;
            a bare int is treated as a one-element shape.
        weight: Second plugin input. Defaults to a constant of ones with
            shape `normalized_shape` in the plugin dtype.
        bias: Third plugin input. Defaults to a constant of zeros with shape
            `normalized_shape` in the plugin dtype.
        scale: Fourth plugin input, cast to float32 first.
            NOTE(review): the default is None, but the value is passed to
            `cast` unconditionally - callers presumably always supply it;
            confirm.
        clamp_val: Optional fifth plugin input. Its presence is also
            reported to the plugin through the "clamp_enabled" field.
        eps: Forwarded to the plugin as the FLOAT32 field "eps".
        dynamic_act_scaling: Forwarded as the INT32 field "dyn_act_scaling";
            also selects the return shape (see below).

    Returns:
        Output 0 of the plugin layer when `dynamic_act_scaling` is False;
        otherwise a tuple of output 0 and output 1 (presumably the
        quantized tensor and its dynamic scales - confirm against the
        plugin).

    Raises:
        TypeError: If `plugin_config.rmsnorm_quantization_plugin` is falsy;
            there is no non-plugin implementation.
    """
    # The config entry is both the on/off switch and the plugin dtype string.
    p_dtype = default_net().plugin_config.rmsnorm_quantization_plugin
    if not p_dtype:
        raise TypeError("Smooth Quant Rms Norm is only supported with plugin")

    plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'RmsnormQuantization', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert plg_creator is not None

    # A single test decides both the "clamp_enabled" field and whether the
    # clamp tensor is appended to the plugin inputs, so the two cannot
    # disagree.
    has_clamp = clamp_val is not None

    output_type = trt.PluginField("out_type_id",
                                  np.array([int(trt.int8)], np.int32),
                                  trt.PluginFieldType.INT32)
    quant_mode = trt.PluginField(
        "quant_mode",
        np.array([int(QuantMode.use_smooth_quant(per_token=True))], np.int32),
        trt.PluginFieldType.INT32)
    clamp_enabled = trt.PluginField("clamp_enabled",
                                    np.array([has_clamp], np.int32),
                                    trt.PluginFieldType.INT32)
    eps_field = trt.PluginField("eps", np.array(eps, dtype=np.float32),
                                trt.PluginFieldType.FLOAT32)
    dyn_act_scaling = trt.PluginField(
        "dyn_act_scaling", np.array([int(dynamic_act_scaling)], np.int32),
        trt.PluginFieldType.INT32)
    pf_type = trt.PluginField(
        "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
        trt.PluginFieldType.INT32)

    pfc = trt.PluginFieldCollection([
        eps_field, dyn_act_scaling, clamp_enabled, quant_mode, pf_type,
        output_type
    ])
    rmsnorm_plug = plg_creator.create_plugin("rmsnorm_quantized", pfc)

    normalized_shape = [normalized_shape] if isinstance(
        normalized_shape, int) else normalized_shape
    if weight is None:
        weight = constant(
            np.ones(normalized_shape, dtype=str_dtype_to_np(p_dtype)))
    if bias is None:
        bias = constant(
            np.zeros(normalized_shape, dtype=str_dtype_to_np(p_dtype)))

    # RMS Norm Plugin only supports float32 scale
    scale = cast(scale, "float32")
    plug_inputs = [
        input.trt_tensor, weight.trt_tensor, bias.trt_tensor,
        scale.trt_tensor
    ]
    if has_clamp:
        plug_inputs += [clamp_val.trt_tensor]

    layer = default_trtnet().add_plugin_v2(plug_inputs, rmsnorm_plug)
    if not default_net().strongly_typed:
        layer.get_output(0).set_dynamic_range(-127, 127)
    _add_plugin_info(layer, plg_creator, "rmsnorm_quantized", pfc)

    if not dynamic_act_scaling:
        return _create_tensor(layer.get_output(0), layer)

    return _create_tensor(layer.get_output(0),
                          layer), _create_tensor(layer.get_output(1), layer)


def fp8_rowwise_rms_norm(
        input: Tensor,
        normalized_shape: Union[int, Tuple[int]],
        weight: Optional[Tensor] = None,
        bias: Optional[Tensor] = None,
        scale: Optional[Tensor] = None,
        clamp_val: Optional[Tensor] = None,
        eps: float = 1e-05,
        dynamic_act_scaling: bool = True
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Add a 'RmsnormQuantization' plugin layer with fp8 output.

    Args:
        input: Tensor to normalize and quantize.
        normalized_shape: Shape used for the default weight/bias constants;
            a bare int is treated as a one-element shape.
        weight: Second plugin input. Defaults to a constant of ones with
            shape `normalized_shape` in the plugin dtype.
        bias: Third plugin input. Defaults to a constant of zeros with shape
            `normalized_shape` in the plugin dtype.
        scale: Fourth plugin input, cast to float32 first. Defaults to a
            one-element constant of ones.
        clamp_val: Optional fifth plugin input. Its presence is also
            reported to the plugin through the "clamp_enabled" field.
        eps: Forwarded to the plugin as the FLOAT32 field "eps".
        dynamic_act_scaling: Forwarded as the INT32 field "dyn_act_scaling";
            also selects the return shape (see below).

    Returns:
        Output 0 of the plugin layer when `dynamic_act_scaling` is False;
        otherwise a tuple of output 0 and output 1 (presumably the
        quantized tensor and its dynamic scales - confirm against the
        plugin).

    Raises:
        TypeError: If `plugin_config.rmsnorm_quantization_plugin` is falsy;
            there is no non-plugin implementation.
    """
    # The config entry is both the on/off switch and the plugin dtype string.
    p_dtype = default_net().plugin_config.rmsnorm_quantization_plugin
    if not p_dtype:
        raise TypeError("Fp8 Rowwise Rms Norm is only supported with plugin")

    plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'RmsnormQuantization', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert plg_creator is not None

    # A single test decides both the "clamp_enabled" field and whether the
    # clamp tensor is appended to the plugin inputs, so the two cannot
    # disagree.
    has_clamp = clamp_val is not None

    output_type = trt.PluginField("out_type_id",
                                  np.array([int(trt.fp8)], np.int32),
                                  trt.PluginFieldType.INT32)
    quant_mode = trt.PluginField(
        "quant_mode",
        np.array([int(QuantMode.from_description(use_fp8_rowwise=True))],
                 np.int32), trt.PluginFieldType.INT32)
    clamp_enabled = trt.PluginField("clamp_enabled",
                                    np.array([has_clamp], np.int32),
                                    trt.PluginFieldType.INT32)
    eps_field = trt.PluginField("eps", np.array(eps, dtype=np.float32),
                                trt.PluginFieldType.FLOAT32)
    dyn_act_scaling = trt.PluginField(
        "dyn_act_scaling", np.array([int(dynamic_act_scaling)], np.int32),
        trt.PluginFieldType.INT32)
    pf_type = trt.PluginField(
        "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
        trt.PluginFieldType.INT32)

    pfc = trt.PluginFieldCollection([
        eps_field, dyn_act_scaling, clamp_enabled, quant_mode, pf_type,
        output_type
    ])
    rmsnorm_plug = plg_creator.create_plugin("rmsnorm_quantized", pfc)

    normalized_shape = [normalized_shape] if isinstance(
        normalized_shape, int) else normalized_shape
    if weight is None:
        weight = constant(
            np.ones(normalized_shape, dtype=str_dtype_to_np(p_dtype)))
    if bias is None:
        bias = constant(
            np.zeros(normalized_shape, dtype=str_dtype_to_np(p_dtype)))
    if scale is None:
        scale = constant(np.ones((1, ), dtype=str_dtype_to_np(p_dtype)))

    # RMS Norm Plugin only supports float32 scale
    scale = cast(scale, "float32")
    plug_inputs = [
        input.trt_tensor, weight.trt_tensor, bias.trt_tensor,
        scale.trt_tensor
    ]
    if has_clamp:
        plug_inputs += [clamp_val.trt_tensor]

    layer = default_trtnet().add_plugin_v2(plug_inputs, rmsnorm_plug)
    if not default_net().strongly_typed:
        layer.get_output(0).set_dynamic_range(-448, 448)
    _add_plugin_info(layer, plg_creator, "rmsnorm_quantized", pfc)

    if not dynamic_act_scaling:
        return _create_tensor(layer.get_output(0), layer)

    return _create_tensor(layer.get_output(0),
                          layer), _create_tensor(layer.get_output(1), layer)


def quantize(input: Tensor,
             scale_factor: Tensor,
             dtype: str,
             axis: int = -1) -> Tensor:
    """Add a TensorRT quantize layer to the default network.

    Args:
        input: Tensor to quantize.
        scale_factor: Scale tensor handed to the quantize layer.
        dtype: Name of the target data type; converted with
            `str_dtype_to_trt`.
        axis: Value assigned to the layer's `axis` attribute.

    Returns:
        A Tensor wrapping output 0 of the quantize layer.
    """
    target_type = str_dtype_to_trt(dtype)
    quantize_layer = default_trtnet().add_quantize(input.trt_tensor,
                                                   scale_factor.trt_tensor,
                                                   target_type)
    quantize_layer.axis = axis
    return _create_tensor(quantize_layer.get_output(0), quantize_layer)


def dequantize(input: Tensor,
               scale_factor: Tensor,
               axis: int = -1,
               output_type: Union[str, trt.DataType] = 'float16') -> Tensor:
    """Add a TensorRT dequantize layer to the default network.

    Args:
        input: Quantized tensor to dequantize.
        scale_factor: Scale tensor handed to the dequantize layer.
        axis: Value assigned to the layer's `axis` attribute.
        output_type: Output data type, either a dtype name (converted with
            `str_dtype_to_trt`) or a `trt.DataType` used as-is.

    Returns:
        A Tensor wrapping output 0 of the dequantize layer.
    """
    trt_output_type = (str_dtype_to_trt(output_type) if isinstance(
        output_type, str) else output_type)

    dequantize_layer = default_trtnet().add_dequantize(
        input.trt_tensor, scale_factor.trt_tensor, trt_output_type)
    dequantize_layer.axis = axis

    # Layer precision is only pinned (to the input dtype) when the network
    # is not strongly typed.
    if not default_net().strongly_typed:
        dequantize_layer.precision = input.dtype

    return _create_tensor(dequantize_layer.get_output(0), dequantize_layer)


def quantize_per_token(
        x: Tensor,
        clamp_val: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
    """Quantize `x` to int8 with one scale per slice along the last axis.

    When the 'QuantizePerToken' plugin is enabled in the plugin config it is
    used. Otherwise the quantization is built from elementwise ops:
    `scale = max(|x|) / 127` over the last axis, and the values are
    `round(x * 127 / max(|x|))` clipped to [-128, 127] and cast to int8.

    Args:
        x: Tensor to quantize.
        clamp_val: Optional second plugin input; its presence is also
            reported to the plugin through the "clamp_enabled" field. It is
            not used by the non-plugin fallback.

    Returns:
        A `(quantized, scales)` pair of Tensors.
    """
    if not default_net().plugin_config.quantize_per_token_plugin:
        x = cast(x, 'float32')
        xmax = x.abs().max(-1, keepdim=True)
        scale = xmax / 127.0
        out = x * 127.0 / xmax
        out = round(out)
        out = clip(out, -128, 127)
        quantized_out = cast(out, 'int8')
        return quantized_out, scale

    plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'QuantizePerToken', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert plg_creator is not None

    # A single test decides both the "clamp_enabled" field and whether the
    # clamp tensor is appended to the plugin inputs, so the two cannot
    # disagree.
    has_clamp = clamp_val is not None

    output_type = trt.PluginField("type_id",
                                  np.array([int(trt.int8)], np.int32),
                                  trt.PluginFieldType.INT32)
    quant_mode = trt.PluginField(
        "quant_mode",
        np.array([int(QuantMode.use_smooth_quant(per_token=True))], np.int32),
        trt.PluginFieldType.INT32)
    clamp_enabled = trt.PluginField("clamp_enabled",
                                    np.array([has_clamp], np.int8),
                                    trt.PluginFieldType.INT8)
    pfc = trt.PluginFieldCollection([output_type, quant_mode, clamp_enabled])
    quantize_plug = plg_creator.create_plugin("quantize_per_token_plugin",
                                              pfc)

    plug_inputs = [x.trt_tensor]
    if has_clamp:
        plug_inputs += [clamp_val.trt_tensor]
    layer = default_trtnet().add_plugin_v2(plug_inputs, quantize_plug)
    if not default_net().strongly_typed:
        layer.get_output(0).set_dynamic_range(-127, 127)
    _add_plugin_info(layer, plg_creator, "quantize_per_token_plugin", pfc)

    quantized = _create_tensor(layer.get_output(0), layer)
    scales = _create_tensor(layer.get_output(1), layer)

    return quantized, scales


def quantize_fp8_per_token(
        x: Tensor,
        clamp_val: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
    """Quantize `x` to fp8 with one scale per slice along the last axis.

    When the 'QuantizePerToken' plugin is enabled in the plugin config it is
    used. Otherwise the quantization is built from elementwise ops:
    `scale = max(|x|) / 448` over the last axis, and the values are
    `round(x * 448 / max(|x|))` clipped to [-448, 448] and cast to fp8.

    Args:
        x: Tensor to quantize.
        clamp_val: Optional second plugin input; its presence is also
            reported to the plugin through the "clamp_enabled" field. It is
            not used by the non-plugin fallback.

    Returns:
        A `(quantized, scales)` pair of Tensors.
    """
    if not default_net().plugin_config.quantize_per_token_plugin:
        x = cast(x, 'float32')
        xmax = x.abs().max(-1, keepdim=True)
        scale = xmax / 448.0
        out = x * 448.0 / xmax
        out = round(out)
        out = clip(out, -448, 448)
        quantized_out = cast(out, 'fp8')
        return quantized_out, scale

    plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'QuantizePerToken', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert plg_creator is not None

    # A single test decides both the "clamp_enabled" field and whether the
    # clamp tensor is appended to the plugin inputs, so the two cannot
    # disagree.
    has_clamp = clamp_val is not None

    output_type = trt.PluginField("type_id",
                                  np.array([int(trt.fp8)], np.int32),
                                  trt.PluginFieldType.INT32)
    quant_mode = trt.PluginField(
        "quant_mode",
        np.array([int(QuantMode.from_description(use_fp8_rowwise=True))],
                 np.int32), trt.PluginFieldType.INT32)
    clamp_enabled = trt.PluginField("clamp_enabled",
                                    np.array([has_clamp], np.int8),
                                    trt.PluginFieldType.INT8)
    pfc = trt.PluginFieldCollection([output_type, quant_mode, clamp_enabled])
    quantize_plug = plg_creator.create_plugin("quantize_per_token_plugin",
                                              pfc)

    plug_inputs = [x.trt_tensor]
    if has_clamp:
        plug_inputs += [clamp_val.trt_tensor]
    layer = default_trtnet().add_plugin_v2(plug_inputs, quantize_plug)
    if not default_net().strongly_typed:
        layer.get_output(0).set_dynamic_range(-448, 448)
    _add_plugin_info(layer, plg_creator, "quantize_per_token_plugin", pfc)

    quantized = _create_tensor(layer.get_output(0), layer)
    scales = _create_tensor(layer.get_output(1), layer)

    return quantized, scales


def quantize_tensor(x, scale):
    """Quantize `x` to int8 using the given `scale`.

    When the 'QuantizeTensor' plugin is enabled in the plugin config it is
    used with `x` and `scale` as its inputs. Otherwise the result is
    `round(x * scale)` clipped to [-128, 127] and cast to int8.

    Args:
        x: Tensor to quantize.
        scale: Multiplier applied to `x` in the fallback path; second plugin
            input in the plugin path.

    Returns:
        The quantized Tensor.
    """
    if not default_net().plugin_config.quantize_tensor_plugin:
        # Fallback built from elementwise ops.
        return cast(clip(round(x * scale), -128, 127), 'int8')

    creator = trt.get_plugin_registry().get_plugin_creator(
        'QuantizeTensor', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert creator is not None

    # The plugin takes no configuration fields.
    fields = trt.PluginFieldCollection([])
    plugin = creator.create_plugin("quantize_tensor_plugin", fields)

    plugin_inputs = [x.trt_tensor, scale.trt_tensor]
    layer = default_trtnet().add_plugin_v2(plugin_inputs, plugin)
    if not default_net().strongly_typed:
        layer.get_output(0).set_dynamic_range(-127, 127)
    _add_plugin_info(layer, creator, "quantize_tensor_plugin", fields)

    return _create_tensor(layer.get_output(0), layer)