# siddsuresh97 - Initial commit: ICLR 2026 Representational Alignment Challenge
# (commit d6c8a4f)
from __future__ import annotations
import numpy as np
def _validate_inputs(x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Coerce both embeddings to float64 arrays and check they are comparable.

    Args:
        x: First embedding; anything ``np.asarray`` can convert.
        y: Second embedding; anything ``np.asarray`` can convert.

    Returns:
        The two inputs, in the same order, as float64 arrays.

    Raises:
        ValueError: If either array is not 2D, or if the two arrays differ
            in the size of their first (sample) dimension.
    """
    first, second = (np.asarray(arr, dtype=np.float64) for arr in (x, y))

    both_2d = first.ndim == 2 and second.ndim == 2
    if not both_2d:
        raise ValueError("CKA expects 2D arrays shaped [num_samples, dim].")

    same_sample_count = first.shape[0] == second.shape[0]
    if not same_sample_count:
        raise ValueError("CKA expects the same number of samples in both embeddings.")

    return first, second
def hsic_biased(k: np.ndarray, l: np.ndarray) -> float:
    """Return the biased HSIC statistic ``trace(K H L H)``.

    ``H`` is the centering matrix ``I - 1/m`` with ``m = k.shape[0]``. No
    further normalisation is applied to the trace. The original note states
    that this matches the reference implementation.

    Args:
        k: First kernel (Gram) matrix.
        l: Second kernel (Gram) matrix.

    Returns:
        The trace of ``K H L H`` as a Python float.
    """
    num_samples = k.shape[0]
    centering = np.eye(num_samples, dtype=k.dtype) - (1.0 / num_samples)

    # Multiply strictly left to right so rounding matches ``k @ h @ l @ h``.
    chained = k @ centering
    chained = chained @ l
    chained = chained @ centering
    return float(np.trace(chained))
def hsic_unbiased(k: np.ndarray, l: np.ndarray) -> float:
    """Unbiased HSIC as in Song et al. (2012).

    With ``K~`` and ``L~`` denoting the inputs with their diagonals zeroed and
    ``m = k.shape[0]``, this evaluates::

        (tr(K~ L~) + sum(K~) * sum(L~) / ((m - 1) * (m - 2))
         - 2 * (1^T K~ L~ 1) / (m - 2)) / (m * (m - 3))

    Args:
        k: First kernel (Gram) matrix. Not modified; a copy is zero-diagonalled.
        l: Second kernel (Gram) matrix. Not modified; a copy is zero-diagonalled.

    Returns:
        The estimate as a Python float, or ``0.0`` when fewer than four
        samples are available.
    """
    m = k.shape[0]
    if m < 4:
        # The (m - 2) and m * (m - 3) normalisers need at least four samples.
        return 0.0

    k_tilde = k.copy()
    l_tilde = l.copy()
    np.fill_diagonal(k_tilde, 0.0)
    np.fill_diagonal(l_tilde, 0.0)

    # tr(K~ L~) written as an element-wise sum to avoid a matrix product.
    term1 = np.sum(k_tilde * l_tilde.T)
    term2 = (np.sum(k_tilde) * np.sum(l_tilde)) / ((m - 1) * (m - 2))
    # 1^T K~ L~ 1 == (column sums of K~) . (row sums of L~); this avoids
    # forming the m x m product K~ @ L~ just to sum its entries.
    cross_sum = k_tilde.sum(axis=0) @ l_tilde.sum(axis=1)
    term3 = (2 * cross_sum) / (m - 2)
    return float((term1 + term2 - term3) / (m * (m - 3)))
def linear_cka_feature(x: np.ndarray, y: np.ndarray, eps: float = 1e-6) -> float:
    """Compute linear CKA directly from column-centred feature matrices.

    The score is ``||Xc^T Yc||_F^2 / (||Xc^T Xc||_F * ||Yc^T Yc||_F + eps)``,
    where ``Xc`` and ``Yc`` are the inputs with each column's mean removed.
    The original note describes this as equivalent to the biased-HSIC
    formulation for linear kernels.

    Args:
        x: First embedding; validated and converted by ``_validate_inputs``.
        y: Second embedding; validated and converted by ``_validate_inputs``.
        eps: Constant added to the denominator.

    Returns:
        The score as a Python float, or ``0.0`` when the product of the two
        self-similarity norms is exactly zero.
    """
    x, y = _validate_inputs(x, y)

    x_centered = x - x.mean(axis=0, keepdims=True)
    y_centered = y - y.mean(axis=0, keepdims=True)

    cross_norm = np.linalg.norm(x_centered.T @ y_centered, ord="fro")
    x_self_norm = np.linalg.norm(x_centered.T @ x_centered, ord="fro")
    y_self_norm = np.linalg.norm(y_centered.T @ y_centered, ord="fro")

    scale = x_self_norm * y_self_norm
    if scale == 0:
        return 0.0

    similarity = cross_norm ** 2
    return float(similarity / (scale + eps))
def linear_cka(
    x: np.ndarray,
    y: np.ndarray,
    *,
    unbiased: bool = False,
    eps: float = 1e-6,
) -> float:
    """Linear CKA computed via HSIC on the Gram matrices ``x @ x.T`` and ``y @ y.T``.

    The score is ``HSIC(K, L) / (sqrt(HSIC(K, K) * HSIC(L, L)) + eps)``. The
    original note states that this matches the reference implementation.

    Args:
        x: First embedding; validated and converted by ``_validate_inputs``.
        y: Second embedding; validated and converted by ``_validate_inputs``.
        unbiased: Use ``hsic_unbiased`` instead of ``hsic_biased``.
        eps: Constant added to the denominator.

    Returns:
        The score as a Python float. Returns ``0.0`` when the product of the
        two self-HSIC terms is zero or negative, since the normaliser is then
        undefined.
    """
    x, y = _validate_inputs(x, y)
    k = x @ x.T
    l = y @ y.T
    hsic_fn = hsic_unbiased if unbiased else hsic_biased

    hsic_kk = hsic_fn(k, k)
    hsic_ll = hsic_fn(l, l)
    self_product = hsic_kk * hsic_ll
    # Rounding on degenerate inputs (and the unbiased estimator in general)
    # can make a self-term zero or slightly negative. The square root would
    # then be NaN, or the ratio would be divided by eps alone. Report no
    # alignment instead, as linear_cka_feature does for a zero denominator.
    # A NaN product fails this comparison and therefore still propagates.
    if self_product <= 0:
        return 0.0

    hsic_kl = hsic_fn(k, l)
    denom = np.sqrt(self_product)
    return float(hsic_kl / (denom + eps))