"""
tests/test_csv.py
=================
Verify CSV (Eq. 14) properties:

  • Empty retained set → zero vector (Eq. 14, second case)
  • Permutation invariance (Property 1)
  • Distinct distributions → distinct CSVs (Lemma 2: injectivity)
  • Linear projection: z = E^T · Z (Property 1)
  • Noise stability: ||z̃ - z||² ∝ 1/|S| (Proposition 2)
"""
|
|
|
| from __future__ import annotations
|
|
|
| import pytest
|
| import torch
|
|
|
| from caff.csv import CSV
|
|
|
|
|
class _StubCache:
    """Minimal cache stub for testing CSV without loading BioLinkBERT.

    Holds a fixed embedding matrix plus a name → row-index lookup so that
    relation names can be resolved to rows of that matrix.
    """

    def __init__(self, embeddings: torch.Tensor, names: list[str]) -> None:
        """Store the embedding matrix and index each name by its position.

        Args:
            embeddings: Tensor whose i-th row is returned for ``names[i]``.
            names: Relation names, in row order.
        """
        self.embeddings = embeddings
        self.relation_to_idx = {name: row for row, name in enumerate(names)}

    def get_batch(self, names: list[str]) -> torch.Tensor:
        """Return the embedding rows for ``names``, in the given order.

        An empty ``names`` list yields a tensor with zero rows.

        Raises:
            KeyError: If a name was not registered at construction time.
        """
        # Pin the index dtype: torch.tensor([]) defaults to a floating-point
        # tensor, which is not a valid index, so an empty list would
        # otherwise raise instead of returning an empty batch.
        rows = torch.tensor(
            [self.relation_to_idx[name] for name in names],
            dtype=torch.long,
            device=self.embeddings.device,
        )
        return self.embeddings[rows]
|
|
|
|
|
def _make_orthonormal_cache(n_relations: int = 5, d: int = 8) -> _StubCache:
    """Build a stub cache whose relation embeddings are standard basis vectors.

    Row ``i`` of the embedding matrix has a single 1.0 in column ``i``, so
    the rows are orthonormal and rank(E) = |R| (Lemma 2 precondition).

    Args:
        n_relations: Number of relations; they are named ``r0``, ``r1``, ...
        d: Embedding dimension.

    Returns:
        A ``_StubCache`` over the ``n_relations × d`` embedding matrix.

    Raises:
        ValueError: If ``n_relations > d``; more rows than dimensions cannot
            be orthonormal, so the rank guarantee would not hold.
    """
    if n_relations > d:
        raise ValueError(
            f"n_relations ({n_relations}) must not exceed d ({d}) "
            "for orthonormal embeddings"
        )
    # torch.eye(n, m) places ones on the main diagonal of an n×m matrix,
    # which is exactly the first n standard basis vectors of R^m.
    E = torch.eye(n_relations, d)
    names = [f"r{i}" for i in range(n_relations)]
    return _StubCache(E, names)
|
|
|
|
|
def test_csv_empty_returns_zero():
    """Eq. 14, second case: an empty retained set S_{ℓ-1} = ∅ gives z = 0_d."""
    encoder = CSV(_make_orthonormal_cache(), pool="mean")
    pooled = encoder([[]])
    # One batch item, d = 8 from the default orthonormal cache.
    zero_row = torch.zeros(1, 8)
    assert pooled.shape == (1, 8)
    assert torch.allclose(pooled, zero_row)
|
|
|
|
|
def test_csv_permutation_invariance():
    """Property 1: reordering the relations leaves the CSV unchanged."""
    encoder = CSV(_make_orthonormal_cache(), pool="mean")
    orderings = (
        ["r0", "r1", "r2"],
        ["r2", "r0", "r1"],
        ["r1", "r2", "r0"],
    )
    # Encode every ordering first, then compare each against the first one.
    reference, *permuted = [encoder([ordering]) for ordering in orderings]
    for candidate in permuted:
        assert torch.allclose(reference, candidate)
|
|
|
|
|
def test_csv_injectivity_on_simplex():
    """Lemma 2: rank(E)=|R| ⟹ Z ↦ E^T Z is injective.

    The two inputs contain the same relations with different multiplicities,
    i.e. distinct retained-context distributions, so their CSVs must differ.
    """
    encoder = CSV(_make_orthonormal_cache(), pool="mean")
    heavy_on_r0 = encoder([["r0", "r0", "r1"]])
    heavy_on_r1 = encoder([["r0", "r1", "r1"]])
    vectors_match = torch.allclose(heavy_on_r0, heavy_on_r1, atol=1e-6)
    assert not vectors_match
|
|
|
|
|
def test_csv_linear_projection():
    """Property 1: z = E^T · Z, with Z the relational distribution."""
    cache = _make_orthonormal_cache()
    encoder = CSV(cache, pool="mean")
    batch = [["r0", "r1", "r1", "r2"]]
    pooled = encoder(batch).squeeze(0)

    # r1 occurs twice among four relations, so it carries double weight.
    e0, e1, e2 = (cache.embeddings[row] for row in range(3))
    weighted_sum = e0 + 2 * e1 + e2
    expected = weighted_sum / 4
    assert torch.allclose(pooled, expected, atol=1e-6)
|
|
|
|
|
def test_csv_batch_independence():
    """Each batch item is encoded on its own, unaffected by its neighbours."""
    cache = _make_orthonormal_cache()
    encoded = CSV(cache, pool="mean")([["r0"], ["r1"], []])
    # Singleton items map to their own embedding; the empty item maps to zero.
    wanted = (cache.embeddings[0], cache.embeddings[1], torch.zeros(8))
    for position, target in enumerate(wanted):
        assert torch.allclose(encoded[position], target)
|
|
|
|
|
def test_csv_max_pool_ablation():
    """Ablation §10.1: the max-pool variant differs from mean pooling."""
    cache = _make_orthonormal_cache()
    mean_encoder = CSV(cache, pool="mean")
    max_encoder = CSV(cache, pool="max")
    pooled_mean = mean_encoder([["r0", "r1"]])
    pooled_max = max_encoder([["r0", "r1"]])
    outputs_match = torch.allclose(pooled_mean, pooled_max)
    assert not outputs_match