File size: 1,728 Bytes
52da7b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import math

from .linalg import Vector, mean


def quantize_vector_absmean(
    values: Vector,
    *,
    threshold: float = 0.5,
) -> tuple[float, list[int]]:
    """Quantize ``values`` to ternary codes using an absolute-mean scale.

    The scale is ``mean`` of the absolute values. Each element is divided by
    that scale and mapped to ``1`` when the ratio is ``>= threshold``, to
    ``-1`` when it is ``<= -threshold``, and to ``0`` otherwise.

    Args:
        values: Vector to quantize.
        threshold: Cut-off applied to ``value / scale`` (keyword-only).

    Returns:
        ``(scale, codes)``. For an empty input this is ``(1.0, [])``. When the
        computed scale is exactly ``0.0``, it is ``(1.0, [0, ...])`` with one
        zero per element.
    """
    if not values:
        return 1.0, []

    absmean = mean([abs(component) for component in values])
    if absmean == 0.0:
        # Every element is zero: avoid dividing by zero and emit all-zero codes.
        return 1.0, [0 for _ in values]

    def _to_code(component):
        # Positive branch is checked first, matching the original precedence.
        ratio = component / absmean
        if ratio >= threshold:
            return 1
        if ratio <= -threshold:
            return -1
        return 0

    codes: list[int] = [_to_code(component) for component in values]
    return absmean, codes


def derive_ternary_mask_from_states(
    states: list[Vector],
    *,
    threshold: float = 0.02,
) -> tuple[float, list[int]]:
    """Derive a per-feature mask from a collection of state vectors.

    For each feature index, the "energy" is ``mean`` of the squared values of
    that feature across all states. The resulting energy vector is passed to
    ``derive_ternary_mask_from_feature_energy``.

    Args:
        states: State vectors; every state must have the same length.
        threshold: Forwarded unchanged to
            ``derive_ternary_mask_from_feature_energy`` (keyword-only). The
            default ``0.02`` mirrors that function's own default.

    Returns:
        ``(1.0, [])`` when ``states`` is empty; otherwise whatever
        ``derive_ternary_mask_from_feature_energy`` returns.

    Raises:
        ValueError: If the states do not all have the same length.
    """
    if not states:
        return 1.0, []

    feature_count = len(states[0])
    for index, state in enumerate(states):
        # Ragged input used to raise a bare IndexError (shorter state) or
        # silently ignore trailing features (longer state); reject it explicitly.
        if len(state) != feature_count:
            raise ValueError(
                f"state {index} has {len(state)} features; expected {feature_count}"
            )

    # zip(*states) yields one tuple per feature (a column across all states).
    feature_energy = [
        mean([component * component for component in column])
        for column in zip(*states)
    ]
    return derive_ternary_mask_from_feature_energy(feature_energy, threshold=threshold)


def derive_ternary_mask_from_feature_energy(
    feature_energy: Vector,
    *,
    threshold: float = 0.02,
) -> tuple[float, list[int]]:
    """Build a 0/1 mask keeping features whose RMS is not negligible.

    Each energy is clamped at zero and square-rooted to give a per-feature RMS.
    A feature is kept (``1``) when its RMS is at least ``threshold`` times
    ``mean`` of all RMS values.

    Args:
        feature_energy: Per-feature energy values.
        threshold: Fraction of the mean RMS a feature must reach to be kept
            (keyword-only).

    Returns:
        ``(1.0, mask)``. The first element is always ``1.0`` in this function.
        Special cases:
        - empty input gives ``(1.0, [])``;
        - a mean RMS of exactly ``0.0`` gives an all-zero mask;
        - if no feature passes the cut-off, the mask falls back to all ones.
    """
    if not feature_energy:
        return 1.0, []

    # max(..., 0.0) guards sqrt against negative energies.
    rms = [math.sqrt(max(energy, 0.0)) for energy in feature_energy]
    average_rms = mean(rms)
    if average_rms == 0.0:
        return 1.0, [0] * len(rms)

    cutoff = threshold * average_rms
    mask = [int(level >= cutoff) for level in rms]
    if not any(mask):
        # Never return an all-zero mask when there is signal: keep every feature.
        mask = [1] * len(rms)
    return 1.0, mask


def apply_ternary_mask(values: Vector, mask: list[int], scale: float) -> Vector:
    """Multiply each element of ``values`` by ``scale * mask[i]``.

    Args:
        values: Input vector; it is not modified.
        mask: Integer multipliers, one per element of ``values``. An empty mask
            means "no mask".
        scale: Multiplier applied together with the mask.

    Returns:
        A new list. When ``mask`` is empty, this is a shallow copy of
        ``values`` with ``scale`` NOT applied.

    Raises:
        ValueError: If ``mask`` is non-empty and its length differs from
            ``len(values)``. Previously a short mask raised ``IndexError`` and
            a long mask was silently truncated.
    """
    if not mask:
        # Empty mask is treated as identity; scale is intentionally ignored here.
        return values[:]
    if len(mask) != len(values):
        raise ValueError(
            f"mask length {len(mask)} does not match values length {len(values)}"
        )
    # Multiplication order (scale * code * value) matches the original.
    return [scale * code * value for code, value in zip(mask, values)]