File size: 3,438 Bytes
52da7b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from .linalg import Matrix, Vector, identity, invert_matrix, matmul, matvec, np, scale_matrix, transpose


def _empty_matrix(matrix: Matrix) -> bool:
    """Report whether ``matrix`` holds no entries.

    When ``np`` is available and the object exposes a ``size`` attribute, the
    object is empty exactly when that size is zero. Anything else is judged
    by plain truthiness, so an empty list counts as empty while a list that
    contains an empty row does not.
    """
    exposes_size = np is not None and hasattr(matrix, "size")
    if not exposes_size:
        return not matrix
    return int(matrix.size) == 0


def ridge_regression_readout(
    states: list[Vector],
    targets: list[Vector],
    *,
    regularization: float,
) -> Matrix:
    """Fit ridge-regression readout weights from paired states and targets.

    With ``S`` the transposed ``states`` and ``T`` the transposed ``targets``,
    the result ``W`` satisfies ``W (S Sᵀ + regularization * I) = T Sᵀ``.

    Args:
        states: Collected state vectors (presumably one per sample, all of
            equal length — confirm against callers).
        targets: Target vectors paired with ``states``.
        regularization: Value added to every diagonal entry of the Gram
            matrix before solving.

    Returns:
        The readout weights. The NumPy path returns nested lists (via
        ``tolist``); the fallback path returns whatever ``matmul`` produces.

    Raises:
        ValueError: If ``states`` or ``targets`` is empty.
    """
    if not (states and targets):
        raise ValueError("States and targets must be non-empty for ridge readout.")

    if np is None:
        # Pure-Python fallback: explicit inverse of the regularized Gram matrix.
        features = transpose(states)
        outputs = transpose(targets)
        gram = matmul(features, transpose(features))
        ridge_gram = [
            [
                entry + (regularization if row == col else 0.0)
                for col, entry in enumerate(gram_row)
            ]
            for row, gram_row in enumerate(gram)
        ]
        gram_inverse = invert_matrix(ridge_gram)
        cross = matmul(outputs, transpose(features))
        return matmul(cross, gram_inverse)

    # NumPy path: keep the arrays sample-major and transpose where needed.
    samples = np.asarray(states, dtype=np.float64)
    outputs = np.asarray(targets, dtype=np.float64)
    gram = samples.T @ samples
    ridge_gram = gram + (regularization * np.eye(gram.shape[0], dtype=np.float64))
    cross = outputs.T @ samples
    # Solve the transposed system instead of forming an explicit inverse:
    # W A = C  <=>  Aᵀ Wᵀ = Cᵀ.
    solution = np.linalg.solve(ridge_gram.T, cross.T)
    return solution.T.tolist()


def ridge_regression_readout_from_moments(
    gram: Matrix,
    cross_covariance: Matrix,
    *,
    regularization: float,
) -> Matrix:
    """Compute ridge readout weights from precomputed moment matrices.

    The result ``W`` satisfies
    ``W (gram + regularization * I) = cross_covariance``.

    Args:
        gram: Gram matrix of the features (presumably square — confirm
            against callers).
        cross_covariance: Target/feature cross-covariance matrix.
        regularization: Value added to every diagonal entry of ``gram``.

    Returns:
        The readout weights. The NumPy path returns an array (no list
        conversion); the fallback path returns whatever ``matmul`` produces.

    Raises:
        ValueError: If either moment matrix is empty.
    """
    if any(_empty_matrix(moment) for moment in (gram, cross_covariance)):
        raise ValueError("Gram and cross-covariance moments must be non-empty for ridge readout.")

    if np is None:
        # Pure-Python fallback: explicit inverse of the regularized Gram matrix.
        ridge_gram = [
            [
                entry + (regularization if row == col else 0.0)
                for col, entry in enumerate(gram_row)
            ]
            for row, gram_row in enumerate(gram)
        ]
        return matmul(cross_covariance, invert_matrix(ridge_gram))

    gram_values = np.asarray(gram, dtype=np.float64)
    ridge_gram = gram_values + (
        regularization * np.eye(gram_values.shape[0], dtype=np.float64)
    )
    cross = np.asarray(cross_covariance, dtype=np.float64)
    # Solve the transposed system instead of forming an explicit inverse:
    # W A = C  <=>  Aᵀ Wᵀ = Cᵀ.
    solution = np.linalg.solve(ridge_gram.T, cross.T)
    return solution.T


def ridge_regression_readout_from_diagonal_moments(
    feature_second_moment: Vector,
    cross_covariance: Matrix,
    *,
    regularization: float,
) -> Matrix:
    """Compute ridge readout weights when only per-feature second moments are used.

    Each weight is ``cross_covariance[i][j]`` divided by
    ``feature_second_moment[j] + regularization``. When the magnitude of that
    denominator is at most ``1e-12`` it is replaced by ``regularization``
    alone. If the replacement is itself exactly zero (for example a feature
    with zero second moment and ``regularization == 0``), the weights of that
    column are set to ``0.0`` rather than dividing by zero. Without this the
    NumPy path produced inf/NaN weights and the pure-Python path raised
    ``ZeroDivisionError``.

    Args:
        feature_second_moment: Per-feature second moments (the NumPy path
            indexes it as a one-dimensional array).
        cross_covariance: Target/feature cross-covariance; each row is
            divided element-wise by the per-feature denominators.
        regularization: Value added to every second moment.

    Returns:
        The readout weights: an array on the NumPy path, nested lists on the
        fallback path.

    Raises:
        ValueError: If either input is empty.
    """
    if _empty_matrix(feature_second_moment) or _empty_matrix(cross_covariance):
        raise ValueError("Diagonal moments and cross-covariance must be non-empty for ridge readout.")
    if np is not None:
        shifted = np.asarray(feature_second_moment, dtype=np.float64) + regularization
        denominator = np.where(np.abs(shifted) > 1e-12, shifted, regularization)
        cross_covariance_array = np.asarray(cross_covariance, dtype=np.float64)
        # Columns whose fallback denominator is exactly zero cannot be divided;
        # divide by a placeholder 1.0 there and overwrite the result with 0.0.
        degenerate = denominator == 0.0
        safe_denominator = np.where(degenerate, 1.0, denominator)
        weights = cross_covariance_array / safe_denominator[None, :]
        return np.where(degenerate[None, :], 0.0, weights)

    denominator = [
        value + regularization if abs(value + regularization) > 1e-12 else regularization
        for value in feature_second_moment
    ]
    return [
        [
            # Same zero-weight rule as the NumPy path for a zero denominator.
            value / denominator[col] if denominator[col] != 0 else 0.0
            for col, value in enumerate(row)
        ]
        for row in cross_covariance
    ]


def apply_readout(weights: Matrix, state: Vector) -> Vector:
    """Apply readout ``weights`` to a single ``state`` vector.

    Delegates directly to ``matvec`` and returns its result unchanged.
    """
    prediction = matvec(weights, state)
    return prediction