|
|
#pragma once
|
|
|
|
|
|
#include <c10/core/QScheme.h>
|
|
|
#include <c10/core/MemoryFormat.h>
|
|
|
#include <c10/macros/Macros.h>
|
|
|
#include <c10/util/Exception.h>
|
|
|
#include <c10/util/intrusive_ptr.h>
|
|
|
#include <c10/core/ScalarType.h>
|
|
|
#include <c10/core/TensorOptions.h>
|
|
|
|
|
|
#include <ATen/Tensor.h>
|
|
|
#include <ATen/TensorUtils.h>
|
|
|
|
|
|
#include <ATen/core/QuantizerBase.h>
|
|
|
|
|
|
#include <cmath>
|
|
|
#include <memory>
|
|
|
#include <utility>
|
|
|
|
|
|
namespace at {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Quantizer subclass whose whole Quantizer interface is implemented out of
 * line.
 *
 * The only code visible here is the constructor, which forwards `scalar_type`
 * to the `Quantizer` base. All five overrides below are declarations; their
 * behaviour is defined in another translation unit.
 *
 * NOTE(review): the name suggests this stands in for a quantizer that has not
 * been chosen yet -- confirm against the implementation file.
 */
struct TORCH_API UnknownQuantizer : public Quantizer {
  /// @param scalar_type  element type passed straight through to `Quantizer`.
  explicit UnknownQuantizer(ScalarType scalar_type)
      : Quantizer(scalar_type) {}

  // --- Quantizer overrides (declared only, defined elsewhere) ---

  Tensor quantize(const Tensor& tensor) override;

  Tensor dequantize(const Tensor& qtensor) override;

  /// Variant of dequantize that is handed the destination `rtensor` and
  /// returns a Tensor reference.
  Tensor& dequantize_out(Tensor& rtensor, const Tensor& qtensor) override;

  QScheme qscheme() const override;

  bool equalTo(QuantizerPtr other) const override;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Intermediate type in the quantizer hierarchy that derives directly from
 * `Quantizer`.
 *
 * It introduces no data members and no virtual functions of its own; the
 * constructor simply hands `scalar_type` to the base class.
 *
 * NOTE(review): presumably the common parent for quantization schemes with
 * evenly spaced levels -- confirm with the project documentation.
 */
struct TORCH_API UniformQuantizer : public Quantizer {
  /// @param scalar_type  element type forwarded to `Quantizer`.
  explicit UniformQuantizer(ScalarType scalar_type)
      : Quantizer(scalar_type) {}
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Intermediate type in the quantizer hierarchy that derives directly from
 * `Quantizer`, parallel to `UniformQuantizer`.
 *
 * It adds no state and no virtual functions; the constructor only forwards
 * `scalar_type` to the base class. No type in this header derives from it.
 *
 * NOTE(review): presumably reserved for schemes with unevenly spaced levels
 * -- confirm with the project documentation.
 */
struct TORCH_API NonUniformQuantizer : public Quantizer {
  /// @param scalar_type  element type forwarded to `Quantizer`.
  explicit NonUniformQuantizer(ScalarType scalar_type)
      : Quantizer(scalar_type) {}
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Common parent of the affine quantizers declared in this header
 * (`PerTensorAffineQuantizer`, `PerChannelAffineQuantizer`).
 *
 * Derives from `UniformQuantizer` and adds nothing beyond a constructor that
 * forwards `scalar_type` up the hierarchy.
 *
 * NOTE(review): "affine" presumably refers to a scale / zero-point mapping
 * between real and quantized values -- the formula itself is not stated in
 * this header; confirm against the kernels.
 */
struct TORCH_API AffineQuantizer : public UniformQuantizer {
  /// @param scalar_type  element type forwarded to `UniformQuantizer`.
  explicit AffineQuantizer(ScalarType scalar_type)
      : UniformQuantizer(scalar_type) {}
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Affine quantizer that stores exactly one `double` scale and one `int64_t`
 * zero point.
 *
 * Both parameters are `const` members fixed at construction time. The
 * quantize / dequantize entry points are declared here and defined out of
 * line; the accessors, the scheme tag and the equality check are inline.
 */
struct TORCH_API PerTensorAffineQuantizer : public AffineQuantizer {
  /**
   * @param scalar_type  element type forwarded to `AffineQuantizer`.
   * @param scale        stored verbatim in `scale_`.
   * @param zero_point   stored verbatim in `zero_point_`.
   */
  explicit PerTensorAffineQuantizer(
      ScalarType scalar_type,
      double scale,
      int64_t zero_point)
      : AffineQuantizer(scalar_type), scale_(scale), zero_point_(zero_point) {}

  // Conversion entry points; definitions live outside this header.
  Tensor quantize(const Tensor& tensor) override;
  Tensor dequantize(const Tensor& qtensor) override;
  Tensor& dequantize_out(Tensor& rtensor, const Tensor& qtensor) override;

  /// Scheme tag for this type; `equalTo` relies on it to justify its downcast.
  QScheme qscheme() const override { return kPerTensorAffine; }

  /// Scale supplied at construction.
  double scale() const { return scale_; }

  /// Zero point supplied at construction.
  int64_t zero_point() const { return zero_point_; }

  /**
   * Compares this quantizer with `other`.
   *
   * @return false when `other` is null or reports a scheme other than
   *         `kPerTensorAffine`; otherwise true only if scalar type, scale and
   *         zero point all compare equal. The scale is compared with exact
   *         floating-point `==`, with no tolerance.
   */
  bool equalTo(QuantizerPtr other) const override {
    auto* candidate = other.get();
    if (candidate == nullptr) {
      return false;
    }
    if (candidate->qscheme() != kPerTensorAffine) {
      return false;
    }

    // The scheme tag was verified above, which is what makes the unchecked
    // downcast acceptable here.
    const auto& rhs = *static_cast<PerTensorAffineQuantizer*>(candidate);

    if (scalar_type() != rhs.scalar_type()) {
      return false;
    }
    if (!(scale() == rhs.scale())) {
      return false;
    }
    return zero_point() == rhs.zero_point();
  }

 private:
  const double scale_;
  // Kept as int64_t even though a narrower type might be enough.
  const int64_t zero_point_;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Affine quantizer parameterised by a tensor of scales, a tensor of zero
 * points and an axis index.
 *
 * The two tensors are taken by value and moved into the members; `axis_` is a
 * `const` member fixed at construction. The quantize / dequantize entry
 * points are declared here and defined out of line.
 *
 * NOTE(review): presumably `scales` and `zero_points` hold one entry per
 * slice along `axis` -- this header does not validate or state their shape;
 * confirm against the factory function / kernels.
 */
struct TORCH_API PerChannelAffineQuantizer : public AffineQuantizer {
  /**
   * @param scalar_type  element type forwarded to `AffineQuantizer`.
   * @param scales       moved into `scales_`.
   * @param zero_points  moved into `zero_points_`.
   * @param axis         stored verbatim in `axis_`.
   */
  explicit PerChannelAffineQuantizer(
      ScalarType scalar_type,
      Tensor scales,
      Tensor zero_points,
      int64_t axis)
      : AffineQuantizer(scalar_type),
        scales_(std::move(scales)),
        zero_points_(std::move(zero_points)),
        axis_(axis) {}

  /// Scheme tag for this type; `equalTo` relies on it to justify its downcast.
  QScheme qscheme() const override {
    return kPerChannelAffine;
  }

  /// Returns the stored scales tensor by value.
  Tensor scales() const {
    return scales_;
  }

  /// Returns the stored zero-points tensor by value.
  Tensor zero_points() const {
    return zero_points_;
  }

  /// Axis index supplied at construction.
  int64_t axis() const {
    return axis_;
  }

  // Conversion entry points; definitions live outside this header.
  Tensor quantize(const Tensor& tensor) override;
  Tensor dequantize(const Tensor& qtensor) override;
  Tensor& dequantize_out(Tensor& rtensor, const Tensor& qtensor) override;

  /**
   * Compares this quantizer with `other`.
   *
   * @return false when `other` is null or reports a scheme other than
   *         `kPerChannelAffine`; otherwise true only if scalar type, axis,
   *         scales and zero points all compare equal.
   */
  bool equalTo(QuantizerPtr other) const override {
    if (!other.get() || other->qscheme() != kPerChannelAffine) {
      return false;
    }
    // The scheme tag was verified above, which is what makes the unchecked
    // downcast acceptable here.
    const auto* rhs =
        static_cast<const PerChannelAffineQuantizer*>(other.get());

    // Evaluate the plain scalar comparisons first so that a mismatch
    // short-circuits before Tensor::equal is invoked at all.
    return scalar_type() == rhs->scalar_type() &&
        axis() == rhs->axis() &&
        scales().equal(rhs->scales()) &&
        zero_points().equal(rhs->zero_points());
  }

 protected:
  Tensor scales_;
  Tensor zero_points_;
  const int64_t axis_;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Specialisation of `PerChannelAffineQuantizer` that reports the scheme
 * `kPerChannelAffineFloatQParams`.
 *
 * It reuses the base class storage (`scales_`, `zero_points_`, `axis_`) and
 * accessors, and re-declares the quantize / dequantize entry points, which
 * are defined out of line.
 *
 * NOTE(review): the name suggests the zero points (and scales) are floating
 * point here rather than integral -- nothing in this header enforces a dtype;
 * confirm against the kernels.
 */
struct TORCH_API PerChannelAffineFloatQParamsQuantizer : public PerChannelAffineQuantizer {
  /**
   * @param scalar_type  element type forwarded to the base class.
   * @param scales       forwarded to the base class.
   * @param zero_points  forwarded to the base class.
   * @param axis         forwarded to the base class.
   */
  explicit PerChannelAffineFloatQParamsQuantizer(
      ScalarType scalar_type,
      Tensor scales,
      Tensor zero_points,
      int64_t axis)
      // The by-value tensor parameters are not used again in this
      // constructor, so move them into the base instead of copying a second
      // time.
      : PerChannelAffineQuantizer(
            scalar_type,
            std::move(scales),
            std::move(zero_points),
            axis) {}

  /// Scheme tag for this type; `equalTo` relies on it to justify its downcast.
  QScheme qscheme() const override {
    return kPerChannelAffineFloatQParams;
  }

  // Conversion entry points; definitions live outside this header.
  Tensor quantize(const Tensor& tensor) override;
  Tensor dequantize(const Tensor& qtensor) override;
  Tensor& dequantize_out(Tensor& rtensor, const Tensor& qtensor) override;

  /**
   * Compares this quantizer with `other`.
   *
   * @return false when `other` is null or reports a scheme other than
   *         `kPerChannelAffineFloatQParams`; otherwise true only if scalar
   *         type, axis, scales and zero points all compare equal.
   */
  bool equalTo(QuantizerPtr other) const override {
    if (!other.get() || other->qscheme() != kPerChannelAffineFloatQParams) {
      return false;
    }
    // The scheme tag was verified above, which is what makes the unchecked
    // downcast acceptable here.
    const auto* rhs =
        static_cast<const PerChannelAffineFloatQParamsQuantizer*>(other.get());

    // Evaluate the plain scalar comparisons first so that a mismatch
    // short-circuits before Tensor::equal is invoked at all.
    return scalar_type() == rhs->scalar_type() &&
        axis() == rhs->axis() &&
        scales().equal(rhs->scales()) &&
        zero_points().equal(rhs->zero_points());
  }
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Declared to return a raw `QTensorImpl*` for the tensor `self`; the
// definition is not part of this header.
// NOTE(review): presumably `self` must be a quantized tensor and the returned
// pointer is non-owning -- confirm against the implementation.
TORCH_API QTensorImpl* get_qtensorimpl(const TensorBase& self);
|
|
|
|
|
|
|
|
|
|
|
|
// Factory declaration: returns a `QuantizerPtr` built from a scale, a zero
// point and a scalar type. The argument types match the constructor of
// `PerTensorAffineQuantizer`, but the argument order differs (scalar type
// comes last here). Defined outside this header.
// NOTE(review): presumably the result wraps a PerTensorAffineQuantizer --
// confirm against the definition.
TORCH_API QuantizerPtr make_per_tensor_affine_quantizer(
    double scale,
    int64_t zero_point,
    ScalarType scalar_type);
|
|
|
|
|
|
// Factory declaration: returns a `QuantizerPtr` built from per-channel scales
// and zero points (both taken by const reference), an axis index and a scalar
// type. Defined outside this header.
// NOTE(review): which concrete quantizer type is returned (plain per-channel
// vs. float-qparams) and what validation is applied to the tensors is decided
// in the definition -- confirm there.
TORCH_API QuantizerPtr make_per_channel_affine_quantizer(
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis,
    ScalarType scalar_type);
|
|
|
|
|
|
// Factory declaration: returns a `QuantizerPtr` given only a scalar type.
// Defined outside this header.
// NOTE(review): presumably the result wraps an UnknownQuantizer -- confirm
// against the definition.
TORCH_API QuantizerPtr make_unknown_quantizer(ScalarType scalar_type);
|
|
|
|
|
|
|
|
|
// Declared to produce a `Tensor` from a size list, tensor options and a
// quantizer handle (taken by value). Defined outside this header.
// NOTE(review): presumably allocates a new quantized tensor of shape `sizes`
// and attaches `quantizer` to it -- confirm against the definition.
TORCH_API Tensor new_qtensor(
    IntArrayRef sizes,
    const TensorOptions& options,
    QuantizerPtr quantizer);
|
|
|
|
|
|
// Declared to take a tensor and a `ConstQuantizerPtr` and return nothing.
// Defined outside this header.
// NOTE(review): the trailing underscore and the name suggest this replaces the
// quantizer attached to `self` in place, even though `self` is passed as a
// const reference -- confirm against the definition.
TORCH_API void set_quantizer_(const Tensor& self, ConstQuantizerPtr quantizer);
|
|
|
|
|
|
// Overload that takes explicit `strides` in addition to `sizes`.
//
// Declared to produce a `Tensor` from a caller-supplied `data` pointer, a
// `deleter` callback that receives a `void*`, a scale / zero-point pair and
// tensor options. Defined outside this header.
//
// NOTE(review): presumably the tensor aliases `data` without copying and
// `deleter` is invoked on it when the tensor's storage is released -- confirm
// against the definition.
// NOTE(review): `scale` is `float` here, whereas
// make_per_tensor_affine_quantizer and PerTensorAffineQuantizer use `double`.
// NOTE(review): `std::function` is used, but this header does not include
// <functional> itself; it appears to rely on a transitive include.
TORCH_API Tensor from_blob_quantized_per_tensor_affine(
    void* data,
    IntArrayRef sizes,
    IntArrayRef strides,
    std::function<void(void*)> deleter,
    const float scale,
    const int64_t zeroPoint,
    const TensorOptions& options);
|
|
|
|
|
|
// Overload without a `strides` argument; otherwise the parameter list matches
// the strided overload of the same name.
//
// Declared to produce a `Tensor` from a caller-supplied `data` pointer, a
// `deleter` callback that receives a `void*`, a scale / zero-point pair and
// tensor options. Defined outside this header.
//
// NOTE(review): presumably strides are derived from `sizes` (e.g. contiguous
// layout) before delegating to the strided overload -- confirm against the
// definition.
TORCH_API Tensor from_blob_quantized_per_tensor_affine(
    void* data,
    IntArrayRef sizes,
    std::function<void(void*)> deleter,
    const float scale,
    const int64_t zeroPoint,
    const TensorOptions& options);
|
|
|
|
|
|
// Per-channel counterpart of from_blob_quantized_per_tensor_affine.
//
// Declared to produce a `Tensor` from a caller-supplied `data` pointer, a
// `deleter` callback that receives a `void*`, per-channel `scales` and
// `zero_points` tensors, an `axis` index and tensor options. There is no
// overload taking explicit strides. Defined outside this header.
//
// NOTE(review): presumably the tensor aliases `data` without copying and
// `deleter` is invoked on it when the tensor's storage is released -- confirm
// against the definition.
TORCH_API Tensor from_blob_quantized_per_channel_affine(
    void* data,
    IntArrayRef sizes,
    std::function<void(void*)> deleter,
    const Tensor& scales,
    const Tensor& zero_points,
    const int64_t axis,
    const TensorOptions& options);
|
|
|
|
|
|
}
|
|
|
|