# obliteratus/tests/test_leace.py
# Commit 102206c ("Upload 135 files") by pliny-the-prompter.
"""Tests for LEACE (LEAst-squares Concept Erasure) direction extraction."""
from __future__ import annotations
import pytest
import torch
from obliteratus.analysis.leace import LEACEExtractor, LEACEResult
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def extractor():
    """Return a shared ``LEACEExtractor`` configured with ``regularization_eps=1e-4``."""
    eps = 1e-4
    return LEACEExtractor(regularization_eps=eps)
@pytest.fixture
def separable_data():
    """Generate clearly separable harmful/harmless activations.

    Returns ``(harmful, harmless)``: two lists of 20 vectors in R^64. Harmful
    samples are centred at ``+e_0`` and harmless samples at ``-e_0`` (``e_0`` is
    the first standard basis vector), each with Gaussian noise of std 0.1.

    A private ``torch.Generator`` is used instead of ``torch.manual_seed`` so
    the data is reproducible without mutating the global RNG state shared by
    every other test in the session.
    """
    gen = torch.Generator().manual_seed(42)
    d = 64
    n = 20
    # Harmful activations: cluster around [1, 0, 0, ...]
    harmful_dir = torch.zeros(d)
    harmful_dir[0] = 1.0
    harmful = [
        harmful_dir + 0.1 * torch.randn(d, generator=gen) for _ in range(n)
    ]
    # Harmless activations: cluster around [-1, 0, 0, ...]
    harmless = [
        -harmful_dir + 0.1 * torch.randn(d, generator=gen) for _ in range(n)
    ]
    return harmful, harmless
@pytest.fixture
def isotropic_data():
    """Data where classes differ only in mean, with isotropic variance.

    Returns ``(harmful, harmless, direction)``. ``direction`` is a random unit
    vector in R^32. Each list holds 30 samples: harmful samples are centred at
    ``+2 * direction`` and harmless at ``-2 * direction``, both with
    unit-variance Gaussian noise.

    Draws come from a private ``torch.Generator`` so the fixture does not
    mutate the global RNG state used by other tests.
    """
    gen = torch.Generator().manual_seed(123)
    d = 32
    n = 30
    direction = torch.randn(d, generator=gen)
    direction = direction / direction.norm()
    harmful = [direction * 2.0 + torch.randn(d, generator=gen) for _ in range(n)]
    harmless = [-direction * 2.0 + torch.randn(d, generator=gen) for _ in range(n)]
    return harmful, harmless, direction
# ---------------------------------------------------------------------------
# LEACEResult
# ---------------------------------------------------------------------------
class TestLEACEResult:
    """Structural checks on the ``LEACEResult`` returned by ``extract``."""

    def test_result_fields(self, extractor, separable_data):
        """Every result field is populated and has the expected sign/shape."""
        harmful, harmless = separable_data
        # Derive the dimension from the data rather than duplicating the
        # fixture's constant, so the test follows any change to the fixture.
        d = harmful[0].shape[-1]
        result = extractor.extract(harmful, harmless, layer_idx=5)
        assert isinstance(result, LEACEResult)
        assert result.layer_idx == 5
        assert result.direction.shape == (d,)
        assert result.generalized_eigenvalue > 0
        assert result.within_class_condition > 0
        assert result.mean_diff_norm > 0
        assert result.erasure_loss >= 0

    def test_direction_is_unit_vector(self, extractor, separable_data):
        """The returned direction is L2-normalised."""
        harmful, harmless = separable_data
        result = extractor.extract(harmful, harmless)
        norm = result.direction.norm().item()
        assert abs(norm - 1.0) < 1e-5
# ---------------------------------------------------------------------------
# Direction quality
# ---------------------------------------------------------------------------
class TestDirectionQuality:
    """Check that the extracted direction lines up with the known concept axis."""

    def test_finds_true_direction(self, extractor, separable_data):
        """LEACE should find a direction aligned with the true separation axis."""
        harmful, harmless = separable_data
        result = extractor.extract(harmful, harmless)
        # True direction is [1, 0, 0, ...]; size it from the data, not a constant.
        true_dir = torch.zeros(harmful[0].shape[-1])
        true_dir[0] = 1.0
        # abs(): the test does not rely on the sign of the returned direction.
        cosine = (result.direction @ true_dir).abs().item()
        # With 20 samples in 64 dims, some noise is expected
        assert cosine > 0.5, f"LEACE direction not aligned with true direction: {cosine}"

    def test_isotropic_matches_diff_of_means(self, extractor, isotropic_data):
        """With isotropic noise, LEACE should roughly match diff-of-means."""
        harmful, harmless, _true_dir = isotropic_data
        result = extractor.extract(harmful, harmless)
        # Diff of means
        diff = torch.stack(harmful).mean(0) - torch.stack(harmless).mean(0)
        diff_normalized = diff / diff.norm()
        cosine = (result.direction @ diff_normalized).abs().item()
        # With finite samples and regularization, some deviation is expected
        assert cosine > 0.5

    def test_leace_differs_from_diff_means_with_anisotropic_noise(self):
        """With anisotropic noise, LEACE should stay aligned with the true direction.

        Noise has std 5.0 along dim 1 (orthogonal to the true direction, dim 0)
        and std 0.1 elsewhere. Despite the test name, only alignment with the
        true direction is asserted; no comparison against diff-of-means is made.
        """
        # Private generator: reproducible without touching the global RNG.
        gen = torch.Generator().manual_seed(77)
        d = 64
        n = 50
        # True refusal direction
        true_dir = torch.zeros(d)
        true_dir[0] = 1.0
        # Add anisotropic noise: high variance in dim 1 (NOT the refusal direction)
        noise_scale = torch.ones(d) * 0.1
        noise_scale[1] = 5.0  # Rogue dimension
        harmful = [
            true_dir * 0.5 + torch.randn(d, generator=gen) * noise_scale
            for _ in range(n)
        ]
        harmless = [
            -true_dir * 0.5 + torch.randn(d, generator=gen) * noise_scale
            for _ in range(n)
        ]
        extractor = LEACEExtractor()
        result = extractor.extract(harmful, harmless)
        cosine_to_true = (result.direction @ true_dir).abs().item()
        # LEACE should still find the true direction, not be distracted by rogue dim
        assert cosine_to_true > 0.5, f"LEACE distracted by rogue dimension: {cosine_to_true}"
# ---------------------------------------------------------------------------
# Comparison with diff-of-means
# ---------------------------------------------------------------------------
class TestCompareWithDiffOfMeans:
    """Check the summary returned by ``LEACEExtractor.compare_with_diff_of_means``."""

    def test_comparison_output(self, extractor, separable_data):
        """The comparison exposes every expected key and a cosine within [0, 1]."""
        harmful, harmless = separable_data
        result = extractor.extract(harmful, harmless)
        harmful_mean = torch.stack(harmful).mean(0)
        harmless_mean = torch.stack(harmless).mean(0)
        comparison = LEACEExtractor.compare_with_diff_of_means(
            result, harmful_mean, harmless_mean,
        )
        expected_keys = (
            "cosine_similarity",
            "leace_eigenvalue",
            "leace_erasure_loss",
            "within_class_condition",
            "mean_diff_norm",
        )
        # Collect every missing key so a failure reports all of them at once.
        missing = [key for key in expected_keys if key not in comparison]
        assert not missing, f"missing keys: {missing}"
        # Small slack for float round-off: the cosine of two nearly parallel
        # unit vectors can land a few ULPs outside [0, 1].
        tol = 1e-6
        assert -tol <= comparison["cosine_similarity"] <= 1.0 + tol
# ---------------------------------------------------------------------------
# Multi-layer extraction
# ---------------------------------------------------------------------------
class TestMultiLayer:
    """Check per-layer extraction via ``extract_all_layers``."""

    def test_extract_all_layers(self, extractor):
        """One result per input layer, each tagged with its own layer index."""
        # Private generator: reproducible without mutating the global RNG.
        gen = torch.Generator().manual_seed(42)
        d = 32
        n = 15
        # Non-contiguous indices (no 3 or 4), so keys must map to the right layer.
        layers = (0, 1, 2, 5)
        harmful_acts = {}
        harmless_acts = {}
        # A single loop keeps the harmful/harmless draws interleaved per layer.
        for layer in layers:
            harmful_acts[layer] = [torch.randn(d, generator=gen) + 0.5 for _ in range(n)]
            harmless_acts[layer] = [torch.randn(d, generator=gen) - 0.5 for _ in range(n)]
        results = extractor.extract_all_layers(harmful_acts, harmless_acts)
        assert set(results.keys()) == set(layers)
        for idx, result in results.items():
            assert result.layer_idx == idx
            assert result.direction.shape == (d,)
# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
class TestEdgeCases:
    """Degenerate inputs: tiny sample counts, zero class separation, extra dims, n < d."""

    def test_single_sample(self, extractor):
        """Should handle single sample per class gracefully."""
        # Seeded private generator so any failure is reproducible.
        gen = torch.Generator().manual_seed(0)
        d = 32
        harmful = [torch.randn(d, generator=gen)]
        harmless = [torch.randn(d, generator=gen)]
        result = extractor.extract(harmful, harmless)
        assert result.direction.shape == (d,)
        assert torch.isfinite(result.direction).all()

    def test_identical_activations(self, extractor):
        """Should handle case where harmful == harmless."""
        gen = torch.Generator().manual_seed(1)
        d = 32
        x = torch.randn(d, generator=gen)
        harmful = [x.clone() for _ in range(5)]
        harmless = [x.clone() for _ in range(5)]
        result = extractor.extract(harmful, harmless)
        assert result.direction.shape == (d,)
        # The class means coincide, so no separating direction exists. The
        # implementation may return a ~zero vector or some fallback; only
        # well-formedness (finite values) is checked here.
        assert torch.isfinite(result.direction).all()

    def test_3d_input_squeezed(self, extractor):
        """Should handle (n, 1, d) shaped inputs."""
        gen = torch.Generator().manual_seed(2)
        d = 32
        harmful = [torch.randn(1, d, generator=gen) for _ in range(10)]
        harmless = [torch.randn(1, d, generator=gen) for _ in range(10)]
        result = extractor.extract(harmful, harmless)
        assert result.direction.shape == (d,)
        assert torch.isfinite(result.direction).all()

    def test_shrinkage(self):
        """Shrinkage should produce valid results."""
        # Private generator: reproducible without mutating the global RNG.
        gen = torch.Generator().manual_seed(42)
        d = 64
        n = 10  # n < d → need shrinkage
        harmful = [torch.randn(d, generator=gen) + 0.3 for _ in range(n)]
        harmless = [torch.randn(d, generator=gen) - 0.3 for _ in range(n)]
        extractor = LEACEExtractor(shrinkage=0.5)
        result = extractor.extract(harmful, harmless)
        assert result.direction.shape == (d,)
        assert torch.isfinite(result.direction).all()