#pragma once
#ifdef USE_XNNPACK
#include <cstdint>
#include <ATen/ATen.h>
#include <ATen/native/xnnpack/Common.h>
using xnnpack_operator = at::native::xnnpack::Operator;
namespace at {
namespace native {
namespace xnnp_utils {
/*
 * Returns the shape in the same order as the memory format,
 * e.g. a channels_last tensor yields NHWC instead of NCHW.
 */
std::vector<size_t> get_mem_format_aware_shape(const at::Tensor& in);
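// Illustrative sketch (not in the original header): for a channels_last
// tensor with logical NCHW sizes {1, 3, 224, 224}, the helper above returns
// the memory-order (NHWC) shape.
//
//   at::Tensor t = at::rand({1, 3, 224, 224})
//                      .contiguous(at::MemoryFormat::ChannelsLast);
//   std::vector<size_t> shape = get_mem_format_aware_shape(t);
//   // shape == {1, 224, 224, 3}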
/*
* Input is always int8_t, output can be [int8_t, uint8_t].
* input + offset = output
* int8_t + 128 = uint8_t
* int8_t + 0 = int8_t
*/
template <typename PT>
void q8_copy_int8_weight_and_add_offset(const at::Tensor& in, at::Tensor& out);
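// Illustrative sketch (not in the original header): with an unsigned output
// type, the +128 offset re-centers the signed range [-128, 127] onto
// [0, 255]. The template argument below is an assumption for illustration.
//
//   // int8_t -128 -> uint8_t   0
//   // int8_t    0 -> uint8_t 128
//   // int8_t  127 -> uint8_t 255
//   q8_copy_int8_weight_and_add_offset<c10::quint8>(weight_int8, weight_out);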
template <int kSpatialDim>
Tensor convert_conv_weights_to_channel_last_tensor(
const at::Tensor& src,
int groups,
bool transpose);
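// Illustrative note (assumption, not in the original header): kSpatialDim
// selects between 2d and 3d convolution weights, and `transpose` marks
// deconvolution weights, whose layout differs from a regular convolution's.
//
//   at::Tensor w_cl = convert_conv_weights_to_channel_last_tensor<2>(
//       w, groups, false /* transpose */);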
/*
* Series of create wrapper functions to call xnn_create_[de]conv* functions.
*/
C10_ALWAYS_INLINE
enum xnn_status xnnp_create_convolution2d_nhwc(
uint32_t pad_top,
uint32_t pad_right,
uint32_t pad_bottom,
uint32_t pad_left,
uint32_t kernel_h,
uint32_t kernel_w,
uint32_t stride_h,
uint32_t stride_w,
uint32_t dilation_h,
uint32_t dilation_w,
uint32_t groups,
size_t group_input_channels,
size_t group_output_channels,
size_t ip_chan_stride,
size_t op_chan_stride,
int8_t izp,
float ip_scale,
int8_t kzp,
const float* k_scales,
const int8_t* kernel,
const int32_t* bias,
int8_t ozp,
float op_scale,
int8_t op_min,
int8_t op_max,
uint32_t flags,
xnn_operator_t* op,
bool per_channel,
bool transpose) {
/* Symmetric quantization forces kzp = 0 */
TORCH_CHECK(!kzp, "XNNPACK Q[SC]8 conv kernels expect the kernel zero point to be zero. "
"But got: ", kzp);
if (transpose) {
TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not support per-channel deconvolution!");
return xnn_create_deconvolution2d_nhwc_qs8(
pad_top, /* uint32_t output_padding_top */
pad_right, /* uint32_t output_padding_right */
pad_bottom, /* uint32_t output_padding_bottom */
pad_left, /* uint32_t output_padding_left */
kernel_h, /* uint32_t kernel_height */
kernel_w, /* uint32_t kernel_width */
stride_h, /* uint32_t stride_height */
stride_w, /* uint32_t stride_width */
dilation_h, /* uint32_t dilation_height */
dilation_w, /* uint32_t dilation_width */
groups, /* uint32_t groups */
group_input_channels, /* size_t group_input_channels */
group_output_channels, /* size_t group_output_channels */
ip_chan_stride, /* size_t input_pixel_stride */
op_chan_stride, /* size_t output_pixel_stride */
izp, /* int8_t input_zero_point */
ip_scale, /* float input_scale */
k_scales[0], /* float kernel_scale */
kernel, /* const int8_t* kernel */
bias, /* const int32_t* bias */
ozp, /* int8_t output_zero_point */
op_scale, /* float output_scale */
op_min, /* int8_t output_min */
op_max, /* int8_t output_max */
flags, /* uint32_t flags */
op); /* xnn_operator_t* deconvolution_op_out */
}
if (!per_channel) {
return xnn_create_convolution2d_nhwc_qs8(
pad_top, /* uint32_t input_padding_top */
pad_right, /* uint32_t input_padding_right */
pad_bottom, /* uint32_t input_padding_bottom */
pad_left, /* uint32_t input_padding_left */
kernel_h, /* uint32_t kernel_height */
kernel_w, /* uint32_t kernel_width */
stride_h, /* uint32_t subsampling_height */
stride_w, /* uint32_t subsampling_width */
dilation_h, /* uint32_t dilation_height */
dilation_w, /* uint32_t dilation_width */
groups, /* uint32_t groups */
group_input_channels, /* size_t group_input_channels */
group_output_channels, /* size_t group_output_channels */
ip_chan_stride, /* size_t input_channel_stride */
op_chan_stride, /* size_t output_channel_stride */
izp, /* int8_t input_zero_point */
ip_scale, /* float input_scale */
k_scales[0], /* float kernel_scale */
kernel, /* const int8_t* kernel */
bias, /* const int32_t* bias */
ozp, /* int8_t output_zero_point */
op_scale, /* float output_scale */
op_min, /* int8_t output_min */
op_max, /* int8_t output_max */
flags, /* uint32_t flags */
op); /* xnn_operator_t* convolution_op_out */
} else { /* per_channel */
return xnn_create_convolution2d_nhwc_qc8(
pad_top, /* uint32_t input_padding_top */
pad_right, /* uint32_t input_padding_right */
pad_bottom, /* uint32_t input_padding_bottom */
pad_left, /* uint32_t input_padding_left */
kernel_h, /* uint32_t kernel_height */
kernel_w, /* uint32_t kernel_width */
stride_h, /* uint32_t subsampling_height */
stride_w, /* uint32_t subsampling_width */
dilation_h, /* uint32_t dilation_height */
dilation_w, /* uint32_t dilation_width */
groups, /* uint32_t groups */
group_input_channels, /* size_t group_input_channels */
group_output_channels, /* size_t group_output_channels */
ip_chan_stride, /* size_t input_channel_stride */
op_chan_stride, /* size_t output_channel_stride */
izp, /* int8_t input_zero_point */
ip_scale, /* float input_scale */
k_scales, /* const float* kernel_scale */
kernel, /* const int8_t* kernel */
bias, /* const int32_t* bias */
ozp, /* int8_t output_zero_point */
op_scale, /* float output_scale */
op_min, /* int8_t output_min */
op_max, /* int8_t output_max */
flags, /* uint32_t flags */
op); /* xnn_operator_t* convolution_op_out */
}
}
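// Illustrative sketch (not in the original header): dispatching a per-tensor
// QS8 3x3 stride-1 convolution through the wrapper above. All variable names
// are hypothetical; `ker_ptr`/`bias_ptr` must point at prepacked data.
//
//   xnn_operator_t conv_op = nullptr;
//   const enum xnn_status status = xnnp_create_convolution2d_nhwc(
//       1, 1, 1, 1,            /* pad: top, right, bottom, left */
//       3, 3,                  /* kernel_h, kernel_w */
//       1, 1,                  /* stride_h, stride_w */
//       1, 1,                  /* dilation_h, dilation_w */
//       1,                     /* groups */
//       in_ch, out_ch,         /* group input/output channels */
//       in_ch, out_ch,         /* input/output channel strides */
//       izp, ip_scale,
//       0 /* kzp */, k_scales.data(), ker_ptr, bias_ptr,
//       ozp, op_scale,
//       std::numeric_limits<int8_t>::min(),
//       std::numeric_limits<int8_t>::max(),
//       0 /* flags */, &conv_op,
//       false /* per_channel */, false /* transpose */);
//   TORCH_CHECK(status == xnn_status_success, "failed to create conv2d op");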
/*
* Series of setup wrapper functions to call xnn_setup_[de]conv* functions.
*/
C10_ALWAYS_INLINE
enum xnn_status xnnp_setup_convolution2d_nhwc(
xnn_operator_t op,
size_t batch,
size_t in_h,
size_t in_w,
const int8_t* inp,
int8_t* outp,
pthreadpool_t pt_pool,
bool per_channel = false,
bool transpose = false,
uint32_t adj_h = 0,
uint32_t adj_w = 0) {
if (transpose) {
TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not support per-channel deconvolution!");
return xnn_setup_deconvolution2d_nhwc_qs8(
op, /* xnn_operator_t deconvolution_op */
batch, /* size_t batch_size */
in_h, /* size_t input_height */
in_w, /* size_t input_width */
adj_h, /* uint32_t adjustment_height */
adj_w, /* uint32_t adjustment_width */
inp, /* const int8_t* input */
outp, /* int8_t* output */
pt_pool); /* pthreadpool_t threadpool */
}
if (!per_channel) {
return xnn_setup_convolution2d_nhwc_qs8(
op, /* xnn_operator_t convolution_op */
batch, /* size_t batch_size */
in_h, /* size_t input_height */
in_w, /* size_t input_width */
inp, /* const int8_t* input */
outp, /* int8_t* output */
pt_pool); /* pthreadpool_t threadpool */
} else { /* per_channel */
return xnn_setup_convolution2d_nhwc_qc8(
op, /* xnn_operator_t convolution_op */
batch, /* size_t batch_size */
in_h, /* size_t input_height */
in_w, /* size_t input_width */
inp, /* const int8_t* input */
outp, /* int8_t* output */
pt_pool); /* pthreadpool_t threadpool */
}
}
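// Illustrative sketch (not in the original header): pairing the setup wrapper
// above with xnn_run_operator, which executes a previously set-up XNNPACK
// operator. `conv_op`, the data pointers, and the threadpool are hypothetical.
//
//   enum xnn_status status = xnnp_setup_convolution2d_nhwc(
//       conv_op, N, H, W, input_ptr, output_ptr, caffe2::pthreadpool_());
//   TORCH_CHECK(status == xnn_status_success, "failed to setup conv2d op");
//   status = xnn_run_operator(conv_op, caffe2::pthreadpool_());
//   TORCH_CHECK(status == xnn_status_success, "failed to run conv2d op");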
/*
* Series of wrapper functions to call xnn_create* and xnn_setup*
* functions for linear
*/
C10_ALWAYS_INLINE
enum xnn_status xnnp_create_fully_connected_nc(
size_t input_channels,
size_t output_channels,
size_t input_stride,
size_t output_stride,
int8_t input_zero_point,
float input_scale,
int8_t kernel_zero_point,
float kernel_scale,
const int8_t* kernel,
const int32_t* bias,
int8_t output_zero_point,
float output_scale,
int8_t output_min,
int8_t output_max,
uint32_t flags,
xnn_operator_t* fully_connected_op_out) {
/* Symmetric quantization forces kzp = 0 */
TORCH_CHECK(!kernel_zero_point, "XNNPACK QS8 linear kernel expects the kernel zero point to be zero. "
"But got: ", kernel_zero_point);
return xnn_create_fully_connected_nc_qs8(
input_channels, /* size_t input_channels */
output_channels, /* size_t output_channels */
input_stride, /* size_t input_stride */
output_stride, /* size_t output_stride */
input_zero_point, /* int8_t input_zero_point */
input_scale, /* float input_scale */
kernel_scale, /* float kernel_scale */
kernel, /* const int8_t* kernel */
bias, /* const int32_t* bias */
output_zero_point, /* int8_t output_zero_point */
output_scale, /* float output_scale */
output_min, /* int8_t output_min */
output_max, /* int8_t output_max */
flags, /* uint32_t flags */
fully_connected_op_out); /* xnn_operator_t* fully_connected_op_out */
}
C10_ALWAYS_INLINE
enum xnn_status xnnp_setup_fully_connected_nc(
xnn_operator_t fully_connected_op,
size_t batch_size,
const int8_t* input,
int8_t* output,
pthreadpool_t threadpool) {
return xnn_setup_fully_connected_nc_qs8(
fully_connected_op, /* xnn_operator_t fully_connected_op */
batch_size, /* size_t batch_size */
input, /* const int8_t* input */
output, /* int8_t* output */
threadpool); /* pthreadpool_t threadpool */
}
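// Illustrative sketch (not in the original header): chaining the linear
// wrappers above with xnn_run_operator. All variable names are hypothetical;
// `weight_ptr`/`bias_ptr` must point at prepacked QS8 data.
//
//   xnn_operator_t fc_op = nullptr;
//   enum xnn_status status = xnnp_create_fully_connected_nc(
//       in_ch, out_ch, in_ch, out_ch,
//       izp, ip_scale, 0 /* kernel_zero_point */, k_scale,
//       weight_ptr, bias_ptr, ozp, op_scale,
//       std::numeric_limits<int8_t>::min(),
//       std::numeric_limits<int8_t>::max(),
//       0 /* flags */, &fc_op);
//   TORCH_CHECK(status == xnn_status_success, "failed to create linear op");
//   status = xnnp_setup_fully_connected_nc(
//       fc_op, batch_size, input_ptr, output_ptr, caffe2::pthreadpool_());
//   TORCH_CHECK(status == xnn_status_success, "failed to setup linear op");
//   status = xnn_run_operator(fc_op, caffe2::pthreadpool_());
//   TORCH_CHECK(status == xnn_status_success, "failed to run linear op");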
} // namespace xnnp_utils
} // namespace native
} // namespace at
#endif // USE_XNNPACK