# obliteratus/tests/test_novel_analysis.py
# Repository page residue: pliny-the-prompter -- "Upload 127 files" (commit 45113e6, verified)
"""Tests for analysis techniques: concept cones, alignment imprints,
multi-token position, and sparse direction surgery."""
from __future__ import annotations
import torch
from obliteratus.analysis.concept_geometry import (
ConceptConeAnalyzer,
ConeConeResult,
MultiLayerConeResult,
CategoryDirection,
DEFAULT_HARM_CATEGORIES,
)
from obliteratus.analysis.alignment_imprint import (
AlignmentImprintDetector,
AlignmentImprint,
BaseInstructDelta,
)
from obliteratus.analysis.multi_token_position import (
MultiTokenPositionAnalyzer,
PositionAnalysisResult,
MultiTokenSummary,
)
from obliteratus.analysis.sparse_surgery import (
SparseDirectionSurgeon,
SparseProjectionResult,
SparseSurgeryPlan,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_category_activations(
    hidden_dim: int = 32,
    n_prompts: int = 30,
    n_categories: int = 5,
    category_spread: float = 0.3,
    seed: int = 42,
):
    """Create synthetic activations with planted per-category refusal directions.

    Each category gets its own unit-norm direction built from one shared
    component plus ``category_spread`` times a category-specific component,
    to simulate a polyhedral cone structure.

    Args:
        hidden_dim: Length of every generated vector.
        n_prompts: Requested number of prompts. It is rounded down to a
            multiple of ``n_categories`` so every category gets the same count.
        n_categories: Number of categories, named ``cat_0`` .. ``cat_{n-1}``.
        category_spread: Weight of the category-specific component; ``0.0``
            makes every category direction equal to the shared one.
        seed: Value handed to ``torch.manual_seed``. Note that this reseeds
            the *global* torch RNG, exactly as the fixed seed did before.

    Returns:
        A 4-tuple ``(harmful_acts, harmless_acts, category_map, cat_dirs)``:
        two equally long lists of 1-D tensors, a dict mapping prompt index to
        category name, and a dict mapping category name to its unit direction.
        ``harmful_acts[i]`` is ``harmless_acts[i]`` plus twice the direction
        of prompt ``i``'s category.

    Raises:
        ValueError: If ``n_categories`` is smaller than 1.
    """
    if n_categories < 1:
        raise ValueError(f"n_categories must be >= 1, got {n_categories}")

    torch.manual_seed(seed)

    # Shared refusal component, common to every category.
    shared = torch.randn(hidden_dim)
    shared = shared / shared.norm()

    # Per-category direction: shared part plus a scaled unique part, renormalised.
    categories = [f"cat_{i}" for i in range(n_categories)]
    cat_dirs = {}
    for cat in categories:
        unique = torch.randn(hidden_dim)
        unique = unique / unique.norm()
        combined = shared + category_spread * unique
        cat_dirs[cat] = combined / combined.norm()

    # Prompts are assigned in contiguous, equally sized runs per category;
    # any remainder of n_prompts is dropped.
    prompts_per_cat = n_prompts // n_categories
    category_map = {
        i * prompts_per_cat + j: cat
        for i, cat in enumerate(categories)
        for j in range(prompts_per_cat)
    }
    actual_n = prompts_per_cat * n_categories

    # Harmless = small noise; harmful = the same noise plus the planted direction.
    harmful_acts = []
    harmless_acts = []
    for idx in range(actual_n):
        base = torch.randn(hidden_dim) * 0.1
        harmful_acts.append(base + 2.0 * cat_dirs[category_map[idx]])
        harmless_acts.append(base)

    return harmful_acts, harmless_acts, category_map, cat_dirs
def _make_refusal_directions(
    n_layers: int = 8,
    hidden_dim: int = 32,
    concentration: str = "distributed",
    seed: int = 123,
):
    """Create synthetic refusal directions with a chosen concentration pattern.

    Args:
        n_layers: Number of layers; keys of both returned dicts are
            ``0 .. n_layers - 1``.
        hidden_dim: Length of every direction vector.
        concentration: Strength/direction pattern to plant:

            * ``"concentrated"`` -- strength 3.0 in the last two layers and
              0.1 elsewhere (SFT-like).
            * ``"distributed"`` -- strength ``1.0 + 0.2 * noise`` in every
              layer (RLHF-like).
            * ``"orthogonal"`` -- strength 1.5 everywhere, and each layer's
              direction has its component along the previous layer's
              direction removed (CAI-like).
            * any other value -- strength 2.0 for layers 2..4 and 0.5
              elsewhere; unknown names are NOT rejected.
        seed: Value handed to ``torch.manual_seed``. Note that this reseeds
            the *global* torch RNG, exactly as the fixed seed did before.

    Returns:
        ``(directions, strengths)``: a dict of layer index to unit-norm 1-D
        tensor, and a dict of layer index to float strength. Both dicts have
        an entry for every layer.
    """
    torch.manual_seed(seed)
    directions = {}
    strengths = {}

    for layer in range(n_layers):
        vec = torch.randn(hidden_dim)
        directions[layer] = vec / vec.norm()

        if concentration == "concentrated":
            # Strong in the last two layers only.
            strengths[layer] = 3.0 if layer >= n_layers - 2 else 0.1
        elif concentration == "distributed":
            # Roughly even across layers, with a little jitter.
            strengths[layer] = 1.0 + 0.2 * torch.randn(1).item()
        elif concentration == "orthogonal":
            if layer > 0:
                # Remove the component along the previous layer's direction,
                # then renormalise (clamped to avoid dividing by ~0).
                prev = directions[layer - 1]
                vec = vec - (vec @ prev) * prev
                vec = vec / vec.norm().clamp(min=1e-8)
                directions[layer] = vec
            # Set outside the `layer > 0` guard so layer 0 gets a strength too.
            strengths[layer] = 1.5
        else:
            # Fallback pattern: a bump over layers 2..4.
            strengths[layer] = 2.0 if 2 <= layer <= 4 else 0.5

    return directions, strengths
# ===========================================================================
# Tests: Concept Cone Geometry
# ===========================================================================
class TestConceptConeAnalyzer:
    """Checks for ConceptConeAnalyzer on synthetic per-category activations."""

    def test_basic_analysis(self):
        """The default fixture yields a populated result for the requested layer."""
        pos_acts, neg_acts, mapping, _ = _make_category_activations()
        cone = ConceptConeAnalyzer(category_map=mapping)
        outcome = cone.analyze_layer(pos_acts, neg_acts, layer_idx=5)

        assert isinstance(outcome, ConeConeResult)
        assert outcome.layer_idx == 5
        assert outcome.category_count >= 2
        assert outcome.cone_dimensionality > 0
        assert outcome.cone_solid_angle >= 0
        assert 0 <= outcome.mean_pairwise_cosine <= 1.0

    def test_polyhedral_detection(self):
        """A large category spread should give a cone dimensionality above 1."""
        # Large spread -> category directions are clearly distinct.
        pos_acts, neg_acts, mapping, _ = _make_category_activations(
            category_spread=2.0,
        )
        cone = ConceptConeAnalyzer(category_map=mapping)
        outcome = cone.analyze_layer(pos_acts, neg_acts)

        assert outcome.cone_dimensionality > 1.0

    def test_linear_detection(self):
        """Zero spread should give near-parallel category directions."""
        # No spread -> every category direction equals the shared one.
        pos_acts, neg_acts, mapping, _ = _make_category_activations(
            category_spread=0.0,
        )
        cone = ConceptConeAnalyzer(category_map=mapping)
        outcome = cone.analyze_layer(pos_acts, neg_acts)

        assert outcome.mean_pairwise_cosine > 0.8

    def test_category_directions_populated(self):
        """Every reported category direction carries sane metadata."""
        pos_acts, neg_acts, mapping, _ = _make_category_activations()
        cone = ConceptConeAnalyzer(category_map=mapping)
        outcome = cone.analyze_layer(pos_acts, neg_acts)

        for entry in outcome.category_directions:
            assert isinstance(entry, CategoryDirection)
            assert entry.strength > 0
            assert entry.n_prompts >= 2
            assert 0 <= entry.specificity <= 1.0

    def test_pairwise_cosines(self):
        """Pairwise cosines lie in [0, 1] and are keyed by sorted pairs."""
        pos_acts, neg_acts, mapping, _ = _make_category_activations()
        cone = ConceptConeAnalyzer(category_map=mapping)
        outcome = cone.analyze_layer(pos_acts, neg_acts)

        for pair, cosine in outcome.pairwise_cosines.items():
            first, second = pair
            assert 0 <= cosine <= 1.0
            assert first < second  # Sorted pair

    def test_general_direction_unit(self):
        """The general direction has (approximately) unit norm."""
        pos_acts, neg_acts, mapping, _ = _make_category_activations()
        cone = ConceptConeAnalyzer(category_map=mapping)
        outcome = cone.analyze_layer(pos_acts, neg_acts)

        norm = outcome.general_direction.norm().item()
        assert abs(norm - 1.0) < 0.01

    def test_multi_layer_analysis(self):
        """Analysing four layers returns one per-layer entry for each."""
        pos_acts, neg_acts, mapping, _ = _make_category_activations()
        layer_ids = range(4)
        # The same activation lists are reused for every layer.
        pos_by_layer = dict.fromkeys(layer_ids, pos_acts)
        neg_by_layer = dict.fromkeys(layer_ids, neg_acts)

        cone = ConceptConeAnalyzer(category_map=mapping)
        outcome = cone.analyze_all_layers(pos_by_layer, neg_by_layer)

        assert isinstance(outcome, MultiLayerConeResult)
        assert len(outcome.per_layer) == 4
        assert outcome.mean_cone_dimensionality > 0

    def test_format_report(self):
        """The text report mentions its title, the layer and the dimensionality."""
        pos_acts, neg_acts, mapping, _ = _make_category_activations()
        cone = ConceptConeAnalyzer(category_map=mapping)
        outcome = cone.analyze_layer(pos_acts, neg_acts, layer_idx=3)

        text = ConceptConeAnalyzer.format_report(outcome)
        for fragment in ("Concept Cone", "Layer 3", "dimensionality"):
            assert fragment in text

    def test_default_category_map(self):
        """The built-in category map has 30 entries incl. weapons and cyber."""
        assert len(DEFAULT_HARM_CATEGORIES) == 30
        known = set(DEFAULT_HARM_CATEGORIES.values())
        assert "weapons" in known
        assert "cyber" in known

    def test_empty_activations(self):
        """Empty inputs produce a result with no categories."""
        cone = ConceptConeAnalyzer()
        outcome = cone.analyze_layer([], [], layer_idx=0)
        assert outcome.category_count == 0

    def test_min_category_size(self):
        """Categories smaller than ``min_category_size`` are dropped."""
        pos_acts, neg_acts, mapping, _ = _make_category_activations(
            n_prompts=10, n_categories=5,
        )
        cone = ConceptConeAnalyzer(category_map=mapping, min_category_size=3)
        outcome = cone.analyze_layer(pos_acts, neg_acts)

        # 10 prompts over 5 categories = 2 each, below the minimum of 3.
        assert outcome.category_count == 0
# ===========================================================================
# Tests: Alignment Imprint Detector
# ===========================================================================
class TestAlignmentImprintDetector:
    """Checks for AlignmentImprintDetector on synthetic refusal directions."""

    @staticmethod
    def _imprint_for(**direction_kwargs):
        """Build synthetic directions and return the detector's imprint for them."""
        layer_dirs, layer_strengths = _make_refusal_directions(**direction_kwargs)
        return AlignmentImprintDetector().detect_imprint(layer_dirs, layer_strengths)

    def test_basic_detection(self):
        """The imprint names a known method with a confidence in [0, 1]."""
        found = self._imprint_for()

        assert isinstance(found, AlignmentImprint)
        assert found.predicted_method in ("dpo", "rlhf", "cai", "sft")
        assert 0 <= found.confidence <= 1.0

    def test_probabilities_sum_to_one(self):
        """The four method probabilities add up to one."""
        found = self._imprint_for()

        dpo_plus_rlhf = found.dpo_probability + found.rlhf_probability
        total = dpo_plus_rlhf + found.cai_probability + found.sft_probability
        assert abs(total - 1.0) < 0.01

    def test_concentrated_detects_sft_or_dpo(self):
        """Tail-concentrated refusal should be attributed to SFT or DPO."""
        found = self._imprint_for(concentration="concentrated")

        # Either label is accepted for a concentrated signature.
        assert found.predicted_method in ("sft", "dpo")

    def test_distributed_detects_not_sft(self):
        """Evenly distributed refusal must not be labelled SFT."""
        found = self._imprint_for(n_layers=16, concentration="distributed")

        assert found.predicted_method != "sft"

    def test_orthogonal_detects_cai(self):
        """Orthogonalised layer directions should give CAI noticeable weight."""
        found = self._imprint_for(n_layers=12, concentration="orthogonal")

        assert found.cai_probability > 0.15

    def test_feature_extraction(self):
        """Extracted features stay inside their expected ranges."""
        found = self._imprint_for()

        assert 0 <= found.gini_coefficient <= 1.0
        assert found.effective_rank > 0
        unit_interval_features = (
            found.cross_layer_smoothness,
            found.tail_layer_bias,
            found.mean_pairwise_orthogonality,
        )
        for value in unit_interval_features:
            assert 0 <= value <= 1.0
        assert found.spectral_decay_rate >= 0

    def test_empty_directions(self):
        """With no directions the method is 'unknown' at zero confidence."""
        found = AlignmentImprintDetector().detect_imprint({})

        assert found.predicted_method == "unknown"
        assert found.confidence == 0.0

    def test_compare_base_instruct(self):
        """Deltas planted along the refusal direction are reported as aligned."""
        torch.manual_seed(42)
        width = 32
        n_layers = 8
        layer_dirs, _ = _make_refusal_directions(hidden_dim=width)

        base_acts = {layer: torch.randn(width) for layer in range(n_layers)}
        # Instruct = base shifted along each layer's refusal direction.
        instruct_acts = {
            layer: base + 1.5 * layer_dirs[layer]
            for layer, base in base_acts.items()
        }

        detector = AlignmentImprintDetector()
        deltas = detector.compare_base_instruct(base_acts, instruct_acts, layer_dirs)

        assert len(deltas) == 8
        for delta in deltas:
            assert isinstance(delta, BaseInstructDelta)
            assert delta.delta_magnitude > 0
            # The planted delta is the refusal direction itself.
            assert abs(delta.cosine_with_refusal) > 0.5

    def test_format_imprint(self):
        """The text report mentions its title, DPO, RLHF and Gini."""
        found = self._imprint_for()

        text = AlignmentImprintDetector.format_imprint(found)
        for fragment in ("Alignment Imprint", "DPO", "RLHF", "Gini"):
            assert fragment in text

    def test_per_layer_strength_populated(self):
        """There is one per-layer strength entry per input direction."""
        layer_dirs, layer_strengths = _make_refusal_directions()
        detector = AlignmentImprintDetector()
        found = detector.detect_imprint(layer_dirs, layer_strengths)

        assert len(found.per_layer_strength) == len(layer_dirs)
# ===========================================================================
# Tests: Multi-Token Position Analysis
# ===========================================================================
class TestMultiTokenPositionAnalyzer:
    """Tests for MultiTokenPositionAnalyzer on sequences with a planted trigger."""

    def _make_activations_with_trigger(
        self, seq_len=20, hidden_dim=32, trigger_pos=5,
    ):
        """Create activations with a planted trigger at a specific position.

        Reseeds the global torch RNG with 42 on every call, so the returned
        direction is identical across calls that share ``hidden_dim``.

        Returns:
            ``(acts, refusal_dir)``: a ``(seq_len, hidden_dim)`` tensor of
            low-amplitude noise with the unit vector ``refusal_dir`` added at
            strength 3.0 at ``trigger_pos``, 1.0 at the last position, and a
            halving tail over the (up to) three positions after the trigger.
        """
        torch.manual_seed(42)
        refusal_dir = torch.randn(hidden_dim)
        refusal_dir = refusal_dir / refusal_dir.norm()

        # Background activations
        acts = torch.randn(seq_len, hidden_dim) * 0.1
        # Strong refusal at trigger position
        acts[trigger_pos] += 3.0 * refusal_dir
        # Weaker refusal at last position
        acts[-1] += 1.0 * refusal_dir
        # Moderate at a few positions after trigger (decay)
        for i in range(trigger_pos + 1, min(trigger_pos + 4, seq_len)):
            decay = 0.5 ** (i - trigger_pos)
            acts[i] += 3.0 * decay * refusal_dir
        return acts, refusal_dir

    def test_basic_analysis(self):
        """A single prompt yields a populated result for the requested layer."""
        acts, ref_dir = self._make_activations_with_trigger()
        analyzer = MultiTokenPositionAnalyzer()
        result = analyzer.analyze_prompt(acts, ref_dir, layer_idx=3)

        assert isinstance(result, PositionAnalysisResult)
        assert result.layer_idx == 3
        assert result.n_tokens == 20
        assert result.peak_strength > 0

    def test_trigger_detection(self):
        """The planted trigger position is detected and is the peak."""
        acts, ref_dir = self._make_activations_with_trigger(trigger_pos=5)
        analyzer = MultiTokenPositionAnalyzer(trigger_threshold=0.5)
        result = analyzer.analyze_prompt(acts, ref_dir)

        assert 5 in result.trigger_positions
        assert result.peak_position == 5

    def test_peak_vs_last(self):
        """Peak should be at trigger, not last token."""
        acts, ref_dir = self._make_activations_with_trigger(trigger_pos=5)
        analyzer = MultiTokenPositionAnalyzer()
        result = analyzer.analyze_prompt(acts, ref_dir)

        assert result.peak_strength > result.last_token_strength
        assert result.peak_position != result.n_tokens - 1

    def test_decay_rate_positive(self):
        """With a halving tail planted after the trigger, decay rate is positive."""
        acts, ref_dir = self._make_activations_with_trigger(trigger_pos=5)
        analyzer = MultiTokenPositionAnalyzer()
        result = analyzer.analyze_prompt(acts, ref_dir)

        assert result.decay_rate > 0

    def test_position_gini_bounded(self):
        """The positional Gini coefficient lies in [0, 1]."""
        acts, ref_dir = self._make_activations_with_trigger()
        analyzer = MultiTokenPositionAnalyzer()
        result = analyzer.analyze_prompt(acts, ref_dir)

        assert 0 <= result.position_gini <= 1.0

    def test_token_profiles_length(self):
        """There is one token profile per sequence position."""
        acts, ref_dir = self._make_activations_with_trigger(seq_len=15)
        analyzer = MultiTokenPositionAnalyzer()
        result = analyzer.analyze_prompt(acts, ref_dir)

        assert len(result.token_profiles) == 15

    def test_custom_token_texts(self):
        """Profiles carry either a supplied token text or a ``pos_`` placeholder."""
        acts, ref_dir = self._make_activations_with_trigger(seq_len=10, trigger_pos=3)
        tokens = ["How", "to", "make", "a", "bomb", "from", "scratch", "please", "help", "me"]
        analyzer = MultiTokenPositionAnalyzer()
        result = analyzer.analyze_prompt(acts, ref_dir, token_texts=tokens)

        for tp in result.token_profiles:
            assert tp.token_text in tokens or tp.token_text.startswith("pos_")

    def test_batch_analysis(self):
        """A five-prompt batch produces a summary with bounded statistics."""
        samples = [
            self._make_activations_with_trigger(trigger_pos=3 + i % 3)
            for i in range(5)
        ]
        batch = [acts for acts, _ in samples]
        # Every sample is built from the same seed, so they all share one
        # direction; take it explicitly instead of relying on a loop variable
        # that happens to survive the loop.
        ref_dir = samples[0][1]

        analyzer = MultiTokenPositionAnalyzer()
        summary = analyzer.analyze_batch(batch, ref_dir)

        assert isinstance(summary, MultiTokenSummary)
        assert len(summary.per_prompt) == 5
        assert summary.mean_peak_vs_last_ratio > 0
        assert summary.mean_trigger_count > 0
        assert 0 <= summary.peak_is_last_fraction <= 1.0
        assert 0 <= summary.last_token_dominance <= 1.0

    def test_last_token_dominant_case(self):
        """When signal is only at last token, peak should equal last."""
        torch.manual_seed(42)
        hidden_dim = 32
        seq_len = 10
        ref_dir = torch.randn(hidden_dim)
        ref_dir = ref_dir / ref_dir.norm()
        acts = torch.randn(seq_len, hidden_dim) * 0.01
        acts[-1] += 5.0 * ref_dir

        analyzer = MultiTokenPositionAnalyzer()
        result = analyzer.analyze_prompt(acts, ref_dir)

        assert result.peak_position == seq_len - 1

    def test_format_position_report(self):
        """The per-prompt report mentions its title and the peak position."""
        acts, ref_dir = self._make_activations_with_trigger()
        analyzer = MultiTokenPositionAnalyzer()
        result = analyzer.analyze_prompt(acts, ref_dir, prompt_text="How to hack?")

        report = MultiTokenPositionAnalyzer.format_position_report(result)
        assert "Multi-Token" in report
        assert "Peak position" in report

    def test_format_summary(self):
        """The batch report mentions its title and the prompt count line."""
        samples = [self._make_activations_with_trigger() for _ in range(3)]
        batch = [acts for acts, _ in samples]
        ref_dir = samples[0][1]  # identical for every sample (same seed)

        analyzer = MultiTokenPositionAnalyzer()
        summary = analyzer.analyze_batch(batch, ref_dir)

        report = MultiTokenPositionAnalyzer.format_summary(summary)
        assert "Summary" in report
        assert "Prompts analyzed" in report

    def test_3d_activations_handled(self):
        """Should handle (1, seq_len, hidden_dim) inputs."""
        acts, ref_dir = self._make_activations_with_trigger()
        acts = acts.unsqueeze(0)  # Add batch dim
        analyzer = MultiTokenPositionAnalyzer()
        result = analyzer.analyze_prompt(acts, ref_dir)

        assert result.n_tokens == 20

    def test_empty_batch(self):
        """An empty batch yields no per-prompt results and a fraction of 1.0."""
        # Seed so the direction does not depend on RNG state left behind by
        # whichever test happened to run before this one.
        torch.manual_seed(0)
        ref_dir = torch.randn(32)
        analyzer = MultiTokenPositionAnalyzer()
        summary = analyzer.analyze_batch([], ref_dir)

        assert len(summary.per_prompt) == 0
        assert summary.peak_is_last_fraction == 1.0
# ===========================================================================
# Tests: Sparse Direction Surgery
# ===========================================================================
class TestSparseDirectionSurgeon:
    """Tests for SparseDirectionSurgeon on weights with row-sparse refusal signal."""

    def _make_weight_with_sparse_refusal(
        self, out_dim=64, in_dim=32, n_refusal_rows=5,
    ):
        """Create a weight matrix where refusal is concentrated in a few rows.

        Reseeds the global torch RNG with 42 on every call, so repeated calls
        with the same arguments return identical tensors.

        Returns:
            ``(W, refusal_dir, refusal_rows)``: an ``(out_dim, in_dim)`` matrix
            of low-amplitude noise, the unit vector that was added (scaled by
            5.0) to the first ``n_refusal_rows`` rows, and the list of those
            row indices.
        """
        torch.manual_seed(42)
        refusal_dir = torch.randn(in_dim)
        refusal_dir = refusal_dir / refusal_dir.norm()
        W = torch.randn(out_dim, in_dim) * 0.1

        # Plant strong refusal signal in specific rows
        refusal_rows = list(range(n_refusal_rows))
        for i in refusal_rows:
            W[i] += 5.0 * refusal_dir
        return W, refusal_dir, refusal_rows

    def test_basic_analysis(self):
        """Analysing one matrix yields a populated result for the requested layer."""
        W, ref_dir, _ = self._make_weight_with_sparse_refusal()
        surgeon = SparseDirectionSurgeon(sparsity=0.1)
        result = surgeon.analyze_weight_matrix(W, ref_dir, layer_idx=3)

        assert isinstance(result, SparseProjectionResult)
        assert result.layer_idx == 3
        assert result.n_rows_total == 64
        assert result.n_rows_modified > 0
        assert result.mean_projection > 0
        assert result.max_projection > result.mean_projection

    def test_refusal_sparsity_index(self):
        """With sparse refusal, RSI should be high."""
        W, ref_dir, _ = self._make_weight_with_sparse_refusal(
            out_dim=100, n_refusal_rows=5,
        )
        surgeon = SparseDirectionSurgeon()
        result = surgeon.analyze_weight_matrix(W, ref_dir)

        assert result.refusal_sparsity_index > 0.3  # Concentrated signal

    def test_energy_removed(self):
        """Top rows should capture most of the refusal energy."""
        W, ref_dir, _ = self._make_weight_with_sparse_refusal(
            out_dim=64, n_refusal_rows=5,
        )
        surgeon = SparseDirectionSurgeon(sparsity=0.15)  # roughly 10 of 64 rows
        result = surgeon.analyze_weight_matrix(W, ref_dir)

        # Only 5 rows carry the planted signal, so the budget covers them.
        assert result.energy_removed > 0.5

    def test_frobenius_change_bounded(self):
        """The relative Frobenius change is positive and below 100%."""
        W, ref_dir, _ = self._make_weight_with_sparse_refusal()
        surgeon = SparseDirectionSurgeon(sparsity=0.1)
        result = surgeon.analyze_weight_matrix(W, ref_dir)

        assert result.frobenius_change > 0
        assert result.frobenius_change < 1.0  # Shouldn't change more than 100%

    def test_apply_sparse_projection(self):
        """Sparse projection should reduce refusal signal."""
        W, ref_dir, _ = self._make_weight_with_sparse_refusal()
        surgeon = SparseDirectionSurgeon(sparsity=0.1)
        W_modified = surgeon.apply_sparse_projection(W, ref_dir)

        original_proj = (W @ ref_dir).abs().sum().item()
        modified_proj = (W_modified @ ref_dir).abs().sum().item()
        assert modified_proj < original_proj

    def test_sparse_preserves_unmodified_rows(self):
        """Rows below the threshold should be unchanged."""
        out_dim = 64
        sparsity = 0.1
        W, ref_dir, _ = self._make_weight_with_sparse_refusal(
            out_dim=out_dim, n_refusal_rows=5,
        )
        surgeon = SparseDirectionSurgeon(sparsity=sparsity)  # ~6 rows
        W_modified = surgeon.apply_sparse_projection(W, ref_dir)

        # Count rows that actually changed
        diffs = (W - W_modified).abs().sum(dim=1)
        n_changed = (diffs > 1e-6).sum().item()
        n_unchanged = (diffs < 1e-6).sum().item()

        # Both bounds derive from the same budget: at most 7 of 64 rows may
        # change, hence at least 57 must stay untouched.
        max_changed = int(sparsity * out_dim) + 1
        assert n_changed <= max_changed  # Sparsity bound
        assert n_unchanged >= out_dim - max_changed  # Most rows unchanged

    def test_dense_vs_sparse_comparison(self):
        """Dense projection should modify all rows; sparse should modify fewer."""
        W, ref_dir, _ = self._make_weight_with_sparse_refusal()

        # Dense projection: remove the component along r from every row.
        r = ref_dir / ref_dir.norm()
        W_dense = W - (W @ r).unsqueeze(1) * r.unsqueeze(0)

        # Sparse projection
        surgeon = SparseDirectionSurgeon(sparsity=0.1)
        W_sparse = surgeon.apply_sparse_projection(W, ref_dir)

        dense_changes = (W - W_dense).abs().sum(dim=1)
        sparse_changes = (W - W_sparse).abs().sum(dim=1)
        n_dense_changed = (dense_changes > 1e-6).sum().item()
        n_sparse_changed = (sparse_changes > 1e-6).sum().item()
        assert n_sparse_changed < n_dense_changed

    def test_plan_surgery(self):
        """Planning over six layers returns one entry per layer and sane means."""
        weights = {}
        directions = {}
        for i in range(6):
            # The helper reseeds on every call, so all six layers are identical.
            W, ref_dir, _ = self._make_weight_with_sparse_refusal()
            weights[i] = W
            directions[i] = ref_dir

        surgeon = SparseDirectionSurgeon(sparsity=0.1)
        plan = surgeon.plan_surgery(weights, directions)

        assert isinstance(plan, SparseSurgeryPlan)
        assert len(plan.per_layer) == 6
        assert 0 < plan.recommended_sparsity < 1.0
        assert plan.mean_refusal_sparsity_index > 0
        assert plan.mean_energy_removed > 0

    def test_auto_sparsity(self):
        """Auto sparsity picks a value between 0.01 and 0.5."""
        W, ref_dir, _ = self._make_weight_with_sparse_refusal()
        surgeon = SparseDirectionSurgeon(auto_sparsity=True)
        result = surgeon.analyze_weight_matrix(W, ref_dir)

        assert 0.01 <= result.sparsity <= 0.5

    def test_auto_sparsity_apply(self):
        """Applying with auto sparsity reduces the total projection."""
        W, ref_dir, _ = self._make_weight_with_sparse_refusal()
        surgeon = SparseDirectionSurgeon(auto_sparsity=True)
        W_modified = surgeon.apply_sparse_projection(W, ref_dir)

        assert (W_modified @ ref_dir).abs().sum() < (W @ ref_dir).abs().sum()

    def test_format_analysis(self):
        """The per-layer report mentions its title, the layer and the RSI."""
        W, ref_dir, _ = self._make_weight_with_sparse_refusal()
        surgeon = SparseDirectionSurgeon(sparsity=0.1)
        result = surgeon.analyze_weight_matrix(W, ref_dir, layer_idx=4)

        report = SparseDirectionSurgeon.format_analysis(result)
        assert "Sparse Direction Surgery" in report
        assert "Layer 4" in report
        assert "Refusal Sparsity Index" in report

    def test_format_plan(self):
        """The plan report mentions its title and the recommended sparsity."""
        # Seed so the random inputs do not depend on RNG state left behind by
        # whichever test happened to run before this one.
        torch.manual_seed(0)
        weights = {i: torch.randn(32, 16) for i in range(4)}
        directions = {i: torch.randn(16) for i in range(4)}

        surgeon = SparseDirectionSurgeon(sparsity=0.1)
        plan = surgeon.plan_surgery(weights, directions)

        report = SparseDirectionSurgeon.format_plan(plan)
        assert "Sparse Direction Surgery Plan" in report
        assert "Recommended sparsity" in report

    def test_empty_inputs(self):
        """Planning with no layers returns an empty per-layer list."""
        surgeon = SparseDirectionSurgeon()
        plan = surgeon.plan_surgery({}, {})

        assert len(plan.per_layer) == 0

    def test_output_dtype_preserved(self):
        """Output should match input dtype."""
        W, ref_dir, _ = self._make_weight_with_sparse_refusal()
        W_half = W.half()
        surgeon = SparseDirectionSurgeon(sparsity=0.1)
        W_out = surgeon.apply_sparse_projection(W_half, ref_dir)

        assert W_out.dtype == torch.float16
# ===========================================================================
# Tests: Integration / Imports
# ===========================================================================
class TestAnalysisImports:
    """Smoke test for names importable from ``obliteratus.analysis``."""

    def test_all_new_modules_importable(self):
        """The four analysis classes import from the package and are not None."""
        from obliteratus.analysis import (
            ConceptConeAnalyzer,
            AlignmentImprintDetector,
            MultiTokenPositionAnalyzer,
            SparseDirectionSurgeon,
        )

        imported = (
            ConceptConeAnalyzer,
            AlignmentImprintDetector,
            MultiTokenPositionAnalyzer,
            SparseDirectionSurgeon,
        )
        for symbol in imported:
            assert symbol is not None