|
|
#pragma once |
|
|
|
|
|
#ifdef USE_XNNPACK |
|
|
#include <cstdint> |
|
|
|
|
|
#include <ATen/ATen.h> |
|
|
#include <ATen/native/xnnpack/Common.h> |
|
|
|
|
|
using xnnpack_operator = at::native::xnnpack::Operator; |
|
|
|
|
|
namespace at { |
|
|
namespace native { |
|
|
namespace xnnp_utils { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Returns the tensor's dimensions in physical (memory-format-aware) order,
// as XNNPACK expects shapes in storage order rather than logical NCHW order.
// NOTE(review): presumably reorders to NHWC for channels-last inputs —
// confirm against the definition in the corresponding .cpp file.
std::vector<size_t> get_mem_format_aware_shape(const at::Tensor& in);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Copies int8 weight values from `in` into `out`, adding an offset to each
// element. NOTE(review): PT presumably selects the source quantized type
// (e.g. c10::quint8 vs c10::qint8), which determines the offset applied —
// confirm against the explicit instantiations in the .cpp file.
template <typename PT>


void q8_copy_int8_weight_and_add_offset(const at::Tensor& in, at::Tensor& out);
|
|
|
|
|
// Reorders convolution weights `src` into a channels-last tensor suitable
// for XNNPACK's NHWC kernels.
//   kSpatialDim: number of spatial dimensions (2 for conv2d).
//   groups:      number of convolution groups.
//   transpose:   whether the weights belong to a transposed convolution
//                (deconvolution), which uses a different weight layout.
template <int kSpatialDim>


Tensor convert_conv_weights_to_channel_last_tensor(


    const at::Tensor& src,


    int groups,


    bool transpose);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/*
 * Creates an XNNPACK quantized 2-D (de)convolution operator in NHWC layout.
 *
 * Dispatch:
 *   - transpose == true  -> deconvolution, per-tensor qs8 only (XNNPACK has
 *                           no per-channel deconvolution kernel).
 *   - per_channel == false -> per-tensor qs8 convolution (uses k_scales[0]).
 *   - per_channel == true  -> per-channel qc8 convolution (uses the full
 *                             k_scales array, one scale per output channel).
 *
 * izp/ozp are the input/output zero points, ip_scale/op_scale the input/
 * output scales, and op_min/op_max the clamped output range. The created
 * operator is returned through `op`.
 *
 * Requires kzp == 0: XNNPACK's Q[SC]8 kernels assume a zero kernel zero
 * point (TORCH_CHECK failure otherwise).
 */
C10_ALWAYS_INLINE
enum xnn_status xnnp_create_convolution2d_nhwc(
    uint32_t pad_top,
    uint32_t pad_right,
    uint32_t pad_bottom,
    uint32_t pad_left,
    uint32_t kernel_h,
    uint32_t kernel_w,
    uint32_t stride_h,
    uint32_t stride_w,
    uint32_t dilation_h,
    uint32_t dilation_w,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t ip_chan_stride,
    size_t op_chan_stride,
    int8_t izp,
    float ip_scale,
    int8_t kzp,
    const float* k_scales,
    const int8_t* kernel,
    const int32_t* bias,
    int8_t ozp,
    float op_scale,
    int8_t op_min,
    int8_t op_max,
    uint32_t flags,
    xnn_operator_t* op,
    bool per_channel,
    bool transpose) {
  // Message previously read "zero.But got:" — a space was missing between
  // the two concatenated string literals.
  TORCH_CHECK(!kzp, "XNNPACK Q[SC]8 conv kernels expects kernel zero point to be zero. "
      "But got: ", kzp);

  if (transpose) {
    TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
    return xnn_create_deconvolution2d_nhwc_qs8(
        pad_top,
        pad_right,
        pad_bottom,
        pad_left,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        dilation_h,
        dilation_w,
        groups,
        group_input_channels,
        group_output_channels,
        ip_chan_stride,
        op_chan_stride,
        izp,
        ip_scale,
        k_scales[0], // per-tensor: single kernel scale
        kernel,
        bias,
        ozp,
        op_scale,
        op_min,
        op_max,
        flags,
        op);
  }

  if (!per_channel) {
    return xnn_create_convolution2d_nhwc_qs8(
        pad_top,
        pad_right,
        pad_bottom,
        pad_left,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        dilation_h,
        dilation_w,
        groups,
        group_input_channels,
        group_output_channels,
        ip_chan_stride,
        op_chan_stride,
        izp,
        ip_scale,
        k_scales[0], // per-tensor: single kernel scale
        kernel,
        bias,
        ozp,
        op_scale,
        op_min,
        op_max,
        flags,
        op);
  } else {
    return xnn_create_convolution2d_nhwc_qc8(
        pad_top,
        pad_right,
        pad_bottom,
        pad_left,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        dilation_h,
        dilation_w,
        groups,
        group_input_channels,
        group_output_channels,
        ip_chan_stride,
        op_chan_stride,
        izp,
        ip_scale,
        k_scales, // per-channel: one scale per output channel
        kernel,
        bias,
        ozp,
        op_scale,
        op_min,
        op_max,
        flags,
        op);
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/*
 * Binds batch size, input spatial dimensions, and input/output buffers to an
 * already-created XNNPACK (de)convolution operator.
 *
 * The dispatch mirrors xnnp_create_convolution2d_nhwc: transposed operators
 * use the per-tensor qs8 deconvolution setup (adj_h/adj_w are the output
 * height/width adjustments), per-channel operators use the qc8 setup, and
 * everything else the per-tensor qs8 setup.
 */
C10_ALWAYS_INLINE
enum xnn_status xnnp_setup_convolution2d_nhwc(
    xnn_operator_t op,
    size_t batch,
    size_t in_h,
    size_t in_w,
    const int8_t* inp,
    int8_t* outp,
    pthreadpool_t pt_pool,
    bool per_channel = false,
    bool transpose = false,
    uint32_t adj_h = 0,
    uint32_t adj_w = 0) {
  if (transpose) {
    // XNNPACK only provides a per-tensor (qs8) deconvolution.
    TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
    return xnn_setup_deconvolution2d_nhwc_qs8(
        op, batch, in_h, in_w, adj_h, adj_w, inp, outp, pt_pool);
  }

  if (per_channel) {
    return xnn_setup_convolution2d_nhwc_qc8(
        op, batch, in_h, in_w, inp, outp, pt_pool);
  }
  return xnn_setup_convolution2d_nhwc_qs8(
      op, batch, in_h, in_w, inp, outp, pt_pool);
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/*
 * Creates an XNNPACK per-tensor quantized (qs8) fully-connected operator.
 *
 * input_zero_point/input_scale and output_zero_point/output_scale are the
 * quantization parameters of the activation tensors; output_min/output_max
 * clamp the output range. The created operator is returned through
 * `fully_connected_op_out`.
 *
 * Requires kernel_zero_point == 0: XNNPACK's QS8 linear kernel assumes a
 * zero kernel zero point (TORCH_CHECK failure otherwise). Only a single
 * kernel_scale is accepted — there is no per-channel variant here.
 */
C10_ALWAYS_INLINE
enum xnn_status xnnp_create_fully_connected_nc(
    size_t input_channels,
    size_t output_channels,
    size_t input_stride,
    size_t output_stride,
    int8_t input_zero_point,
    float input_scale,
    int8_t kernel_zero_point,
    float kernel_scale,
    const int8_t* kernel,
    const int32_t* bias,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* fully_connected_op_out) {
  // Message previously read "zero.But got:" — a space was missing between
  // the two concatenated string literals.
  TORCH_CHECK(!kernel_zero_point, "XNNPACK QS8 linear kernel expects kernel zero point to be zero. "
      "But got: ", kernel_zero_point);
  return xnn_create_fully_connected_nc_qs8(
      input_channels,
      output_channels,
      input_stride,
      output_stride,
      input_zero_point,
      input_scale,
      kernel_scale,
      kernel,
      bias,
      output_zero_point,
      output_scale,
      output_min,
      output_max,
      flags,
      fully_connected_op_out);
}
|
|
|
|
|
// Binds the batch size and input/output buffers to an already-created qs8
// fully-connected operator. Thin forwarder to the XNNPACK setup call.
C10_ALWAYS_INLINE


enum xnn_status xnnp_setup_fully_connected_nc(


    xnn_operator_t fully_connected_op,


    size_t batch_size,


    const int8_t* input,


    int8_t* output,


    pthreadpool_t threadpool) {


  return xnn_setup_fully_connected_nc_qs8(


      fully_connected_op,


      batch_size,


      input,


      output,


      threadpool);


}
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
#endif |
|
|
|