# obliteratus/tests/test_new_analysis_modules.py
# pliny-the-prompter's picture
# Upload 127 files
# 45113e6 verified
"""Tests for the five new analysis modules:
1. Tuned Lens (learned-affine logit lens variant)
2. Activation Patching (real interchange intervention)
3. Enhanced SAE Decomposition Pipeline
4. Wasserstein-Optimal Direction Extraction
5. Bayesian-Optimized Kernel Projection
"""
from __future__ import annotations
import pytest
import torch
import torch.nn as nn
from obliteratus.analysis.tuned_lens import (
TunedLensTrainer,
TunedLensProbe,
RefusalTunedLens,
TunedLensResult,
MultiLayerTunedLensResult,
)
from obliteratus.analysis.activation_patching import (
ActivationPatcher,
PatchingSite,
ActivationPatchingResult,
)
from obliteratus.analysis.sae_abliteration import (
SAEDecompositionPipeline,
SAEDecompositionResult,
FeatureClusterResult,
)
from obliteratus.analysis.wasserstein_optimal import (
WassersteinOptimalExtractor,
WassersteinDirectionResult,
WassersteinComparisonResult,
MultiLayerWassersteinResult,
)
from obliteratus.analysis.bayesian_kernel_projection import (
BayesianKernelProjection,
BayesianOptimizationResult,
ProjectionConfig,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_activations(
hidden_dim=32, n_per_class=20, separation=2.0, seed=42,
):
"""Create harmful/harmless activations with planted refusal signal."""
torch.manual_seed(seed)
direction = torch.randn(hidden_dim)
direction = direction / direction.norm()
harmful = [
torch.randn(hidden_dim) * 0.3 + separation * direction
for _ in range(n_per_class)
]
harmless = [
torch.randn(hidden_dim) * 0.3
for _ in range(n_per_class)
]
return harmful, harmless, direction
def _make_multilayer_activations(
n_layers=6, hidden_dim=32, n_per_class=20, separation=2.0, seed=42,
):
"""Create per-layer activations with planted refusal signals."""
torch.manual_seed(seed)
harmful_acts = {}
harmless_acts = {}
directions = {}
for li in range(n_layers):
d = torch.randn(hidden_dim)
d = d / d.norm()
directions[li] = d
strength = separation if 1 <= li <= n_layers - 2 else 0.3
harmful_acts[li] = [
torch.randn(hidden_dim) * 0.3 + strength * d
for _ in range(n_per_class)
]
harmless_acts[li] = [
torch.randn(hidden_dim) * 0.3
for _ in range(n_per_class)
]
return harmful_acts, harmless_acts, directions
class FakeTokenizer:
    """Fake tokenizer that maps strings to reproducible token IDs.

    Every string encodes to exactly one ID in ``range(vocab_size)``.
    """

    def __init__(self, vocab_size=100):
        self.vocab_size = vocab_size

    def encode(self, text, add_special_tokens=False):
        """Return a one-element list holding the token ID for ``text``.

        ``add_special_tokens`` is accepted for signature compatibility and
        ignored.
        """
        # The built-in hash() of a str is salted per interpreter process
        # (see PYTHONHASHSEED), so it would yield different IDs on every test
        # run. A CRC32 of the UTF-8 bytes is stable across runs and platforms.
        import zlib

        return [zlib.crc32(text.encode("utf-8")) % self.vocab_size]

    def decode(self, ids):
        """Render the first ID of ``ids`` as ``"tok_<id>"``."""
        return f"tok_{ids[0]}"
class FakeModel(nn.Module):
    """Toy network exposing ``lm_head``, ``transformer.ln_f`` and ``transformer.h``.

    ``transformer.h`` is a ``ModuleList`` of ``n_layers`` square linear
    layers applied residually; ``lm_head`` is a bias-free projection from
    ``hidden_dim`` to ``vocab_size``.
    """

    def __init__(self, hidden_dim=32, vocab_size=100, n_layers=4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.n_layers = n_layers

        # Build order (head, norm, blocks) is kept so that random parameter
        # initialisation consumes the RNG in the same sequence.
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)

        backbone = nn.Module()
        backbone.ln_f = nn.LayerNorm(hidden_dim)
        blocks = nn.ModuleList()
        for _ in range(n_layers):
            blocks.append(nn.Linear(hidden_dim, hidden_dim))
        backbone.h = blocks
        self.transformer = backbone

    def forward(self, input_ids):
        """Return an object whose ``logits`` has shape (batch, seq, vocab_size).

        Only the shape of ``input_ids`` is used; the initial hidden state is
        freshly sampled noise on every call.
        """
        n_batch, n_tokens = input_ids.shape
        hidden = torch.randn(n_batch, n_tokens, self.hidden_dim)
        for block in self.transformer.h:
            hidden = block(hidden) + hidden
        normed = self.transformer.ln_f(hidden)
        logits = self.lm_head(normed)
        # Lightweight ad-hoc container exposing just `.logits`.
        output_cls = type('Output', (), {'logits': logits})
        return output_cls()
# ===========================================================================
# Tests: Tuned Lens
# ===========================================================================
class TestTunedLensTrainer:
    """Checks for ``TunedLensTrainer`` probe fitting on random activations."""

    def test_train_single_probe(self):
        """One trained probe has the requested layer index, affine shapes, and loss below 1."""
        dim, count = 16, 30
        source = torch.randn(count, dim)
        # Target is the source plus small noise.
        target = source + torch.randn(count, dim) * 0.1

        trainer = TunedLensTrainer(dim, n_epochs=20)
        probe = trainer.train_probe(source, target, layer_idx=3)

        assert isinstance(probe, TunedLensProbe)
        assert probe.layer_idx == 3
        assert probe.weight.shape == (dim, dim)
        assert probe.bias.shape == (dim,)
        assert probe.train_loss < 1.0  # should converge somewhat

    def test_train_all_layers(self):
        """Training over a layer dict yields one correctly shaped probe per layer key."""
        dim, count, n_layers = 16, 20, 4
        per_layer = {}
        for layer in range(n_layers):
            per_layer[layer] = torch.randn(count, dim)
        final = torch.randn(count, dim)

        trainer = TunedLensTrainer(dim, n_epochs=10)
        probes = trainer.train_all_layers(per_layer, final)

        assert len(probes) == n_layers
        for layer in range(n_layers):
            assert layer in probes
            assert probes[layer].weight.shape == (dim, dim)

    def test_probe_near_identity_for_final_layer(self):
        """Probe for the final layer should be close to identity."""
        dim, count = 16, 50
        acts = torch.randn(count, dim)

        # Source and target are the same tensor.
        trainer = TunedLensTrainer(dim, n_epochs=50)
        probe = trainer.train_probe(acts, acts, layer_idx=0)

        # Frobenius distance of the learned weight from the identity matrix.
        distance = (probe.weight - torch.eye(dim)).norm().item()
        assert distance < 1.0
class TestRefusalTunedLens:
    """Exercises ``RefusalTunedLens`` against the fake model and tokenizer."""

    def test_analyze_direction(self):
        """Single-layer analysis returns a well-formed ``TunedLensResult``."""
        dim, vocab, top_k = 32, 100, 10
        model = FakeModel(dim, vocab)
        tokenizer = FakeTokenizer(vocab)
        direction = torch.randn(dim)
        # Probe weight is the identity plus a small random perturbation.
        near_identity = torch.eye(dim) + torch.randn(dim, dim) * 0.01
        probe = TunedLensProbe(
            layer_idx=2,
            weight=near_identity,
            bias=torch.zeros(dim),
            train_loss=0.01,
        )

        lens = RefusalTunedLens(top_k=top_k)
        result = lens.analyze_direction(direction, probe, model, tokenizer)

        assert isinstance(result, TunedLensResult)
        assert result.layer_idx == 2
        assert len(result.top_promoted) <= top_k
        assert len(result.top_suppressed) <= top_k
        assert isinstance(result.correction_magnitude, float)
        assert result.correction_magnitude >= 0

    def test_analyze_all_layers(self):
        """Multi-layer analysis covers every layer and picks a valid strongest layer."""
        dim, vocab, n_layers = 32, 100, 4
        model = FakeModel(dim, vocab)
        tokenizer = FakeTokenizer(vocab)

        directions = {}
        for layer in range(n_layers):
            directions[layer] = torch.randn(dim)

        # Identity probes with zero bias.
        probes = {}
        for layer in range(n_layers):
            probes[layer] = TunedLensProbe(
                layer_idx=layer,
                weight=torch.eye(dim),
                bias=torch.zeros(dim),
                train_loss=0.01,
            )

        lens = RefusalTunedLens(top_k=5)
        result = lens.analyze_all_layers(directions, probes, model, tokenizer)

        assert isinstance(result, MultiLayerTunedLensResult)
        assert len(result.per_layer) == n_layers
        assert result.strongest_refusal_layer in range(n_layers)

    def test_compare_with_logit_lens(self):
        """Gaps that are a positive multiple of the logit-lens gaps agree at ~1.0."""
        logit_gaps = {0: 0.1, 1: 0.5, 2: 0.3, 3: 0.8}

        per_layer = {}
        for layer, gap in logit_gaps.items():
            per_layer[layer] = TunedLensResult(
                layer_idx=layer,
                top_promoted=[],
                top_suppressed=[],
                refusal_token_mean_boost=0.0,
                compliance_token_mean_boost=0.0,
                refusal_compliance_gap=gap * 1.1,  # similar ranking
                correction_magnitude=0.1,
            )
        tuned_result = MultiLayerTunedLensResult(
            per_layer=per_layer,
            probes={},
            strongest_refusal_layer=3,
            peak_gap_layer=3,
            mean_refusal_compliance_gap=0.5,
            logit_lens_agreement=0.0,
        )

        agreement = RefusalTunedLens.compare_with_logit_lens(
            tuned_result, logit_gaps,
        )

        # Same ranking → correlation should be 1.0
        assert agreement == pytest.approx(1.0, abs=0.01)

    def test_format_report(self):
        """An empty result still renders a titled report noting no layers."""
        empty_result = MultiLayerTunedLensResult(
            per_layer={},
            probes={},
            strongest_refusal_layer=0,
            peak_gap_layer=0,
            mean_refusal_compliance_gap=0.0,
            logit_lens_agreement=0.0,
        )

        report = RefusalTunedLens.format_report(empty_result)

        for fragment in ("Tuned Lens", "No layers analyzed"):
            assert fragment in report
# ===========================================================================
# Tests: Activation Patching
# ===========================================================================
class TestActivationPatcher:
    """Checks for ``PatchingSite`` construction and ``ActivationPatcher`` sweeps."""

    @staticmethod
    def _model_and_inputs():
        """Return a 4-layer fake model plus random clean/corrupted ID tensors of shape (1, 10)."""
        model = FakeModel(32, vocab_size=100, n_layers=4)
        clean_ids = torch.randint(0, 100, (1, 10))
        corrupted_ids = torch.randint(0, 100, (1, 10))
        return model, clean_ids, corrupted_ids

    def test_patching_site_creation(self):
        """A site built without a head index reports ``head_idx`` as None."""
        site = PatchingSite(layer_idx=3, component="residual")

        assert (site.layer_idx, site.component) == (3, "residual")
        assert site.head_idx is None

    def test_patching_site_with_head(self):
        """An explicit head index is stored on the site."""
        site = PatchingSite(layer_idx=2, component="attn_head", head_idx=5)

        assert site.head_idx == 5

    def test_patch_sweep_with_model(self):
        """Test full patching sweep on fake model."""
        model, clean_ids, corrupted_ids = self._model_and_inputs()

        patcher = ActivationPatcher(significance_threshold=0.05)
        result = patcher.patch_sweep(
            model, clean_ids, corrupted_ids, mode="noising",
        )

        assert isinstance(result, ActivationPatchingResult)
        assert result.patching_mode == "noising"
        assert result.n_layers == 4
        assert len(result.effects) > 0
        assert isinstance(result.circuit_fraction, float)
        assert 0.0 <= result.circuit_fraction <= 1.0

    def test_patch_sweep_denoising(self):
        """The requested denoising mode is echoed on the result."""
        model, clean_ids, corrupted_ids = self._model_and_inputs()

        patcher = ActivationPatcher()
        result = patcher.patch_sweep(
            model, clean_ids, corrupted_ids, mode="denoising",
        )

        assert result.patching_mode == "denoising"

    def test_custom_metric(self):
        """A user-supplied metric function is accepted and yields a float baseline."""
        model, clean_ids, corrupted_ids = self._model_and_inputs()

        def custom_metric(logits):
            # Collapse the logits to a single Python float.
            return logits.sum().item()

        patcher = ActivationPatcher(metric_fn=custom_metric)
        result = patcher.patch_sweep(model, clean_ids, corrupted_ids)

        assert isinstance(result, ActivationPatchingResult)
        assert isinstance(result.clean_baseline, float)

    def test_format_report(self):
        """A hand-built result renders a report naming the module and the mode."""
        empty_result = ActivationPatchingResult(
            n_layers=4,
            n_sites=4,
            patching_mode="noising",
            effects=[],
            clean_baseline=1.0,
            corrupted_baseline=0.0,
            total_effect=1.0,
            significant_sites=[],
            circuit_fraction=0.0,
            top_causal_layers=[],
        )

        report = ActivationPatcher.format_report(empty_result)

        for fragment in ("Activation Patching", "noising"):
            assert fragment in report
# ===========================================================================
# Tests: Enhanced SAE Decomposition Pipeline
# ===========================================================================
class TestSAEDecompositionPipeline:
    """Checks for ``SAEDecompositionPipeline`` on synthetic activations."""

    def test_basic_pipeline(self):
        """A run with ``top_k_features=8`` yields eight entries in every per-feature output."""
        top_k = 8
        harmful, harmless, _ = _make_activations(
            hidden_dim=16, n_per_class=30, separation=2.0,
        )

        pipeline = SAEDecompositionPipeline(
            expansion=2, n_epochs=10, top_k_features=top_k, n_clusters=3,
        )
        result = pipeline.run(harmful, harmless, layer_idx=0)

        assert isinstance(result, SAEDecompositionResult)
        assert result.layer_idx == 0
        assert result.sae is not None
        assert result.refusal_features.n_refusal_features == top_k
        per_feature_outputs = (
            result.feature_sparsity,
            result.feature_monosemanticity,
            result.per_feature_refusal_reduction,
            result.cumulative_refusal_reduction,
        )
        for values in per_feature_outputs:
            assert len(values) == top_k
        assert 0.0 <= result.raw_direction_overlap <= 1.0

    def test_feature_clustering(self):
        """Clustering output has the requested cluster count and in-range labels."""
        n_clusters, top_k = 3, 8
        harmful, harmless, _ = _make_activations(hidden_dim=16, n_per_class=30)

        pipeline = SAEDecompositionPipeline(
            expansion=2, n_epochs=10, top_k_features=top_k, n_clusters=n_clusters,
        )
        clusters = pipeline.run(harmful, harmless).feature_clusters

        assert clusters is not None
        assert isinstance(clusters, FeatureClusterResult)
        assert clusters.n_clusters == n_clusters
        assert len(clusters.cluster_labels) == top_k
        for label in clusters.cluster_labels:
            assert 0 <= label < n_clusters
        assert clusters.cluster_directions.shape[0] == n_clusters
        assert -1.0 <= clusters.silhouette_score <= 1.0

    def test_cumulative_reduction_monotonic(self):
        """The cumulative refusal-reduction curve never drops by more than 1e-6."""
        harmful, harmless, _ = _make_activations(
            hidden_dim=16, n_per_class=30, separation=3.0,
        )

        pipeline = SAEDecompositionPipeline(
            expansion=2, n_epochs=10, top_k_features=6,
        )
        curve = pipeline.run(harmful, harmless).cumulative_refusal_reduction

        # Cumulative reduction should be non-decreasing
        for earlier in range(len(curve) - 1):
            assert curve[earlier + 1] >= curve[earlier] - 1e-6

    def test_format_report(self):
        """The rendered report carries its title and the variance-explained line."""
        harmful, harmless, _ = _make_activations(hidden_dim=16, n_per_class=20)

        pipeline = SAEDecompositionPipeline(
            expansion=2, n_epochs=5, top_k_features=4, n_clusters=2,
        )
        report = SAEDecompositionPipeline.format_report(
            pipeline.run(harmful, harmless)
        )

        for fragment in ("SAE Feature Decomposition", "Variance explained"):
            assert fragment in report
# ===========================================================================
# Tests: Wasserstein-Optimal Direction Extraction
# ===========================================================================
class TestWassersteinOptimalExtractor:
    """Checks for ``WassersteinOptimalExtractor`` on synthetic activations."""

    @staticmethod
    def _unit_mean_difference(harmful, harmless):
        """Return the normalised difference of the two class means (diff-in-means)."""
        harmful_mat = torch.stack(harmful).float()
        harmless_mat = torch.stack(harmless).float()
        gap = harmful_mat.mean(0) - harmless_mat.mean(0)
        return gap / gap.norm()

    def test_basic_extraction(self):
        """Extraction returns a unit-norm direction and non-negative cost terms."""
        harmful, harmless, _ = _make_activations(
            hidden_dim=32, n_per_class=30, separation=3.0,
        )

        result = WassersteinOptimalExtractor().extract(
            harmful, harmless, layer_idx=0,
        )

        assert isinstance(result, WassersteinDirectionResult)
        assert result.layer_idx == 0
        assert result.direction.shape == (32,)
        assert abs(result.direction.norm().item() - 1.0) < 1e-5
        assert result.wasserstein_cost >= 0
        assert result.mean_shift_component >= 0
        assert result.bures_component >= 0
        assert result.cost_effectiveness_ratio >= 0

    def test_direction_captures_signal(self):
        """Wasserstein direction should have non-trivial refusal projection."""
        harmful, harmless, planted = _make_activations(
            hidden_dim=32, n_per_class=30, separation=3.0,
        )

        result = WassersteinOptimalExtractor().extract(harmful, harmless)

        # Direction should have some alignment with planted signal
        alignment = abs((result.direction @ planted).item())
        assert alignment > 0.1  # not totally orthogonal

    def test_extract_all_layers(self):
        """Multi-layer extraction covers all four layers and picks a valid best layer."""
        n_layers = 4
        harmful_acts, harmless_acts, _ = _make_multilayer_activations(
            n_layers=n_layers, hidden_dim=16, n_per_class=20,
        )

        result = WassersteinOptimalExtractor().extract_all_layers(
            harmful_acts, harmless_acts,
        )

        assert isinstance(result, MultiLayerWassersteinResult)
        assert len(result.per_layer) == n_layers
        assert result.best_layer in range(n_layers)
        assert result.mean_cost_ratio >= 0

    def test_compare_with_alternatives(self):
        """Comparison reports cost ratios for both alternatives and cosines in [0, 1]."""
        harmful, harmless, planted = _make_activations(
            hidden_dim=16, n_per_class=30, separation=3.0,
        )
        extractor = WassersteinOptimalExtractor()
        w_result = extractor.extract(harmful, harmless)

        # Use planted direction as "Fisher" and diff-in-means
        comparison = extractor.compare_with_alternatives(
            w_result,
            harmful,
            harmless,
            fisher_direction=planted,
            dim_direction=self._unit_mean_difference(harmful, harmless),
        )

        assert isinstance(comparison, WassersteinComparisonResult)
        assert comparison.wasserstein_cost_ratio >= 0
        assert comparison.fisher_cost_ratio is not None
        assert comparison.dim_cost_ratio is not None
        assert 0 <= comparison.cosine_wasserstein_fisher <= 1
        assert 0 <= comparison.cosine_wasserstein_dim <= 1

    def test_wasserstein_lower_cost_than_dim(self):
        """Wasserstein-optimal should have lower cost ratio than diff-in-means."""
        harmful, harmless, _ = _make_activations(
            hidden_dim=32, n_per_class=50, separation=2.0,
        )
        extractor = WassersteinOptimalExtractor()
        w_result = extractor.extract(harmful, harmless)

        comparison = extractor.compare_with_alternatives(
            w_result,
            harmful,
            harmless,
            dim_direction=self._unit_mean_difference(harmful, harmless),
        )

        # Wasserstein should have lower or equal cost ratio by construction
        tolerance = 1e-4
        assert comparison.wasserstein_cost_ratio <= comparison.dim_cost_ratio + tolerance

    def test_format_report(self):
        """The rendered report mentions Wasserstein and the cost ratio."""
        harmful, harmless, _ = _make_activations(hidden_dim=16, n_per_class=20)

        # The same activations are reused for both layer keys.
        result = WassersteinOptimalExtractor().extract_all_layers(
            {0: harmful, 1: harmful},
            {0: harmless, 1: harmless},
        )
        report = WassersteinOptimalExtractor.format_report(result)

        assert "Wasserstein" in report
        assert "cost ratio" in report.lower()
# ===========================================================================
# Tests: Bayesian-Optimized Kernel Projection
# ===========================================================================
class TestBayesianKernelProjection:
    """Checks for ``BayesianKernelProjection`` on synthetic per-layer activations."""

    def test_basic_optimization(self):
        """A 30-trial run records 30 trials and in-range best-result metrics."""
        harmful_acts, harmless_acts, directions = _make_multilayer_activations(
            n_layers=6, hidden_dim=16, n_per_class=20,
        )
        optimizer = BayesianKernelProjection(
            n_trials=30, refusal_weight=0.6, distortion_weight=0.4,
        )

        result = optimizer.optimize(harmful_acts, harmless_acts, directions)

        assert isinstance(result, BayesianOptimizationResult)
        assert result.n_trials == 30
        assert result.best_score >= 0
        assert 0 <= result.best_refusal_reduction <= 1.0
        assert result.best_harmless_distortion >= 0
        assert len(result.all_trials) == 30

    def test_best_config_structure(self):
        """The best config is a ``ProjectionConfig`` with sane field ranges."""
        harmful_acts, harmless_acts, directions = _make_multilayer_activations(
            n_layers=4, hidden_dim=16, n_per_class=15,
        )
        optimizer = BayesianKernelProjection(n_trials=20)

        config = optimizer.optimize(
            harmful_acts, harmless_acts, directions,
        ).best_config

        assert isinstance(config, ProjectionConfig)
        assert config.layer_range[0] <= config.layer_range[1]
        assert config.n_directions >= 1
        assert 0 <= config.regularization <= 0.5

    def test_pareto_front(self):
        """The Pareto front is non-empty and its distortion never increases along it."""
        harmful_acts, harmless_acts, directions = _make_multilayer_activations(
            n_layers=6, hidden_dim=16, n_per_class=20,
        )
        optimizer = BayesianKernelProjection(n_trials=50)

        front = optimizer.optimize(
            harmful_acts, harmless_acts, directions,
        ).pareto_configs

        # Pareto front should have at least 1 entry
        assert len(front) >= 1
        # The front is presumably sorted by decreasing refusal reduction
        # (verify against the optimizer). On a non-dominated front that means
        # each entry's distortion is greater than or equal to the next one's.
        for current, following in zip(front, front[1:]):
            assert (
                current.harmless_distortion
                >= following.harmless_distortion - 1e-8
            )

    def test_layer_importance(self):
        """Every one of the six layers gets an importance value in [0, 1]."""
        harmful_acts, harmless_acts, directions = _make_multilayer_activations(
            n_layers=6, hidden_dim=16, n_per_class=20,
        )
        optimizer = BayesianKernelProjection(n_trials=50)

        result = optimizer.optimize(harmful_acts, harmless_acts, directions)

        assert len(result.layer_importance) == 6
        for importance in result.layer_importance.values():
            assert 0 <= importance <= 1.0

    def test_tpe_improves_over_random(self):
        """TPE phase should produce better configs than random exploration."""
        harmful_acts, harmless_acts, directions = _make_multilayer_activations(
            n_layers=6, hidden_dim=16, n_per_class=20,
        )
        optimizer = BayesianKernelProjection(n_trials=60, seed=42)

        trials = optimizer.optimize(
            harmful_acts, harmless_acts, directions,
        ).all_trials

        # Compare the best (lowest) combined score among the first 20 trials
        # with the best among the last 20. This test treats those windows as
        # the random and TPE phases respectively.
        best_random = min(t.combined_score for t in trials[:20])
        best_tpe = min(t.combined_score for t in trials[-20:])

        # TPE should find at least as good (lower = better)
        # This is probabilistic so we allow some slack
        assert best_tpe <= best_random + 0.3

    def test_empty_input(self):
        """Empty inputs produce a zero-trial result with a zero best score."""
        optimizer = BayesianKernelProjection(n_trials=10)

        result = optimizer.optimize({}, {}, {})

        assert result.n_trials == 0
        assert result.best_score == 0.0

    def test_format_report(self):
        """The rendered report includes its title, Pareto, and layer-importance sections."""
        harmful_acts, harmless_acts, directions = _make_multilayer_activations(
            n_layers=4, hidden_dim=16, n_per_class=15,
        )
        optimizer = BayesianKernelProjection(n_trials=20)

        result = optimizer.optimize(harmful_acts, harmless_acts, directions)
        report = BayesianKernelProjection.format_report(result)

        for fragment in ("Bayesian", "Pareto", "Layer importance"):
            assert fragment in report
# ===========================================================================
# Tests: Module imports
# ===========================================================================
class TestModuleImports:
    """Checks that the new analysis classes are re-exported by ``obliteratus.analysis``."""

    def test_all_new_modules_importable(self):
        """Each new class can be imported directly from the package."""
        from obliteratus.analysis import (
            TunedLensTrainer,
            RefusalTunedLens,
            ActivationPatcher,
            WassersteinOptimalExtractor,
            BayesianKernelProjection,
            SAEDecompositionPipeline,
        )

        imported = (
            TunedLensTrainer,
            RefusalTunedLens,
            ActivationPatcher,
            WassersteinOptimalExtractor,
            BayesianKernelProjection,
            SAEDecompositionPipeline,
        )
        for obj in imported:
            assert obj is not None

    def test_new_modules_in_all(self):
        """Each new class name is listed in the package's ``__all__``."""
        import obliteratus.analysis as analysis

        expected_names = (
            "TunedLensTrainer",
            "RefusalTunedLens",
            "ActivationPatcher",
            "WassersteinOptimalExtractor",
            "BayesianKernelProjection",
            "SAEDecompositionPipeline",
        )
        for name in expected_names:
            assert name in analysis.__all__