#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_

#include <stddef.h>
#include <stdint.h>

#include "compression/nuq.h"
#include "compression/sfp.h"
#include "hwy/base.h"

#endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_
// Include guard for (potentially) SIMD code; re-included once per HWY_TARGET.
#if defined(THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE) == \
    defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE
#endif

#include "compression/sfp-inl.h"
#include "hwy/contrib/sort/vqsort-inl.h"
#include "hwy/highway.h"
// Fallback for Highway versions that do not yet define HWY_IF_CONSTEXPR.
#ifndef HWY_IF_CONSTEXPR
#define HWY_IF_CONSTEXPR if
#endif

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
// Clusters a group of kGroupSize floats into kClusters centers via dynamic
// programming (exact 1D least-squares clustering of the sorted values).
class NuqClustering {
  // Stashes a small payload (here: the original index of a value) in the low
  // mantissa bits of a float, so that the original order can be recovered
  // after sorting. Requires kGroupSize to be a power of two.
  struct FloatPayload {
    // Zeroes the payload bits so the value can be used in arithmetic.
    static HWY_INLINE float Clear(float f) {
      const uint32_t binary32 = hwy::BitCastScalar<uint32_t>(f);
      return hwy::BitCastScalar<float>(binary32 &
                                       ~static_cast<uint32_t>(kGroupSize - 1));
    }

    // Stores `bits` (< kGroupSize) in the low mantissa bits of `f`.
    static HWY_INLINE float Set(float f, size_t bits) {
      HWY_DASSERT(bits < kGroupSize);
      const uint32_t binary32 = hwy::BitCastScalar<uint32_t>(Clear(f));
      return hwy::BitCastScalar<float>(static_cast<uint32_t>(binary32 | bits));
    }

    // Returns the payload previously stored by `Set`.
    static HWY_INLINE size_t Get(float f) {
      return hwy::BitCastScalar<uint32_t>(f) &
             static_cast<uint32_t>(kGroupSize - 1);
    }
  };

  // Cost of clustering an interval of the sorted values: sum of squared
  // distances to their mean, computed in O(1) from cumulative sums.
  class ClusterCost {
   public:
    explicit ClusterCost(const float* sorted) {
      cumsum_[0] = cumsum2_[0] = 0.0;
      for (size_t i = 0; i < kGroupSize; ++i) {
        const float x = FloatPayload::Clear(sorted[i]);
        cumsum_[1 + i] = x + cumsum_[i];
        cumsum2_[1 + i] = x * x + cumsum2_[i];
      }

      // inv_len_[0] is unused; start at 1 to avoid dividing by zero.
      inv_len_[0] = 0.0f;
      for (size_t i = 1; i <= kGroupSize; ++i) {
        inv_len_[i] = 1.0f / static_cast<float>(i);
      }
    }

    // Sum of the sorted values in [first, last].
    float SumOfSorted(size_t first, size_t last) const {
      return cumsum_[last + 1] - cumsum_[first];
    }

    // Returns, for each lane i in [0, Lanes(df)), the cost of clustering the
    // values in [first, last + i]: their sum of squared distances to the mean.
    template <class DF>
    hn::Vec<DF> operator()(DF df, size_t first, size_t last) const {
      HWY_DASSERT(first < kGroupSize);
      HWY_DASSERT(last < kGroupSize);
      const int len = static_cast<int>(last) - static_cast<int>(first) + 1;
      const hn::Vec<DF> vlen = hn::Iota(df, static_cast<float>(len));

      const hn::Vec<DF> u_lo = hn::Set(df, cumsum_[first]);
      const hn::Vec<DF> u_lo2 = hn::Set(df, cumsum2_[first]);
      const hn::Vec<DF> hi = hn::LoadU(df, cumsum_ + last + 1);
      const hn::Vec<DF> hi2 = hn::LoadU(df, cumsum2_ + last + 1);
      const hn::Vec<DF> sum = hn::Sub(hi, u_lo);
      const hn::Vec<DF> sum2 = hn::Sub(hi2, u_lo2);

      // Mean of each interval: mu = sum / len.
      const hn::Vec<DF> mu = hn::Mul(sum, hn::LoadU(df, inv_len_ + len));

      // sum (x - mu)^2 = sum2 - 2 * mu * sum + len * mu^2.
      const hn::Vec<DF> mu2 = hn::Mul(mu, mu);
      const hn::Vec<DF> two_mu = hn::Add(mu, mu);
      return hn::NegMulAdd(two_mu, sum, hn::MulAdd(vlen, mu2, sum2));
    }

   private:
    // Cumulative sums over the (payload-cleared) sorted values; index i holds
    // the sum of the first i values, hence size kGroupSize + 1.
    float cumsum_[kGroupSize + 1];
    float cumsum2_[kGroupSize + 1];
    float inv_len_[kGroupSize + 1];
  };
  // One step of the dynamic program: returns, for each lane i, the cost of
  // covering [0, last + i] with `num_clusters + 1` clusters where the
  // rightmost cluster starts at `j`, i.e. D(num_clusters - 1, j - 1) plus the
  // cost of the cluster [j, last + i]. Lanes with j > last + i are invalid and
  // receive a large sentinel value.
  template <class DF>
  static HWY_INLINE hn::Vec<DF> ClusterDynProg(
      DF df, const AlignedMatrix<float>& D, const ClusterCost& cc,
      const size_t num_clusters, const size_t last, const size_t j) {
    HWY_DASSERT(last < kGroupSize);
    HWY_DASSERT(0 != j && j < kGroupSize);

    const hn::RebindToSigned<decltype(df)> di;
    using VF = hn::Vec<decltype(df)>;
    using VI = hn::Vec<decltype(di)>;
    using MI = hn::Mask<decltype(di)>;

    const VI vlast = hn::Iota(di, static_cast<int32_t>(last));

    // The rightmost cluster may not be empty: require j <= last + i.
    const MI valid = hn::Lt(hn::Set(di, static_cast<int32_t>(j) - 1), vlast);
    // Sentinel for invalid lanes, large enough to never be chosen as the min.
    const VF max = hn::Set(df, 1E38f);
    // Cost of covering [0, j - 1] with the previous num_clusters clusters.
    const VF vd = hn::Set(df, D(num_clusters - 1, j - 1));
    return hn::MaskedAddOr(max, RebindMask(df, valid), vd, cc(df, j, last));
  }
 public:
  // Clusters `x[0, kGroupSize)` into at most kClusters clusters via an exact
  // dynamic-programming solution that minimizes the total squared distance
  // (L2) of each value to its cluster center. Writes the cluster centers
  // (means, in non-decreasing order) to `centers[0, kClusters)` and the
  // per-value cluster index to `indices[0, kGroupSize)`. `buf` provides
  // scratch storage for the cost matrix D and the argmin matrix T. Returns the
  // number of leading clusters that remained unused (their centers are zero).
  template <class DF>
  static HWY_NOINLINE size_t ClusterExactL2(DF df, const float* x,
                                            ClusterBuf& buf,
                                            float* HWY_RESTRICT centers,
                                            uint16_t* HWY_RESTRICT indices) {
    const hn::RebindToSigned<decltype(df)> di;
    using VF = hn::Vec<decltype(df)>;
    using VI = hn::Vec<decltype(di)>;
    const VI k1 = hn::Set(di, 1);
    const size_t N = hn::Lanes(df);

    // Sort the values together with their original index (stored in the
    // payload bits), so `indices` can later be written in the original order.
    HWY_ALIGN float sorted_and_i[kGroupSize];
    for (size_t i = 0; i < kGroupSize; ++i) {
      sorted_and_i[i] = FloatPayload::Set(x[i], i);
    }
    hn::VQSortStatic(sorted_and_i, kGroupSize, hwy::SortAscending());
    ClusterCost cc(sorted_and_i);

    // D(k, last): cost of covering [0, last] with k + 1 clusters.
    AlignedMatrix<float>& D = buf.d;
    // T(k, last): start index of the rightmost cluster in that solution.
    AlignedMatrix<int32_t>& T = buf.t;

    // Base case: a single cluster covering [0, last].
    for (size_t last = 0; last < kGroupSize; last += N) {
      hn::Store(cc(df, 0, last), df, &D(0, last));
      hn::Store(Zero(di), di, &T(0, last));
    }

    for (size_t num_clusters = 1; num_clusters < kClusters; ++num_clusters) {
      // For each `last`, find the best start `j` of the rightmost cluster.
      for (size_t last = 0; last < kGroupSize; last += N) {
        // j == 0: the rightmost cluster covers everything, others are unused.
        VF min = hn::LoadU(df, &D(0, last));
        VI arg = hn::Zero(di);

        VI vj = k1;
        for (size_t j = 1; j < last + N; ++j, vj = hn::Add(vj, k1)) {
          const VF c = ClusterDynProg(df, D, cc, num_clusters, last, j);

          // Retain the lower cost and the `j` that produced it.
          const auto less = hn::Lt(c, min);
          min = hn::IfThenElse(less, c, min);
          arg = hn::IfThenElse(RebindMask(di, less), vj, arg);
        }
        hn::Store(min, df, &D(num_clusters, last));
        hn::Store(arg, di, &T(num_clusters, last));
      }
    }

    // Backtrack through T to recover cluster boundaries, centers and indices.
    size_t last = kGroupSize - 1;
    size_t unused_clusters = 0;
    for (size_t k = kClusters - 1; k < kClusters; --k) {
      const size_t start = static_cast<size_t>(T(k, last));
      // Center = mean of the sorted values in [start, last].
      const float sum = cc.SumOfSorted(start, last);
      const int size = static_cast<int>(last) - static_cast<int>(start) + 1;
      HWY_DASSERT(0 < size && size <= static_cast<int>(kGroupSize));
      centers[k] = sum / static_cast<float>(size);

      // Write the cluster index of each member at its original position,
      // recovered from the payload bits.
      for (size_t i = start; i <= last; ++i) {
        const size_t idx_x = FloatPayload::Get(sorted_and_i[i]);
        HWY_DASSERT(idx_x < kGroupSize);
        indices[idx_x] = static_cast<uint16_t>(k);
      }

      // If this cluster already reaches the first value, clusters 0..k-1 are
      // unused; zero their centers and stop.
      if (start == 0) {
        unused_clusters = k;
        for (size_t cluster = 0; cluster < unused_clusters; ++cluster) {
          centers[cluster] = 0.0f;
        }
        break;
      }

      last = start - 1;
      HWY_DASSERT(last < kGroupSize);
    }

    if (HWY_IS_DEBUG_BUILD) {
      // Centers are means of disjoint intervals of sorted values, hence
      // non-decreasing.
      for (size_t i = unused_clusters + 1; i < kClusters; ++i) {
        HWY_DASSERT(centers[i] >= centers[i - 1]);
      }
    }
    return unused_clusters;
  }
};
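
// Illustrative usage sketch, not part of the original file: clusters one group
// of kGroupSize values held in a hypothetical `values` array. Assumes
// ClusterBuf (from nuq.h) is default-constructible and sized via Resize, as
// also done by NuqCodec::Enc below.
//
//   const hn::ScalableTag<float> df;
//   ClusterBuf buf;
//   buf.Resize(kGroupSize);  // capacity for one group
//   HWY_ALIGN float centers[kClusters];
//   HWY_ALIGN uint16_t indices[kGroupSize];
//   const size_t unused =
//       NuqClustering::ClusterExactL2(df, values, buf, centers, indices);
//   // values[i] is then approximated by centers[indices[i]].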

// Order-preserving 4-bit packing/unpacking of u16 lanes holding values in
// [0, 16). Consecutive elements share one byte, the earlier element in the
// lower nibble, so a pack/unpack round trip is the identity.
class NibbleCodec {
 public:
  // Packs four u16 vectors whose lanes are 4-bit values into 4 * N16 nibbles
  // = 2 * N16 bytes at `out`, preserving element order.
  template <class D16, class V16 = hn::Vec<D16>>
  static HWY_INLINE void OrderedPackU16(D16 d16, V16 in0, V16 in1, V16 in2,
                                        V16 in3, uint8_t* HWY_RESTRICT out) {
    const hn::Repartition<uint8_t, D16> d8;
    const hn::Repartition<uint32_t, D16> d32;
    const hn::Repartition<uint64_t, D16> d64;
    using V8 = hn::Vec<decltype(d8)>;

    // Combines each pair of u16 lanes (each holding a nibble) into the lowest
    // byte of the pair: viewing the pair as u32, shift the upper nibble down
    // by 12 so it lands in bits [4, 8), then merge via Xor (bits are disjoint).
    const auto combine_u16_pair_to_8 = [d16, d32](V16 v16) HWY_ATTR {
      return hn::Xor(
          v16, hn::BitCast(d16, hn::ShiftRight<12>(hn::BitCast(d32, v16))));
    };

    const V16 u8_0 = combine_u16_pair_to_8(in0);
    const V16 u8_1 = combine_u16_pair_to_8(in1);
    const V16 u8_2 = combine_u16_pair_to_8(in2);
    const V16 u8_3 = combine_u16_pair_to_8(in3);
    V8 packed;
    if (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) {
      // Byte-granularity ConcatEven is efficient on these targets: gather the
      // packed (even) bytes of two vectors at a time, then of the results.
      const V8 x2_0 = hn::ConcatEven(d8, BitCast(d8, u8_1), BitCast(d8, u8_0));
      const V8 x2_1 = hn::ConcatEven(d8, BitCast(d8, u8_3), BitCast(d8, u8_2));
      packed = hn::ConcatEven(d8, x2_1, x2_0);
    } else {
      // Other x86 targets: avoid 8-bit ConcatEven by first gathering both
      // packed bytes of each u32 pair into its lowest u16 lane.
      const auto combine_u32_pair_to_16 = [d16, d64](V16 v16) HWY_ATTR {
        return hn::Xor(
            v16, hn::BitCast(d16, hn::ShiftRight<24>(hn::BitCast(d64, v16))));
      };
      const V16 u16_0 = combine_u32_pair_to_16(u8_0);
      const V16 u16_1 = combine_u32_pair_to_16(u8_1);
      const V16 u16_2 = combine_u32_pair_to_16(u8_2);
      const V16 u16_3 = combine_u32_pair_to_16(u8_3);

      // Now 16-bit ConcatEven suffices to gather the packed u16 lanes.
      const V16 x2_0 = hn::ConcatEven(d16, u16_1, u16_0);
      const V16 x2_1 = hn::ConcatEven(d16, u16_3, u16_2);
      packed = hn::BitCast(d8, hn::ConcatEven(d16, x2_1, x2_0));
    }
    hn::StoreU(packed, d8, out);
  }

  // Unpacks Lanes(d16) nibbles from `packed` into the lanes of a u16 vector,
  // in the same order in which they were packed.
  template <class D16, class V16 = hn::Vec<D16>>
  static HWY_INLINE V16 OrderedUnpackU16(D16 d16, const uint8_t* packed) {
    const hn::Repartition<uint8_t, D16> d8;
    using V8 = hn::Vec<decltype(d8)>;
    // Only Lanes(d16) / 2 packed bytes are read per call.
    const hn::CappedTag<uint8_t, d16.MaxBytes() / 4> d_load;

    // Each packed byte expands to two u16 lanes (= four bytes). First
    // replicate every loaded byte into four consecutive byte positions
    // (`rep4`); the nibbles are selected afterwards. The replication strategy
    // depends on the target and vector size.
    V8 rep4;
    if (HWY_HAVE_SCALABLE) {
      // Scalable vectors: compute the byte indices (i & ~3) at runtime.
      const size_t num_bytes = HWY_MAX(1, hn::Lanes(d8) / 4);
      const V8 bytes = hn::LoadN(d8, packed, num_bytes);
      const V8 idx = hn::And(hn::Iota(d8, 0), hn::Set(d8, 0xFCu));
      rep4 = hn::TableLookupBytes(bytes, idx);
    } else if (hn::MaxLanes(d16) <= 8) {
      // <= 128-bit vectors: a single in-register byte shuffle suffices.
      const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed));
      alignas(16) static constexpr uint8_t kRep4[16] = {
          HWY_REP4(0), HWY_REP4(1), HWY_REP4(2), HWY_REP4(3)};
      rep4 = hn::TableLookupBytes(bytes, hn::Load(d8, kRep4));
    } else if (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) {
      // Cross-block byte shuffles are cheap here: one full-vector lookup.
      const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed));
      alignas(64) static constexpr uint8_t kRep4[64] = {
          HWY_REP4(0),  HWY_REP4(1),  HWY_REP4(2),  HWY_REP4(3),
          HWY_REP4(4),  HWY_REP4(5),  HWY_REP4(6),  HWY_REP4(7),
          HWY_REP4(8),  HWY_REP4(9),  HWY_REP4(10), HWY_REP4(11),
          HWY_REP4(12), HWY_REP4(13), HWY_REP4(14), HWY_REP4(15)};
      rep4 = hn::TableLookupLanes(bytes, hn::SetTableIndices(d8, kRep4));
    } else if (hn::MaxLanes(d16) == 16) {
      // 256-bit vectors: TableLookupBytes operates per 128-bit block, so first
      // broadcast the (at most 8) loaded bytes to both blocks.
      const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed));
      const V8 bcast = hn::ConcatLowerLower(d8, bytes, bytes);
      alignas(32) static constexpr uint8_t kRep4[32] = {
          HWY_REP4(0), HWY_REP4(1), HWY_REP4(2), HWY_REP4(3),
          HWY_REP4(4), HWY_REP4(5), HWY_REP4(6), HWY_REP4(7)};
      rep4 = hn::TableLookupBytes(bcast, hn::Load(d8, kRep4));
    } else if (hn::MaxLanes(d16) == 32) {
      // 512-bit vectors: duplicate the 16 loaded bytes into each 128-bit block.
      const V8 bytes = hn::LoadDup128(d8, packed);
      alignas(64) static constexpr uint8_t kRep4[64] = {
          HWY_REP4(0),  HWY_REP4(1),  HWY_REP4(2),  HWY_REP4(3),
          HWY_REP4(4),  HWY_REP4(5),  HWY_REP4(6),  HWY_REP4(7),
          HWY_REP4(8),  HWY_REP4(9),  HWY_REP4(10), HWY_REP4(11),
          HWY_REP4(12), HWY_REP4(13), HWY_REP4(14), HWY_REP4(15)};
      rep4 = hn::TableLookupBytes(bytes, hn::Load(d8, kRep4));
    } else {
      HWY_DASSERT(false);
    }

    // Each u16 lane of rep4 holds its source byte in both halves: even lanes
    // keep the low nibble, odd lanes shift the high nibble down, then mask.
    const V16 mask4 = hn::Set(d16, 0xF);
    const V16 u16 = BitCast(d16, rep4);
    return hn::And(mask4, hn::OddEven(hn::ShiftRight<4>(u16), u16));
  }
};
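
// Illustrative round trip, not part of the original file: packs four index
// vectors and unpacks the first again. `idx0` through `idx3` are hypothetical
// u16 vectors whose lanes are all < 16.
//
//   const hn::ScalableTag<uint16_t> d16;
//   uint8_t packed[2 * hn::MaxLanes(d16)];
//   NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3, packed);
//   const auto unpacked0 = NibbleCodec::OrderedUnpackU16(d16, packed);
//   // unpacked0 now equals idx0 lane-for-lane.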

// Encodes/decodes NuqStream: per group, kClusters cluster centers stored as
// SFP bytes, plus kGroupSize 4-bit cluster indices packed by NibbleCodec.
class NuqCodec {
  // Returns how many vectors are needed to hold the table of kClusters bf16
  // centers: 1 if a fixed-size vector has at least 32 bytes, otherwise 2.
  template <class DU>
  static constexpr size_t NumTables(DU du) {
    return (!HWY_HAVE_SCALABLE && du.MaxBytes() >= 32) ? 1 : 2;
  }

  // Decodes a group's SFP-encoded centers into one or two bf16 table vectors
  // (viewed as u16): returns the first; if NumTables(du) == 2, also sets
  // *tbl1 to the upper half of the table.
  template <class DU, HWY_IF_U16_D(DU)>
  static HWY_INLINE hn::Vec<DU> LoadTable(DU du, const uint8_t* centers,
                                          hn::Vec<DU>* HWY_RESTRICT tbl1) {
    // The table is decoded to bf16 and viewed as u16 for the lane lookups.
    const hn::CappedTag<hwy::bfloat16_t, kClusters> d_table;
    // Either the whole table fits in one vector, or two tables are used.
    HWY_DASSERT(hn::Lanes(du) >= hn::Lanes(d_table) || NumTables(du) == 2);

    HWY_ALIGN hwy::bfloat16_t table[kClusters];
    SfpCodec::Dec(d_table, reinterpret_cast<const SfpStream*>(centers),
                  kClusters, table);

    // Lookups require at least 8 u16 lanes (128-bit vectors).
    HWY_DASSERT(hn::Lanes(du) >= 8);

    HWY_IF_CONSTEXPR(NumTables(du) == 2) {
      // Second half of the table, for TwoTablesLookupLanes.
      const hn::CappedTag<hwy::bfloat16_t, kClusters / 2> d_table2;
      *tbl1 = hn::ResizeBitCast(du, hn::LoadU(d_table2, table + kClusters / 2));
    }
    return hn::ResizeBitCast(du, hn::Load(d_table, table));
  }

  // Unpacks 2 * N16 packed nibble indices and gathers the corresponding
  // centers from the table vector(s) into `c0` and `c1` (as u16/bf16 lanes).
  template <class DU>
  static HWY_INLINE void TableLookups(DU du, hn::Vec<DU> tbl0, hn::Vec<DU> tbl1,
                                      const uint8_t* packed, hn::Vec<DU>& c0,
                                      hn::Vec<DU>& c1) {
    using V16 = hn::Vec<decltype(du)>;
    const size_t N16 = hn::Lanes(du);

    const V16 idx0 = NibbleCodec::OrderedUnpackU16(du, packed);
    const V16 idx1 = NibbleCodec::OrderedUnpackU16(du, packed + N16 / 2);

    const auto indices0 = hn::IndicesFromVec(du, idx0);
    const auto indices1 = hn::IndicesFromVec(du, idx1);

    HWY_IF_CONSTEXPR(NumTables(du) == 1) {
      (void)tbl1;
      c0 = hn::TableLookupLanes(tbl0, indices0);
      c1 = hn::TableLookupLanes(tbl0, indices1);
    }
    HWY_IF_CONSTEXPR(NumTables(du) == 2) {
      c0 = hn::TwoTablesLookupLanes(du, tbl0, tbl1, indices0);
      c1 = hn::TwoTablesLookupLanes(du, tbl0, tbl1, indices1);
    }
  }
 public:
  // Encodes `num` floats from `in` into `out` at offset `out_ofs` (both
  // multiples of kGroupSize): per group, kClusters SFP-encoded centers plus
  // kGroupSize packed 4-bit cluster indices. `out_capacity` is the total
  // number of values `out` can hold and determines where the packed index
  // section starts. Returns the total number of unused clusters, a measure of
  // quantization quality.
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE size_t Enc(DF df, const float* const in, const size_t num,
                               ClusterBuf& buf, const size_t out_capacity,
                               NuqStream* const out, const size_t out_ofs) {
    const hn::Repartition<uint16_t, DF> d16;
    using V16 = hn::Vec<decltype(d16)>;

    const size_t N16 = hn::Lanes(d16);
    HWY_ASSERT(kGroupSize >= 4 * N16);

    HWY_ASSERT(out_ofs + num <= out_capacity);
    buf.Resize(num);
    HWY_ASSERT(num % kGroupSize == 0);
    HWY_ASSERT(out_capacity % kGroupSize == 0);
    HWY_ASSERT(out_ofs % kGroupSize == 0);
    const size_t num_groups = num / kGroupSize;
    const size_t ofs_groups = out_ofs / kGroupSize;

    // Cluster each group independently; centers and indices go into `buf`.
    size_t unused_clusters = 0;
    for (size_t g = 0; g < num_groups; ++g) {
      const float* HWY_RESTRICT g_in = in + g * kGroupSize;
      float* HWY_RESTRICT g_centers = buf.centers.get() + g * kClusters;
      uint16_t* HWY_RESTRICT g_idx = buf.idx.get() + g * kGroupSize;
      unused_clusters +=
          NuqClustering::ClusterExactL2(df, g_in, buf, g_centers, g_idx);
    }

    // Compress all centers to SFP, stored at the start of the stream.
    uint8_t* centers = &out->byte + ofs_groups * kClusters;
    SfpCodec::Enc(df, buf.centers.get(), num_groups * kClusters,
                  reinterpret_cast<SfpStream*>(centers));
    uint8_t* packed_start = &out->byte + NuqStream::PackedStart(out_capacity) +
                            ofs_groups * kGroupSize / 2;

    // Pack the 16-bit cluster indices into nibbles, group by group.
    HWY_UNROLL(1)
    for (size_t g = 0; g < num_groups; ++g) {
      const uint16_t* HWY_RESTRICT g_idx = buf.idx.get() + g * kGroupSize;
      uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;

      HWY_UNROLL(1)
      for (size_t i = 0; i < kGroupSize; i += 4 * N16) {
        const V16 idx0 = hn::LoadU(d16, g_idx + i + N16 * 0);
        const V16 idx1 = hn::LoadU(d16, g_idx + i + N16 * 1);
        const V16 idx2 = hn::LoadU(d16, g_idx + i + N16 * 2);
        const V16 idx3 = hn::LoadU(d16, g_idx + i + N16 * 3);
        NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3,
                                    g_packed + i / 2);
      }
    }

    return unused_clusters;
  }

  // Decodes `num` values starting at offset `in_ofs` (both multiples of
  // kGroupSize) from `in` to bf16 `out`.
  template <class DBF, HWY_IF_BF16_D(DBF)>
  static HWY_INLINE void Dec(DBF dbf, const size_t in_capacity,
                             const NuqStream* const in, const size_t in_ofs,
                             hwy::bfloat16_t* const out, const size_t num) {
    const hn::RebindToUnsigned<decltype(dbf)> d16;
    using V16 = hn::Vec<decltype(d16)>;

    const size_t N16 = hn::Lanes(d16);
    HWY_DASSERT(kGroupSize >= 4 * N16);

    HWY_DASSERT(in_ofs + num <= in_capacity);
    HWY_DASSERT(in_capacity % kGroupSize == 0);
    HWY_DASSERT(in_ofs % kGroupSize == 0);
    HWY_DASSERT(num % kGroupSize == 0);
    const size_t num_groups = num / kGroupSize;
    const size_t ofs_groups = in_ofs / kGroupSize;
    const uint8_t* tables = &in->byte + ofs_groups * kClusters;
    const uint8_t* packed_start = &in->byte +
                                  NuqStream::PackedStart(in_capacity) +
                                  ofs_groups * kGroupSize / 2;

    HWY_UNROLL(1)
    for (size_t g = 0; g < num_groups; ++g) {
      const uint8_t* g_centers = tables + g * kClusters;
      const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;
      hwy::bfloat16_t* HWY_RESTRICT g_out = out + g * kGroupSize;

      V16 tbl1 = Zero(d16);
      const V16 tbl0 = LoadTable(d16, g_centers, &tbl1);

      // Look up two vectors of centers per iteration and store them as bf16.
      HWY_UNROLL(1)
      for (size_t i = 0; i < kGroupSize; i += 2 * N16) {
        V16 c0, c1;
        TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1);
        hn::StoreU(BitCast(dbf, c0), dbf, g_out + i + N16 * 0);
        hn::StoreU(BitCast(dbf, c1), dbf, g_out + i + N16 * 1);
      }
    }
  }

  // Decodes `num` values starting at offset `in_ofs` (both multiples of
  // kGroupSize) from `in` to f32 `out`.
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void Dec(DF df, const size_t in_capacity,
                             const NuqStream* const in, const size_t in_ofs,
                             float* const out, const size_t num) {
    const hn::Repartition<hwy::bfloat16_t, DF> dbf;
    const hn::RebindToUnsigned<decltype(dbf)> d16;
    using V16 = hn::Vec<decltype(d16)>;
    using VF = hn::Vec<DF>;

    const size_t NF = hn::Lanes(df);
    HWY_DASSERT(kGroupSize >= 4 * NF);

    HWY_DASSERT(in_ofs + num <= in_capacity);
    HWY_DASSERT(in_capacity % kGroupSize == 0);
    HWY_DASSERT(in_ofs % kGroupSize == 0);
    HWY_DASSERT(num % kGroupSize == 0);
    const size_t ofs_groups = in_ofs / kGroupSize;
    const size_t num_groups = num / kGroupSize;
    const uint8_t* tables = &in->byte + ofs_groups * kClusters;
    const uint8_t* packed_start = &in->byte +
                                  NuqStream::PackedStart(in_capacity) +
                                  ofs_groups * kGroupSize / 2;

    HWY_UNROLL(1)
    for (size_t g = 0; g < num_groups; ++g) {
      const uint8_t* g_centers = tables + g * kClusters;
      const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;
      float* HWY_RESTRICT g_out = out + g * kGroupSize;

      V16 tbl1 = Zero(d16);
      const V16 tbl0 = LoadTable(d16, g_centers, &tbl1);

      // Look up bf16 centers, then promote both halves of each vector to f32.
      HWY_UNROLL(1)
      for (size_t i = 0; i < kGroupSize; i += 4 * NF) {
        V16 c0, c1;
        TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1);
        const VF f0 = hn::PromoteLowerTo(df, BitCast(dbf, c0));
        const VF f1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
        const VF f2 = hn::PromoteLowerTo(df, BitCast(dbf, c1));
        const VF f3 = hn::PromoteUpperTo(df, BitCast(dbf, c1));
        hn::StoreU(f0, df, g_out + i + NF * 0);
        hn::StoreU(f1, df, g_out + i + NF * 1);
        hn::StoreU(f2, df, g_out + i + NF * 2);
        hn::StoreU(f3, df, g_out + i + NF * 3);
      }
    }
  }

  // Accumulates into `sum0..3` the products of `num` decoded values (starting
  // at `in_ofs`) with the bf16 vector `vec_aligned`; callers reduce the four
  // accumulators afterwards.
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void Dot(DF df, const size_t in_capacity,
                             const NuqStream* const in, const size_t in_ofs,
                             const hwy::bfloat16_t* const vec_aligned,
                             const size_t num, hn::Vec<DF>& sum0,
                             hn::Vec<DF>& sum1, hn::Vec<DF>& sum2,
                             hn::Vec<DF>& sum3) {
    const hn::Repartition<hwy::bfloat16_t, DF> dbf;
    const hn::RebindToUnsigned<decltype(dbf)> d16;
    using VBF = hn::Vec<decltype(dbf)>;
    using V16 = hn::Vec<decltype(d16)>;
    const size_t N16 = hn::Lanes(d16);
    HWY_DASSERT(kGroupSize >= 4 * N16);

    HWY_DASSERT(in_ofs + num <= in_capacity);
    HWY_DASSERT(in_capacity % kGroupSize == 0);
    HWY_DASSERT(in_ofs % kGroupSize == 0);
    HWY_DASSERT(num % kGroupSize == 0);
    const size_t ofs_groups = in_ofs / kGroupSize;
    const size_t num_groups = num / kGroupSize;
    const uint8_t* tables = &in->byte + ofs_groups * kClusters;
    const uint8_t* packed_start = &in->byte +
                                  NuqStream::PackedStart(in_capacity) +
                                  ofs_groups * kGroupSize / 2;

    HWY_UNROLL(1)
    for (size_t g = 0; g < num_groups; ++g) {
      const uint8_t* g_centers = tables + g * kClusters;
      const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;
      const hwy::bfloat16_t* HWY_RESTRICT g_in = vec_aligned + g * kGroupSize;

      V16 tbl1 = Zero(d16);
      const V16 tbl0 = LoadTable(d16, g_centers, &tbl1);

      // bf16 * bf16 with f32 accumulation via ReorderWidenMulAccumulate.
      HWY_UNROLL(1)
      for (size_t i = 0; i < kGroupSize; i += 2 * N16) {
        V16 c0, c1;
        TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1);
        const VBF in0 = hn::Load(dbf, g_in + i + N16 * 0);
        const VBF in1 = hn::Load(dbf, g_in + i + N16 * 1);
        sum0 = hn::ReorderWidenMulAccumulate(df, in0, BitCast(dbf, c0), sum0,
                                             sum1);
        sum2 = hn::ReorderWidenMulAccumulate(df, in1, BitCast(dbf, c1), sum2,
                                             sum3);
      }
    }
  }

  // Accumulates into `sum0..3` the products of `num` decoded values (starting
  // at `in_ofs`) with the f32 vector `vec_aligned`; callers reduce the four
  // accumulators afterwards.
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void Dot(DF df, const size_t in_capacity,
                             const NuqStream* const in, const size_t in_ofs,
                             const float* const vec_aligned, const size_t num,
                             hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
                             hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
    const hn::Repartition<hwy::bfloat16_t, DF> dbf;
    const hn::RebindToUnsigned<decltype(dbf)> d16;
    using VF = hn::Vec<decltype(df)>;
    using V16 = hn::Vec<decltype(d16)>;
    const size_t NF = hn::Lanes(df);
    HWY_DASSERT(kGroupSize >= 4 * NF);

    HWY_DASSERT(in_ofs + num <= in_capacity);
    HWY_DASSERT(in_capacity % kGroupSize == 0);
    HWY_DASSERT(in_ofs % kGroupSize == 0);
    HWY_DASSERT(num % kGroupSize == 0);
    const size_t ofs_groups = in_ofs / kGroupSize;
    const size_t num_groups = num / kGroupSize;
    const uint8_t* tables = &in->byte + ofs_groups * kClusters;
    const uint8_t* packed_start = &in->byte +
                                  NuqStream::PackedStart(in_capacity) +
                                  ofs_groups * kGroupSize / 2;

    HWY_UNROLL(1)
    for (size_t g = 0; g < num_groups; ++g) {
      const uint8_t* g_centers = tables + g * kClusters;
      const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;
      const float* HWY_RESTRICT g_in = vec_aligned + g * kGroupSize;

      V16 tbl1 = Zero(d16);
      const V16 tbl0 = LoadTable(d16, g_centers, &tbl1);

      // Promote bf16 centers to f32 and use four independent FMA accumulators.
      HWY_UNROLL(1)
      for (size_t i = 0; i < kGroupSize; i += 4 * NF) {
        V16 c0, c1;
        TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1);
        const VF in0 = hn::LoadU(df, g_in + i + NF * 0);
        const VF in1 = hn::LoadU(df, g_in + i + NF * 1);
        const VF in2 = hn::LoadU(df, g_in + i + NF * 2);
        const VF in3 = hn::LoadU(df, g_in + i + NF * 3);
        const VF f0 = hn::PromoteLowerTo(df, BitCast(dbf, c0));
        const VF f1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
        const VF f2 = hn::PromoteLowerTo(df, BitCast(dbf, c1));
        const VF f3 = hn::PromoteUpperTo(df, BitCast(dbf, c1));
        sum0 = hn::MulAdd(in0, f0, sum0);
        sum1 = hn::MulAdd(in1, f1, sum1);
        sum2 = hn::MulAdd(in2, f2, sum2);
        sum3 = hn::MulAdd(in3, f3, sum3);
      }
    }
  }
};
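
// Illustrative encode/decode round trip, not part of the original file;
// `raw`, `capacity` and the NuqStream allocation are hypothetical:
//
//   const hn::ScalableTag<float> df;
//   ClusterBuf buf;
//   // `nuq` points to NuqStream storage sized for `capacity` values.
//   NuqCodec::Enc(df, raw, capacity, buf, capacity, nuq, /*out_ofs=*/0);
//   std::vector<float> decoded(capacity);
//   NuqCodec::Dec(df, capacity, nuq, /*in_ofs=*/0, decoded.data(), capacity);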

}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

#endif  // per-target include guard