# Source path: obliteratus/tests/test_analysis.py
# Page metadata (not part of the module): pliny-the-prompter's picture;
# "Upload 127 files"; commit 45113e6 (verified)
"""Tests for the analysis techniques."""
from __future__ import annotations
import torch
from obliteratus.analysis.whitened_svd import WhitenedSVDExtractor, WhitenedSVDResult
from obliteratus.analysis.cross_layer import CrossLayerAlignmentAnalyzer, CrossLayerResult
from obliteratus.analysis.activation_probing import ActivationProbe, ProbeResult
# ---------------------------------------------------------------------------
# WhitenedSVDExtractor
# ---------------------------------------------------------------------------
class TestWhitenedSVD:
    """Tests for ``WhitenedSVDExtractor`` and the ``WhitenedSVDResult`` it returns.

    Every test seeds the global torch RNG so the synthetic activations are
    reproducible from run to run.
    """

    def test_basic_extraction(self) -> None:
        """Whitened SVD should extract directions from activation differences."""
        torch.manual_seed(42)
        n_prompts, hidden_dim = 10, 32
        n_directions = 3

        # Create activations with a clear refusal direction
        refusal_dir = torch.randn(hidden_dim)
        refusal_dir = refusal_dir / refusal_dir.norm()
        harmless = [torch.randn(hidden_dim) for _ in range(n_prompts)]
        # Harmful activations are the harmless ones shifted along refusal_dir.
        harmful = [h + 2.0 * refusal_dir for h in harmless]

        extractor = WhitenedSVDExtractor()
        result = extractor.extract(harmful, harmless, n_directions=n_directions)

        assert isinstance(result, WhitenedSVDResult)
        assert result.directions.shape == (n_directions, hidden_dim)
        assert result.singular_values.shape == (n_directions,)
        assert result.variance_explained > 0
        assert result.condition_number > 0
        assert result.effective_rank > 0

    def test_directions_are_unit_vectors(self) -> None:
        """Extracted directions should be unit length."""
        torch.manual_seed(42)
        harmless = [torch.randn(16) for _ in range(8)]
        harmful = [h + torch.randn(16) * 0.5 for h in harmless]

        extractor = WhitenedSVDExtractor()
        result = extractor.extract(harmful, harmless, n_directions=2)

        # Iterate over the rows directly; each row is one extracted direction.
        for idx, direction in enumerate(result.directions):
            norm = direction.norm().item()
            assert abs(norm - 1.0) < 1e-4, f"direction {idx} has norm {norm:.6f}"

    def test_primary_aligns_with_planted_direction(self) -> None:
        """Primary whitened direction should capture the planted refusal signal.

        Whitening rotates directions relative to the covariance structure,
        so perfect alignment with the raw direction is not expected. We verify
        the whitened direction explains substantial variance and has moderate
        alignment (whitening intentionally reweights dimensions).
        """
        torch.manual_seed(42)
        hidden_dim = 64
        n_prompts = 30

        refusal_dir = torch.randn(hidden_dim)
        refusal_dir = refusal_dir / refusal_dir.norm()

        # Isotropic harmless activations (whitening has minimal effect)
        harmless = [torch.randn(hidden_dim) * 0.1 for _ in range(n_prompts)]
        harmful = [h + 5.0 * refusal_dir for h in harmless]

        extractor = WhitenedSVDExtractor(regularization_eps=1e-3)
        result = extractor.extract(harmful, harmless, n_directions=1)

        # Sign of a singular direction is arbitrary, so compare |cosine|.
        cos_sim = (result.directions[0] @ refusal_dir).abs().item()

        # Moderate alignment expected (whitening reweights dimensions)
        assert cos_sim > 0.2, f"Expected alignment > 0.2, got {cos_sim:.3f}"
        # More importantly: the direction should explain most variance
        assert result.variance_explained > 0.5

    def test_extract_all_layers(self) -> None:
        """Should extract directions for all provided layers."""
        torch.manual_seed(42)
        n_layers = 4
        n_directions = 2

        harmful_acts = {}
        harmless_acts = {}
        # Kept as an explicit loop so harmful/harmless draws stay interleaved
        # per layer (the RNG consumption order defines the test data).
        for layer in range(n_layers):
            harmful_acts[layer] = [torch.randn(16) for _ in range(5)]
            harmless_acts[layer] = [torch.randn(16) for _ in range(5)]

        extractor = WhitenedSVDExtractor()
        results = extractor.extract_all_layers(
            harmful_acts, harmless_acts, n_directions=n_directions
        )

        assert len(results) == n_layers
        for idx in range(n_layers):
            assert idx in results
            assert results[idx].directions.shape[0] == n_directions

    def test_compare_with_standard(self) -> None:
        """Comparison should return valid cosine similarities."""
        torch.manual_seed(42)
        harmless = [torch.randn(16) for _ in range(8)]
        harmful = [h + torch.randn(16) for h in harmless]

        extractor = WhitenedSVDExtractor()
        result = extractor.extract(harmful, harmless, n_directions=2)

        # An arbitrary unit vector stands in for the "standard" direction.
        std_dir = torch.randn(16)
        std_dir = std_dir / std_dir.norm()

        comparison = WhitenedSVDExtractor.compare_with_standard(result, std_dir)

        assert "primary_direction_cosine" in comparison
        assert "subspace_principal_cosine" in comparison
        assert 0 <= comparison["primary_direction_cosine"] <= 1.0

    def test_handles_3d_activations(self) -> None:
        """Should handle activations that carry an extra leading dimension.

        Each activation has shape ``(1, hidden_dim)`` rather than
        ``(hidden_dim,)``; the extracted directions must still come out as
        ``(n_directions, hidden_dim)``.
        """
        torch.manual_seed(42)
        hidden_dim = 16

        # (1, hidden_dim) shape from hook output
        harmless = [torch.randn(1, hidden_dim) for _ in range(5)]
        harmful = [torch.randn(1, hidden_dim) for _ in range(5)]

        extractor = WhitenedSVDExtractor()
        result = extractor.extract(harmful, harmless, n_directions=2)

        assert result.directions.shape == (2, hidden_dim)

    def test_variance_explained_bounded(self) -> None:
        """Variance explained should be between 0 and 1."""
        torch.manual_seed(42)
        harmless = [torch.randn(16) for _ in range(8)]
        harmful = [torch.randn(16) for _ in range(8)]

        extractor = WhitenedSVDExtractor()
        result = extractor.extract(harmful, harmless, n_directions=3)

        assert 0 <= result.variance_explained <= 1.0
# ---------------------------------------------------------------------------
# CrossLayerAlignmentAnalyzer
# ---------------------------------------------------------------------------
class TestCrossLayerAlignment:
    """Tests for ``CrossLayerAlignmentAnalyzer`` and its ``CrossLayerResult``.

    Every test that draws random directions seeds the global torch RNG first,
    so a failure can be reproduced with the exact same inputs.
    """

    def test_identical_directions(self) -> None:
        """Identical directions across layers should give persistence = 1."""
        torch.manual_seed(42)
        direction = torch.randn(32)
        direction = direction / direction.norm()
        # Independent copies so no two layers share tensor storage.
        directions = {i: direction.clone() for i in range(5)}

        analyzer = CrossLayerAlignmentAnalyzer()
        result = analyzer.analyze(directions)

        assert isinstance(result, CrossLayerResult)
        assert result.direction_persistence_score > 0.99
        assert result.mean_adjacent_cosine > 0.99
        assert result.total_geodesic_distance < 0.01

    def test_orthogonal_directions(self) -> None:
        """Orthogonal directions should give low persistence."""
        # Create orthogonal directions via QR decomposition
        torch.manual_seed(42)
        random_rows = torch.randn(5, 32)
        # QR of the (32, 5) transpose: the columns of Q are orthonormal.
        q_matrix, _ = torch.linalg.qr(random_rows.T)
        directions = {i: q_matrix[:, i] for i in range(5)}

        analyzer = CrossLayerAlignmentAnalyzer()
        result = analyzer.analyze(directions)

        assert result.direction_persistence_score < 0.3
        assert result.mean_adjacent_cosine < 0.3

    def test_cluster_detection(self) -> None:
        """Should detect clusters of similar directions."""
        torch.manual_seed(42)
        # Create two clusters
        d1 = torch.randn(32)
        d1 = d1 / d1.norm()
        d2 = torch.randn(32)
        d2 = d2 / d2.norm()

        # Layers 0-2 sit near d1, layers 3-4 near d2 (small additive noise).
        directions = {
            0: d1,
            1: d1 + 0.01 * torch.randn(32),
            2: d1 + 0.01 * torch.randn(32),
            3: d2,
            4: d2 + 0.01 * torch.randn(32),
        }
        # Normalize
        directions = {k: v / v.norm() for k, v in directions.items()}

        analyzer = CrossLayerAlignmentAnalyzer(cluster_threshold=0.9)
        result = analyzer.analyze(directions)

        # Should find at least 2 clusters
        assert result.cluster_count >= 2

    def test_empty_input(self) -> None:
        """Should handle empty input gracefully."""
        analyzer = CrossLayerAlignmentAnalyzer()
        result = analyzer.analyze({})

        assert result.layer_indices == []
        assert result.cluster_count == 0

    def test_single_layer(self) -> None:
        """Single layer should work fine."""
        torch.manual_seed(42)
        analyzer = CrossLayerAlignmentAnalyzer()
        result = analyzer.analyze({5: torch.randn(16)})

        assert result.layer_indices == [5]
        assert result.direction_persistence_score == 1.0

    def test_strong_layers_filter(self) -> None:
        """Should only analyze specified strong layers."""
        torch.manual_seed(42)
        strong_layers = [2, 5, 7]
        directions = {i: torch.randn(16) for i in range(10)}

        analyzer = CrossLayerAlignmentAnalyzer()
        result = analyzer.analyze(directions, strong_layers=strong_layers)

        assert result.layer_indices == strong_layers
        assert result.cosine_matrix.shape == (len(strong_layers), len(strong_layers))

    def test_cosine_matrix_symmetry(self) -> None:
        """Cosine matrix should be symmetric."""
        torch.manual_seed(42)
        directions = {i: torch.randn(16) for i in range(4)}

        analyzer = CrossLayerAlignmentAnalyzer()
        result = analyzer.analyze(directions)

        # Largest absolute difference between the matrix and its transpose.
        diff = (result.cosine_matrix - result.cosine_matrix.T).abs().max().item()
        assert diff < 1e-5

    def test_cosine_matrix_diagonal_ones(self) -> None:
        """Diagonal of cosine matrix should be 1.0."""
        torch.manual_seed(42)
        n_layers = 4
        directions = {i: torch.randn(16) for i in range(n_layers)}

        analyzer = CrossLayerAlignmentAnalyzer()
        result = analyzer.analyze(directions)

        diagonal = result.cosine_matrix.diagonal()
        assert diagonal.shape[0] == n_layers
        # Same 1e-4 tolerance as a per-entry check, applied to the worst entry.
        worst = (diagonal - 1.0).abs().max().item()
        assert worst < 1e-4, f"diagonal deviates from 1.0 by {worst:.2e}"

    def test_angular_drift_monotonic(self) -> None:
        """Angular drift should be monotonically non-decreasing."""
        torch.manual_seed(42)
        directions = {i: torch.randn(16) for i in range(6)}

        analyzer = CrossLayerAlignmentAnalyzer()
        result = analyzer.analyze(directions)

        drift = result.angular_drift
        # The 1e-6 slack absorbs floating-point noise between equal steps.
        for i in range(len(drift) - 1):
            assert drift[i + 1] >= drift[i] - 1e-6

    def test_format_report(self) -> None:
        """Format report should produce a non-empty string."""
        torch.manual_seed(42)
        directions = {i: torch.randn(16) for i in range(4)}

        analyzer = CrossLayerAlignmentAnalyzer()
        result = analyzer.analyze(directions)
        report = CrossLayerAlignmentAnalyzer.format_report(result)

        assert "Cross-Layer" in report
        assert "persistence" in report
# ---------------------------------------------------------------------------
# ActivationProbe
# ---------------------------------------------------------------------------
class TestActivationProbe:
    """Tests for ``ActivationProbe`` (per-layer and aggregate) and ``ProbeResult``.

    Refusal directions are always passed as unit vectors, and every test that
    uses random data seeds the global torch RNG first.
    """

    def test_clean_elimination(self) -> None:
        """With no planted refusal signal the gap and d' should be small.

        Harmful and harmless activations are both plain random draws, standing
        in for "post-abliteration" activations, so the test expects a
        projection gap below 1.0 and a d' below 2.0.
        """
        torch.manual_seed(42)
        hidden_dim = 32
        refusal_dir = torch.randn(hidden_dim)
        refusal_dir = refusal_dir / refusal_dir.norm()

        # "Post-abliteration" activations: direction has been removed
        harmless = [torch.randn(hidden_dim) for _ in range(10)]
        harmful = [torch.randn(hidden_dim) for _ in range(10)]
        # Both sets are random, no refusal signal => gap should be small

        probe = ActivationProbe()
        result = probe.probe_layer(harmful, harmless, refusal_dir)

        assert abs(result.projection_gap) < 1.0
        assert result.separation_d_prime < 2.0

    def test_residual_detection(self) -> None:
        """Should detect residual refusal signal when direction wasn't removed."""
        torch.manual_seed(42)
        hidden_dim = 32
        refusal_dir = torch.randn(hidden_dim)
        refusal_dir = refusal_dir / refusal_dir.norm()

        harmless = [torch.randn(hidden_dim) for _ in range(10)]
        # Harmful still has strong refusal direction component
        harmful = [h + 5.0 * refusal_dir for h in harmless]

        probe = ActivationProbe()
        result = probe.probe_layer(harmful, harmless, refusal_dir)

        assert abs(result.projection_gap) > 1.0
        assert result.separation_d_prime > 2.0

    def test_probe_all_layers(self) -> None:
        """Should compute aggregate metrics across layers."""
        torch.manual_seed(42)
        hidden_dim = 16
        n_layers = 4

        harmful_acts = {}
        harmless_acts = {}
        refusal_dirs = {}
        # Kept as one loop so the per-layer RNG draw order
        # (harmful, harmless, direction) stays fixed.
        for layer in range(n_layers):
            harmful_acts[layer] = [torch.randn(hidden_dim) for _ in range(5)]
            harmless_acts[layer] = [torch.randn(hidden_dim) for _ in range(5)]
            direction = torch.randn(hidden_dim)
            refusal_dirs[layer] = direction / direction.norm()

        probe = ActivationProbe()
        result = probe.probe_all_layers(harmful_acts, harmless_acts, refusal_dirs)

        assert isinstance(result, ProbeResult)
        assert len(result.per_layer) == n_layers
        assert 0 <= result.refusal_elimination_score <= 1.0
        assert result.mean_projection_gap >= 0

    def test_res_score_range(self) -> None:
        """RES should always be between 0 and 1."""
        # Each iteration reseeds, so no seed is needed before the loop.
        for seed in range(5):
            torch.manual_seed(seed)
            harmful = {0: [torch.randn(8) for _ in range(3)]}
            harmless = {0: [torch.randn(8) for _ in range(3)]}
            direction = torch.randn(8)
            dirs = {0: direction / direction.norm()}

            probe = ActivationProbe()
            result = probe.probe_all_layers(harmful, harmless, dirs)

            assert 0 <= result.refusal_elimination_score <= 1.0, (
                f"RES out of range for seed {seed}: "
                f"{result.refusal_elimination_score}"
            )

    def test_format_report(self) -> None:
        """Format report should produce readable output."""
        torch.manual_seed(42)
        harmful = {0: [torch.randn(8) for _ in range(3)]}
        harmless = {0: [torch.randn(8) for _ in range(3)]}
        # Unit-normalise the direction, as the other tests in this class do.
        direction = torch.randn(8)
        dirs = {0: direction / direction.norm()}

        probe = ActivationProbe()
        result = probe.probe_all_layers(harmful, harmless, dirs)
        report = ActivationProbe.format_report(result)

        assert "Refusal Elimination Score" in report

    def test_empty_input(self) -> None:
        """Should handle empty input gracefully."""
        probe = ActivationProbe()
        result = probe.probe_all_layers({}, {}, {})

        assert result.refusal_elimination_score == 0.0
        assert len(result.per_layer) == 0