File size: 4,886 Bytes
c1af2fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/core/ivalue.h>
// Abstract interface for packed parameters of a quantized linear op.
//
// Members declared `= 0` must be provided by every concrete packed-parameter
// type. All other members are optional: their default bodies throw
// std::runtime_error with a "... is not implemented for this packed parameter
// type" message, so a concrete type only overrides what it supports.
struct LinearPackedParamsBase : public torch::jit::CustomClassHolder {
  // Applies the packed linear parameters to `input` and returns the result.
  // NOTE(review): output_scale / output_zero_point are presumably the
  // quantization parameters of the returned tensor -- confirm against the
  // concrete implementations.
  virtual at::Tensor apply(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) = 0;

  // Same contract as apply(); presumably with a fused ReLU on the result
  // (inferred from the name -- confirm against the concrete implementations).
  virtual at::Tensor apply_relu(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) = 0;

  // out variant of LinearPackedParamsBase::apply
  // Optional: the default implementation always throws std::runtime_error.
  virtual at::Tensor& apply_out(
      const at::Tensor& /*input*/,
      double /*output_scale*/,
      int64_t /*output_zero_point*/,
      at::Tensor& /*output*/) {
    throw std::runtime_error(
        "apply_out is not implemented for this packed "
        "parameter type");
  }

  // out variant of LinearPackedParamsBase::apply_relu
  // Optional: the default implementation always throws std::runtime_error.
  virtual at::Tensor& apply_relu_out(
      const at::Tensor& /*input*/,
      double /*output_scale*/,
      int64_t /*output_zero_point*/,
      at::Tensor& /*output*/) {
    throw std::runtime_error(
        "apply_relu_out is not implemented for this packed "
        "parameter type");
  }

  // Corresponding pattern (the ops with `*` are part of the pattern that
  // represents the computation of
  // quantized::linear_with_input_q_dq_qweight_dq_output_fp32):
  //   input -> q* -> dq* -> linear* ->
  //         qweight -> dq* /
  //
  // After fusion:
  //   input -> quantized::linear_with_input_q_dq_qweight_dq_output_fp32* ->
  //                                                            qweight /
  //
  // Additional Note: the weight is packed as well
  // Params:
  //   input: float32 Tensor, will be quantized to quint8 in the op
  //   input_scale, input_zero_point: presumably the quantization parameters
  //     used to quantize `input` -- confirm against the implementations
  //   (the packed qint8 weight and bias are presumably supplied by this
  //   object rather than passed as an argument)
  // Returns:
  //   float32 Tensor
  // Optional: the default implementation always throws std::runtime_error.
  virtual at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32(
      at::Tensor /*input*/,
      double /*input_scale*/,
      int64_t /*input_zero_point*/) {
    throw std::runtime_error(
        "apply_with_input_q_dq_qweight_dq_output_fp32 is not implemented for this packed "
        "parameter type");
  }

  // Corresponding pattern (the ops with `*` are part of the pattern that
  // represents the computation of
  // quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32):
  //   input -> q* -> dq* -> linear* -> relu* ->
  //         qweight -> dq* /
  //
  // After fusion:
  //   input -> quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32* ->
  //                                                                 qweight /
  //
  // Additional Note: the weight is packed as well
  // Params:
  //   input: float32 Tensor, will be quantized to quint8 in the op
  //   input_scale, input_zero_point: presumably the quantization parameters
  //     used to quantize `input` -- confirm against the implementations
  // Returns:
  //   float32 Tensor
  // Optional: the default implementation always throws std::runtime_error.
  virtual at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32(
      at::Tensor /*input*/,
      double /*input_scale*/,
      int64_t /*input_zero_point*/) {
    throw std::runtime_error(
        "apply_with_input_q_dq_qweight_dq_relu_output_fp32 is not implemented for this packed "
        "parameter type");
  }

  // Dynamic variant of apply(): takes no output quantization parameters.
  // NOTE: the default argument of a virtual function is taken from the static
  // type used at the call site, so overrides should repeat `= false` to keep
  // calls through a derived type consistent with calls through this base.
  virtual at::Tensor apply_dynamic(
      at::Tensor input,
      bool reduce_range = false) = 0;

  // Dynamic variant of apply_relu(); same default-argument caveat as
  // apply_dynamic().
  virtual at::Tensor apply_dynamic_relu(
      at::Tensor input,
      bool reduce_range = false) = 0;

  // out variant of LinearPackedParamsBase::apply_dynamic
  // Optional: the default implementation always throws std::runtime_error.
  virtual at::Tensor& apply_dynamic_out(
      const at::Tensor& /* input */,
      at::Tensor& /* output */,
      bool /* reduce_range */) {
    throw std::runtime_error(
        "apply_dynamic_out is not implemented for this packed "
        "parameter type");
  }

  // out variant of LinearPackedParamsBase::apply_dynamic_relu
  // Optional: the default implementation always throws std::runtime_error.
  virtual at::Tensor& apply_dynamic_relu_out(
      const at::Tensor& /* input */,
      at::Tensor& /* output */,
      bool /* reduce_range */) {
    throw std::runtime_error(
        "apply_dynamic_relu_out is not implemented for this packed "
        "parameter type");
  }

  // Returns a (Tensor, optional Tensor) pair -- presumably the unpacked
  // weight and the optional bias; confirm against the implementations.
  virtual std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() = 0;

  // Returns the bias tensor, or an empty optional.
  virtual std::optional<at::Tensor> bias() = 0;

  // Optional setter for the bias.
  // The default implementation always throws std::runtime_error.
  virtual void set_bias(std::optional<at::Tensor> /*bias*/) {
    throw std::runtime_error(
        "set_bias is not implemented for this packed "
        "parameter type");
  }
};
// Abstract interface for packed parameters of a quantized convolution op.
//
// Every member is pure virtual, so a concrete packed-parameter type has to
// implement all of them. `kSpatialDim` is not referenced inside this
// interface; it only distinguishes the instantiations (presumably by the
// number of spatial dimensions of the convolution -- confirm against the
// concrete types). It defaults to 2.
template <int kSpatialDim = 2>
struct ConvPackedParamsBase : torch::jit::CustomClassHolder {
  // Applies the packed convolution parameters to `input`.
  // NOTE(review): output_scale / output_zero_point are presumably the
  // quantization parameters of the returned tensor -- confirm.
  virtual at::Tensor apply(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) = 0;

  // Same contract as apply(); presumably with a fused ReLU on the result
  // (inferred from the name -- confirm).
  virtual at::Tensor apply_relu(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) = 0;

  // Dynamic variant: takes no output quantization parameters. Unlike some
  // other interfaces, `reduce_range` has no default here and must always be
  // passed explicitly.
  virtual at::Tensor apply_dynamic(
      const at::Tensor& input,
      bool reduce_range) = 0;

  // Returns a (Tensor, optional Tensor) pair -- presumably the unpacked
  // weight and the optional bias; confirm against the implementations.
  virtual std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() = 0;

  // Read-only accessors for the convolution configuration.
  virtual torch::List<int64_t> stride() const = 0;
  virtual torch::List<int64_t> padding() const = 0;
  virtual torch::List<int64_t> output_padding() const = 0;
  virtual torch::List<int64_t> dilation() const = 0;
  virtual int64_t groups() const = 0;
  virtual bool transpose() const = 0;
};
|