File size: 7,893 Bytes
055eba4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_

#include <math.h>  // exp, log
#include <stddef.h>
#include <stdint.h>  // uint8_t
#include <stdio.h>

#include <vector>

#include "hwy/aligned_allocator.h"  // HWY_ALIGNMENT
#include "hwy/base.h"               // ScalarAbs
#include "hwy/contrib/sort/vqsort.h"
#include "hwy/stats.h"

namespace gcpp {

// Error-free transformation of a floating-point addition [Knuth98/Moller65].
// Returns the rounded result of `a + b` and stores the rounding error in
// `err`, so that (return value) + `err` equals `a + b` exactly. The return
// value is already the best estimate of the sum; `err` is meant to be
// accumulated separately rather than added back. In contrast to Fast2Sum
// [Dekker71], no relative ordering of the exponents of `a` and `b` is needed.
template <typename T>
static inline T TwoSum(T a, T b, T& err) {
  const T rounded = a + b;
  // Reconstruct the parts of `a` and `b` that are represented in `rounded`.
  const T a_part = rounded - b;
  const T b_part = rounded - a_part;
  // The remainder of each operand is what rounding discarded.
  const T lost_from_a = a - a_part;
  const T lost_from_b = b - b_part;
  err = lost_from_a + lost_from_b;
  return rounded;
}

// Running sum with about twice the precision of T, at a cost of 7 * n FLOPS.
// Keeps the rounded sum and the accumulated rounding errors in two separate
// members. Rump/Ogita/Oishi08, Algorithm 6.11 in Handbook of Floating-Point
// Arithmetic.
template <typename T>
class CascadedSummation {
 public:
  // Adds `t` to the running sum; the rounding error of that addition is
  // collected separately.
  void Notify(T t) {
    T rounding_error;
    const T updated_sum = TwoSum(sum_, t, rounding_error);
    sum_ = updated_sum;
    sum_err_ += rounding_error;
  }

  // Merges the state of `other` into this accumulator: its sum is added like
  // any other input, then its accumulated error is added to ours.
  void Assimilate(const CascadedSummation& other) {
    Notify(other.sum_);
    sum_err_ += other.sum_err_;
  }

  // Accumulated rounding error; lets users observe how much difference the
  // extra precision made.
  T Err() const { return sum_err_; }

  // Returns the sum of all `t` passed to `Notify`, including the error term.
  T Total() const {
    const T total = sum_ + sum_err_;
    return total;
  }

 private:
  T sum_ = T{0};      // rounded running sum
  T sum_err_ = T{0};  // sum of the rounding errors reported by TwoSum
};

// Summarizes the error of a distortion (e.g. quantization) applied to a series
// of numbers. Feed each (original, distorted) pair to `Notify`; separately
// accumulated instances can be merged with `Assimilate`.
// Users should check all four resulting metrics (NumExact, NumRoundedToZero,
// GeomeanValueDivL1, WeightedAverageL1) because each covers different aspects.
class DistortionStats {
 public:
  // Records one sample: `original` is the reference value, `distorted` the
  // value after the lossy transformation. Asserts that an `original` of zero
  // was left undistorted.
  void Notify(float original, float distorted) {
    (void)padding_;  // prevent unused member warning

    const bool rounded_to_zero = (original != 0.0f) && (distorted == 0.0f);
    // We expect original == 0 is not distorted (can be exactly represented).
    HWY_ASSERT(original != 0.0f || distorted == 0.0f);

    s_original_.Notify(original);
    const float l1f = hwy::ScalarAbs(original - distorted);
    const double l1 = static_cast<double>(l1f);
    s_l1_.Notify(l1f);

    // Histogram of L1 in steps of 1E-4; anything at or beyond 99E-4 goes into
    // the last bin. Clamp in the floating-point domain *before* converting:
    // casting a double that does not fit in `int` (very large error, NaN) is
    // undefined behavior and could produce a negative bin index.
    const double scaled_l1 = l1 * 1E4;
    const int bin = (scaled_l1 < 99.0) ? static_cast<int>(scaled_l1) : 99;
    b_l1_.Notify(bin);

    // Only nonzero errors participate in WeightedAverageL1.
    if (l1f != 0.0f) {
      l1_.push_back(l1f);
    }
    sum_l1_.Notify(l1f);
    if (rounded_to_zero) sum_l1_rounded_.Notify(l1f);

    // Event counts
    {
      n_ += 1;
      // Rounding (small) negative numbers to 0 does not influence dot products
      // as much as an actual sign flip, so do not count them.
      n_sign_flip_ +=
          ((original < 0.0f) != (distorted < 0.0f)) && !rounded_to_zero;
      n_exact_ += (l1f == 0.0f);
      n_rounded_to_zero_ += rounded_to_zero;
    }

    // Signal to noise ratio (Shannon's channel capacity, NOT the L2-based and
    // logarithmic PSNR) to estimate the ratios of original to the L1 norm.
    if (l1f != 0.0f) {  // prevent division by zero
      const double snr =
          1.0 + static_cast<double>(hwy::ScalarAbs(original)) / l1;
      // For numerical purposes (prevents overflow). A hierarchical geomean
      // could also work, but that is more complex and not necessarily better.
      // We will return exp() of the arithmetic mean.
      sum_log_snr_ += log(snr);
      num_snr_ += 1;
    }
  }

  // Merges everything recorded by `other` into this instance, as if its
  // samples had been passed to our `Notify` after our own.
  void Assimilate(const DistortionStats& other) {
    s_original_.Assimilate(other.s_original_);
    s_l1_.Assimilate(other.s_l1_);
    b_l1_.Assimilate(other.b_l1_);
    sum_l1_.Assimilate(other.sum_l1_);
    sum_l1_rounded_.Assimilate(other.sum_l1_rounded_);
    l1_.insert(l1_.end(), other.l1_.begin(), other.l1_.end());

    n_ += other.n_;
    n_sign_flip_ += other.n_sign_flip_;
    n_exact_ += other.n_exact_;
    n_rounded_to_zero_ += other.n_rounded_to_zero_;

    sum_log_snr_ += other.sum_log_snr_;
    num_snr_ += other.num_snr_;
  }

  // Number of samples whose distorted value equals the original.
  size_t NumExact() const { return n_exact_; }
  // Number of samples whose sign changed, excluding those rounded to zero.
  size_t NumSignFlip() const { return n_sign_flip_; }
  // Number of nonzero originals whose distorted value is zero.
  size_t NumRoundedToZero() const { return n_rounded_to_zero_; }
  // Total absolute error.
  double SumL1() const { return sum_l1_.Total(); }
  // Total absolute error for numbers that were rounded to zero.
  double SumL1Rounded() const { return sum_l1_rounded_.Total(); }

  // Returns geomean of 1 + S/N (Shannon channel capacity). This is computed via
  // the ratio of input magnitude to nonzero L1 norms. Higher is better.
  // Returns 0 if no sample had a nonzero error.
  double GeomeanValueDivL1() const {
    if (num_snr_ == 0) return 0.0;
    return exp(sum_log_snr_ / static_cast<double>(num_snr_));
  }

  // Returns weighted average of nonzero L1 norms. Those further from the median
  // L1 norm are much more heavily weighted, such that this behaves more like
  // the L-infinity norm, but still includes all differences, not just the max.
  // Lower is better, magnitude depends on the input magnitude.
  // Returns 0 if all samples were exact.
  double WeightedAverageL1() const {
    if (l1_.empty()) return 0.0;  // all exact

    std::vector<float> weights(l1_);  // copy so we can modify
    const float median = [&weights]() {
      const size_t mid = weights.size() / 2;
      // We just want the median; partial sort is faster.
      hwy::VQSelect(weights.data(), weights.size(), mid, hwy::SortAscending());
      return weights[mid];
    }();
    weights = l1_;  // restore original order

    // Replace with distance from median (might have too few samples for mode).
    float max_abs = -1.0f;
    for (float& d : weights) {
      d = hwy::ScalarAbs(d - median);
      max_abs = HWY_MAX(max_abs, d);
    }
    HWY_ASSERT(max_abs >= 0.0f);
    // All L1 values are equal to the median: return that common value, which
    // also prevents division by zero below.
    if (max_abs == 0.0f) return median;

    // Normalize to max difference and exponentiate.
    const double inv_max = 1.0 / static_cast<double>(max_abs);
    double sum_weights = 0.0;
    for (float& w : weights) {
      const double normalized = static_cast<double>(w) * inv_max;
      const double amplified = exp(4.0 * normalized * normalized);
      sum_weights += amplified;
      w = static_cast<float>(amplified);
    }
    // At least 1.0 per weight, plus more for at least one weight because we
    // verified via max_abs that not all are equal.
    HWY_ASSERT(sum_weights > static_cast<double>(weights.size()));

    // Return weighted average.
    double weighted_sum = 0.0;
    for (size_t i = 0; i < weights.size(); ++i) {
      weighted_sum += l1_[i] * weights[i];
    }
    return weighted_sum / sum_weights;
  }

  // Mutable access to the running statistics of the absolute errors.
  hwy::Stats& L1() { return s_l1_; }
  // Mutable access to the running statistics of the original values.
  hwy::Stats& Original() { return s_original_; }

 private:
  hwy::Stats s_original_;
  hwy::Stats s_l1_;
  hwy::Bins<100> b_l1_;                       // histogram of L1 * 1E4
  CascadedSummation<double> sum_l1_;          // all
  CascadedSummation<double> sum_l1_rounded_;  // only if rounded_to_zero
  std::vector<float> l1_;                     // nonzero L1, in Notify order

  // Event counts
  size_t n_ = 0;
  size_t n_sign_flip_ = 0;
  size_t n_exact_ = 0;
  size_t n_rounded_to_zero_ = 0;

  double sum_log_snr_ = 0.0;
  size_t num_snr_ = 0;  // number of terms in sum_log_snr_

  uint8_t padding_[HWY_ALIGNMENT];  // prevents false sharing
};

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_DISTORTION_H_