"""Tests for the five new analysis modules: 1. Tuned Lens (learned-affine logit lens variant) 2. Activation Patching (real interchange intervention) 3. Enhanced SAE Decomposition Pipeline 4. Wasserstein-Optimal Direction Extraction 5. Bayesian-Optimized Kernel Projection """ from __future__ import annotations import pytest import torch import torch.nn as nn from obliteratus.analysis.tuned_lens import ( TunedLensTrainer, TunedLensProbe, RefusalTunedLens, TunedLensResult, MultiLayerTunedLensResult, ) from obliteratus.analysis.activation_patching import ( ActivationPatcher, PatchingSite, ActivationPatchingResult, ) from obliteratus.analysis.sae_abliteration import ( SAEDecompositionPipeline, SAEDecompositionResult, FeatureClusterResult, ) from obliteratus.analysis.wasserstein_optimal import ( WassersteinOptimalExtractor, WassersteinDirectionResult, WassersteinComparisonResult, MultiLayerWassersteinResult, ) from obliteratus.analysis.bayesian_kernel_projection import ( BayesianKernelProjection, BayesianOptimizationResult, ProjectionConfig, ) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_activations( hidden_dim=32, n_per_class=20, separation=2.0, seed=42, ): """Create harmful/harmless activations with planted refusal signal.""" torch.manual_seed(seed) direction = torch.randn(hidden_dim) direction = direction / direction.norm() harmful = [ torch.randn(hidden_dim) * 0.3 + separation * direction for _ in range(n_per_class) ] harmless = [ torch.randn(hidden_dim) * 0.3 for _ in range(n_per_class) ] return harmful, harmless, direction def _make_multilayer_activations( n_layers=6, hidden_dim=32, n_per_class=20, separation=2.0, seed=42, ): """Create per-layer activations with planted refusal signals.""" torch.manual_seed(seed) harmful_acts = {} harmless_acts = {} directions = {} for li in range(n_layers): d = torch.randn(hidden_dim) d = d / d.norm() 
directions[li] = d strength = separation if 1 <= li <= n_layers - 2 else 0.3 harmful_acts[li] = [ torch.randn(hidden_dim) * 0.3 + strength * d for _ in range(n_per_class) ] harmless_acts[li] = [ torch.randn(hidden_dim) * 0.3 for _ in range(n_per_class) ] return harmful_acts, harmless_acts, directions class FakeTokenizer: """Fake tokenizer that maps strings to reproducible token IDs.""" def __init__(self, vocab_size=100): self.vocab_size = vocab_size def encode(self, text, add_special_tokens=False): return [hash(text) % self.vocab_size] def decode(self, ids): return f"tok_{ids[0]}" class FakeModel(nn.Module): """Fake model with lm_head and transformer.ln_f for testing.""" def __init__(self, hidden_dim=32, vocab_size=100, n_layers=4): super().__init__() self.hidden_dim = hidden_dim self.vocab_size = vocab_size self.n_layers = n_layers self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False) self.transformer = nn.Module() self.transformer.ln_f = nn.LayerNorm(hidden_dim) self.transformer.h = nn.ModuleList([ nn.Linear(hidden_dim, hidden_dim) for _ in range(n_layers) ]) def forward(self, input_ids): # Fake forward pass batch_size, seq_len = input_ids.shape x = torch.randn(batch_size, seq_len, self.hidden_dim) for layer in self.transformer.h: x = layer(x) + x logits = self.lm_head(self.transformer.ln_f(x)) return type('Output', (), {'logits': logits})() # =========================================================================== # Tests: Tuned Lens # =========================================================================== class TestTunedLensTrainer: def test_train_single_probe(self): hidden_dim = 16 n_samples = 30 layer_acts = torch.randn(n_samples, hidden_dim) final_acts = layer_acts + torch.randn(n_samples, hidden_dim) * 0.1 trainer = TunedLensTrainer(hidden_dim, n_epochs=20) probe = trainer.train_probe(layer_acts, final_acts, layer_idx=3) assert isinstance(probe, TunedLensProbe) assert probe.layer_idx == 3 assert probe.weight.shape == (hidden_dim, hidden_dim) 
assert probe.bias.shape == (hidden_dim,) assert probe.train_loss < 1.0 # should converge somewhat def test_train_all_layers(self): hidden_dim = 16 n_samples = 20 layer_acts = { i: torch.randn(n_samples, hidden_dim) for i in range(4) } final_acts = torch.randn(n_samples, hidden_dim) trainer = TunedLensTrainer(hidden_dim, n_epochs=10) probes = trainer.train_all_layers(layer_acts, final_acts) assert len(probes) == 4 for i in range(4): assert i in probes assert probes[i].weight.shape == (hidden_dim, hidden_dim) def test_probe_near_identity_for_final_layer(self): """Probe for the final layer should be close to identity.""" hidden_dim = 16 n_samples = 50 acts = torch.randn(n_samples, hidden_dim) trainer = TunedLensTrainer(hidden_dim, n_epochs=50) probe = trainer.train_probe(acts, acts, layer_idx=0) # Weight should be close to identity identity = torch.eye(hidden_dim) diff = (probe.weight - identity).norm().item() assert diff < 1.0 class TestRefusalTunedLens: def test_analyze_direction(self): hidden_dim = 32 vocab_size = 100 model = FakeModel(hidden_dim, vocab_size) tokenizer = FakeTokenizer(vocab_size) direction = torch.randn(hidden_dim) probe = TunedLensProbe( layer_idx=2, weight=torch.eye(hidden_dim) + torch.randn(hidden_dim, hidden_dim) * 0.01, bias=torch.zeros(hidden_dim), train_loss=0.01, ) lens = RefusalTunedLens(top_k=10) result = lens.analyze_direction(direction, probe, model, tokenizer) assert isinstance(result, TunedLensResult) assert result.layer_idx == 2 assert len(result.top_promoted) <= 10 assert len(result.top_suppressed) <= 10 assert isinstance(result.correction_magnitude, float) assert result.correction_magnitude >= 0 def test_analyze_all_layers(self): hidden_dim = 32 vocab_size = 100 model = FakeModel(hidden_dim, vocab_size) tokenizer = FakeTokenizer(vocab_size) directions = { i: torch.randn(hidden_dim) for i in range(4) } probes = { i: TunedLensProbe( layer_idx=i, weight=torch.eye(hidden_dim), bias=torch.zeros(hidden_dim), train_loss=0.01, ) for i in 
range(4) } lens = RefusalTunedLens(top_k=5) result = lens.analyze_all_layers(directions, probes, model, tokenizer) assert isinstance(result, MultiLayerTunedLensResult) assert len(result.per_layer) == 4 assert result.strongest_refusal_layer in range(4) def test_compare_with_logit_lens(self): logit_gaps = {0: 0.1, 1: 0.5, 2: 0.3, 3: 0.8} tuned_result = MultiLayerTunedLensResult( per_layer={ i: TunedLensResult( layer_idx=i, top_promoted=[], top_suppressed=[], refusal_token_mean_boost=0.0, compliance_token_mean_boost=0.0, refusal_compliance_gap=v * 1.1, # similar ranking correction_magnitude=0.1, ) for i, v in logit_gaps.items() }, probes={}, strongest_refusal_layer=3, peak_gap_layer=3, mean_refusal_compliance_gap=0.5, logit_lens_agreement=0.0, ) agreement = RefusalTunedLens.compare_with_logit_lens(tuned_result, logit_gaps) # Same ranking → correlation should be 1.0 assert agreement == pytest.approx(1.0, abs=0.01) def test_format_report(self): result = MultiLayerTunedLensResult( per_layer={}, probes={}, strongest_refusal_layer=0, peak_gap_layer=0, mean_refusal_compliance_gap=0.0, logit_lens_agreement=0.0, ) report = RefusalTunedLens.format_report(result) assert "Tuned Lens" in report assert "No layers analyzed" in report # =========================================================================== # Tests: Activation Patching # =========================================================================== class TestActivationPatcher: def test_patching_site_creation(self): site = PatchingSite(layer_idx=3, component="residual") assert site.layer_idx == 3 assert site.component == "residual" assert site.head_idx is None def test_patching_site_with_head(self): site = PatchingSite(layer_idx=2, component="attn_head", head_idx=5) assert site.head_idx == 5 def test_patch_sweep_with_model(self): """Test full patching sweep on fake model.""" hidden_dim = 32 model = FakeModel(hidden_dim, vocab_size=100, n_layers=4) clean_ids = torch.randint(0, 100, (1, 10)) corrupted_ids = 
torch.randint(0, 100, (1, 10)) patcher = ActivationPatcher(significance_threshold=0.05) result = patcher.patch_sweep( model, clean_ids, corrupted_ids, mode="noising", ) assert isinstance(result, ActivationPatchingResult) assert result.patching_mode == "noising" assert result.n_layers == 4 assert len(result.effects) > 0 assert isinstance(result.circuit_fraction, float) assert 0.0 <= result.circuit_fraction <= 1.0 def test_patch_sweep_denoising(self): hidden_dim = 32 model = FakeModel(hidden_dim, vocab_size=100, n_layers=4) clean_ids = torch.randint(0, 100, (1, 10)) corrupted_ids = torch.randint(0, 100, (1, 10)) patcher = ActivationPatcher() result = patcher.patch_sweep( model, clean_ids, corrupted_ids, mode="denoising", ) assert result.patching_mode == "denoising" def test_custom_metric(self): hidden_dim = 32 model = FakeModel(hidden_dim, vocab_size=100, n_layers=4) clean_ids = torch.randint(0, 100, (1, 10)) corrupted_ids = torch.randint(0, 100, (1, 10)) def custom_metric(logits): return logits.sum().item() patcher = ActivationPatcher(metric_fn=custom_metric) result = patcher.patch_sweep(model, clean_ids, corrupted_ids) assert isinstance(result, ActivationPatchingResult) assert isinstance(result.clean_baseline, float) def test_format_report(self): result = ActivationPatchingResult( n_layers=4, n_sites=4, patching_mode="noising", effects=[], clean_baseline=1.0, corrupted_baseline=0.0, total_effect=1.0, significant_sites=[], circuit_fraction=0.0, top_causal_layers=[], ) report = ActivationPatcher.format_report(result) assert "Activation Patching" in report assert "noising" in report # =========================================================================== # Tests: Enhanced SAE Decomposition Pipeline # =========================================================================== class TestSAEDecompositionPipeline: def test_basic_pipeline(self): harmful, harmless, _ = _make_activations(hidden_dim=16, n_per_class=30, separation=2.0) pipeline = SAEDecompositionPipeline( 
expansion=2, n_epochs=10, top_k_features=8, n_clusters=3, ) result = pipeline.run(harmful, harmless, layer_idx=0) assert isinstance(result, SAEDecompositionResult) assert result.layer_idx == 0 assert result.sae is not None assert result.refusal_features.n_refusal_features == 8 assert len(result.feature_sparsity) == 8 assert len(result.feature_monosemanticity) == 8 assert len(result.per_feature_refusal_reduction) == 8 assert len(result.cumulative_refusal_reduction) == 8 assert 0.0 <= result.raw_direction_overlap <= 1.0 def test_feature_clustering(self): harmful, harmless, _ = _make_activations(hidden_dim=16, n_per_class=30) pipeline = SAEDecompositionPipeline( expansion=2, n_epochs=10, top_k_features=8, n_clusters=3, ) result = pipeline.run(harmful, harmless) clusters = result.feature_clusters assert clusters is not None assert isinstance(clusters, FeatureClusterResult) assert clusters.n_clusters == 3 assert len(clusters.cluster_labels) == 8 assert all(0 <= lbl < 3 for lbl in clusters.cluster_labels) assert clusters.cluster_directions.shape[0] == 3 assert -1.0 <= clusters.silhouette_score <= 1.0 def test_cumulative_reduction_monotonic(self): harmful, harmless, _ = _make_activations(hidden_dim=16, n_per_class=30, separation=3.0) pipeline = SAEDecompositionPipeline(expansion=2, n_epochs=10, top_k_features=6) result = pipeline.run(harmful, harmless) # Cumulative reduction should be non-decreasing for i in range(1, len(result.cumulative_refusal_reduction)): assert result.cumulative_refusal_reduction[i] >= result.cumulative_refusal_reduction[i - 1] - 1e-6 def test_format_report(self): harmful, harmless, _ = _make_activations(hidden_dim=16, n_per_class=20) pipeline = SAEDecompositionPipeline(expansion=2, n_epochs=5, top_k_features=4, n_clusters=2) result = pipeline.run(harmful, harmless) report = SAEDecompositionPipeline.format_report(result) assert "SAE Feature Decomposition" in report assert "Variance explained" in report # 
# ===========================================================================
# Tests: Wasserstein-Optimal Direction Extraction
# ===========================================================================

class TestWassersteinOptimalExtractor:
    """Tests for ``WassersteinOptimalExtractor``."""

    @staticmethod
    def _diff_in_means_direction(harmful, harmless):
        """Return the unit vector from the harmless mean to the harmful mean."""
        harmful_mat = torch.stack(harmful).float()
        harmless_mat = torch.stack(harmless).float()
        dim_dir = harmful_mat.mean(0) - harmless_mat.mean(0)
        return dim_dir / dim_dir.norm()

    def test_basic_extraction(self):
        """The result has a unit-norm direction and non-negative costs."""
        harmful, harmless, _ = _make_activations(
            hidden_dim=32, n_per_class=30, separation=3.0,
        )

        extractor = WassersteinOptimalExtractor()
        result = extractor.extract(harmful, harmless, layer_idx=0)

        assert isinstance(result, WassersteinDirectionResult)
        assert result.layer_idx == 0
        assert result.direction.shape == (32,)
        assert abs(result.direction.norm().item() - 1.0) < 1e-5
        assert result.wasserstein_cost >= 0
        assert result.mean_shift_component >= 0
        assert result.bures_component >= 0
        assert result.cost_effectiveness_ratio >= 0

    def test_direction_captures_signal(self):
        """Wasserstein direction should have non-trivial refusal projection."""
        harmful, harmless, planted_dir = _make_activations(
            hidden_dim=32, n_per_class=30, separation=3.0,
        )

        extractor = WassersteinOptimalExtractor()
        result = extractor.extract(harmful, harmless)

        # Direction should have some alignment with planted signal
        cosine = abs((result.direction @ planted_dir).item())
        assert cosine > 0.1  # not totally orthogonal

    def test_extract_all_layers(self):
        """Multi-layer extraction reports every layer and a valid best layer."""
        harmful_acts, harmless_acts, _ = _make_multilayer_activations(
            n_layers=4, hidden_dim=16, n_per_class=20,
        )

        extractor = WassersteinOptimalExtractor()
        result = extractor.extract_all_layers(harmful_acts, harmless_acts)

        assert isinstance(result, MultiLayerWassersteinResult)
        assert len(result.per_layer) == 4
        assert result.best_layer in range(4)
        assert result.mean_cost_ratio >= 0

    def test_compare_with_alternatives(self):
        """Comparison fills in both alternative ratios and bounded cosines."""
        harmful, harmless, planted_dir = _make_activations(
            hidden_dim=16, n_per_class=30, separation=3.0,
        )

        extractor = WassersteinOptimalExtractor()
        w_result = extractor.extract(harmful, harmless)

        # Use planted direction as "Fisher" and diff-in-means
        dim_dir = self._diff_in_means_direction(harmful, harmless)

        comparison = extractor.compare_with_alternatives(
            w_result, harmful, harmless,
            fisher_direction=planted_dir,
            dim_direction=dim_dir,
        )

        assert isinstance(comparison, WassersteinComparisonResult)
        assert comparison.wasserstein_cost_ratio >= 0
        assert comparison.fisher_cost_ratio is not None
        assert comparison.dim_cost_ratio is not None
        assert 0 <= comparison.cosine_wasserstein_fisher <= 1
        assert 0 <= comparison.cosine_wasserstein_dim <= 1

    def test_wasserstein_lower_cost_than_dim(self):
        """Wasserstein-optimal should have lower cost ratio than diff-in-means."""
        harmful, harmless, _ = _make_activations(
            hidden_dim=32, n_per_class=50, separation=2.0,
        )

        extractor = WassersteinOptimalExtractor()
        w_result = extractor.extract(harmful, harmless)

        dim_dir = self._diff_in_means_direction(harmful, harmless)

        comparison = extractor.compare_with_alternatives(
            w_result, harmful, harmless, dim_direction=dim_dir,
        )

        # Wasserstein should have lower or equal cost ratio by construction
        assert comparison.wasserstein_cost_ratio <= comparison.dim_cost_ratio + 1e-4

    def test_format_report(self):
        """The report mentions Wasserstein and the cost ratio."""
        harmful, harmless, _ = _make_activations(hidden_dim=16, n_per_class=20)
        extractor = WassersteinOptimalExtractor()
        # The same activations are reused for both layers.
        result = extractor.extract_all_layers(
            {0: harmful, 1: harmful},
            {0: harmless, 1: harmless},
        )
        report = WassersteinOptimalExtractor.format_report(result)
        assert "Wasserstein" in report
        assert "cost ratio" in report.lower()


# ===========================================================================
# Tests: Bayesian-Optimized Kernel Projection
# ===========================================================================

class TestBayesianKernelProjection:
    """Tests for ``BayesianKernelProjection``."""

    def test_basic_optimization(self):
        """The result records all trials and keeps its metrics in range."""
        harmful_acts, harmless_acts, directions = _make_multilayer_activations(
            n_layers=6, hidden_dim=16, n_per_class=20,
        )

        optimizer = BayesianKernelProjection(
            n_trials=30, refusal_weight=0.6, distortion_weight=0.4,
        )
        result = optimizer.optimize(harmful_acts, harmless_acts, directions)

        assert isinstance(result, BayesianOptimizationResult)
        assert result.n_trials == 30
        assert result.best_score >= 0
        assert 0 <= result.best_refusal_reduction <= 1.0
        assert result.best_harmless_distortion >= 0
        assert len(result.all_trials) == 30

    def test_best_config_structure(self):
        """The best config is a ``ProjectionConfig`` with sane field values."""
        harmful_acts, harmless_acts, directions = _make_multilayer_activations(
            n_layers=4, hidden_dim=16, n_per_class=15,
        )

        optimizer = BayesianKernelProjection(n_trials=20)
        result = optimizer.optimize(harmful_acts, harmless_acts, directions)

        config = result.best_config
        assert isinstance(config, ProjectionConfig)
        assert config.layer_range[0] <= config.layer_range[1]
        assert config.n_directions >= 1
        assert 0 <= config.regularization <= 0.5

    def test_pareto_front(self):
        """The Pareto front is non-empty and its distortion is non-increasing."""
        harmful_acts, harmless_acts, directions = _make_multilayer_activations(
            n_layers=6, hidden_dim=16, n_per_class=20,
        )

        optimizer = BayesianKernelProjection(n_trials=50)
        result = optimizer.optimize(harmful_acts, harmless_acts, directions)

        # Pareto front should have at least 1 entry
        assert len(result.pareto_configs) >= 1

        # Pareto entries should be non-dominated. The front is expected to be
        # sorted by decreasing refusal reduction, so each entry must carry at
        # least as much harmless distortion as the one after it (an entry
        # that removes less refusal AND distorts more would be dominated).
        front = result.pareto_configs
        for current, following in zip(front, front[1:]):
            assert (
                current.harmless_distortion
                >= following.harmless_distortion - 1e-8
            )

    def test_layer_importance(self):
        """Every layer gets an importance score within ``[0, 1]``."""
        harmful_acts, harmless_acts, directions = _make_multilayer_activations(
            n_layers=6, hidden_dim=16, n_per_class=20,
        )

        optimizer = BayesianKernelProjection(n_trials=50)
        result = optimizer.optimize(harmful_acts, harmless_acts, directions)

        assert len(result.layer_importance) == 6
        for imp in result.layer_importance.values():
            assert 0 <= imp <= 1.0

    def test_tpe_improves_over_random(self):
        """TPE phase should produce better configs than random exploration."""
        harmful_acts, harmless_acts, directions = _make_multilayer_activations(
            n_layers=6, hidden_dim=16, n_per_class=20,
        )

        optimizer = BayesianKernelProjection(n_trials=60, seed=42)
        result = optimizer.optimize(harmful_acts, harmless_acts, directions)

        # Compare the best score of the first 20 trials (random) with the
        # best score of the last 20 trials (TPE).
        best_random = min(t.combined_score for t in result.all_trials[:20])
        best_tpe = min(t.combined_score for t in result.all_trials[-20:])

        # TPE should find at least as good (lower = better)
        # This is probabilistic so we allow some slack
        assert best_tpe <= best_random + 0.3

    def test_empty_input(self):
        """Empty inputs give a result with zero trials and a zero score."""
        optimizer = BayesianKernelProjection(n_trials=10)
        result = optimizer.optimize({}, {}, {})
        assert result.n_trials == 0
        assert result.best_score == 0.0

    def test_format_report(self):
        """The report contains the Bayesian, Pareto and importance sections."""
        harmful_acts, harmless_acts, directions = _make_multilayer_activations(
            n_layers=4, hidden_dim=16, n_per_class=15,
        )
        optimizer = BayesianKernelProjection(n_trials=20)
        result = optimizer.optimize(harmful_acts, harmless_acts, directions)

        report = BayesianKernelProjection.format_report(result)
        assert "Bayesian" in report
        assert "Pareto" in report
        assert "Layer importance" in report


# ===========================================================================
# Tests: Module imports
# ===========================================================================

class TestModuleImports:
    """Checks that the new classes are exported by ``obliteratus.analysis``."""

    def test_all_new_modules_importable(self):
        """Each new class can be imported from the package root."""
        from obliteratus.analysis import TunedLensTrainer
        from obliteratus.analysis import RefusalTunedLens
        from obliteratus.analysis import ActivationPatcher
        from obliteratus.analysis import WassersteinOptimalExtractor
        from obliteratus.analysis import BayesianKernelProjection
        from obliteratus.analysis import SAEDecompositionPipeline

        assert TunedLensTrainer is not None
        assert RefusalTunedLens is not None
        assert ActivationPatcher is not None
        assert WassersteinOptimalExtractor is not None
        assert BayesianKernelProjection is not None
        assert SAEDecompositionPipeline is not None

    def test_new_modules_in_all(self):
        """Each new class name is listed in the package ``__all__``."""
        import obliteratus.analysis as analysis

        assert "TunedLensTrainer" in analysis.__all__
        assert "RefusalTunedLens" in analysis.__all__
        assert "ActivationPatcher" in analysis.__all__
        assert "WassersteinOptimalExtractor" in analysis.__all__
        assert "BayesianKernelProjection" in analysis.__all__
        assert "SAEDecompositionPipeline" in analysis.__all__