# obliteratus/tests/test_causal_and_transfer.py
# Hosting-page residue kept as comments: uploaded by pliny-the-prompter,
# commit 45113e6 ("Upload 127 files", verified).
"""Tests for causal tracing, residual stream decomposition,
probing classifiers, and cross-model transfer analysis."""
from __future__ import annotations
import math
import torch
from obliteratus.analysis.causal_tracing import (
CausalRefusalTracer,
CausalTracingResult,
ComponentCausalEffect,
)
from obliteratus.analysis.residual_stream import (
ResidualStreamDecomposer,
ResidualStreamResult,
LayerDecomposition,
)
from obliteratus.analysis.probing_classifiers import (
LinearRefusalProbe,
ProbeResult,
ProbingSuiteResult,
)
from obliteratus.analysis.cross_model_transfer import (
TransferAnalyzer,
CrossModelResult,
CrossCategoryResult,
CrossLayerResult,
UniversalityReport,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_layer_activations(
    n_layers=8, hidden_dim=32, refusal_strength=2.0,
):
    """Build deterministic fake per-layer activations with a planted signal.

    Every layer gets its own random unit-norm direction.  The layer's
    activation is a shared small baseline vector, plus that direction scaled
    by a per-layer strength, plus a little noise.  Layers 2 through 5 use
    ``refusal_strength``; every other layer uses a fixed strength of 0.3.

    The global torch RNG is reseeded with 42 on each call, so repeated calls
    with the same arguments return identical tensors.

    Returns:
        ``(activations, directions)`` -- two dicts keyed by layer index.
    """
    torch.manual_seed(42)
    layer_dirs = {}
    layer_acts = {}
    baseline = torch.randn(hidden_dim) * 0.1
    for layer in range(n_layers):
        raw = torch.randn(hidden_dim)
        unit = raw / raw.norm()
        layer_dirs[layer] = unit
        # The middle layers carry the strong signal.
        if 2 <= layer <= 5:
            scale = refusal_strength
        else:
            scale = 0.3
        jitter = torch.randn(hidden_dim) * 0.05
        layer_acts[layer] = baseline + scale * unit + jitter
    return layer_acts, layer_dirs
def _make_separable_activations(
    n_per_class=20, hidden_dim=16, separation=3.0, seed=42,
):
    """Generate two clouds of vectors split along one random unit direction.

    The global torch RNG is reseeded with ``seed``.  A random unit vector is
    drawn first; then ``n_per_class`` "harmful" samples are drawn as noise
    (scaled by 0.5) shifted by ``+separation`` along that vector, followed by
    ``n_per_class`` "harmless" samples shifted by ``-separation``.

    Returns:
        ``(harmful, harmless, direction)`` -- two lists of tensors and the
        unit direction used to separate them.
    """
    torch.manual_seed(seed)
    axis = torch.randn(hidden_dim)
    axis = axis / axis.norm()
    shift = separation * axis

    # Harmful samples are drawn first, then harmless ones; the order matters
    # for reproducing the same random stream.
    harmful = []
    for _ in range(n_per_class):
        harmful.append(torch.randn(hidden_dim) * 0.5 + shift)

    harmless = []
    for _ in range(n_per_class):
        harmless.append(torch.randn(hidden_dim) * 0.5 - shift)

    return harmful, harmless, axis
# ===========================================================================
# Tests: Causal Tracing
# ===========================================================================
class TestCausalTracing:
    """Exercises CausalRefusalTracer on synthetic per-layer activations."""

    def test_basic_tracing(self):
        """Default 8-layer input produces a populated CausalTracingResult."""
        acts, dirs = _make_layer_activations()
        outcome = CausalRefusalTracer(noise_level=3.0).trace_from_activations(
            acts, dirs,
        )
        assert isinstance(outcome, CausalTracingResult)
        assert outcome.n_layers == 8
        assert outcome.clean_refusal_strength > 0
        assert len(outcome.component_effects) == 8

    def test_causal_components_identified(self):
        """With a low causal threshold, a non-empty circuit is reported."""
        acts, dirs = _make_layer_activations()
        outcome = CausalRefusalTracer(
            noise_level=3.0, causal_threshold=0.05,
        ).trace_from_activations(acts, dirs)
        assert outcome.circuit_size > 0
        assert outcome.circuit_fraction > 0
        assert len(outcome.causal_components) > 0

    def test_corruption_reduces_strength(self):
        """Heavy noise yields a non-zero total corruption effect."""
        acts, dirs = _make_layer_activations(refusal_strength=5.0)
        outcome = CausalRefusalTracer(noise_level=10.0).trace_from_activations(
            acts, dirs,
        )
        # Only a non-zero difference between clean and corrupted is checked.
        assert outcome.total_corruption_effect != 0

    def test_single_direction_input(self):
        """A single tensor is accepted in place of the per-layer dict."""
        acts, dirs = _make_layer_activations()
        shared_dir = dirs[3]  # one direction reused for all layers
        outcome = CausalRefusalTracer().trace_from_activations(acts, shared_dir)
        assert outcome.n_layers == 8
        assert len(outcome.component_effects) == 8

    def test_component_effects_structure(self):
        """Each effect is a full-layer ComponentCausalEffect with effect >= 0."""
        acts, dirs = _make_layer_activations()
        outcome = CausalRefusalTracer().trace_from_activations(acts, dirs)
        for effect in outcome.component_effects:
            assert isinstance(effect, ComponentCausalEffect)
            assert effect.component_type == "full_layer"
            assert effect.causal_effect >= 0

    def test_correlation_causal_agreement_bounded(self):
        """The correlation/causal agreement score lies within [-1, 1]."""
        acts, dirs = _make_layer_activations()
        outcome = CausalRefusalTracer().trace_from_activations(acts, dirs)
        assert -1.0 <= outcome.correlation_causal_agreement <= 1.0

    def test_silent_contributors(self):
        """identify_silent_contributors returns both keys and honours top_k."""
        acts, dirs = _make_layer_activations()
        tracer = CausalRefusalTracer()
        outcome = tracer.trace_from_activations(acts, dirs)
        summary = tracer.identify_silent_contributors(outcome, top_k=3)
        assert "silent_contributors" in summary
        assert "loud_non_contributors" in summary
        assert len(summary["silent_contributors"]) <= 3

    def test_custom_component_types(self):
        """Two component types over 8 layers give 16 component effects."""
        acts, dirs = _make_layer_activations()
        outcome = CausalRefusalTracer().trace_from_activations(
            acts, dirs,
            component_types=["attention", "mlp"],
        )
        # 8 layers * 2 types = 16 effects
        assert len(outcome.component_effects) == 16

    def test_format_report(self):
        """The formatted report mentions the tracing title and circuit size."""
        acts, dirs = _make_layer_activations()
        outcome = CausalRefusalTracer().trace_from_activations(acts, dirs)
        text = CausalRefusalTracer.format_tracing_report(outcome)
        assert "Causal Tracing" in text
        assert "Circuit size" in text
# ===========================================================================
# Tests: Residual Stream Decomposition
# ===========================================================================
class TestResidualStreamDecomposition:
    """Exercises ResidualStreamDecomposer on synthetic activations."""

    def test_basic_decomposition(self):
        """Default input yields 8 layers and positive attention/MLP totals."""
        acts, dirs = _make_layer_activations()
        outcome = ResidualStreamDecomposer().decompose(acts, dirs)
        assert isinstance(outcome, ResidualStreamResult)
        assert outcome.n_layers == 8
        assert len(outcome.per_layer) == 8
        assert outcome.total_attention_contribution > 0
        assert outcome.total_mlp_contribution > 0

    def test_attention_fraction_bounded(self):
        """attention_fraction lies within [0, 1]."""
        acts, dirs = _make_layer_activations()
        outcome = ResidualStreamDecomposer().decompose(acts, dirs)
        assert 0 <= outcome.attention_fraction <= 1.0

    def test_with_head_count(self):
        """Supplying a head count produces a non-empty refusal_heads list."""
        acts, dirs = _make_layer_activations()
        outcome = ResidualStreamDecomposer(n_heads_per_layer=4).decompose(
            acts, dirs,
        )
        assert outcome.n_refusal_heads >= 0
        assert len(outcome.refusal_heads) > 0

    def test_layer_decomposition_structure(self):
        """Every per-layer entry is a LayerDecomposition with sane fields."""
        acts, dirs = _make_layer_activations()
        outcome = ResidualStreamDecomposer().decompose(acts, dirs)
        for _idx, layer_info in outcome.per_layer.items():
            assert isinstance(layer_info, LayerDecomposition)
            assert 0 <= layer_info.attn_mlp_ratio <= 1.0
            assert layer_info.cumulative_refusal >= 0

    def test_accumulation_profile(self):
        """The accumulation profile has 8 entries and never decreases."""
        acts, dirs = _make_layer_activations()
        outcome = ResidualStreamDecomposer().decompose(acts, dirs)
        profile = outcome.accumulation_profile
        assert len(profile) == 8
        # Accumulation should be monotonically non-decreasing
        for earlier, later in zip(profile, profile[1:]):
            assert later >= earlier

    def test_with_explicit_attn_mlp(self):
        """Test with provided attention and MLP outputs."""
        torch.manual_seed(42)
        hidden_dim = 16
        n_layers = 4
        refusal_dir = torch.randn(hidden_dim)
        refusal_dir = refusal_dir / refusal_dir.norm()

        residual = {}
        attn_by_layer = {}
        mlp_by_layer = {}
        for layer in range(n_layers):
            attn_out = torch.randn(hidden_dim) * 0.5
            mlp_out = torch.randn(hidden_dim) * 0.5
            attn_by_layer[layer] = attn_out
            mlp_by_layer[layer] = mlp_out
            # Layer 0 starts from a small random vector; later layers build
            # on the previous layer's residual.
            if layer == 0:
                carried = torch.randn(hidden_dim) * 0.1
            else:
                carried = residual[layer - 1]
            residual[layer] = attn_out + mlp_out + carried

        outcome = ResidualStreamDecomposer().decompose(
            residual, refusal_dir,
            attn_outputs=attn_by_layer, mlp_outputs=mlp_by_layer,
        )
        assert len(outcome.per_layer) == n_layers

    def test_single_direction(self):
        """A single (un-normalised) tensor is accepted as the direction."""
        acts, _ = _make_layer_activations()
        lone_dir = torch.randn(32)
        outcome = ResidualStreamDecomposer().decompose(acts, lone_dir)
        assert outcome.n_layers == 8

    def test_head_concentration_bounded(self):
        """head_concentration lies within [0, 1]."""
        acts, dirs = _make_layer_activations()
        outcome = ResidualStreamDecomposer(n_heads_per_layer=8).decompose(
            acts, dirs,
        )
        assert 0 <= outcome.head_concentration <= 1.0

    def test_format_decomposition(self):
        """The formatted report names the residual stream, attention and MLP."""
        acts, dirs = _make_layer_activations()
        outcome = ResidualStreamDecomposer(n_heads_per_layer=4).decompose(
            acts, dirs,
        )
        text = ResidualStreamDecomposer.format_decomposition(outcome)
        assert "Residual Stream" in text
        assert "Attention" in text
        assert "MLP" in text
# ===========================================================================
# Tests: Probing Classifiers
# ===========================================================================
class TestProbingClassifiers:
    """Exercises LinearRefusalProbe on synthetic two-class activations."""

    def test_separable_data_high_accuracy(self):
        """With well-separated data, probe should achieve high accuracy."""
        pos, neg, axis = _make_separable_activations(
            n_per_class=30, separation=5.0,
        )
        outcome = LinearRefusalProbe(n_epochs=200).probe_layer(
            pos, neg, axis, layer_idx=5,
        )
        assert isinstance(outcome, ProbeResult)
        assert outcome.layer_idx == 5
        assert outcome.accuracy > 0.7  # Should be separable

    def test_inseparable_data_low_accuracy(self):
        """With overlapping data, probe should have lower accuracy."""
        pos, neg, axis = _make_separable_activations(
            n_per_class=30, separation=0.01,
        )
        outcome = LinearRefusalProbe(n_epochs=50).probe_layer(pos, neg, axis)
        # Loose upper bound: only rules out near-perfect accuracy.
        assert outcome.accuracy < 0.9

    def test_learned_direction_unit(self):
        """The learned direction has norm 1 to within 0.01."""
        pos, neg, axis = _make_separable_activations()
        outcome = LinearRefusalProbe(n_epochs=100).probe_layer(pos, neg, axis)
        learned_norm = outcome.learned_direction.norm().item()
        assert abs(learned_norm - 1.0) < 0.01

    def test_cosine_with_analytical(self):
        """Learned direction should align with analytical direction."""
        pos, neg, axis = _make_separable_activations(
            n_per_class=50, separation=5.0,
        )
        outcome = LinearRefusalProbe(n_epochs=300).probe_layer(pos, neg, axis)
        # With clear separation, learned direction should agree
        assert outcome.cosine_with_analytical > 0.3

    def test_without_analytical_direction(self):
        """Omitting the analytical direction gives a cosine of exactly 0.0."""
        pos, neg, _ = _make_separable_activations()
        outcome = LinearRefusalProbe(n_epochs=50).probe_layer(pos, neg)
        assert outcome.cosine_with_analytical == 0.0

    def test_auroc_bounded(self):
        """AUROC lies within [0, 1]."""
        pos, neg, axis = _make_separable_activations()
        outcome = LinearRefusalProbe(n_epochs=100).probe_layer(pos, neg, axis)
        assert 0 <= outcome.auroc <= 1.0

    def test_mutual_information_nonnegative(self):
        """Mutual information is never negative."""
        pos, neg, axis = _make_separable_activations()
        outcome = LinearRefusalProbe(n_epochs=100).probe_layer(pos, neg, axis)
        assert outcome.mutual_information >= 0

    def test_probe_all_layers(self):
        """probe_all_layers returns one entry per layer for 6 layers."""
        pos_by_layer = {}
        neg_by_layer = {}
        axis_by_layer = {}
        for layer in range(6):
            pos, neg, axis = _make_separable_activations(
                n_per_class=15, separation=3.0, seed=layer * 10,
            )
            pos_by_layer[layer] = pos
            neg_by_layer[layer] = neg
            axis_by_layer[layer] = axis

        suite = LinearRefusalProbe(n_epochs=100).probe_all_layers(
            pos_by_layer, neg_by_layer, axis_by_layer,
        )
        assert isinstance(suite, ProbingSuiteResult)
        assert len(suite.per_layer) == 6
        assert suite.best_accuracy > 0
        assert suite.total_mutual_information >= 0

    def test_format_report(self):
        """The formatted suite report carries its title and an accuracy field."""
        pos_by_layer = {}
        neg_by_layer = {}
        for layer in range(4):
            pos, neg, _ = _make_separable_activations(
                n_per_class=15, seed=layer,
            )
            pos_by_layer[layer] = pos
            neg_by_layer[layer] = neg

        suite = LinearRefusalProbe(n_epochs=50).probe_all_layers(
            pos_by_layer, neg_by_layer,
        )
        text = LinearRefusalProbe.format_probing_report(suite)
        assert "Linear Probing" in text
        assert "accuracy" in text.lower()

    def test_cross_entropy_finite(self):
        """The reported cross-entropy is a finite number."""
        pos, neg, axis = _make_separable_activations()
        outcome = LinearRefusalProbe(n_epochs=100).probe_layer(pos, neg, axis)
        assert math.isfinite(outcome.cross_entropy)
# ===========================================================================
# Tests: Cross-Model Transfer Analysis
# ===========================================================================
class TestTransferAnalysis:
    """Exercises TransferAnalyzer across models, categories and layers.

    Every test that draws random tensors seeds the global torch RNG first so
    its inputs do not depend on test execution order.
    """

    def test_cross_model_identical(self):
        """Identical directions should give perfect transfer."""
        torch.manual_seed(42)
        dirs = {i: torch.randn(32) for i in range(8)}
        analyzer = TransferAnalyzer()
        result = analyzer.analyze_cross_model(dirs, dirs, "model_a", "model_a")
        assert isinstance(result, CrossModelResult)
        assert result.mean_transfer_score > 0.99

    def test_cross_model_random(self):
        """Random directions should give low transfer."""
        torch.manual_seed(42)
        dirs_a = {i: torch.randn(32) for i in range(8)}
        torch.manual_seed(99)
        dirs_b = {i: torch.randn(32) for i in range(8)}
        analyzer = TransferAnalyzer()
        result = analyzer.analyze_cross_model(dirs_a, dirs_b, "a", "b")
        # Random 32-dim vectors have low expected cosine
        assert result.mean_transfer_score < 0.7

    def test_cross_model_structure(self):
        """Result exposes a bounded threshold fraction and one entry per layer."""
        torch.manual_seed(42)
        dirs_a = {i: torch.randn(32) for i in range(8)}
        dirs_b = {i: torch.randn(32) for i in range(8)}
        analyzer = TransferAnalyzer()
        result = analyzer.analyze_cross_model(dirs_a, dirs_b)
        assert 0 <= result.transfer_above_threshold <= 1.0
        assert len(result.per_layer_transfer) == 8

    def test_cross_category_similar(self):
        """Similar categories should cluster together."""
        torch.manual_seed(42)
        shared = torch.randn(32)
        shared = shared / shared.norm()
        cat_dirs = {}
        for cat in ["weapons", "bombs", "explosives"]:
            d = shared + 0.2 * torch.randn(32)
            cat_dirs[cat] = d / d.norm()
        # Add one very different category
        cat_dirs["fraud"] = torch.randn(32)
        analyzer = TransferAnalyzer()
        result = analyzer.analyze_cross_category(cat_dirs)
        assert isinstance(result, CrossCategoryResult)
        assert result.mean_cross_category_transfer > 0
        assert len(result.categories) == 4

    def test_cross_category_specificity(self):
        """Most-universal/most-specific categories and clusters are populated."""
        torch.manual_seed(42)
        cat_dirs = {f"cat_{i}": torch.randn(16) for i in range(5)}
        analyzer = TransferAnalyzer()
        result = analyzer.analyze_cross_category(cat_dirs)
        assert result.most_universal_category != ""
        assert result.most_specific_category != ""
        assert len(result.category_clusters) > 0

    def test_cross_layer(self):
        """Cross-layer analysis returns non-negative transfer statistics."""
        _, directions = _make_layer_activations()
        analyzer = TransferAnalyzer()
        result = analyzer.analyze_cross_layer(directions)
        assert isinstance(result, CrossLayerResult)
        assert result.mean_adjacent_transfer >= 0
        assert result.transfer_decay_rate >= 0

    def test_cross_layer_adjacent_vs_distant(self):
        """Adjacent layers typically have higher transfer than distant ones."""
        torch.manual_seed(42)
        # Create directions with gradual drift
        d = torch.randn(32)
        d = d / d.norm()
        directions = {}
        for i in range(10):
            noise = torch.randn(32) * 0.1 * i
            di = d + noise
            directions[i] = di / di.norm()
        analyzer = TransferAnalyzer()
        result = analyzer.analyze_cross_layer(directions)
        # Adjacent should have higher transfer than distant (0.1 tolerance)
        assert result.mean_adjacent_transfer >= result.mean_distant_transfer - 0.1

    def test_universality_index(self):
        """Combining all three analyses gives an index within [0, 1]."""
        torch.manual_seed(42)
        dirs = {i: torch.randn(32) for i in range(6)}
        analyzer = TransferAnalyzer()
        cross_model = analyzer.analyze_cross_model(dirs, dirs)
        cross_layer = analyzer.analyze_cross_layer(dirs)
        cat_dirs = {f"cat_{i}": torch.randn(32) for i in range(4)}
        cross_cat = analyzer.analyze_cross_category(cat_dirs)
        report = analyzer.compute_universality_index(
            cross_model=cross_model,
            cross_category=cross_cat,
            cross_layer=cross_layer,
        )
        assert isinstance(report, UniversalityReport)
        assert 0 <= report.universality_index <= 1.0

    def test_universality_empty(self):
        """With no inputs the universality index is exactly 0.0."""
        analyzer = TransferAnalyzer()
        report = analyzer.compute_universality_index()
        assert report.universality_index == 0.0

    def test_format_cross_model(self):
        """The cross-model report carries its title and the first model name."""
        torch.manual_seed(42)
        dirs = {i: torch.randn(32) for i in range(4)}
        analyzer = TransferAnalyzer()
        result = analyzer.analyze_cross_model(dirs, dirs, "llama", "mistral")
        report = TransferAnalyzer.format_cross_model(result)
        assert "Cross-Model" in report
        assert "llama" in report

    def test_format_cross_category(self):
        """The cross-category report carries its title."""
        torch.manual_seed(42)
        cat_dirs = {f"cat_{i}": torch.randn(16) for i in range(3)}
        analyzer = TransferAnalyzer()
        result = analyzer.analyze_cross_category(cat_dirs)
        report = TransferAnalyzer.format_cross_category(result)
        assert "Cross-Category" in report

    def test_format_universality(self):
        """The universality report carries its title even for empty input."""
        analyzer = TransferAnalyzer()
        report_obj = analyzer.compute_universality_index()
        report = TransferAnalyzer.format_universality(report_obj)
        assert "Universality" in report

    def test_dimension_mismatch_handled(self):
        """Cross-model with different hidden dims should truncate."""
        # Seed like the other randomised tests so the inputs are reproducible
        # regardless of which tests ran earlier.
        torch.manual_seed(42)
        dirs_a = {0: torch.randn(32), 1: torch.randn(32)}
        dirs_b = {0: torch.randn(64), 1: torch.randn(64)}
        analyzer = TransferAnalyzer()
        result = analyzer.analyze_cross_model(dirs_a, dirs_b)
        assert len(result.per_layer_transfer) == 2
# ===========================================================================
# Tests: Integration
# ===========================================================================
class TestNewImports:
    """Checks that the analysis classes are re-exported by the package."""

    def test_all_new_modules_importable(self):
        """All four analyzer classes import from ``obliteratus.analysis``."""
        # Imported here, not at module level, so that a missing re-export
        # fails this test rather than collection of the whole file.
        from obliteratus.analysis import (
            CausalRefusalTracer,
            ResidualStreamDecomposer,
            LinearRefusalProbe,
            TransferAnalyzer,
        )

        exported = (
            CausalRefusalTracer,
            ResidualStreamDecomposer,
            LinearRefusalProbe,
            TransferAnalyzer,
        )
        for symbol in exported:
            assert symbol is not None