# CAFF / tests /test_csv.py
# MrDhifallah's picture
# Upload folder using huggingface_hub
# 634ebe8 verified
# Raw
# History Blame Contribute Delete
# 3.72 kB
"""
tests/test_csv.py
=================
Verify CSV (Eq. 14) properties:
• Empty retained set → zero vector (Eq. 14, second case)
• Permutation invariance (Property 1)
• Distinct distributions → distinct CSVs (Lemma 2: injectivity)
• Linear projection: z = E^T · Z (Property 1)
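• Noise stability: ||z̃ - z||² ∝ 1/|S| (Proposition 2)
  (NOTE: no test visible in this file exercises this property yet)
• Batch independence: batch items do not contaminate each other
• Max-pool ablation: pool="max" output differs from pool="mean" (§10.1)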
"""
from __future__ import annotations
import pytest
import torch
from caff.csv import CSV
class _StubCache:
"""Minimal cache stub for testing CSV without loading BioLinkBERT."""
def __init__(self, embeddings: torch.Tensor, names: list[str]) -> None:
self.embeddings = embeddings
self.relation_to_idx = {n: i for i, n in enumerate(names)}
def get_batch(self, names: list[str]) -> torch.Tensor:
idx = torch.tensor(
[self.relation_to_idx[n] for n in names],
device=self.embeddings.device,
)
return self.embeddings[idx]
def _make_orthonormal_cache(n_relations: int = 5, d: int = 8) -> _StubCache:
    """Build a stub cache whose relation embeddings are standard basis vectors.

    Relation ``r{i}`` is assigned the i-th standard basis vector of length
    ``d``, so the rows are orthonormal and rank(E) = |R| (Lemma 2
    precondition).

    Args:
        n_relations: Number of relations; they are named ``r0`` .. ``r{n-1}``.
        d: Embedding dimension.

    Returns:
        A ``_StubCache`` wrapping the ``(n_relations, d)`` embedding matrix.

    Raises:
        ValueError: If ``n_relations`` exceeds ``d``; that many orthonormal
            vectors cannot exist in ``d`` dimensions.
    """
    if n_relations > d:
        raise ValueError(
            f"n_relations ({n_relations}) must not exceed d ({d}) "
            "for the embeddings to be orthonormal"
        )
    # Rectangular identity: ones on the main diagonal, zeros elsewhere.
    E = torch.eye(n_relations, d)
    names = [f"r{i}" for i in range(n_relations)]
    return _StubCache(E, names)
def test_csv_empty_returns_zero():
    """Eq. 14, second case: S_{ℓ-1} = ∅ → z = 0_d.

    A batch holding one empty retained set must come back as a single
    all-zero row of width 8 (the stub cache's embedding dimension).
    """
    stub = _make_orthonormal_cache()
    pooled = CSV(stub, pool="mean")([[]])  # one batch item, nothing retained
    expected = torch.zeros(1, 8)
    assert pooled.shape == expected.shape
    assert torch.allclose(pooled, expected)
def test_csv_permutation_invariance():
    """Property 1: order of relations does not change the CSV.

    Encodes the same three relations in their original order and in two
    rotated orders, and checks that every result matches the first.
    """
    stub = _make_orthonormal_cache()
    encode = CSV(stub, pool="mean")
    reference = encode([["r0", "r1", "r2"]])
    rotations = (
        ["r2", "r0", "r1"],
        ["r1", "r2", "r0"],
    )
    for reordered in rotations:
        assert torch.allclose(reference, encode([reordered]))
def test_csv_injectivity_on_simplex():
    """Lemma 2: rank(E)=|R| ⟹ Z ↦ E^T Z is injective.

    Two retained sets over the same relations but with different
    multiplicities must not map to the same CSV.
    """
    stub = _make_orthonormal_cache()
    encode = CSV(stub, pool="mean")
    mostly_r0 = ["r0", "r0", "r1"]  # 2/3 of r0, 1/3 of r1
    mostly_r1 = ["r0", "r1", "r1"]  # 1/3 of r0, 2/3 of r1
    z_a = encode([mostly_r0])
    z_b = encode([mostly_r1])
    assert not torch.allclose(z_a, z_b, atol=1e-6)
def test_csv_linear_projection():
    """Property 1: z = E^T · Z (relational distribution).

    With r1 appearing twice among four retained relations, the CSV must
    equal the count-weighted average of the corresponding embedding rows.
    """
    stub = _make_orthonormal_cache()
    encode = CSV(stub, pool="mean")
    retained = ["r0", "r1", "r1", "r2"]
    actual = encode([retained]).squeeze(0)
    # Expected: (e_{r0} + 2 e_{r1} + e_{r2}) / 4
    e0, e1, e2 = stub.embeddings[:3]
    expected = (e0 + 2 * e1 + e2) / 4
    assert torch.allclose(actual, expected, atol=1e-6)
def test_csv_batch_independence():
    """Different batch items do not contaminate each other.

    A batch of [r0], [r1] and an empty set must give, row by row, the r0
    embedding, the r1 embedding and a zero vector.
    """
    stub = _make_orthonormal_cache()
    out = CSV(stub, pool="mean")([["r0"], ["r1"], []])
    expected_rows = (
        stub.embeddings[0],
        stub.embeddings[1],
        torch.zeros(8),
    )
    for position, wanted in enumerate(expected_rows):
        assert torch.allclose(out[position], wanted)
def test_csv_max_pool_ablation():
    """Ablation §10.1: max-pool variant produces different output.

    Feeds the same two-relation retained set through a mean-pooled and a
    max-pooled CSV and checks that the two results are not close.
    """
    stub = _make_orthonormal_cache()
    batch = [["r0", "r1"]]
    by_pool = {mode: CSV(stub, pool=mode)(batch) for mode in ("mean", "max")}
    assert not torch.allclose(by_pool["mean"], by_pool["max"])