File size: 3,438 Bytes
52da7b7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 | from .linalg import Matrix, Vector, identity, invert_matrix, matmul, matvec, np, scale_matrix, transpose
def _empty_matrix(matrix: Matrix) -> bool:
    """Return True when ``matrix`` holds no elements.

    When NumPy is available and the object exposes a ``size`` attribute, the
    element count is used, because truth-testing a multi-element array raises
    instead of answering. Anything else is judged by ordinary truthiness.
    """
    has_element_count = np is not None and hasattr(matrix, "size")
    if not has_element_count:
        return not matrix
    element_count = int(matrix.size)
    return element_count == 0
def ridge_regression_readout(
    states: list[Vector],
    targets: list[Vector],
    *,
    regularization: float,
) -> Matrix:
    """Fit linear readout weights by ridge regression on paired samples.

    With ``S`` holding one state per column and ``T`` one target per column,
    the returned ``W`` satisfies ``W (S S^T + regularization * I) = T S^T``.

    Args:
        states: One state vector per sample.
        targets: One target vector per sample, in the same order as ``states``.
        regularization: Value added to every diagonal entry of the Gram
            matrix ``S S^T`` before solving.

    Returns:
        The readout weights, one row per target dimension and one column per
        state dimension. The NumPy path returns nested lists (``tolist()``);
        the fallback returns whatever ``matmul`` produces.

    Raises:
        ValueError: If ``states`` or ``targets`` is empty, or if they do not
            contain the same number of samples.
    """
    if not states or not targets:
        raise ValueError("States and targets must be non-empty for ridge readout.")
    # Each state must be paired with exactly one target; checking here gives a
    # clear error instead of a shape failure deep inside the matrix products.
    if len(states) != len(targets):
        raise ValueError(
            "States and targets must contain the same number of samples for ridge readout; "
            f"got {len(states)} states and {len(targets)} targets."
        )

    if np is not None:
        state_matrix = np.asarray(states, dtype=np.float64).T
        target_matrix = np.asarray(targets, dtype=np.float64).T
        gram = state_matrix @ state_matrix.T
        regularized = gram + (regularization * np.eye(gram.shape[0], dtype=np.float64))
        cross_covariance = target_matrix @ state_matrix.T
        # Solve the transposed system rather than forming an explicit inverse.
        return np.linalg.solve(regularized.T, cross_covariance.T).T.tolist()

    # Pure-Python fallback built on the project's list-based helpers.
    state_matrix = transpose(states)
    target_matrix = transpose(targets)
    # Needed for both the Gram matrix and the cross-covariance; compute once.
    state_matrix_transposed = transpose(state_matrix)
    gram = matmul(state_matrix, state_matrix_transposed)
    regularized = [
        [
            value + (regularization if row == col else 0.0)
            for col, value in enumerate(gram_row)
        ]
        for row, gram_row in enumerate(gram)
    ]
    inverse = invert_matrix(regularized)
    cross_covariance = matmul(target_matrix, state_matrix_transposed)
    return matmul(cross_covariance, inverse)
def ridge_regression_readout_from_moments(
    gram: Matrix,
    cross_covariance: Matrix,
    *,
    regularization: float,
) -> Matrix:
    """Compute ridge readout weights from precomputed moment matrices.

    The returned ``W`` satisfies
    ``W (gram + regularization * I) = cross_covariance``.

    Args:
        gram: Square second-moment matrix of the features.
        cross_covariance: Target-by-feature cross moments.
        regularization: Value added to each diagonal entry of ``gram``.

    Returns:
        The readout weights. The NumPy path returns an ``ndarray``; the
        fallback returns whatever ``matmul`` produces.

    Raises:
        ValueError: If either moment matrix is empty.
    """
    if _empty_matrix(gram) or _empty_matrix(cross_covariance):
        raise ValueError("Gram and cross-covariance moments must be non-empty for ridge readout.")

    if np is None:
        # Pure-Python fallback: add the ridge term on the diagonal, invert,
        # then right-multiply the cross moments by the inverse.
        ridge_gram = [
            [
                entry + (regularization if i == j else 0.0)
                for j, entry in enumerate(gram_row)
            ]
            for i, gram_row in enumerate(gram)
        ]
        ridge_inverse = invert_matrix(ridge_gram)
        return matmul(cross_covariance, ridge_inverse)

    gram_values = np.asarray(gram, dtype=np.float64)
    ridge_term = regularization * np.eye(gram_values.shape[0], dtype=np.float64)
    ridge_gram = gram_values + ridge_term
    cross_values = np.asarray(cross_covariance, dtype=np.float64)
    # Solve the transposed system so no explicit inverse is formed.
    transposed_weights = np.linalg.solve(ridge_gram.T, cross_values.T)
    return transposed_weights.T
def ridge_regression_readout_from_diagonal_moments(
    feature_second_moment: Vector,
    cross_covariance: Matrix,
    *,
    regularization: float,
) -> Matrix:
    """Compute ridge readout weights when only diagonal moments are known.

    Each column of ``cross_covariance`` is divided by the matching entry of
    ``feature_second_moment`` plus ``regularization``. A divisor whose
    magnitude is not above ``1e-12`` is replaced by ``regularization`` itself.

    Note: that replacement is not guarded further, so with
    ``regularization == 0`` a vanishing divisor still leads to a division by
    zero.

    Args:
        feature_second_moment: Per-feature second moments.
        cross_covariance: Target-by-feature cross moments.
        regularization: Value added to every second moment before dividing.

    Returns:
        The readout weights, shaped like ``cross_covariance``. The NumPy path
        returns an ``ndarray``; the fallback returns nested lists.

    Raises:
        ValueError: If either input is empty.
    """
    if _empty_matrix(feature_second_moment) or _empty_matrix(cross_covariance):
        raise ValueError("Diagonal moments and cross-covariance must be non-empty for ridge readout.")

    if np is not None:
        shifted = np.asarray(feature_second_moment, dtype=np.float64) + regularization
        usable = np.abs(shifted) > 1e-12
        divisors = np.where(usable, shifted, regularization)
        cross_values = np.asarray(cross_covariance, dtype=np.float64)
        # Broadcast the per-feature divisors across every target row.
        return cross_values / divisors[np.newaxis, :]

    divisors = []
    for moment in feature_second_moment:
        shifted = moment + regularization
        divisors.append(shifted if abs(shifted) > 1e-12 else regularization)

    readout = []
    for covariance_row in cross_covariance:
        readout.append(
            [entry / divisors[index] for index, entry in enumerate(covariance_row)]
        )
    return readout
def apply_readout(weights: Matrix, state: Vector) -> Vector:
    """Apply readout ``weights`` to ``state`` via ``matvec`` and return the result."""
    output = matvec(weights, state)
    return output
|