"""Tests for LEACE (LEAst-squares Concept Erasure) direction extraction."""

from __future__ import annotations

import pytest
import torch

from obliteratus.analysis.leace import LEACEExtractor, LEACEResult


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def extractor():
    """LEACE extractor constructed with an explicit ``regularization_eps=1e-4``."""
    return LEACEExtractor(regularization_eps=1e-4)


@pytest.fixture
def separable_data():
    """Generate clearly separable harmful/harmless activations.

    Returns ``(harmful, harmless)``: two lists of 20 vectors of dimension 64,
    clustered around +e0 and -e0 respectively with N(0, 0.1^2) noise.
    """
    torch.manual_seed(42)
    d = 64
    n = 20
    # Harmful activations: cluster around [1, 0, 0, ...]
    harmful_dir = torch.zeros(d)
    harmful_dir[0] = 1.0
    harmful = [harmful_dir + 0.1 * torch.randn(d) for _ in range(n)]
    # Harmless activations: cluster around [-1, 0, 0, ...]
    harmless = [-harmful_dir + 0.1 * torch.randn(d) for _ in range(n)]
    return harmful, harmless


@pytest.fixture
def isotropic_data():
    """Data where classes differ only in mean, with isotropic variance.

    Returns ``(harmful, harmless, direction)`` where ``direction`` is the unit
    vector separating the two class means (means are +/- 2 * direction).
    """
    torch.manual_seed(123)
    d = 32
    n = 30
    direction = torch.randn(d)
    direction = direction / direction.norm()
    harmful = [direction * 2.0 + torch.randn(d) for _ in range(n)]
    harmless = [-direction * 2.0 + torch.randn(d) for _ in range(n)]
    return harmful, harmless, direction


# ---------------------------------------------------------------------------
# LEACEResult
# ---------------------------------------------------------------------------


class TestLEACEResult:
    def test_result_fields(self, extractor, separable_data):
        """``extract`` returns a populated LEACEResult with sane scalar fields."""
        harmful, harmless = separable_data
        result = extractor.extract(harmful, harmless, layer_idx=5)

        assert isinstance(result, LEACEResult)
        assert result.layer_idx == 5
        assert result.direction.shape == (64,)
        assert result.generalized_eigenvalue > 0
        assert result.within_class_condition > 0
        assert result.mean_diff_norm > 0
        assert result.erasure_loss >= 0

    def test_direction_is_unit_vector(self, extractor, separable_data):
        """The extracted direction is normalized to unit L2 norm."""
        harmful, harmless = separable_data
        result = extractor.extract(harmful, harmless)
        norm = result.direction.norm().item()
        assert norm == pytest.approx(1.0, abs=1e-5)


# ---------------------------------------------------------------------------
# Direction quality
# ---------------------------------------------------------------------------


class TestDirectionQuality:
    def test_finds_true_direction(self, extractor, separable_data):
        """LEACE should find a direction aligned with the true separation axis."""
        harmful, harmless = separable_data
        result = extractor.extract(harmful, harmless)

        # True direction is [1, 0, 0, ...]
        true_dir = torch.zeros(64)
        true_dir[0] = 1.0

        cosine = (result.direction @ true_dir).abs().item()
        # With 20 samples in 64 dims, some noise is expected
        assert cosine > 0.5, f"LEACE direction not aligned with true direction: {cosine}"

    def test_isotropic_matches_diff_of_means(self, extractor, isotropic_data):
        """With isotropic noise, LEACE should roughly match diff-of-means."""
        # The ground-truth direction is not needed: this test compares against
        # the *empirical* diff-of-means, not the generating direction.
        harmful, harmless, _ = isotropic_data
        result = extractor.extract(harmful, harmless)

        # Diff of means
        diff = torch.stack(harmful).mean(0) - torch.stack(harmless).mean(0)
        diff_normalized = diff / diff.norm()

        cosine = (result.direction @ diff_normalized).abs().item()
        # With finite samples and regularization, some deviation is expected
        assert cosine > 0.5

    def test_leace_differs_from_diff_means_with_anisotropic_noise(self):
        """With anisotropic noise, LEACE should still recover the true direction.

        A single high-variance "rogue" dimension (dim 1) is orthogonal to the
        true separation axis (dim 0). The assertion only requires LEACE to stay
        aligned with the true axis; the diff-of-means cosine is reported in the
        failure message for diagnosis but is not asserted against.
        """
        torch.manual_seed(77)
        d = 64
        n = 50

        # True refusal direction
        true_dir = torch.zeros(d)
        true_dir[0] = 1.0

        # Add anisotropic noise: high variance in dim 1 (NOT the refusal direction)
        noise_scale = torch.ones(d) * 0.1
        noise_scale[1] = 5.0  # Rogue dimension

        harmful = [true_dir * 0.5 + torch.randn(d) * noise_scale for _ in range(n)]
        harmless = [-true_dir * 0.5 + torch.randn(d) * noise_scale for _ in range(n)]

        extractor = LEACEExtractor()
        result = extractor.extract(harmful, harmless)

        cosine_to_true = (result.direction @ true_dir).abs().item()

        # Diff-of-means baseline, computed only for the diagnostic message.
        diff = torch.stack(harmful).mean(0) - torch.stack(harmless).mean(0)
        dom_cosine = ((diff / diff.norm()) @ true_dir).abs().item()

        # LEACE should still find the true direction, not be distracted by rogue dim
        assert cosine_to_true > 0.5, (
            f"LEACE distracted by rogue dimension: {cosine_to_true} "
            f"(diff-of-means cosine to true: {dom_cosine})"
        )


# ---------------------------------------------------------------------------
# Comparison with diff-of-means
# ---------------------------------------------------------------------------


class TestCompareWithDiffOfMeans:
    def test_comparison_output(self, extractor, separable_data):
        """``compare_with_diff_of_means`` reports the expected keys and a valid |cosine|."""
        harmful, harmless = separable_data
        result = extractor.extract(harmful, harmless)

        harmful_mean = torch.stack(harmful).mean(0)
        harmless_mean = torch.stack(harmless).mean(0)

        comparison = LEACEExtractor.compare_with_diff_of_means(
            result, harmful_mean, harmless_mean,
        )

        required_keys = {
            "cosine_similarity",
            "leace_eigenvalue",
            "leace_erasure_loss",
            "within_class_condition",
            "mean_diff_norm",
        }
        missing = required_keys - set(comparison)
        assert not missing, f"comparison is missing keys: {sorted(missing)}"

        # A float32 cosine of (nearly) parallel unit vectors can overshoot 1.0
        # by rounding error, so allow a tiny tolerance on both bounds.
        cos_sim = float(comparison["cosine_similarity"])
        assert -1e-6 <= cos_sim <= 1.0 + 1e-6


# ---------------------------------------------------------------------------
# Multi-layer extraction
# ---------------------------------------------------------------------------


class TestMultiLayer:
    def test_extract_all_layers(self, extractor):
        """``extract_all_layers`` returns one result per input layer, tagged with its index."""
        torch.manual_seed(42)
        d = 32
        n = 15
        harmful_acts = {}
        harmless_acts = {}
        for layer in [0, 1, 2, 5]:
            harmful_acts[layer] = [torch.randn(d) + 0.5 for _ in range(n)]
            harmless_acts[layer] = [torch.randn(d) - 0.5 for _ in range(n)]

        results = extractor.extract_all_layers(harmful_acts, harmless_acts)

        assert set(results.keys()) == {0, 1, 2, 5}
        for idx, result in results.items():
            assert result.layer_idx == idx
            assert result.direction.shape == (d,)


# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------


class TestEdgeCases:
    def test_single_sample(self, extractor):
        """Should handle single sample per class gracefully."""
        # Seed so the test is deterministic and independent of test order.
        torch.manual_seed(0)
        d = 32
        harmful = [torch.randn(d)]
        harmless = [torch.randn(d)]
        result = extractor.extract(harmful, harmless)
        assert result.direction.shape == (d,)
        assert torch.isfinite(result.direction).all()

    def test_identical_activations(self, extractor):
        """Should handle case where harmful == harmless."""
        torch.manual_seed(1)
        d = 32
        x = torch.randn(d)
        harmful = [x.clone() for _ in range(5)]
        harmless = [x.clone() for _ in range(5)]
        result = extractor.extract(harmful, harmless)
        assert result.direction.shape == (d,)
        # The mean difference is exactly zero here, so there is no meaningful
        # direction; only finiteness is required (whether the extractor returns
        # a zero vector or some fallback direction is deliberately not pinned).
        assert torch.isfinite(result.direction).all()

    def test_3d_input_squeezed(self, extractor):
        """Should handle (n, 1, d) shaped inputs."""
        torch.manual_seed(2)
        d = 32
        harmful = [torch.randn(1, d) for _ in range(10)]
        harmless = [torch.randn(1, d) for _ in range(10)]
        result = extractor.extract(harmful, harmless)
        assert result.direction.shape == (d,)
        assert torch.isfinite(result.direction).all()

    def test_shrinkage(self):
        """Shrinkage should produce valid results."""
        torch.manual_seed(42)
        d = 64
        n = 10  # n < d → need shrinkage
        harmful = [torch.randn(d) + 0.3 for _ in range(n)]
        harmless = [torch.randn(d) - 0.3 for _ in range(n)]

        extractor = LEACEExtractor(shrinkage=0.5)
        result = extractor.extract(harmful, harmless)
        assert result.direction.shape == (d,)
        assert torch.isfinite(result.direction).all()