| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| #ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_ |
| #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_ |
|
|
| #include <math.h> |
| #include <stddef.h> |
| #include <stdio.h> |
|
|
| #include <vector> |
|
|
| #include "hwy/aligned_allocator.h" |
| #include "hwy/base.h" |
| #include "hwy/contrib/sort/vqsort.h" |
| #include "hwy/stats.h" |
|
|
| namespace gcpp { |
|
|
| |
| |
| |
| |
// Error-free transformation of a floating-point addition (Knuth's two-sum).
// Returns the rounded result of `a + b` and stores in `err` the rounding error
// of that addition, reconstructed from the parts of `a` and `b` that did not
// make it into the result. No assumption is made about the relative magnitude
// of `a` and `b`.
template <typename T>
static inline T TwoSum(T a, T b, T& err) {
  const T rounded = a + b;
  // The portions of `a` and `b` that are actually represented in `rounded`.
  const T a_part = rounded - b;
  const T b_part = rounded - a_part;
  // Whatever was lost from each operand, combined into a single correction.
  const T lost_from_a = a - a_part;
  const T lost_from_b = b - b_part;
  err = lost_from_a + lost_from_b;
  return rounded;
}
|
|
| |
| |
// Compensated accumulator: keeps a running sum plus a separate running total
// of the rounding errors reported by `TwoSum`, so that `Total()` can add the
// accumulated correction back in.
template <typename T>
class CascadedSummation {
 public:
  // Adds `t` to the running sum and accumulates the rounding error of that
  // addition separately.
  void Notify(T t) {
    T rounding_error;
    sum_ = TwoSum(sum_, t, rounding_error);
    sum_err_ += rounding_error;
  }

  // Merges another accumulator into this one: its sum is added like any other
  // value (tracking the rounding error), then its error term is added to ours.
  // The order of these two steps is significant for the floating-point result.
  void Assimilate(const CascadedSummation& other) {
    Notify(other.sum_);
    sum_err_ += other.sum_err_;
  }

  // Returns the accumulated rounding error alone.
  T Err() const { return sum_err_; }

  // Returns the running sum with the accumulated rounding error added back.
  T Total() const { return sum_ + sum_err_; }

 private:
  T sum_ = T{0};      // Rounded running sum.
  T sum_err_ = T{0};  // Sum of the per-addition rounding errors.
};
|
|
| |
| |
| |
| |
| class DistortionStats { |
| public: |
| void Notify(float original, float distorted) { |
| (void)padding_; |
|
|
| const bool rounded_to_zero = (original != 0.0f) && (distorted == 0.0f); |
| |
| HWY_ASSERT(original != 0.0f || distorted == 0.0f); |
|
|
| s_original_.Notify(original); |
| const float l1f = hwy::ScalarAbs(original - distorted); |
| const double l1 = static_cast<double>(l1f); |
| s_l1_.Notify(l1f); |
| b_l1_.Notify(HWY_MIN(99, static_cast<int>(l1f * 1E4))); |
| if (l1f != 0.0f) { |
| l1_.push_back(l1f); |
| } |
| sum_l1_.Notify(l1f); |
| if (rounded_to_zero) sum_l1_rounded_.Notify(l1f); |
|
|
| |
| { |
| n_ += 1; |
| |
| |
| n_sign_flip_ += |
| ((original < 0.0f) != (distorted < 0.0f)) && !rounded_to_zero; |
| n_exact_ += (l1f == 0.0f); |
| n_rounded_to_zero += rounded_to_zero; |
| } |
|
|
| |
| |
| if (l1f != 0.0) { |
| const double snr = |
| 1.0 + static_cast<double>(hwy::ScalarAbs(original)) / l1; |
| |
| |
| |
| sum_log_snr_ += log(snr); |
| num_snr_ += 1; |
| } |
| } |
|
|
| void Assimilate(const DistortionStats& other) { |
| s_original_.Assimilate(other.s_original_); |
| s_l1_.Assimilate(other.s_l1_); |
| b_l1_.Assimilate(other.b_l1_); |
| sum_l1_.Assimilate(other.sum_l1_); |
| sum_l1_rounded_.Assimilate(other.sum_l1_rounded_); |
| l1_.insert(l1_.end(), other.l1_.begin(), other.l1_.end()); |
|
|
| n_ += other.n_; |
| n_sign_flip_ += other.n_sign_flip_; |
| n_exact_ += other.n_exact_; |
| n_rounded_to_zero += other.n_rounded_to_zero; |
|
|
| sum_log_snr_ += other.sum_log_snr_; |
| num_snr_ += other.num_snr_; |
| } |
|
|
| size_t NumExact() const { return n_exact_; } |
| size_t NumSignFlip() const { return n_sign_flip_; } |
| size_t NumRoundedToZero() const { return n_rounded_to_zero; } |
| |
| double SumL1() const { return sum_l1_.Total(); } |
| |
| double SumL1Rounded() const { return sum_l1_rounded_.Total(); } |
|
|
| |
| |
| double GeomeanValueDivL1() const { |
| if (num_snr_ == 0) return 0.0; |
| return exp(sum_log_snr_ / static_cast<double>(num_snr_)); |
| } |
|
|
| |
| |
| |
| |
| double WeightedAverageL1() const { |
| if (l1_.empty()) return 0.0f; |
|
|
| std::vector<float> weights(l1_); |
| const float median = [&weights]() { |
| const size_t mid = weights.size() / 2; |
| |
| hwy::VQSelect(weights.data(), weights.size(), mid, hwy::SortAscending()); |
| return weights[mid]; |
| }(); |
| weights = l1_; |
|
|
| |
| float max_abs = -1.0f; |
| for (float& d : weights) { |
| d = hwy::ScalarAbs(d - median); |
| max_abs = HWY_MAX(max_abs, d); |
| } |
| HWY_ASSERT(max_abs >= 0.0f); |
| |
| if (max_abs == 0.0f) return median; |
|
|
| |
| const double inv_max = 1.0 / static_cast<double>(max_abs); |
| double sum_weights = 0.0; |
| for (float& w : weights) { |
| const double normalized = static_cast<double>(w) * inv_max; |
| const double amplified = exp(4.0 * normalized * normalized); |
| sum_weights += amplified; |
| w = static_cast<float>(amplified); |
| } |
| |
| |
| HWY_ASSERT(sum_weights > static_cast<double>(weights.size())); |
|
|
| |
| double weighted_sum = 0.0; |
| for (size_t i = 0; i < weights.size(); ++i) { |
| weighted_sum += l1_[i] * weights[i]; |
| } |
| return weighted_sum / sum_weights; |
| } |
|
|
| hwy::Stats& L1() { return s_l1_; } |
| hwy::Stats& Original() { return s_original_; } |
|
|
| private: |
| hwy::Stats s_original_; |
| hwy::Stats s_l1_; |
| hwy::Bins<100> b_l1_; |
| CascadedSummation<double> sum_l1_; |
| CascadedSummation<double> sum_l1_rounded_; |
| std::vector<float> l1_; |
|
|
| |
| size_t n_ = 0; |
| size_t n_sign_flip_ = 0; |
| size_t n_exact_ = 0; |
| size_t n_rounded_to_zero = 0; |
|
|
| double sum_log_snr_ = 0.0; |
| size_t num_snr_ = 0; |
|
|
| uint8_t padding_[HWY_ALIGNMENT]; |
| }; |
|
|
| } |
|
|
| #endif |
|
|