/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <cstddef>
#include <cstdint>

#include "./FbgemmBuild.h"
#include "./UtilsAvx2.h"
/// @defgroup fbgemm-quant-utils-avx2 Quantization Utilities (AVX2)
///
namespace fbgemm {
/// Struct from <a href="https://github.com/google/gemmlowp">`gemmlowp`</a>
///
/// A structure to hold quantization parameters `scale` and `zero_point`.
/// The meaning of these values is as the constants in the quantization equation
///
/// `real_value = scale * (quantized_value - zero_point)`
///
/// In other words, 'zero_point' is the quantized value that corresponds
/// to the real value 0, and 'scale' is the difference of real values
/// corresponding to consecutive quantized values.
struct FBGEMM_API TensorQuantizationParams {
  /// Difference of real values corresponding to consecutive quantized values.
  float scale;
  /// The quantized value that corresponds to the real value 0.
  std::int32_t zero_point;
  /// Bit width of the quantized representation (e.g. 8) — presumed from the
  /// quantization context; confirm against the definitions of Min()/Max().
  int precision;
  /// Smallest real value representable with these parameters.
  float Min() const;
  /// Largest real value representable with these parameters.
  float Max() const;
};
/// Parameters when we scale from int32 intermediate matrix multiplication
/// results to 8-bit integers
struct FBGEMM_API RequantizationParams {
  /// For floating-point requantization
  float real_multiplier;
  /// For fixed-point requantization
  std::int32_t multiplier;
  /// Right shift applied together with `multiplier` in the fixed-point path.
  int right_shift;
  /// Quantization parameters (scale / zero point) of the 8-bit target.
  TensorQuantizationParams target_qparams;
};
////////////////////////////////////////////////////////////////////////////////
// Utility functions
////////////////////////////////////////////////////////////////////////////////
/// @ingroup fbgemm-quant-utils-avx2
///
/// @brief Quantize `len` contiguous floats from `src` into integer type `T`
/// in `dst` using AVX2, according to `qparams`.
///
/// @tparam T       Output integer type (default `std::uint8_t`).
/// @tparam LEGACY  NOTE(review): presumably selects a legacy rounding /
///                 saturation behavior — confirm against the definition.
/// @param src      Input buffer of `len` floats.
/// @param dst      Output buffer of `len` quantized values.
/// @param len      Number of elements to quantize.
/// @param qparams  Scale / zero-point used for quantization.
template <typename T = std::uint8_t, bool LEGACY = true>
void QuantizeAvx2(
    const float* src,
    T* dst,
    // std::int64_t: <cstdint> only guarantees the std-qualified alias, and
    // this declaration already uses std::uint8_t.
    std::int64_t len,
    const TensorQuantizationParams& qparams);
/// @ingroup fbgemm-quant-utils-avx2
///
/// @brief Quantize then immediately dequantize `len` floats with AVX2,
/// writing the round-tripped float values to `dst` (simulated quantization).
///
/// @tparam T       Intermediate quantized type (default `std::uint8_t`).
/// @param src      Input buffer of `len` floats.
/// @param dst      Output buffer of `len` round-tripped floats.
/// @param len      Number of elements.
/// @param qparams  Scale / zero-point used for the quantize step.
/// @param noise_ratio NOTE(review): meaning not visible in this header —
///                 presumably a threshold controlling stochastic pass-through;
///                 confirm against the definition.
template <typename T = std::uint8_t>
void FusedQuantizeDequantizeAvx2(
    const float* src,
    float* dst,
    int len,
    const TensorQuantizationParams& qparams,
    float noise_ratio = 0.0f);
/// @ingroup fbgemm-quant-utils-avx2
///
/// Random number generator in [0, 9] based on
/// <a href="https://www.jstatsoft.org/v08/i14/paper">this paper</a>.
///
/// @return The next pseudo-random value in the sequence.
// std::uint32_t: <cstdint> only guarantees the std-qualified alias; empty
// parameter list is the idiomatic C++ spelling of `(void)`.
std::uint32_t FBGEMM_API Xor128();
/// @ingroup fbgemm-quant-utils-avx2
///
/// @brief Find the min and max value in a float matrix.
///
/// @param m    Pointer to `len` contiguous floats.
/// @param min  Output: minimum value found.
/// @param max  Output: maximum value found.
/// @param len  Number of elements to scan.
// std::int64_t: <cstdint> only guarantees the std-qualified alias.
void FBGEMM_API
FindMinMax(const float* m, float* min, float* max, std::int64_t len);
/// @ingroup fbgemm-quant-utils-avx2
///
/// @brief Requantize `len` int32 values to uint8 with AVX2 using the
/// fixed-point path of `params` (multiplier / right shift).
///
/// @param src     Input buffer of `len` int32 values.
/// @param dst     Output buffer of `len` uint8 values.
/// @param len     Number of elements.
/// @param params  Requantization parameters (fixed-point fields are used).
void RequantizeFixedPointAvx2(
    const std::int32_t* src,
    std::uint8_t* dst,
    int len,
    const RequantizationParams& params);
/// @ingroup fbgemm-quant-utils-avx2
///
/// @brief Requantize `len` int32 values to uint8 with AVX2 — presumably via
/// the floating-point path of `params` (`real_multiplier`), in contrast to
/// RequantizeFixedPointAvx2; confirm against the definition.
///
/// @param src     Input buffer of `len` int32 values.
/// @param dst     Output buffer of `len` uint8 values.
/// @param len     Number of elements.
/// @param params  Requantization parameters.
void RequantizeAvx2(
    const std::int32_t* src,
    std::uint8_t* dst,
    int len,
    const RequantizationParams& params);
/// @ingroup fbgemm-quant-utils-avx2
///
/// Requantize with avx2 and bias is fused.
///
/// @tparam A_SYMMETRIC  A-matrix is symmetrically quantized (zero zero-point)
///                      — presumably lets the kernel skip the corresponding
///                      zero-point compensation; confirm in the definition.
/// @tparam B_SYMMETRIC  Same for the B matrix.
/// @tparam Q_GRAN       Quantization granularity (see QuantizationGranularity
///                      in UtilsAvx2.h).
/// @tparam HAS_BIAS     Fuse bias addition into requantization.
/// @tparam FUSE_RELU    Fuse ReLU (clamp at the zero point) into the output.
/// @tparam BIAS_TYPE    Element type of the bias (default int32).
/// @tparam DIRECT       NOTE(review): meaning not visible here — confirm.
///
/// @param out     uint8 output buffer.
/// @param inp     int32 intermediate results to requantize.
/// @param block   Region of the output to process (block_type_t from
///                UtilsAvx2.h).
/// @param ld_out  Leading dimension (row stride) of `out`.
/// @param ld_in   Leading dimension (row stride) of `inp`.
/// @param r       Requantization parameters bundle.
template <
    bool A_SYMMETRIC,
    bool B_SYMMETRIC,
    QuantizationGranularity Q_GRAN,
    bool HAS_BIAS,
    bool FUSE_RELU,
    typename BIAS_TYPE = std::int32_t,
    bool DIRECT = false>
FBGEMM_API void requantizeOutputProcessingAvx2(
    std::uint8_t* out,
    const std::int32_t* inp,
    const block_type_t& block,
    int ld_out,
    int ld_in,
    const requantizationParams_t<BIAS_TYPE>& r);
/// @ingroup fbgemm-quant-utils-avx2
///
/// @brief Variant of requantizeOutputProcessingAvx2 for grouped convolution
/// (inferred from "GConv" in the name — confirm in the definition).
///
/// @tparam A_SYMMETRIC  A-matrix is symmetrically quantized (zero zero-point).
/// @tparam B_SYMMETRIC  Same for the B matrix.
/// @tparam Q_GRAN       Quantization granularity.
/// @tparam HAS_BIAS     Fuse bias addition into requantization.
/// @tparam FUSE_RELU    Fuse ReLU into the output.
/// @tparam C_PER_G      Presumably channels per group — confirm.
/// @tparam BIAS_TYPE    Element type of the bias (default int32).
///
/// @param out     uint8 output buffer.
/// @param inp     int32 intermediate results to requantize.
/// @param block   Region of the output to process.
/// @param ld_out  Leading dimension (row stride) of `out`.
/// @param ld_in   Leading dimension (row stride) of `inp`.
/// @param r       Requantization parameters bundle.
template <
    bool A_SYMMETRIC,
    bool B_SYMMETRIC,
    QuantizationGranularity Q_GRAN,
    bool HAS_BIAS,
    bool FUSE_RELU,
    int C_PER_G,
    typename BIAS_TYPE = std::int32_t>
FBGEMM_API void requantizeOutputProcessingGConvAvx2(
    std::uint8_t* out,
    const std::int32_t* inp,
    const block_type_t& block,
    int ld_out,
    int ld_in,
    const requantizationParams_t<BIAS_TYPE>& r);
/// @ingroup fbgemm-quant-utils-avx2
///
/// @brief Like requantizeOutputProcessingAvx2, but produces float output
/// instead of uint8.
///
/// @tparam A_SYMMETRIC  A-matrix is symmetrically quantized (zero zero-point).
/// @tparam B_SYMMETRIC  Same for the B matrix.
/// @tparam Q_GRAN       Quantization granularity.
/// @tparam HAS_BIAS     Fuse bias addition.
/// @tparam FUSE_RELU    Fuse ReLU into the output.
///
/// @param out     float output buffer.
/// @param inp     int32 intermediate results to convert.
/// @param block   Region of the output to process.
/// @param ld_out  Leading dimension (row stride) of `out`.
/// @param ld_in   Leading dimension (row stride) of `inp`.
/// @param r       Float-requantization parameters bundle.
template <
    bool A_SYMMETRIC,
    bool B_SYMMETRIC,
    QuantizationGranularity Q_GRAN,
    bool HAS_BIAS,
    bool FUSE_RELU>
FBGEMM_API void requantizeForFloatAvx2(
    float* out,
    const std::int32_t* inp,
    const block_type_t& block,
    int ld_out,
    int ld_in,
    const requantizationForFloatParams_t& r);
/// @ingroup fbgemm-quant-utils-avx2
///
/// @brief Rowwise-quantize a float/half matrix into a fused BIT_RATE-bit
/// format whose per-row scale and bias are stored as fp16 alongside the
/// payload (layout inferred from the name — confirm in the definition).
///
/// @tparam InputType      Input element type (float or half).
/// @tparam BIT_RATE       Bits per quantized element.
/// @param input           Input matrix, `input_rows` x `input_columns`.
/// @param input_rows      Number of rows.
/// @param input_columns   Number of columns.
/// @param output          Destination buffer for the fused quantized rows.
template <typename InputType, int BIT_RATE>
void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2(
    const InputType* input,
    // std::size_t: bare `size_t` is only guaranteed by <cstddef>, which this
    // header did not include; the std-qualified alias is always available.
    std::size_t input_rows,
    int input_columns,
    std::uint8_t* output);
/// @ingroup fbgemm-quant-utils-avx2
///
/// @brief Rowwise-quantize a float/half matrix into a fused 8-bit format
/// whose per-row scale and bias are stored as float alongside the payload
/// (layout inferred from the name — confirm in the definition).
///
/// @tparam InputType      Input element type (float or half).
/// @param input           Input matrix, `input_rows` x `input_columns`.
/// @param input_rows      Number of rows.
/// @param input_columns   Number of columns.
/// @param output          Destination buffer for the fused quantized rows.
template <typename InputType>
void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2(
    const InputType* input,
    // std::size_t: bare `size_t` is only guaranteed by <cstddef>.
    std::size_t input_rows,
    int input_columns,
    std::uint8_t* output);
/// @ingroup fbgemm-quant-utils-avx2
///
/// @brief Dequantize fused BIT_RATE-bit rowwise-quantized rows (fp16
/// scale/bias) back to float or half — inverse of
/// FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2 (inferred from the
/// names — confirm in the definition).
///
/// @tparam OutputType     Output element type (float or half).
/// @tparam BIT_RATE       Bits per quantized element in the input.
/// @param input           Fused quantized rows.
/// @param input_rows      Number of rows.
/// @param input_columns   Number of columns of the fused representation.
/// @param output          Destination matrix.
template <typename OutputType, int BIT_RATE>
void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2(
    const std::uint8_t* input,
    // std::size_t: bare `size_t` is only guaranteed by <cstddef>.
    std::size_t input_rows,
    int input_columns,
    OutputType* output);
/// @ingroup fbgemm-quant-utils-avx2
///
/// @brief Dequantize fused 8-bit rowwise-quantized rows (float scale/bias)
/// back to float or half — inverse of
/// FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2 (inferred from the
/// names — confirm in the definition).
///
/// @tparam OutputType     Output element type (float or half).
/// @param input           Fused quantized rows.
/// @param input_rows      Number of rows.
/// @param input_columns   Number of columns of the fused representation.
/// @param output          Destination matrix.
template <typename OutputType>
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(
    const std::uint8_t* input,
    // std::size_t: bare `size_t` is only guaranteed by <cstddef>.
    std::size_t input_rows,
    int input_columns,
    OutputType* output);
} // namespace fbgemm