// AffineQuantizer.h
// Declarations for affine quantize/dequantize routines and their
// per-backend dispatch stubs (see DECLARE_DISPATCH below).
#pragma once
#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/quantized/AffineQuantizerBase.h>
namespace at {
namespace native {

// ---------------------------------------------------------------------------
// Entry points. Each writes its result into the second (output) argument and
// returns a reference to it. Bodies live elsewhere; only contracts visible
// from the signatures are documented here.
// ---------------------------------------------------------------------------

// Quantize `rtensor` into `qtensor` with one scale/zero_point pair shared by
// the whole tensor (per-tensor affine scheme). Returns `qtensor`.
Tensor& quantize_tensor_per_tensor_affine(
    const Tensor& rtensor,
    Tensor& qtensor,
    double scale,
    int64_t zero_point);

// Per-channel variant: `scales` and `zero_points` hold one quantization
// parameter per slice along dimension `axis`.
// NOTE(review): `scales`/`zero_points` are taken by value here while the
// matching *_fn typedefs below take `const Tensor&` — a Tensor copy is only a
// refcounted handle, but `const Tensor&` would be consistent; changing it
// requires touching the out-of-line definitions too. TODO: confirm/unify.
Tensor& quantize_tensor_per_channel_affine(
    const Tensor& rtensor,
    Tensor& qtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

// Per-channel variant whose qparams are floating point (note the *_fn typedef
// for the sub-byte path also uses float qparams — presumably related; confirm
// against the kernel implementations).
Tensor& quantize_tensor_per_channel_float_qparams(
    const Tensor& rtensor,
    Tensor& qtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

// Inverse of quantize_tensor_per_tensor_affine: reconstructs real values from
// `qtensor` into `rtensor`. Returns `rtensor`.
Tensor& dequantize_tensor_per_tensor_affine(
    const Tensor& qtensor,
    Tensor& rtensor,
    double scale,
    int64_t zero_point);

// Inverse of quantize_tensor_per_channel_affine (same by-value qparams note
// as above).
Tensor& dequantize_tensor_per_channel_affine(
    const Tensor& qtensor,
    Tensor& rtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

// Inverse of quantize_tensor_per_channel_float_qparams.
Tensor& dequantize_tensor_per_channel_float_qparams(
    const Tensor& qtensor,
    Tensor& rtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

// ---------------------------------------------------------------------------
// Kernel signatures. One function-pointer alias per entry point above; the
// DECLARE_DISPATCH stubs below let each backend (CPU/CUDA/...) register its
// own implementation with these signatures.
// ---------------------------------------------------------------------------

using quantize_tensor_per_tensor_affine_fn =
    void (*)(const Tensor& rtensor, Tensor& qtensor, double scale, int64_t zero_point);
using quantize_tensor_per_channel_affine_fn = void (*)(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);
using quantize_tensor_per_channel_float_qparams_fn = void (*)(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);
using dequantize_tensor_per_tensor_affine_fn =
    void (*)(const Tensor& qtensor, Tensor& rtensor, double scale, int64_t zero_point);
using dequantize_tensor_per_channel_affine_fn = void (*)(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);
using dequantize_tensor_per_channel_float_qparams_fn = void (*)(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

// Sub-byte variants take float (not double) scale and a float zero_point —
// presumably for quantized dtypes narrower than 8 bits; confirm against the
// registered kernels. No per-channel sub-byte signature is declared here.
using quantize_tensor_per_tensor_affine_sub_byte_fn =
    void (*)(const Tensor& rtensor, Tensor& qtensor, float scale, float zero_point);
using dequantize_tensor_per_tensor_affine_sub_byte_fn =
    void (*)(const Tensor& qtensor, Tensor& rtensor, float scale, float zero_point);

// ---------------------------------------------------------------------------
// Dispatch stubs (DEFINE_DISPATCH / REGISTER_DISPATCH happen in the backend
// translation units — see ATen/native/DispatchStub.h).
// ---------------------------------------------------------------------------

DECLARE_DISPATCH(
    quantize_tensor_per_tensor_affine_fn,
    quantize_tensor_per_tensor_affine_stub);
DECLARE_DISPATCH(
    quantize_tensor_per_channel_affine_fn,
    quantize_tensor_per_channel_affine_stub);
DECLARE_DISPATCH(
    quantize_tensor_per_channel_float_qparams_fn,
    quantize_tensor_per_channel_float_qparams_stub);
DECLARE_DISPATCH(
    dequantize_tensor_per_tensor_affine_fn,
    dequantize_tensor_per_tensor_affine_stub);
DECLARE_DISPATCH(
    dequantize_tensor_per_channel_affine_fn,
    dequantize_tensor_per_channel_affine_stub);
DECLARE_DISPATCH(
    dequantize_tensor_per_channel_float_qparams_fn,
    dequantize_tensor_per_channel_float_qparams_stub);
DECLARE_DISPATCH(
    quantize_tensor_per_tensor_affine_sub_byte_fn,
    quantize_tensor_per_tensor_affine_sub_byte_stub);
DECLARE_DISPATCH(
    dequantize_tensor_per_tensor_affine_sub_byte_fn,
    dequantize_tensor_per_tensor_affine_sub_byte_stub);

// ---------------------------------------------------------------------------
// Exported templated helpers, parameterized on the underlying quantized
// value type T. Unlike the Tensor&-returning entry points above, these take
// and return Tensor by value — presumably an older calling convention; the
// explicit instantiations live in a .cpp. TODO(review): confirm callers.
// ---------------------------------------------------------------------------

template <typename T>
TORCH_API Tensor quantize_tensor(
    Tensor rtensor,
    Tensor qtensor,
    double scale,
    int64_t zero_point);

template <typename T>
TORCH_API Tensor dequantize_tensor(
    Tensor qtensor,
    Tensor rtensor,
    double scale,
    int64_t zero_point);

} // namespace native
} // namespace at