| from __future__ import annotations |
|
|
| import numpy as np |
|
|
|
|
def _validate_inputs(x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Coerce both embeddings to float64 arrays and check their shapes.

    Args:
        x: First embedding; anything ``np.asarray`` accepts.
        y: Second embedding; anything ``np.asarray`` accepts.

    Returns:
        The ``(x, y)`` pair as float64 arrays.

    Raises:
        ValueError: If either array is not 2-D, or if the two arrays have a
            different length along the first axis.
    """
    first, second = (np.asarray(arr, dtype=np.float64) for arr in (x, y))

    both_2d = first.ndim == 2 and second.ndim == 2
    if not both_2d:
        raise ValueError("CKA expects 2D arrays shaped [num_samples, dim].")

    if len(first) != len(second):
        raise ValueError("CKA expects the same number of samples in both embeddings.")

    return first, second
|
|
|
|
def hsic_biased(k: np.ndarray, l: np.ndarray) -> float:
    """Biased, un-normalised HSIC statistic ``trace(K H L H)``.

    ``H = I - 1/m`` is the m-by-m centering matrix, with ``m`` taken from the
    first axis of ``k``.  No ``1 / (m - 1) ** 2`` factor is applied; the raw
    trace is returned.  (The original note on this function says it matches
    the reference implementation.)

    Args:
        k: Square kernel matrix for the first representation.
        l: Square kernel matrix for the second representation, same shape
            as ``k``.

    Returns:
        The trace as a Python float, or 0.0 when ``k`` has no rows.
    """
    m = k.shape[0]
    if m == 0:
        # 1.0 / m would raise ZeroDivisionError.  Report 0.0 instead, the
        # same value hsic_unbiased returns when there are too few samples.
        return 0.0
    h = np.eye(m, dtype=k.dtype) - (1.0 / m)
    return float(np.trace(k @ h @ l @ h))
|
|
|
|
def hsic_unbiased(k: np.ndarray, l: np.ndarray) -> float:
    """Unbiased HSIC as in Song et al. (2012).

    With ``K~`` and ``L~`` denoting the inputs with their diagonals zeroed,
    this evaluates::

        [ tr(K~ L~)
          + (1' K~ 1)(1' L~ 1) / ((m - 1)(m - 2))
          - 2 (1' K~ L~ 1) / (m - 2) ] / (m (m - 3))

    Args:
        k: Square kernel matrix for the first representation.
        l: Square kernel matrix for the second representation, same shape
            as ``k``.

    Returns:
        The estimate as a Python float.  Returns 0.0 when ``m < 4``, because
        the normalising factors above are zero or change sign for such
        sample counts.  The inputs are not modified (copies are taken before
        the diagonals are cleared).
    """
    m = k.shape[0]
    if m < 4:
        return 0.0

    k_tilde = k.copy()
    l_tilde = l.copy()
    np.fill_diagonal(k_tilde, 0.0)
    np.fill_diagonal(l_tilde, 0.0)

    term1 = np.sum(k_tilde * l_tilde.T)
    term2 = (np.sum(k_tilde) * np.sum(l_tilde)) / ((m - 1) * (m - 2))
    # 1' K~ L~ 1 equals (column sums of K~) . (row sums of L~); computing it
    # this way avoids forming the full m-by-m product just to sum it.
    cross_sum = k_tilde.sum(axis=0) @ l_tilde.sum(axis=1)
    term3 = (2 * cross_sum) / (m - 2)
    return float((term1 + term2 - term3) / (m * (m - 3)))
|
|
|
|
def linear_cka_feature(x: np.ndarray, y: np.ndarray, eps: float = 1e-6) -> float:
    """Linear CKA evaluated directly on the feature matrices.

    After validating the inputs and removing each column's mean, the score is
    ``||X'Y||_F^2 / (||X'X||_F * ||Y'Y||_F + eps)``.  Per the original note,
    this is equivalent to the biased-HSIC formulation with linear kernels.

    Args:
        x: First embedding, shaped ``[num_samples, dim_x]``.
        y: Second embedding, shaped ``[num_samples, dim_y]``.
        eps: Constant added to the denominator before dividing.

    Returns:
        The similarity as a Python float; 0.0 when the product of the two
        norms in the denominator is exactly zero.

    Raises:
        ValueError: Propagated from ``_validate_inputs`` on bad shapes.
    """
    feats_x, feats_y = _validate_inputs(x, y)

    # Remove per-feature means so only variation across samples remains.
    centered_x = feats_x - feats_x.mean(axis=0, keepdims=True)
    centered_y = feats_y - feats_y.mean(axis=0, keepdims=True)

    xty = centered_x.T @ centered_y
    xtx = centered_x.T @ centered_x
    yty = centered_y.T @ centered_y

    similarity = np.linalg.norm(xty, ord="fro") ** 2
    scale = np.linalg.norm(xtx, ord="fro") * np.linalg.norm(yty, ord="fro")
    if scale == 0:
        return 0.0
    return float(similarity / (scale + eps))
|
|
|
|
def linear_cka(
    x: np.ndarray,
    y: np.ndarray,
    *,
    unbiased: bool = False,
    eps: float = 1e-6,
) -> float:
    """Linear CKA computed via HSIC on the Gram matrices ``x x'`` and ``y y'``.

    The score is ``HSIC(K, L) / (sqrt(HSIC(K, K) * HSIC(L, L)) + eps)``.
    (The original note on this function says it matches the reference
    implementation.)

    Args:
        x: First embedding, shaped ``[num_samples, dim_x]``.
        y: Second embedding, shaped ``[num_samples, dim_y]``.
        unbiased: Use ``hsic_unbiased`` instead of ``hsic_biased``.
        eps: Constant added to the denominator before dividing.

    Returns:
        The similarity as a Python float.  Returns 0.0 when there are no
        samples or when the normalising term is zero, matching what
        ``linear_cka_feature`` does for a zero denominator.

    Raises:
        ValueError: Propagated from ``_validate_inputs`` on bad shapes.
    """
    x, y = _validate_inputs(x, y)

    if x.shape[0] == 0:
        # The biased estimator divides by the sample count, so an empty
        # batch would otherwise raise ZeroDivisionError.
        return 0.0

    k = x @ x.T
    l = y @ y.T

    hsic_fn = hsic_unbiased if unbiased else hsic_biased
    hsic_kk = hsic_fn(k, k)
    hsic_ll = hsic_fn(l, l)
    hsic_kl = hsic_fn(k, l)

    # The self-terms should be non-negative, but rounding on (near-)constant
    # inputs can leave a tiny negative value; clamp so the square root cannot
    # produce NaN.  max(value, 0.0) still lets a NaN input propagate.
    denom = np.sqrt(max(hsic_kk, 0.0) * max(hsic_ll, 0.0))
    if denom == 0:
        # No variation in at least one embedding: report zero similarity
        # instead of dividing rounding noise by eps (or 0/0 when eps == 0).
        return 0.0
    return float(hsic_kl / (denom + eps))
|
|