| import numpy as np | |
| from src.cka.compute import linear_cka, linear_cka_feature | |
def hsic_biased(k, l):
    """Return the biased HSIC statistic ``trace(K H L H)`` for two Gram matrices.

    ``H = I - (1/n) 11^T`` is the centering matrix. No ``1/(n-1)^2``
    normalisation factor is applied; that factor would cancel in a
    normalised ratio such as CKA.

    Args:
        k: Gram matrix of shape (n, n).
        l: Gram matrix of shape (n, n).

    Returns:
        The trace as a Python float.
    """
    n = k.shape[0]
    # Subtracting 1/n from every entry of the identity yields the centering matrix.
    centering = np.eye(n, dtype=k.dtype) - 1 / n
    # Same left-to-right association as ``k @ H @ l @ H``.
    product = np.matmul(np.matmul(np.matmul(k, centering), l), centering)
    return float(product.trace())
def hsic_unbiased(k, l):
    """Return the unbiased HSIC estimator (Song et al., 2012) for two Gram matrices.

    With ``K~`` / ``L~`` denoting copies of ``k`` / ``l`` whose diagonals are
    zeroed, and ``m`` the number of samples, the estimate is::

        1 / (m (m - 3)) * [ tr(K~ L~)
                            + (1' K~ 1)(1' L~ 1) / ((m - 1)(m - 2))
                            - 2 / (m - 2) * 1' K~ L~ 1 ]

    Args:
        k: Square Gram matrix of shape (m, m). It is copied, not modified.
        l: Square Gram matrix of shape (m, m). It is copied, not modified.

    Returns:
        The estimate as a Python float.

    Raises:
        ValueError: If ``m < 4``. One of the denominators ``m``, ``m - 1``,
            ``m - 2`` or ``m - 3`` is then zero, and the computation would
            otherwise silently return ``inf``/``nan``.
    """
    m = k.shape[0]
    if m < 4:
        raise ValueError(f"unbiased HSIC needs at least 4 samples, got {m}")

    # Zero the diagonals on copies so the caller's matrices are untouched.
    k_tilde = k.copy()
    l_tilde = l.copy()
    np.fill_diagonal(k_tilde, 0.0)
    np.fill_diagonal(l_tilde, 0.0)

    # tr(K~ L~) computed as an elementwise sum against the transpose.
    trace_term = np.sum(k_tilde * l_tilde.T)
    # (1' K~ 1)(1' L~ 1) / ((m - 1)(m - 2))
    sum_product_term = np.sum(k_tilde) * np.sum(l_tilde) / ((m - 1) * (m - 2))
    # 2 / (m - 2) * 1' K~ L~ 1
    cross_term = 2 * np.sum(k_tilde @ l_tilde) / (m - 2)

    hsic_value = trace_term + sum_product_term - cross_term
    hsic_value /= m * (m - 3)
    return float(hsic_value)
def ref_cka(a, b, unbiased=False):
    """Reference linear CKA computed from Gram matrices.

    Builds the linear Gram matrices ``a a^T`` and ``b b^T`` and returns
    ``HSIC(K, L) / (sqrt(HSIC(K, K) * HSIC(L, L)) + 1e-6)``. The HSIC
    estimator is the unbiased one when ``unbiased`` is true, otherwise the
    biased one.

    Args:
        a: Matrix whose rows are paired with the rows of ``b``.
        b: Matrix with the same number of rows as ``a``.
        unbiased: Select the unbiased HSIC estimator.

    Returns:
        The CKA value as a Python float.
    """
    gram_a = a @ a.T
    gram_b = b @ b.T
    if unbiased:
        estimator = hsic_unbiased
    else:
        estimator = hsic_biased
    self_a = estimator(gram_a, gram_a)
    self_b = estimator(gram_b, gram_b)
    cross = estimator(gram_a, gram_b)
    # The small additive constant guards against a zero denominator.
    denominator = np.sqrt(self_a * self_b) + 1e-6
    return float(cross / denominator)
def main():
    """Compare the project's linear CKA functions against the reference ``ref_cka``.

    Draws two seeded random 64x128 float64 matrices. It then prints, for the
    biased, unbiased and feature-space variants, the project value, the
    reference value and their absolute difference.
    """
    generator = np.random.default_rng(0)
    shape = (64, 128)
    a = generator.standard_normal(shape, dtype=np.float64)
    b = generator.standard_normal(shape, dtype=np.float64)

    ours_biased = linear_cka(a, b, unbiased=False)
    reference_biased = ref_cka(a, b, unbiased=False)
    ours_unbiased = linear_cka(a, b, unbiased=True)
    reference_unbiased = ref_cka(a, b, unbiased=True)
    ours_feature = linear_cka_feature(a, b)

    comparisons = (
        ("biased:", ours_biased, reference_biased),
        ("unbiased:", ours_unbiased, reference_unbiased),
        # linear_cka_feature is checked against the biased reference
        # (presumably the feature-space form of biased linear CKA).
        ("feature:", ours_feature, reference_biased),
    )
    for label, ours, reference in comparisons:
        print(label, ours, reference, "diff", abs(ours - reference))
if __name__ == "__main__":
    # Run the comparison only when executed as a script, not on import.
    main()