Spaces:
Running on Zero
Running on Zero
| """Tests for LEACE (LEAst-squares Concept Erasure) direction extraction.""" | |
| from __future__ import annotations | |
| import pytest | |
| import torch | |
| from obliteratus.analysis.leace import LEACEExtractor, LEACEResult | |
| # --------------------------------------------------------------------------- | |
| # Fixtures | |
| # --------------------------------------------------------------------------- | |
@pytest.fixture
def extractor():
    """Provide a ``LEACEExtractor`` configured with ``regularization_eps=1e-4``.

    Registered as a pytest fixture so that test methods taking an
    ``extractor`` argument receive a fresh instance. Without the decorator,
    pytest reports "fixture 'extractor' not found" for every such test.
    """
    return LEACEExtractor(regularization_eps=1e-4)
@pytest.fixture
def separable_data():
    """Generate clearly separable harmful/harmless activations.

    Returns:
        ``(harmful, harmless)``: two lists of 20 tensors of shape ``(64,)``.
        Harmful samples are ``+e_0`` plus Gaussian noise with std 0.1;
        harmless samples are ``-e_0`` plus the same kind of noise.
    """
    # Seeds the global torch RNG so the generated data is reproducible.
    torch.manual_seed(42)
    d = 64
    n = 20
    # Harmful activations: cluster around [1, 0, 0, ...]
    harmful_dir = torch.zeros(d)
    harmful_dir[0] = 1.0
    harmful = [harmful_dir + 0.1 * torch.randn(d) for _ in range(n)]
    # Harmless activations: cluster around [-1, 0, 0, ...]
    harmless = [-harmful_dir + 0.1 * torch.randn(d) for _ in range(n)]
    return harmful, harmless
@pytest.fixture
def isotropic_data():
    """Data where classes differ only in mean, with isotropic variance.

    Returns:
        ``(harmful, harmless, direction)``: two lists of 30 tensors of shape
        ``(32,)`` plus the unit-norm ``direction`` (shape ``(32,)``). Harmful
        samples are ``2 * direction`` plus standard-normal noise; harmless
        samples are ``-2 * direction`` plus standard-normal noise.
    """
    # Seeds the global torch RNG so the generated data is reproducible.
    torch.manual_seed(123)
    d = 32
    n = 30
    direction = torch.randn(d)
    direction = direction / direction.norm()
    harmful = [direction * 2.0 + torch.randn(d) for _ in range(n)]
    harmless = [-direction * 2.0 + torch.randn(d) for _ in range(n)]
    return harmful, harmless, direction
| # --------------------------------------------------------------------------- | |
| # LEACEResult | |
| # --------------------------------------------------------------------------- | |
class TestLEACEResult:
    """Structural checks on the ``LEACEResult`` returned by ``extract``."""

    def test_result_fields(self, extractor, separable_data):
        pos_acts, neg_acts = separable_data
        out = extractor.extract(pos_acts, neg_acts, layer_idx=5)

        assert isinstance(out, LEACEResult)
        assert out.layer_idx == 5
        assert out.direction.shape == (64,)
        # These scalar diagnostics must all be strictly positive.
        for attr in (
            "generalized_eigenvalue",
            "within_class_condition",
            "mean_diff_norm",
        ):
            assert getattr(out, attr) > 0
        # The erasure loss is only required to be non-negative.
        assert out.erasure_loss >= 0

    def test_direction_is_unit_vector(self, extractor, separable_data):
        out = extractor.extract(*separable_data)
        deviation = abs(out.direction.norm().item() - 1.0)
        assert deviation < 1e-5
| # --------------------------------------------------------------------------- | |
| # Direction quality | |
| # --------------------------------------------------------------------------- | |
class TestDirectionQuality:
    """Check that the extracted direction tracks the true class-separation axis."""

    def test_finds_true_direction(self, extractor, separable_data):
        """LEACE should find a direction aligned with the true separation axis."""
        harmful, harmless = separable_data
        result = extractor.extract(harmful, harmless)
        # True direction is [1, 0, 0, ...]
        true_dir = torch.zeros(64)
        true_dir[0] = 1.0
        # Compared up to sign (.abs()): v and -v count as the same direction.
        cosine = (result.direction @ true_dir).abs().item()
        # With 20 samples in 64 dims, some noise is expected
        assert cosine > 0.5, f"LEACE direction not aligned with true direction: {cosine}"

    def test_isotropic_matches_diff_of_means(self, extractor, isotropic_data):
        """With isotropic noise, LEACE should roughly match diff-of-means."""
        # The fixture's ground-truth direction is not needed: the reference
        # here is the empirical diff-of-means.
        harmful, harmless, _ = isotropic_data
        result = extractor.extract(harmful, harmless)
        # Diff of means
        diff = torch.stack(harmful).mean(0) - torch.stack(harmless).mean(0)
        diff_normalized = diff / diff.norm()
        cosine = (result.direction @ diff_normalized).abs().item()
        # With finite samples and regularization, some deviation is expected
        assert cosine > 0.5

    def test_leace_differs_from_diff_means_with_anisotropic_noise(self):
        """With anisotropic noise, LEACE should stay aligned with the true direction.

        The class means differ only along dimension 0, while dimension 1
        carries large shared noise (std 5.0 vs 0.1 elsewhere). The test
        asserts that LEACE is not pulled toward that high-variance dimension.
        The diff-of-means alignment is computed only to make a failure
        message informative; no ordering between the two is asserted.
        """
        torch.manual_seed(77)
        d = 64
        n = 50
        # True refusal direction
        true_dir = torch.zeros(d)
        true_dir[0] = 1.0
        # Add anisotropic noise: high variance in dim 1 (NOT the refusal direction)
        noise_scale = torch.ones(d) * 0.1
        noise_scale[1] = 5.0  # Rogue dimension
        harmful = [true_dir * 0.5 + torch.randn(d) * noise_scale for _ in range(n)]
        harmless = [-true_dir * 0.5 + torch.randn(d) * noise_scale for _ in range(n)]

        extractor = LEACEExtractor()
        result = extractor.extract(harmful, harmless)
        cosine_to_true = (result.direction @ true_dir).abs().item()

        # Diff-of-means baseline, reported alongside LEACE on failure.
        diff = torch.stack(harmful).mean(0) - torch.stack(harmless).mean(0)
        dom_cosine = ((diff / diff.norm()) @ true_dir).abs().item()

        # LEACE should still find the true direction, not be distracted by rogue dim
        assert cosine_to_true > 0.5, (
            f"LEACE distracted by rogue dimension: {cosine_to_true} "
            f"(diff-of-means cosine: {dom_cosine})"
        )
| # --------------------------------------------------------------------------- | |
| # Comparison with diff-of-means | |
| # --------------------------------------------------------------------------- | |
class TestCompareWithDiffOfMeans:
    """Checks on the report dict produced by ``compare_with_diff_of_means``."""

    def test_comparison_output(self, extractor, separable_data):
        pos_acts, neg_acts = separable_data
        leace = extractor.extract(pos_acts, neg_acts)

        report = LEACEExtractor.compare_with_diff_of_means(
            leace,
            torch.stack(pos_acts).mean(0),
            torch.stack(neg_acts).mean(0),
        )

        expected_keys = (
            "cosine_similarity",
            "leace_eigenvalue",
            "leace_erasure_loss",
            "within_class_condition",
            "mean_diff_norm",
        )
        for key in expected_keys:
            assert key in report

        similarity = report["cosine_similarity"]
        assert 0 <= similarity <= 1.0
| # --------------------------------------------------------------------------- | |
| # Multi-layer extraction | |
| # --------------------------------------------------------------------------- | |
class TestMultiLayer:
    """``extract_all_layers`` over a non-contiguous set of layer indices."""

    def test_extract_all_layers(self, extractor):
        torch.manual_seed(42)
        dim, count = 32, 15
        layer_ids = (0, 1, 2, 5)

        harmful_acts, harmless_acts = {}, {}
        # Fill both dicts layer by layer so the RNG draws happen in the
        # order harmful(layer) -> harmless(layer) for each layer in turn.
        for layer in layer_ids:
            harmful_acts[layer] = [torch.randn(dim) + 0.5 for _ in range(count)]
            harmless_acts[layer] = [torch.randn(dim) - 0.5 for _ in range(count)]

        per_layer = extractor.extract_all_layers(harmful_acts, harmless_acts)

        assert set(per_layer.keys()) == set(layer_ids)
        for layer, res in per_layer.items():
            assert res.layer_idx == layer
            assert res.direction.shape == (dim,)
| # --------------------------------------------------------------------------- | |
| # Edge cases | |
| # --------------------------------------------------------------------------- | |
class TestEdgeCases:
    """Degenerate or unusual inputs must still yield a finite, correctly shaped direction.

    Each test seeds the global torch RNG itself. Otherwise its data would
    depend on RNG state left behind by whichever test or fixture ran before
    it, which makes failures order-dependent and hard to reproduce.
    """

    def test_single_sample(self, extractor):
        """Should handle single sample per class gracefully."""
        torch.manual_seed(0)
        d = 32
        harmful = [torch.randn(d)]
        harmless = [torch.randn(d)]
        result = extractor.extract(harmful, harmless)
        assert result.direction.shape == (d,)
        assert torch.isfinite(result.direction).all()

    def test_identical_activations(self, extractor):
        """Should handle case where harmful == harmless."""
        torch.manual_seed(1)
        d = 32
        x = torch.randn(d)
        harmful = [x.clone() for _ in range(5)]
        harmless = [x.clone() for _ in range(5)]
        result = extractor.extract(harmful, harmless)
        assert result.direction.shape == (d,)
        # The class means coincide, so there is no meaningful direction. Only
        # finiteness is required; the extractor may return ~0 or a fallback.
        assert torch.isfinite(result.direction).all()

    def test_3d_input_squeezed(self, extractor):
        """Should handle per-sample tensors of shape (1, d), i.e. (n, 1, d) once stacked."""
        torch.manual_seed(2)
        d = 32
        harmful = [torch.randn(1, d) for _ in range(10)]
        harmless = [torch.randn(1, d) for _ in range(10)]
        result = extractor.extract(harmful, harmless)
        assert result.direction.shape == (d,)

    def test_shrinkage(self):
        """Shrinkage should produce valid results."""
        torch.manual_seed(42)
        d = 64
        n = 10  # n < d → need shrinkage
        harmful = [torch.randn(d) + 0.3 for _ in range(n)]
        harmless = [torch.randn(d) - 0.3 for _ in range(n)]
        extractor = LEACEExtractor(shrinkage=0.5)
        result = extractor.extract(harmful, harmless)
        assert result.direction.shape == (d,)
        assert torch.isfinite(result.direction).all()