|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once |
|
|
|
|
|
#include "./Utils.h" |
|
|
|
|
|
#include <asmjit/core.h> |
|
|
#include <asmjit/x86.h> |
|
|
|
|
|
namespace fbgemm { |
|
|
|
|
|
#if ASMJIT_LIBRARY_VERSION >= ASMJIT_LIBRARY_MAKE_VERSION(1, 17, 0) |
|
|
|
|
|
class Xmm : public asmjit::x86::Vec { |
|
|
public: |
|
|
using Vec::Vec; |
|
|
using Vec::operator=; |
|
|
Xmm(uint32_t regId) : Vec(asmjit::x86::Vec::make_xmm(regId)) {} |
|
|
|
|
|
ASMJIT_INLINE_NODEBUG Xmm half() const noexcept { |
|
|
return Xmm(id()); |
|
|
} |
|
|
}; |
|
|
|
|
|
|
|
|
class Ymm : public asmjit::x86::Vec { |
|
|
public: |
|
|
using Vec::Vec; |
|
|
using Vec::operator=; |
|
|
Ymm(uint32_t regId) : Vec(asmjit::x86::Vec::make_ymm(regId)) {} |
|
|
|
|
|
ASMJIT_INLINE_NODEBUG Xmm half() const noexcept { |
|
|
return Xmm(id()); |
|
|
} |
|
|
}; |
|
|
|
|
|
|
|
|
class Zmm : public asmjit::x86::Vec { |
|
|
public: |
|
|
using Vec::Vec; |
|
|
using Vec::operator=; |
|
|
Zmm(uint32_t regId) : Vec(asmjit::x86::Vec::make_zmm(regId)) {} |
|
|
|
|
|
ASMJIT_INLINE_NODEBUG Ymm half() const noexcept { |
|
|
return Ymm(id()); |
|
|
} |
|
|
}; |
|
|
#else |
|
|
using Xmm = asmjit::x86::Xmm; |
|
|
using Ymm = asmjit::x86::Ymm; |
|
|
using Zmm = asmjit::x86::Zmm; |
|
|
#endif |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * @brief Compile-time description of one instruction set.
 *
 * Declared here without a definition; each instruction set is described by
 * an explicit specialization.
 */
template <inst_set_t instSet>
struct simd_info;
|
|
|
|
|
template <> |
|
|
struct simd_info<inst_set_t::avx2> { |
|
|
static constexpr int WIDTH_BITS = 256; |
|
|
static constexpr int WIDTH_BYTES = 32; |
|
|
static constexpr int WIDTH_32BIT_ELEMS = 8; |
|
|
static constexpr int NUM_VEC_REGS = 16; |
|
|
|
|
|
using vec_reg_t = Ymm; |
|
|
}; |
|
|
|
|
|
template <> |
|
|
struct simd_info<inst_set_t::sve> { |
|
|
|
|
|
static constexpr int WIDTH_BITS = 256; |
|
|
static constexpr int WIDTH_BYTES = 32; |
|
|
static constexpr int WIDTH_32BIT_ELEMS = 8; |
|
|
static constexpr int NUM_VEC_REGS = 32; |
|
|
}; |
|
|
|
|
|
template <> |
|
|
struct simd_info<inst_set_t::avx512> { |
|
|
static constexpr int WIDTH_BITS = 512; |
|
|
static constexpr int WIDTH_BYTES = 64; |
|
|
static constexpr int WIDTH_32BIT_ELEMS = 16; |
|
|
static constexpr int NUM_VEC_REGS = 32; |
|
|
|
|
|
using vec_reg_t = Zmm; |
|
|
}; |
|
|
|
|
|
template <> |
|
|
struct simd_info<inst_set_t::avx512_vnni> |
|
|
: public simd_info<inst_set_t::avx512> {}; |
|
|
|
|
|
template <> |
|
|
struct simd_info<inst_set_t::avx512_ymm> { |
|
|
static constexpr int WIDTH_BITS = 256; |
|
|
static constexpr int WIDTH_BYTES = 32; |
|
|
static constexpr int WIDTH_32BIT_ELEMS = 8; |
|
|
static constexpr int NUM_VEC_REGS = 32; |
|
|
|
|
|
using vec_reg_t = Ymm; |
|
|
}; |
|
|
|
|
|
template <> |
|
|
struct simd_info<inst_set_t::avx512_vnni_ymm> |
|
|
: public simd_info<inst_set_t::avx512_ymm> {}; |
|
|
|
|
|
} |
|
|
|