Spaces:
Running
Running
| """Tests for the SOTA abliteration pipeline.""" | |
| from __future__ import annotations | |
| import json | |
| from pathlib import Path | |
| from unittest.mock import MagicMock | |
| import pytest | |
| import torch | |
| from transformers import GPT2Config, GPT2LMHeadModel | |
| from obliteratus.abliterate import ( | |
| HARMFUL_PROMPTS, | |
| HARMLESS_PROMPTS, | |
| METHODS, | |
| STAGES, | |
| AbliterationPipeline, | |
| PipelineStage, | |
| StageResult, | |
| ) | |
| from obliteratus.models.loader import ModelHandle | |
| # --------------------------------------------------------------------------- | |
| # Fixtures | |
| # --------------------------------------------------------------------------- | |
def _make_tiny_handle():
    """Build a snapshotted ``ModelHandle`` wrapping a very small random GPT-2.

    The tokenizer is a ``MagicMock``: calling it returns one fixed encoding
    (random ids of shape ``(1, 10)`` plus an all-ones attention mask), and
    ``decode`` always yields the same canned sentence.
    """
    tiny_config = GPT2Config(
        vocab_size=1000,
        n_positions=128,
        n_embd=64,
        n_layer=4,
        n_head=2,
        n_inner=256,
    )

    # Build the model before drawing the fake token ids so the random
    # number generator is consumed in the same order as before.
    tiny_model = GPT2LMHeadModel(tiny_config)
    tiny_model.eval()

    fake_encoding = {
        "input_ids": torch.randint(0, 1000, (1, 10)),
        "attention_mask": torch.ones(1, 10, dtype=torch.long),
    }
    fake_tokenizer = MagicMock()
    fake_tokenizer.pad_token = "<pad>"
    fake_tokenizer.eos_token = "<eos>"
    fake_tokenizer.return_value = fake_encoding
    fake_tokenizer.decode.return_value = (
        "The capital of France is Paris, a beautiful city"
    )

    tiny_handle = ModelHandle(
        model=tiny_model,
        tokenizer=fake_tokenizer,
        config=tiny_config,
        model_name="gpt2-test",
        task="causal_lm",
    )
    tiny_handle.snapshot()
    return tiny_handle
def _make_varied_tokenizer(handle):
    """Install a tokenizer side effect that yields a new encoding per call.

    Each call bumps a counter, reseeds torch's global RNG with that counter,
    and returns random ids of shape ``(1, 5)`` with an all-ones attention
    mask. The result of the k-th call is therefore reproducible.
    """
    n_calls = 0

    def seeded_encoding(prompt, **kwargs):
        nonlocal n_calls
        n_calls += 1
        # Seeding with the call index makes call k deterministic.
        torch.manual_seed(n_calls)
        token_ids = torch.randint(0, 1000, (1, 5))
        mask = torch.ones(1, 5, dtype=torch.long)
        return {"input_ids": token_ids, "attention_mask": mask}

    handle.tokenizer.side_effect = seeded_encoding
@pytest.fixture
def handle():
    """Pytest fixture that returns a freshly built tiny handle for each test.

    Registered with ``@pytest.fixture`` so that tests can request it by
    naming a ``handle`` parameter; without the decorator pytest would report
    the fixture as missing.
    """
    return _make_tiny_handle()
| # --------------------------------------------------------------------------- | |
| # Data & stage definitions | |
| # --------------------------------------------------------------------------- | |
class TestPrompts:
    """Size checks on the built-in harmful / harmless prompt lists."""

    def test_harmful_prompts_expanded(self):
        n_harmful = len(HARMFUL_PROMPTS)
        assert n_harmful >= 99

    def test_harmless_prompts_expanded(self):
        n_harmless = len(HARMLESS_PROMPTS)
        assert n_harmless >= 99

    def test_prompt_lists_same_length(self):
        # The two lists are expected to have the same number of entries.
        assert len(HARMLESS_PROMPTS) == len(HARMFUL_PROMPTS)

    def test_prompt_count_512(self):
        """512 prompts across 7 severity tiers."""
        for prompts in (HARMFUL_PROMPTS, HARMLESS_PROMPTS):
            assert len(prompts) == 512

    def test_prompt_volume_slicing(self):
        """Slicing at standard volumes gives correct counts."""
        standard_volumes = (33, 66, 99, 256, 512)
        for n in standard_volumes:
            harmful_slice = HARMFUL_PROMPTS[:n]
            harmless_slice = HARMLESS_PROMPTS[:n]
            assert len(harmful_slice) == n
            assert len(harmless_slice) == n
class TestStages:
    """Checks on the stage list and the stage/result dataclasses."""

    def test_six_stages(self):
        n_stages = len(STAGES)
        assert n_stages == 6

    def test_stage_keys(self):
        expected_order = [
            "summon",
            "probe",
            "distill",
            "excise",
            "verify",
            "rebirth",
        ]
        assert [entry.key for entry in STAGES] == expected_order

    def test_stage_dataclass(self):
        built = PipelineStage(key="test", name="TEST", description="A test stage")
        assert built.key == "test"
        assert built.name == "TEST"

    def test_stage_result_defaults(self):
        # Only ``stage`` and ``status`` are supplied; the rest must default.
        outcome = StageResult(stage="test", status="running")
        assert outcome.message == ""
        assert outcome.duration == 0.0
        assert outcome.details == {}
| # --------------------------------------------------------------------------- | |
| # Method presets | |
| # --------------------------------------------------------------------------- | |
class TestMethods:
    """Checks on the ``METHODS`` preset table."""

    def test_methods_exist(self):
        expected_names = {
            "basic",
            "advanced",
            "aggressive",
            "informed",
            "surgical",
            "inverted",
            "nuclear",
            "optimized",
            "failspy",
            "gabliteration",
            "heretic",
            "rdo",
            "spectral_cascade",
        }
        assert set(METHODS.keys()) == expected_names

    def test_basic_single_direction(self):
        preset = METHODS["basic"]
        assert preset["n_directions"] == 1
        assert preset["norm_preserve"] is False
        assert preset["regularization"] == 0.0
        assert preset["refinement_passes"] == 1

    def test_advanced_multi_direction(self):
        preset = METHODS["advanced"]
        assert preset["n_directions"] > 1
        assert preset["norm_preserve"] is True
        assert preset["regularization"] > 0
        assert preset["refinement_passes"] >= 2

    def test_aggressive_full_gabliteration(self):
        preset = METHODS["aggressive"]
        assert preset["n_directions"] >= 8
        assert preset["norm_preserve"] is True
        assert preset["refinement_passes"] >= 3
| # --------------------------------------------------------------------------- | |
| # Pipeline init | |
| # --------------------------------------------------------------------------- | |
class TestPipelineInit:
    """Constructor behaviour of ``AbliterationPipeline``: defaults, presets, overrides."""

    def test_default_prompts(self):
        pipe = AbliterationPipeline(model_name="test-model")
        assert pipe.harmful_prompts == HARMFUL_PROMPTS
        assert pipe.harmless_prompts == HARMLESS_PROMPTS

    def test_custom_prompts(self):
        custom_harmful = ["bad prompt"]
        custom_harmless = ["good prompt"]
        pipe = AbliterationPipeline(
            model_name="test-model",
            harmful_prompts=custom_harmful,
            harmless_prompts=custom_harmless,
        )
        assert pipe.harmful_prompts == custom_harmful
        assert pipe.harmless_prompts == custom_harmless

    def test_defaults(self):
        pipe = AbliterationPipeline(model_name="test-model")
        assert pipe.device == "auto"
        assert pipe.dtype == "float16"
        assert pipe.output_dir == Path("abliterated")
        assert pipe.trust_remote_code is False
        assert pipe.handle is None

    def test_default_method_is_advanced(self):
        pipe = AbliterationPipeline(model_name="test-model")
        preset = METHODS["advanced"]
        assert pipe.method == "advanced"
        assert pipe.n_directions == preset["n_directions"]
        assert pipe.norm_preserve == preset["norm_preserve"]
        assert pipe.regularization == preset["regularization"]

    def test_method_basic(self):
        pipe = AbliterationPipeline(model_name="test-model", method="basic")
        assert pipe.n_directions == 1
        assert pipe.norm_preserve is False
        assert pipe.regularization == 0.0

    def test_method_aggressive(self):
        pipe = AbliterationPipeline(model_name="test-model", method="aggressive")
        assert pipe.n_directions == 8
        assert pipe.norm_preserve is True
        assert pipe.refinement_passes == 3

    def test_explicit_overrides_method(self):
        # Explicit keyword values must win over the "basic" preset.
        pipe = AbliterationPipeline(
            model_name="test-model",
            method="basic",
            n_directions=6,
            norm_preserve=True,
            regularization=0.5,
            refinement_passes=4,
        )
        assert pipe.n_directions == 6
        assert pipe.norm_preserve is True
        assert pipe.regularization == 0.5
        assert pipe.refinement_passes == 4

    def test_callbacks(self):
        seen_stages = []
        seen_logs = []
        pipe = AbliterationPipeline(
            model_name="test-model",
            on_stage=lambda result: seen_stages.append(result),
            on_log=lambda message: seen_logs.append(message),
        )

        pipe.log("hello")
        assert seen_logs == ["hello"]

        pipe._emit("test", "running", "msg")
        assert len(seen_stages) == 1
        assert seen_stages[0].stage == "test"
| # --------------------------------------------------------------------------- | |
| # _project_out_advanced (norm-preserving + regularization) | |
| # --------------------------------------------------------------------------- | |
class TestProjectOutAdvanced:
    """Tests for ``AbliterationPipeline._project_out_advanced``.

    Covers norm preservation (for two different weight layouts) and partial
    removal via the ``regularization`` argument.
    """

    def test_norm_preserving(self):
        """Norm-preserving mode should keep Frobenius norm approximately constant."""

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.o_proj = torch.nn.Linear(4, 8, bias=False)

        module = Wrapper()
        torch.manual_seed(42)
        module.o_proj.weight.data = torch.randn(8, 4)
        original_norm = module.o_proj.weight.data.norm().item()

        direction = torch.randn(4, 1)
        direction = direction / direction.norm()

        AbliterationPipeline._project_out_advanced(
            module, direction, ["o_proj"], norm_preserve=True, regularization=0.0
        )
        new_norm = module.o_proj.weight.data.norm().item()

        # With amplification cap (1.10x max), exact norm preservation isn't
        # guaranteed on tiny matrices (hidden_dim=4) where a single direction
        # removes a large fraction of energy, so only require that the norm
        # stays within 15% of the original.
        assert new_norm >= original_norm * 0.85, \
            f"Norm should be approximately preserved (within cap): {original_norm:.4f} vs {new_norm:.4f}"

    def test_regularization_partial_removal(self):
        """Regularization should preserve some of the refusal component."""

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.o_proj = torch.nn.Linear(4, 8, bias=False)

        module_full = Wrapper()
        module_reg = Wrapper()
        torch.manual_seed(42)
        W_orig = torch.randn(8, 4)
        # Both modules start from identical weights so results are comparable.
        module_full.o_proj.weight.data = W_orig.clone()
        module_reg.o_proj.weight.data = W_orig.clone()

        direction = torch.randn(4, 1)
        direction = direction / direction.norm()

        # Full removal
        AbliterationPipeline._project_out_advanced(
            module_full, direction, ["o_proj"], norm_preserve=False, regularization=0.0
        )
        # Regularized (30% preserved)
        AbliterationPipeline._project_out_advanced(
            module_reg, direction, ["o_proj"], norm_preserve=False, regularization=0.3
        )

        W_full = module_full.o_proj.weight.data
        W_reg = module_reg.o_proj.weight.data

        # Full removal should have zero projection on direction
        proj_full = (W_full @ direction).norm().item()
        assert proj_full < 1e-4

        # Regularized should have non-zero projection (30% preserved)
        proj_reg = (W_reg @ direction).norm().item()
        proj_orig = (W_orig @ direction).norm().item()
        expected_ratio = 0.3
        actual_ratio = proj_reg / proj_orig if proj_orig > 0 else 0
        assert abs(actual_ratio - expected_ratio) < 0.05, \
            f"Expected ~{expected_ratio:.0%} preserved, got {actual_ratio:.0%}"

    def test_norm_preserving_transposed(self):
        """Norm-preserving should also work for transposed weights."""

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.c_proj = torch.nn.Linear(8, 4, bias=False)

        module = Wrapper()
        torch.manual_seed(42)
        # Here the direction's dimension (4) matches the weight's first axis.
        module.c_proj.weight.data = torch.randn(4, 8)
        original_norm = module.c_proj.weight.data.norm().item()

        direction = torch.randn(4, 1)
        direction = direction / direction.norm()

        AbliterationPipeline._project_out_advanced(
            module, direction, ["c_proj"], norm_preserve=True, regularization=0.0
        )
        new_norm = module.c_proj.weight.data.norm().item()

        # With amplification cap (1.10x max), exact norm preservation isn't
        # guaranteed on tiny matrices where a single direction removes a large
        # fraction of energy.
        assert new_norm >= original_norm * 0.80, \
            f"Norm should be approximately preserved (within cap): {original_norm:.4f} vs {new_norm:.4f}"
| # --------------------------------------------------------------------------- | |
| # Full attention projection (q/k/v + o_proj) | |
| # --------------------------------------------------------------------------- | |
class TestAttentionFullProjection:
    """Test that ALL attention weight matrices are projected (not just o_proj)."""

    def test_qkv_all_projected(self):
        """q_proj, k_proj, v_proj should all be projected alongside o_proj."""
        from obliteratus.abliterate import _ATTN_OUT_NAMES, _ATTN_IN_NAMES

        hidden = 16
        proj_names = ["q_proj", "k_proj", "v_proj", "o_proj"]

        class FakeAttn(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # Registered in q, k, v, o order.
                for proj in proj_names:
                    setattr(self, proj, torch.nn.Linear(hidden, hidden, bias=False))

        attn = FakeAttn()
        torch.manual_seed(42)
        for param in attn.parameters():
            param.data = torch.randn_like(param.data)

        originals = {}
        for proj in proj_names:
            originals[proj] = getattr(attn, proj).weight.data.clone()

        d = torch.randn(hidden, 1)
        d = d / d.norm()

        target_names = _ATTN_OUT_NAMES + _ATTN_IN_NAMES
        count = AbliterationPipeline._project_out_advanced(attn, d, target_names)

        assert count == 4, f"Should project 4 weights (q/k/v/o), got {count}"
        for name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
            current = getattr(attn, name).weight.data
            assert not torch.allclose(current, originals[name]), \
                f"{name} should be modified"

    def test_project_all_does_not_early_return(self):
        """_project_out_advanced should project ALL matching weights, not just first."""
        from obliteratus.abliterate import _FFN_IN_NAMES

        hidden = 16

        class FakeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.gate_proj = torch.nn.Linear(hidden, 32, bias=False)

        mod = FakeModule()
        torch.manual_seed(42)
        up_before = mod.up_proj.weight.data.clone()
        gate_before = mod.gate_proj.weight.data.clone()

        d = torch.randn(hidden, 1)
        d = d / d.norm()

        count = AbliterationPipeline._project_out_advanced(mod, d, _FFN_IN_NAMES)

        assert count == 2, f"Should project both up_proj and gate_proj, got {count}"
        assert not torch.allclose(mod.up_proj.weight.data, up_before), "up_proj should be modified"
        assert not torch.allclose(mod.gate_proj.weight.data, gate_before), "gate_proj should be modified"

    def test_lm_head_projection(self):
        """lm_head should be projectable via _project_out_advanced."""
        hidden, vocab = 16, 100

        class FakeModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lm_head = torch.nn.Linear(hidden, vocab, bias=False)

        model = FakeModel()
        torch.manual_seed(42)
        head_before = model.lm_head.weight.data.clone()

        d = torch.randn(hidden, 1)
        d = d / d.norm()

        count = AbliterationPipeline._project_out_advanced(
            model, d, ["lm_head"], regularization=0.0,
        )

        assert count == 1, "Should project lm_head"
        assert not torch.allclose(model.lm_head.weight.data, head_before), "lm_head should be modified"

        # With regularization 0.0, nothing should remain along the direction.
        proj = (model.lm_head.weight.data @ d).norm().item()
        assert proj < 1e-4, f"Refusal direction should be removed from lm_head, proj={proj}"
class TestKneeDetectionThreshold:
    """Test that knee detection uses 5% threshold to include more layers."""

    def test_five_percent_threshold_includes_more(self):
        """Layers between 5% and 10% of max should now be included."""
        # (layer index, norm) pairs in descending norm order. The maximum is
        # 10.0, and the last two entries sit at 7% and 6% of it.
        sorted_layers = [
            (0, 10.0),
            (1, 8.0),
            (2, 6.0),
            (3, 0.7),
            (4, 0.6),
        ]

        selected = AbliterationPipeline._select_layers_knee(sorted_layers)

        # At least one of the two weak layers must be kept (> 5% threshold).
        weak_layer_kept = 3 in selected or 4 in selected
        assert weak_layer_kept, (
            f"Layers with 6-7% of max signal should be included, got {selected}"
        )
| # --------------------------------------------------------------------------- | |
| # MoE projection (router, shared expert, input/output, fused) | |
| # --------------------------------------------------------------------------- | |
class TestProjectMoEExperts:
    """Test the full MoE projection pipeline: router, shared expert, experts.

    Each test builds a small fake MoE module and calls
    ``AbliterationPipeline._project_moe_experts`` on it. The nested
    ``_make_expert`` factories take no ``self`` and are therefore declared
    ``@staticmethod``; without that, ``self._make_expert()`` raises
    ``TypeError`` before the projection code is ever reached.
    """

    def _make_direction(self, hidden_dim=16):
        """Return a random unit-norm column vector of shape ``(hidden_dim, 1)``."""
        d = torch.randn(hidden_dim, 1)
        return d / d.norm()

    def test_router_gate_projected(self):
        """Router/gate weight should have refusal direction removed."""
        hidden = 16
        n_experts = 4

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=True)
                self.experts = torch.nn.ModuleList([
                    self._make_expert() for _ in range(n_experts)
                ])

            @staticmethod
            def _make_expert():
                m = torch.nn.Module()
                m.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                m.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                return m

        moe = FakeMoE()
        d = self._make_direction(hidden)
        W_gate_orig = moe.gate.weight.data.clone()

        count = AbliterationPipeline._project_moe_experts(moe, d)
        assert count > 0

        # Gate weight should have been modified
        assert not torch.allclose(moe.gate.weight.data, W_gate_orig), \
            "Router/gate weights should be projected"

        # The gate weight's projection onto the direction should be ~0
        proj = (moe.gate.weight.data @ d).norm().item()
        assert proj < 1e-4, f"Gate should have no component along refusal dir, got {proj}"

    def test_shared_expert_projected(self):
        """Shared expert (always-on) should have both input and output projected."""
        hidden = 16

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, 2, bias=False)
                self.shared_expert = torch.nn.Module()
                self.shared_expert.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.shared_expert.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.experts = torch.nn.ModuleList([
                    self._make_expert() for _ in range(2)
                ])

            @staticmethod
            def _make_expert():
                m = torch.nn.Module()
                m.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                m.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                return m

        moe = FakeMoE()
        d = self._make_direction(hidden)
        shared_down_orig = moe.shared_expert.down_proj.weight.data.clone()
        shared_up_orig = moe.shared_expert.up_proj.weight.data.clone()

        count = AbliterationPipeline._project_moe_experts(moe, d)
        assert count > 0

        # Both shared expert output AND input projections should be modified
        assert not torch.allclose(moe.shared_expert.down_proj.weight.data, shared_down_orig), \
            "Shared expert output (down_proj) should be projected"
        assert not torch.allclose(moe.shared_expert.up_proj.weight.data, shared_up_orig), \
            "Shared expert input (up_proj) should be projected"

    def test_expert_input_projections_projected(self):
        """Expert input projections (up_proj, gate_proj) should also be modified."""
        hidden = 16

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.gate_proj = torch.nn.Linear(hidden, 32, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(2)])

        moe = FakeMoE()
        d = self._make_direction(hidden)
        up_orig = moe.experts[0].up_proj.weight.data.clone()

        count = AbliterationPipeline._project_moe_experts(moe, d)

        # Each expert contributes 2 projections (output + input)
        # 2 experts * 2 = 4 minimum
        assert count >= 4, f"Expected >= 4 projections (out+in per expert), got {count}"
        assert not torch.allclose(moe.experts[0].up_proj.weight.data, up_orig), \
            "Expert input (up_proj) should be projected"

    def test_fused_3d_output_and_input(self):
        """Fused 3D parameter patterns (GPT-OSS style) should project both directions."""
        hidden = 16
        intermediate = 32
        n_experts = 4

        class FusedExperts(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # One 3D parameter per projection; the leading axis indexes experts.
                self.down_proj = torch.nn.Parameter(torch.randn(n_experts, intermediate, hidden))
                self.up_proj = torch.nn.Parameter(torch.randn(n_experts, intermediate, hidden))

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.experts = FusedExperts()

        moe = FakeMoE()
        d = self._make_direction(hidden)
        down_orig = moe.experts.down_proj.data.clone()
        up_orig = moe.experts.up_proj.data.clone()

        count = AbliterationPipeline._project_moe_experts(moe, d)

        # 4 experts output + 4 experts input = 8
        assert count == 8, f"Expected 8 fused projections, got {count}"
        assert not torch.allclose(moe.experts.down_proj.data, down_orig), \
            "Fused output (down_proj) should be projected"
        assert not torch.allclose(moe.experts.up_proj.data, up_orig), \
            "Fused input (up_proj) should be projected"

    def test_fused_3d_norm_preserve(self):
        """Fused 3D projections should preserve norms when requested."""
        hidden = 16
        intermediate = 32
        n_experts = 4

        class FusedExperts(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Parameter(torch.randn(n_experts, intermediate, hidden))

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.experts = FusedExperts()

        moe = FakeMoE()
        d = self._make_direction(hidden)

        # Record per-expert norms before
        orig_norms = [moe.experts.down_proj.data[i].norm().item() for i in range(n_experts)]

        AbliterationPipeline._project_moe_experts(moe, d, norm_preserve=True)

        # Check per-expert norms preserved
        for i in range(n_experts):
            new_norm = moe.experts.down_proj.data[i].norm().item()
            assert abs(orig_norms[i] - new_norm) < 1e-3, \
                f"Expert {i} norm not preserved: {orig_norms[i]:.4f} vs {new_norm:.4f}"

    def test_no_experts_returns_zero(self):
        """Module without experts attribute should return 0."""

        class NoMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mlp = torch.nn.Linear(16, 32)

        moe = NoMoE()
        d = self._make_direction(16)
        assert AbliterationPipeline._project_moe_experts(moe, d) == 0

    def test_router_bias_projected(self):
        """Router bias must stay untouched even with project_biases=True."""
        hidden = 16

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, 4, bias=True)
                self.experts = torch.nn.ModuleList([
                    self._make_expert() for _ in range(4)
                ])

            @staticmethod
            def _make_expert():
                m = torch.nn.Module()
                m.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                return m

        moe = FakeMoE()
        d = self._make_direction(hidden)
        bias_orig = moe.gate.bias.data.clone()

        count = AbliterationPipeline._project_moe_experts(moe, d, project_biases=True)

        # Gate has 4 outputs (num_experts), direction has 16 dims
        # bias shape (4,) != direction shape (16,), so bias won't match.
        # This is correct: router bias is (num_experts,), not (hidden_dim,),
        # so _project_bias won't modify it (shape mismatch is expected).
        assert torch.allclose(moe.gate.bias.data, bias_orig), (
            "Router bias should be unchanged when shape mismatches direction"
        )
        assert isinstance(count, int)
        assert count > 0  # expert weights should still be projected

    def test_router_auto_detection_fallback(self):
        """Unknown router name should be auto-detected and projected."""
        import warnings as w

        hidden = 16
        n_experts = 4

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # Unusual router name not in _ROUTER_NAMES
                self.moe_gate_proj = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([
                    self._make_expert() for _ in range(n_experts)
                ])

            @staticmethod
            def _make_expert():
                m = torch.nn.Module()
                m.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                return m

        moe = FakeMoE()
        d = self._make_direction(hidden)
        gate_orig = moe.moe_gate_proj.weight.data.clone()

        # Capture every warning so the auto-detection notice can be inspected.
        with w.catch_warnings(record=True) as caught:
            w.simplefilter("always")
            AbliterationPipeline._project_moe_experts(moe, d)

        # Should auto-detect and project the unusual router name
        assert not torch.allclose(moe.moe_gate_proj.weight.data, gate_orig), \
            "Auto-detected router should be projected"

        # Should emit a warning about the auto-detection
        auto_detect_warnings = [
            x for x in caught
            if "auto-detected" in str(x.message)
        ]
        assert len(auto_detect_warnings) > 0, "Should warn about auto-detected router"

    def test_full_moe_all_components(self):
        """End-to-end: all MoE components should be modified together."""
        hidden = 16

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.up_proj = torch.nn.Linear(hidden, 32, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, 4, bias=False)
                self.shared_expert = torch.nn.Module()
                self.shared_expert.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.shared_expert.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(4)])

        moe = FakeMoE()
        d = self._make_direction(hidden)

        count = AbliterationPipeline._project_moe_experts(moe, d)

        # Expected: 1 (gate) + 2 (shared out+in) + 4*2 (expert out+in) = 11
        assert count == 11, f"Expected 11 total projections, got {count}"
| # --------------------------------------------------------------------------- | |
| # SOTA technique #1: Safety-neuron masking (GateBreaker-style z-score) | |
| # --------------------------------------------------------------------------- | |
class TestSafetyNeuronMasking:
    """Tests for ``AbliterationPipeline._mask_safety_neurons``."""

    def test_outlier_neurons_zeroed(self):
        """Neurons with outsized refusal projection should be zeroed."""
        hidden = 16

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 64, bias=False)

        module = Wrapper()
        torch.manual_seed(42)

        d = torch.randn(hidden, 1)
        d = d / d.norm()

        # Overwrite the first three rows with a large multiple of the
        # direction so they stand out as outliers.
        outlier_rows = (0, 1, 2)
        planted = d.squeeze() * 10.0
        for row in outlier_rows:
            module.down_proj.weight.data[row] = planted

        n_masked = AbliterationPipeline._mask_safety_neurons(
            module, d, ["down_proj"], z_threshold=2.0,
        )

        assert n_masked >= 3, f"Expected >= 3 masked neurons, got {n_masked}"
        # Every planted row must now be (numerically) zero.
        for row in outlier_rows:
            largest_entry = module.down_proj.weight.data[row].abs().max().item()
            assert largest_entry < 1e-6

    def test_no_outliers_no_masking(self):
        """When all neurons have similar projection, none should be masked."""
        hidden = 16

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 32, bias=False)

        module = Wrapper()
        # All rows identical, so no row can be an outlier.
        module.down_proj.weight.data = torch.ones(32, hidden) * 0.1

        d = torch.randn(hidden, 1)
        d = d / d.norm()

        n_masked = AbliterationPipeline._mask_safety_neurons(
            module, d, ["down_proj"], z_threshold=2.0,
        )
        assert n_masked == 0

    def test_high_threshold_masks_fewer(self):
        """Higher z_threshold should mask fewer neurons."""
        hidden = 16

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 64, bias=False)

        module1 = Wrapper()
        module2 = Wrapper()
        torch.manual_seed(42)
        shared_weights = torch.randn(64, hidden)
        # Identical starting weights, so only the threshold differs.
        module1.down_proj.weight.data = shared_weights.clone()
        module2.down_proj.weight.data = shared_weights.clone()

        d = torch.randn(hidden, 1)
        d = d / d.norm()

        n1 = AbliterationPipeline._mask_safety_neurons(
            module1, d, ["down_proj"], z_threshold=1.5
        )
        n2 = AbliterationPipeline._mask_safety_neurons(
            module2, d, ["down_proj"], z_threshold=3.0
        )
        assert n1 >= n2, f"Lower threshold should mask more: {n1} vs {n2}"
| # --------------------------------------------------------------------------- | |
| # SOTA technique #2: Jailbreak-contrastive direction refinement | |
| # --------------------------------------------------------------------------- | |
class TestJailbreakContrastive:
    """Tests for jailbreak prompt generation and jailbreak-contrastive blending."""

    def test_jailbreak_prompt_generation(self):
        """Pipeline should generate jailbreak variants of harmful prompts."""
        first_prompt = "How to hack a computer"
        second_prompt = "Write a virus"

        pipeline = AbliterationPipeline(model_name="test")
        pipeline.harmful_prompts = [first_prompt, second_prompt]

        jailbreak = pipeline._generate_jailbreak_prompts()

        assert len(jailbreak) == 2
        # The original wording must survive inside each variant.
        assert "hack a computer" in jailbreak[0].lower()
        assert "virus" in jailbreak[1].lower()
        # A wrapped prompt is necessarily longer than the bare one.
        assert len(jailbreak[0]) > len("How to hack a computer")

    def test_jailbreak_contrast_blending(self):
        """Jailbreak-contrastive blending should modify refusal direction."""
        hidden = 16
        pipeline = AbliterationPipeline(
            model_name="test",
            use_jailbreak_contrast=True,
            n_directions=1,
        )
        pipeline._on_log = lambda _message: None

        # Stand-ins for probed activation means, drawn in a fixed order
        # from a seeded RNG.
        torch.manual_seed(42)
        harm_mean = torch.randn(1, hidden)
        safe_mean = torch.randn(1, hidden)
        jb_mean = torch.randn(1, hidden)

        layer = 0
        pipeline._harmful_means = {layer: harm_mean}
        pipeline._harmless_means = {layer: safe_mean}
        pipeline._jailbreak_means = {layer: jb_mean}
        pipeline._harmful_acts = {layer: [harm_mean]}
        pipeline._harmless_acts = {layer: [safe_mean]}
        pipeline._jailbreak_acts = {layer: [jb_mean]}

        pipeline._distill()

        # The resulting direction must be unit length.
        d = pipeline.refusal_directions[0]
        assert abs(d.norm().item() - 1.0) < 1e-4

        # Compare against the plain normalised harmful-minus-harmless difference.
        std_diff = (harm_mean - safe_mean).squeeze()
        std_dir = std_diff / std_diff.norm()
        cosine = (d @ std_dir).item()
        assert cosine < 0.99, f"Blended direction too similar to standard: cos={cosine}"

    def test_surgical_method_enables_jailbreak(self):
        """Surgical method should enable jailbreak-contrastive by default."""
        surgical = METHODS["surgical"]
        assert surgical["use_jailbreak_contrast"] is True
| # --------------------------------------------------------------------------- | |
| # SOTA technique #3: Layer-adaptive projection strength | |
| # --------------------------------------------------------------------------- | |
class TestLayerAdaptiveStrength:
    """Tests for the layer-adaptive projection strength option."""

    def test_layer_weights_computed(self):
        """Layer-adaptive weights should be proportional to refusal signal."""
        hidden = 16
        pipeline = AbliterationPipeline(
            model_name="test",
            layer_adaptive_strength=True,
            n_directions=1,
        )
        pipeline._on_log = lambda _message: None

        # Layer 0 gets a 10x larger harmful mean than layer 1; the harmless
        # means are zero, so the difference equals the harmful mean.
        torch.manual_seed(42)
        strong_diff = torch.randn(1, hidden) * 10.0
        weak_diff = torch.randn(1, hidden) * 1.0
        zero_mean = torch.zeros(1, hidden)

        harmful_by_layer = {0: strong_diff, 1: weak_diff}
        pipeline._harmful_means = harmful_by_layer
        pipeline._harmless_means = {0: zero_mean, 1: zero_mean}
        pipeline._harmful_acts = {0: [strong_diff], 1: [weak_diff]}
        pipeline._harmless_acts = {0: [zero_mean], 1: [zero_mean]}

        pipeline._distill()

        # At least one layer must have received a weight.
        assert len(pipeline._layer_excise_weights) > 0
        # The largest weight is expected to be close to 1.0.
        max_weight = max(pipeline._layer_excise_weights.values())
        assert max_weight > 0.9, f"Max weight should be ~1.0, got {max_weight}"

    def test_surgical_method_enables_adaptive(self):
        """Surgical method should enable layer-adaptive by default."""
        surgical = METHODS["surgical"]
        assert surgical["layer_adaptive_strength"] is True
| # --------------------------------------------------------------------------- | |
| # SOTA technique #5: Attention head surgery | |
| # --------------------------------------------------------------------------- | |
class TestAttentionHeadSurgery:
    """Tests for ``AbliterationPipeline._project_head_selective``.

    Every test builds a stand-in attention module that exposes only a
    bias-free ``o_proj`` linear layer, runs the selective projection, and
    inspects the per-head column slices of ``o_proj.weight``.
    """

    # Scores passed to the projection in every test, listed from the highest
    # score (head 0) to the lowest (head 3).
    HEAD_SCORES = [(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)]

    @staticmethod
    def _make_attn(in_features, out_features):
        """Return a minimal module whose only child is a bias-free ``o_proj``.

        ``torch.nn.Linear(in_features, out_features)`` stores its weight with
        shape ``(out_features, in_features)``.
        """

        class FakeAttn(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.o_proj = torch.nn.Linear(in_features, out_features, bias=False)

        return FakeAttn()

    def test_head_selective_projection(self):
        """Selective head projection should only modify the targeted head's columns."""
        hidden = 16
        n_heads = 4
        head_dim = hidden // n_heads

        # Seed BEFORE building the module: nn.Linear draws its initial weights
        # from the global RNG, so seeding afterwards leaves them unseeded.
        torch.manual_seed(42)
        attn = self._make_attn(hidden, hidden)
        W_orig = attn.o_proj.weight.data.clone()
        d = torch.randn(hidden, 1)
        d = d / d.norm()

        n_modified = AbliterationPipeline._project_head_selective(
            attn, d, list(self.HEAD_SCORES), n_heads=n_heads, head_fraction=0.25,
        )
        assert n_modified >= 1, "Should modify at least 1 head"

        W_new = attn.o_proj.weight.data
        # Head 0 columns (targeted) should be modified
        assert not torch.allclose(
            W_new[:, 0:head_dim], W_orig[:, 0:head_dim]
        ), "Targeted head 0 should be modified"
        # Head 3 columns (NOT targeted) should be untouched
        assert torch.allclose(
            W_new[:, 3 * head_dim:4 * head_dim],
            W_orig[:, 3 * head_dim:4 * head_dim],
        ), "Non-targeted head 3 should be untouched"

    def test_head_surgery_norm_preserve(self):
        """Head surgery with norm_preserve should maintain per-head norms."""
        hidden = 16
        n_heads = 4
        head_dim = hidden // n_heads

        # Seed first so the o_proj initialisation is reproducible (see above).
        torch.manual_seed(42)
        attn = self._make_attn(hidden, hidden)
        d = torch.randn(hidden, 1)
        d = d / d.norm()

        orig_norms = [
            attn.o_proj.weight.data[:, h * head_dim:(h + 1) * head_dim].norm().item()
            for h in range(n_heads)
        ]

        AbliterationPipeline._project_head_selective(
            attn, d, list(self.HEAD_SCORES), n_heads=n_heads,
            head_fraction=0.5, norm_preserve=True,
        )

        # Targeted heads should have preserved norms
        for h in range(2):  # top 50% = 2 heads
            new_norm = attn.o_proj.weight.data[:, h * head_dim:(h + 1) * head_dim].norm().item()
            assert abs(orig_norms[h] - new_norm) < 1e-3, \
                f"Head {h} norm not preserved: {orig_norms[h]:.4f} vs {new_norm:.4f}"

    def test_head_surgery_non_square_gqa(self):
        """Head surgery should work for GQA models with non-square o_proj (attn_dim != hidden_dim)."""
        hidden_dim = 12  # model hidden dimension
        attn_dim = 32  # attention dimension (n_heads * head_dim_attn)
        n_heads = 4
        head_dim_attn = attn_dim // n_heads  # 8

        # o_proj maps attn_dim -> hidden_dim, so its weight is (12, 32).
        attn = self._make_attn(attn_dim, hidden_dim)
        # The weight is overwritten right after seeding, so the unseeded
        # initialisation above never reaches the assertions.
        torch.manual_seed(42)
        attn.o_proj.weight.data = torch.randn(hidden_dim, attn_dim)
        W_orig = attn.o_proj.weight.data.clone()
        d = torch.randn(hidden_dim, 1)
        d = d / d.norm()

        n_modified = AbliterationPipeline._project_head_selective(
            attn, d, list(self.HEAD_SCORES), n_heads=n_heads, head_fraction=0.25,
        )
        assert n_modified >= 1, "Should modify at least 1 head"

        W_new = attn.o_proj.weight.data
        # Head 0 columns (targeted) should be modified
        assert not torch.allclose(
            W_new[:, 0:head_dim_attn], W_orig[:, 0:head_dim_attn]
        ), "Targeted head 0 should be modified"
        # Head 3 columns (NOT targeted) should be untouched
        assert torch.allclose(
            W_new[:, 3 * head_dim_attn:4 * head_dim_attn],
            W_orig[:, 3 * head_dim_attn:4 * head_dim_attn],
        ), "Non-targeted head 3 should be untouched"

    def test_head_surgery_gqa_norm_preserve(self):
        """Head surgery on GQA non-square o_proj with norm_preserve."""
        hidden_dim = 12
        attn_dim = 32
        n_heads = 4
        head_dim_attn = attn_dim // n_heads

        attn = self._make_attn(attn_dim, hidden_dim)
        # Weight is replaced after seeding, so the values are deterministic.
        torch.manual_seed(42)
        attn.o_proj.weight.data = torch.randn(hidden_dim, attn_dim)
        d = torch.randn(hidden_dim, 1)
        d = d / d.norm()

        orig_norms = [
            attn.o_proj.weight.data[:, h * head_dim_attn:(h + 1) * head_dim_attn].norm().item()
            for h in range(n_heads)
        ]

        AbliterationPipeline._project_head_selective(
            attn, d, list(self.HEAD_SCORES), n_heads=n_heads,
            head_fraction=0.5, norm_preserve=True,
        )

        for h in range(2):  # top 50% = 2 heads
            new_norm = attn.o_proj.weight.data[:, h * head_dim_attn:(h + 1) * head_dim_attn].norm().item()
            assert abs(orig_norms[h] - new_norm) < 1e-3, \
                f"GQA head {h} norm not preserved: {orig_norms[h]:.4f} vs {new_norm:.4f}"
| # --------------------------------------------------------------------------- | |
| # SOTA technique #6: SAE feature-level abliteration | |
| # --------------------------------------------------------------------------- | |
class TestSAEAbliteration:
    """Tests for the sparse-autoencoder helpers in ``sae_abliteration``."""

    def test_sae_train_and_reconstruct(self):
        """A trained SAE maps an input to a same-shaped reconstruction and sparse codes."""
        from obliteratus.analysis.sae_abliteration import train_sae

        hidden = 32
        n_samples = 64

        # Synthetic training activations drawn from a fixed seed.
        torch.manual_seed(42)
        training_acts = [torch.randn(hidden) for _ in range(n_samples)]
        sae = train_sae(training_acts, hidden, expansion=2, n_epochs=10, lr=1e-3)

        # A single forward pass returns (reconstruction, latent code).
        probe = torch.randn(1, hidden)
        reconstruction, latent = sae(probe)
        assert reconstruction.shape == probe.shape
        assert latent.shape == (1, 2 * hidden)  # expansion=2

        # More than 30% of the latent entries must be exactly zero.
        zero_fraction = (latent == 0).float().mean()
        assert zero_fraction > 0.3, "Features should be sparse"

    def test_refusal_feature_identification(self):
        """Features picked from separated harmful/harmless data align with the planted direction."""
        from obliteratus.analysis.sae_abliteration import (
            identify_refusal_features,
            train_sae,
        )

        hidden = 32
        torch.manual_seed(42)

        # Plant a unit direction and shift the two populations along +/- 2x of it.
        refusal_dir = torch.randn(hidden)
        refusal_dir = refusal_dir / refusal_dir.norm()
        harmful_acts = [torch.randn(hidden) + 2.0 * refusal_dir for _ in range(32)]
        harmless_acts = [torch.randn(hidden) - 2.0 * refusal_dir for _ in range(32)]

        sae = train_sae(harmful_acts + harmless_acts, hidden, expansion=2, n_epochs=30, lr=3e-4)
        result = identify_refusal_features(
            sae, harmful_acts, harmless_acts, layer_idx=0, top_k=4,
        )

        assert result.n_refusal_features == 4
        assert result.sae_directions.shape == (4, hidden)
        assert result.variance_explained > 0.0

        # At least one returned direction should overlap with the planted one.
        best_cos = max(
            abs((direction @ refusal_dir).item())
            for direction in result.sae_directions
        )
        assert best_cos > 0.1, f"SAE should find direction aligned with refusal: best_cos={best_cos}"

    def test_sae_directions_unit_norm(self):
        """Every direction returned by ``identify_refusal_features`` has unit length."""
        from obliteratus.analysis.sae_abliteration import (
            identify_refusal_features,
            train_sae,
        )

        hidden = 16
        torch.manual_seed(42)
        harmful = [torch.randn(hidden) + torch.ones(hidden) for _ in range(16)]
        harmless = [torch.randn(hidden) - torch.ones(hidden) for _ in range(16)]

        sae = train_sae(harmful + harmless, hidden, expansion=2, n_epochs=10)
        result = identify_refusal_features(sae, harmful, harmless, 0, top_k=3)

        for i, direction in enumerate(result.sae_directions):
            norm = direction.norm().item()
            assert abs(norm - 1.0) < 1e-3, f"Direction {i} norm={norm}, expected 1.0"
| # --------------------------------------------------------------------------- | |
| # Surgical method preset | |
| # --------------------------------------------------------------------------- | |
class TestSurgicalMethod:
    """Checks on the ``surgical`` preset and how the pipeline applies it."""

    def test_surgical_enables_all_sota(self):
        """The ``surgical`` preset turns on each of the six SOTA technique flags."""
        surgical = METHODS["surgical"]
        for flag in (
            "use_jailbreak_contrast",
            "layer_adaptive_strength",
            "safety_neuron_masking",
            "per_expert_directions",
            "attention_head_surgery",
            "use_sae_features",
        ):
            assert surgical[flag] is True

    def test_basic_disables_all_sota(self):
        """The ``basic`` preset leaves the checked SOTA flags absent or False."""
        basic = METHODS["basic"]
        for flag in (
            "use_jailbreak_contrast",
            "layer_adaptive_strength",
            "safety_neuron_masking",
        ):
            assert basic.get(flag, False) is False

    def test_pipeline_init_surgical(self):
        """A pipeline built with ``method="surgical"`` exposes every flag as True."""
        pipe = AbliterationPipeline(model_name="test", method="surgical")
        for flag in (
            "use_jailbreak_contrast",
            "layer_adaptive_strength",
            "safety_neuron_masking",
            "per_expert_directions",
            "attention_head_surgery",
            "use_sae_features",
        ):
            assert getattr(pipe, flag) is True

    def test_pipeline_init_explicit_override(self):
        """A keyword argument wins over the preset; untouched flags keep the preset value."""
        pipe = AbliterationPipeline(
            model_name="test",
            method="surgical",
            safety_neuron_masking=False,
        )
        assert pipe.safety_neuron_masking is False
        # Everything not overridden still comes from the surgical preset.
        assert pipe.use_jailbreak_contrast is True
| # --------------------------------------------------------------------------- | |
| # Inverted method (semantic refusal inversion) | |
| # --------------------------------------------------------------------------- | |
class TestInvertedMethod:
    """Tests for the ``inverted`` preset and the reflection-style projections."""

    def test_inverted_preset_config(self):
        """Inverted method preset should enable inversion flag."""
        cfg = METHODS["inverted"]
        assert cfg["invert_refusal"] is True
        assert cfg["n_directions"] == 8
        assert cfg["use_jailbreak_contrast"] is True

    def test_surgical_does_not_invert(self):
        """Surgical method should NOT enable inversion by default."""
        cfg = METHODS["surgical"]
        assert cfg.get("invert_refusal", False) is False

    def test_pipeline_init_inverted(self):
        """Pipeline initialized with inverted method should have flag set."""
        pipeline = AbliterationPipeline(model_name="test", method="inverted")
        assert pipeline.invert_refusal is True
        assert pipeline.use_jailbreak_contrast is True
        assert pipeline.safety_neuron_masking is False  # zeroing + reflection is destructive

    def test_pipeline_invert_explicit_override(self):
        """Explicit invert_refusal param should override method default."""
        pipeline = AbliterationPipeline(
            model_name="test", method="surgical", invert_refusal=True,
        )
        assert pipeline.invert_refusal is True

        pipeline2 = AbliterationPipeline(
            model_name="test", method="inverted", invert_refusal=False,
        )
        assert pipeline2.invert_refusal is False

    def test_reflection_math(self):
        """2x projection (reflection) should negate the refusal component."""
        hidden = 16

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.o_proj = torch.nn.Linear(hidden, 32, bias=False)

        # Seed BEFORE building the module: nn.Linear draws its initial weights
        # from the global RNG, so seeding afterwards leaves them unseeded.
        torch.manual_seed(42)
        module = Wrapper()
        W_orig = module.o_proj.weight.data.clone()
        d = torch.randn(hidden, 1)
        d = d / d.norm()

        # Original projection onto d
        orig_proj = (W_orig @ d).squeeze()

        # Reflection: regularization=-1.0 → scale=2.0
        AbliterationPipeline._project_out_advanced(
            module, d, ["o_proj"], regularization=-1.0,
        )

        W_reflected = module.o_proj.weight.data
        new_proj = (W_reflected @ d).squeeze()

        # After reflection, projection should be NEGATED (sign flipped)
        assert torch.allclose(new_proj, -orig_proj, atol=1e-4), (
            f"Reflected projection should be negated: expected ~{-orig_proj[:3]} got {new_proj[:3]}"
        )

    def test_reflection_preserves_orthogonal_component(self):
        """Reflection should not change the component perpendicular to d."""
        hidden = 8

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.o_proj = torch.nn.Linear(hidden, 16, bias=False)

        # Seed first so the randomly initialised weights are reproducible.
        torch.manual_seed(42)
        module = Wrapper()
        W_orig = module.o_proj.weight.data.clone()
        d = torch.randn(hidden, 1)
        d = d / d.norm()

        # Compute original orthogonal component
        orig_d_component = (W_orig @ d) @ d.T  # rank-1 matrix: projection onto d
        orig_ortho = W_orig - orig_d_component  # everything except d-component

        AbliterationPipeline._project_out_advanced(
            module, d, ["o_proj"], regularization=-1.0,
        )

        W_reflected = module.o_proj.weight.data
        new_d_component = (W_reflected @ d) @ d.T
        new_ortho = W_reflected - new_d_component

        # Orthogonal component should be unchanged
        assert torch.allclose(orig_ortho, new_ortho, atol=1e-4), (
            "Reflection should preserve orthogonal component"
        )

    def test_moe_expert_safety_classification(self):
        """_identify_safety_experts should classify experts by router affinity."""
        from unittest.mock import patch

        import obliteratus.abliterate as abl_module

        hidden = 16
        n_experts = 4

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([
                    torch.nn.Linear(hidden, hidden) for _ in range(n_experts)
                ])

        class FakeLayer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.self_attn = torch.nn.Module()
                self.self_attn.o_proj = torch.nn.Linear(hidden, hidden, bias=False)
                self.mlp = FakeMoE()

        config = GPT2Config(n_embd=hidden, n_head=2, n_layer=1, vocab_size=100, n_positions=64)
        model = MagicMock()
        model.parameters.return_value = iter([torch.zeros(1)])
        handle = ModelHandle(
            model=model, tokenizer=MagicMock(),
            config=config, model_name="test", task="causal_lm",
        )

        pipeline = AbliterationPipeline(model_name="test", method="inverted")
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None

        # Set up fake layer and direction. Every router row is overwritten
        # below, so the layer's unseeded initialisation does not reach the
        # router weights this test sets up.
        layer = FakeLayer()
        torch.manual_seed(42)
        d = torch.randn(hidden)
        d = d / d.norm()

        # Set router weights: expert 0 aligned with d, expert 3 anti-aligned
        layer.mlp.gate.weight.data[0] = d * 5.0
        layer.mlp.gate.weight.data[1] = d * 1.0
        layer.mlp.gate.weight.data[2] = d * -1.0
        layer.mlp.gate.weight.data[3] = d * -5.0

        # Swap the module-level lookups for ones returning the fake layer.
        # patch.object restores the originals even if the call below raises.
        with patch.object(abl_module, "get_layer_modules", lambda h: [layer]), \
                patch.object(abl_module, "get_ffn_module", lambda lay, a: lay.mlp):
            pipeline.refusal_directions = {0: d}
            pipeline._strong_layers = [0]
            pipeline._identify_safety_experts()

        assert 0 in pipeline._expert_safety_scores
        scores = pipeline._expert_safety_scores[0]
        # Expert 0 should be highest safety affinity
        assert scores[0][0] == 0, f"Expert 0 should be top safety, got {scores[0]}"
        # Expert 3 should be lowest
        assert scores[-1][0] == 3, f"Expert 3 should be lowest, got {scores[-1]}"

    def test_moe_inverted_excision_selective(self):
        """Inverted MoE excision should reflect safety experts and remove from capability."""
        hidden = 16
        n_experts = 4

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, hidden, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(n_experts)])

        moe = FakeMoE()
        # Every parameter is re-drawn after seeding, so the weights are
        # deterministic regardless of how the module was initialised.
        torch.manual_seed(42)
        for p in moe.parameters():
            p.data = torch.randn_like(p.data)

        d = torch.randn(hidden, 1)
        d = d / d.norm()

        # Set up safety scores: experts 0,1 are safety, 2,3 are capability
        pipeline = AbliterationPipeline(model_name="test", method="inverted")
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        pipeline._expert_safety_scores = {
            0: [(0, 5.0), (1, 3.0), (2, -1.0), (3, -3.0)]
        }

        orig_router = moe.gate.weight.data.clone()

        count = pipeline._project_moe_experts_inverted(
            moe, d, 0, norm_preserve=False, project_biases=False,
        )
        assert count > 0, "Should project some weights"

        # Router should be reflected (capped at 1.5x to prevent extreme logits
        # that cause CUDA illegal memory access in batched expert forward).
        # With router_reg = max(reflect_reg, -0.5) → scale = 1.5:
        #   new_proj ≈ orig_proj - 1.5 * orig_proj = -0.5 * orig_proj
        # Additionally, _stabilize_router_weights clamps outliers, so we
        # verify the sign is flipped and magnitude is substantial.
        router_proj = (moe.gate.weight.data @ d.squeeze()).squeeze()
        orig_router_proj = (orig_router @ d.squeeze()).squeeze()
        cosine = torch.nn.functional.cosine_similarity(
            router_proj.unsqueeze(0), -orig_router_proj.unsqueeze(0),
        )
        assert cosine > 0.5, (
            f"Router projection should be at least partially reflected, cosine={cosine.item():.3f}"
        )

        # Safety expert 0: should be reflected (projection negated)
        e0_proj = (moe.experts[0].down_proj.weight.data @ d).norm()
        # After reflection the projection doesn't go to zero — it negates
        assert e0_proj > 1e-4, "Safety expert should have non-zero projection (reflected, not removed)"

        # Capability expert 3: should have projection removed (near zero)
        e3_proj = (moe.experts[3].down_proj.weight.data @ d).norm().item()
        assert e3_proj < 1e-3, f"Capability expert should have projection removed, got {e3_proj}"
| # --------------------------------------------------------------------------- | |
| # Nuclear method | |
| # --------------------------------------------------------------------------- | |
class TestNuclearMethod:
    """Tests for the ``nuclear`` preset and the techniques it layers on top of inversion."""

    def test_nuclear_preset_config(self):
        """Nuclear method should match inverted baseline + permanent weight techniques."""
        cfg = METHODS["nuclear"]
        assert cfg["invert_refusal"] is True
        assert cfg["n_directions"] == 4  # fewer than inverted to avoid over-ablation
        assert cfg["refinement_passes"] == 2  # same as inverted
        assert cfg["reflection_strength"] == 1.25  # tempered for CoT coherence
        assert cfg["project_embeddings"] is True
        assert cfg["embed_regularization"] == 0.50  # conservative cascade limit
        assert cfg["activation_steering"] is True  # residual cleanup hooks
        assert cfg["steering_strength"] == 0.15  # light residual correction
        assert cfg["expert_transplant"] is True
        assert cfg["transplant_blend"] == 0.10  # gentle nudge, not overwrite
        assert cfg["use_jailbreak_contrast"] is True
        assert cfg["attention_head_surgery"] is True
        assert cfg["layer_adaptive_strength"] is True  # per-layer scaling

    def test_nuclear_pipeline_init(self):
        """Pipeline initialized with nuclear method should have all flags set."""
        pipeline = AbliterationPipeline(model_name="test", method="nuclear")
        assert pipeline.invert_refusal is True
        assert pipeline.reflection_strength == 1.25
        assert pipeline.embed_regularization == 0.50
        assert pipeline.transplant_blend == 0.10
        assert pipeline.project_embeddings is True
        assert pipeline.activation_steering is True  # residual cleanup
        assert pipeline.expert_transplant is True
        assert pipeline.n_directions == 4
        assert pipeline.refinement_passes == 2
        assert pipeline.layer_adaptive_strength is True

    def test_reflection_strength_configurable(self):
        """reflection_strength should be explicitly overridable."""
        pipeline = AbliterationPipeline(
            model_name="test", method="inverted", reflection_strength=3.0,
        )
        assert pipeline.reflection_strength == 3.0

    def test_inverted_default_strength_is_2(self):
        """Inverted method should default to reflection_strength=2.0."""
        pipeline = AbliterationPipeline(model_name="test", method="inverted")
        assert pipeline.reflection_strength == 2.0

    def test_boosted_reflection_math(self):
        """2.5x reflection should produce stronger negation than 2x."""
        hidden = 16

        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.o_proj = torch.nn.Linear(hidden, 32, bias=False)

        # Seed before drawing ANY random tensor so that both the direction and
        # the weights are reproducible from run to run.
        torch.manual_seed(42)
        d = torch.randn(hidden, 1)
        d = d / d.norm()

        # 2x reflection
        module_2x = Wrapper()
        module_2x.o_proj.weight.data = torch.randn(32, hidden)
        orig = module_2x.o_proj.weight.data.clone()
        AbliterationPipeline._project_out_advanced(
            module_2x, d, ["o_proj"], regularization=-1.0,  # scale=2.0
        )
        proj_2x = (module_2x.o_proj.weight.data @ d).squeeze()

        # 2.5x reflection, starting from the same original weights
        module_25x = Wrapper()
        module_25x.o_proj.weight.data = orig.clone()
        AbliterationPipeline._project_out_advanced(
            module_25x, d, ["o_proj"], regularization=-1.5,  # scale=2.5
        )
        proj_25x = (module_25x.o_proj.weight.data @ d).squeeze()

        # Presumably the residual projection is -1.0x the original at scale 2.0
        # and -1.5x at scale 2.5; only the ordering of magnitudes is asserted.
        assert proj_25x.norm() > proj_2x.norm(), (
            "2.5x reflection should produce stronger (more negative) projection than 2x"
        )

    def test_activation_steering_hook(self):
        """Steering hooks should subtract refusal direction from hidden states."""
        hidden = 8

        class FakeLayer(torch.nn.Module):
            def forward(self, x):
                return x

        layer = FakeLayer()
        layers = torch.nn.ModuleList([layer])

        # Enable steering explicitly with a known strength so the test does
        # not depend on whichever defaults the chosen preset carries.
        pipeline = AbliterationPipeline(
            model_name="test", method="inverted", activation_steering=True,
            steering_strength=0.5,
        )
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None

        # Fixed seed: the direction and hidden states below are random draws.
        torch.manual_seed(42)
        d = torch.randn(hidden)
        d = d / d.norm()
        pipeline.refusal_directions = {0: d}
        pipeline._strong_layers = [0]

        n_hooks = pipeline._install_activation_steering(layers)
        try:
            assert n_hooks == 1
            assert len(pipeline._steering_hooks) == 1

            # Create a hidden state with strong refusal component
            batch = torch.randn(1, 4, hidden)
            refusal_component = 5.0 * d.unsqueeze(0).unsqueeze(0).expand_as(batch)
            input_hidden = batch + refusal_component

            # Run through the layer (hook should fire)
            output = layer(input_hidden)

            # The refusal component should be reduced
            proj_before = torch.einsum("bsh,h->bs", input_hidden, d).abs().mean()
            proj_after = torch.einsum("bsh,h->bs", output, d).abs().mean()
            assert proj_after < proj_before, (
                f"Steering should reduce refusal projection: before={proj_before:.3f}, after={proj_after:.3f}"
            )
        finally:
            # Always detach the hooks, even when an assertion above fails.
            for hook in pipeline._steering_hooks:
                hook.remove()

    def test_expert_transplant(self):
        """Expert transplant should blend safety expert weights toward the capability average."""
        from unittest.mock import patch

        import obliteratus.abliterate as abl_module

        hidden = 16
        n_experts = 4

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, hidden, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(n_experts)])

        class FakeLayer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.self_attn = torch.nn.Module()
                self.self_attn.o_proj = torch.nn.Linear(hidden, hidden, bias=False)
                self.mlp = FakeMoE()

        layer = FakeLayer()
        layers = torch.nn.ModuleList([layer])

        # Every parameter is re-drawn after seeding, so values are deterministic.
        torch.manual_seed(42)
        for p in layer.parameters():
            p.data = torch.randn_like(p.data)

        # Save original safety expert weight
        orig_safety0 = layer.mlp.experts[0].down_proj.weight.data.clone()
        # Save capability expert weights for computing expected mean
        # With top-third classification (n_experts // 3 = 1), only expert 0
        # is safety; experts 1, 2, 3 are all capability.
        cap1 = layer.mlp.experts[1].down_proj.weight.data.clone()
        cap2 = layer.mlp.experts[2].down_proj.weight.data.clone()
        cap3 = layer.mlp.experts[3].down_proj.weight.data.clone()
        expected_mean = (cap1 + cap2 + cap3) / 3.0

        config = GPT2Config(n_embd=hidden, n_head=2, n_layer=1, vocab_size=100, n_positions=64)
        model = MagicMock()
        model.parameters.return_value = iter([torch.zeros(1)])
        handle = ModelHandle(model=model, tokenizer=MagicMock(), config=config, model_name="test", task="causal_lm")

        pipeline = AbliterationPipeline(model_name="test", method="nuclear")
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        pipeline._strong_layers = [0]

        # Scores are listed from highest to lowest affinity; per the top-third
        # note above, only expert 0 is expected to be treated as a safety
        # expert and to have its weights blended.
        pipeline._expert_safety_scores = {
            0: [(0, 5.0), (1, 3.0), (2, -1.0), (3, -3.0)]
        }

        # patch.object restores the real lookup even if the call raises.
        with patch.object(abl_module, "get_ffn_module", lambda lay, a: lay.mlp):
            count = pipeline._transplant_expert_weights(layers)

        assert count >= 1, f"Should blend at least 1 weight (top-third safety expert), got {count}"

        # Safety expert 0 should be a 10% blend toward capability mean
        # (nuclear default transplant_blend=0.10)
        # new = 0.90 * original + 0.10 * capability_mean
        blend = pipeline.transplant_blend  # 0.10
        expected_blend = (1.0 - blend) * orig_safety0 + blend * expected_mean
        transplanted = layer.mlp.experts[0].down_proj.weight.data
        assert torch.allclose(transplanted, expected_blend, atol=1e-4), (
            f"Safety expert weight should be {blend:.0%} blended toward capability mean"
        )

        # Capability expert 2 should be unchanged
        assert torch.allclose(layer.mlp.experts[2].down_proj.weight.data, cap2, atol=1e-6), (
            "Capability expert should be unchanged"
        )

    def test_gather_state_dict_raises_on_missing_offload(self):
        """Should raise RuntimeError (not silently corrupt) when offload dir is missing."""
        config = GPT2Config(n_embd=8, n_head=2, n_layer=1, vocab_size=100, n_positions=64)

        # Create a fake model whose state_dict returns a meta tensor
        fake_model = MagicMock()
        meta_tensor = torch.empty(4, 8, device="meta")
        fake_model.state_dict.return_value = {"layer.weight": meta_tensor}

        handle = ModelHandle(
            model=fake_model, tokenizer=MagicMock(), config=config,
            model_name="test", task="causal_lm",
        )
        handle._offload_dir = "/nonexistent/path"

        pipeline = AbliterationPipeline(model_name="test", method="nuclear")
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None

        with pytest.raises(RuntimeError, match="bricked checkpoint"):
            pipeline._gather_state_dict()
| # --------------------------------------------------------------------------- | |
| # Knee detection | |
| # --------------------------------------------------------------------------- | |
class TestKneeDetection:
    """Tests for ``AbliterationPipeline._select_layers_knee``.

    The input is a list of ``(layer_index, norm)`` pairs, given here in
    descending order of norm.
    """

    def test_empty_input(self):
        """No candidate layers in, no layers out."""
        assert AbliterationPipeline._select_layers_knee([]) == []

    def test_two_layers(self):
        """With only two candidates, both layer indices are returned."""
        picked = AbliterationPipeline._select_layers_knee([(0, 5.0), (1, 3.0)])
        assert set(picked) == {0, 1}

    def test_clear_knee(self):
        """A sharp drop after the top cluster keeps that cluster and trims the tail."""
        strong_cluster = [(14, 10.0), (15, 9.5), (13, 9.0)]
        weak_tail = [(16, 2.0), (12, 1.5), (17, 1.0), (11, 0.5), (18, 0.2), (10, 0.1)]

        picked = AbliterationPipeline._select_layers_knee(strong_cluster + weak_tail)

        # Every member of the strong cluster must be selected ...
        for layer_idx in (14, 15, 13):
            assert layer_idx in picked
        # ... and the selection must stop well short of all nine layers.
        assert len(picked) <= 5

    def test_minimum_threshold_filters_noise(self):
        """The dominant layer is kept when the other candidate is only 5% of the max.

        Only the presence of layer 0 is asserted; whether layer 1 is dropped
        is not checked here.
        """
        ranked = [(0, 10.0), (1, 0.5)]  # 0.5 is 5% of 10
        picked = AbliterationPipeline._select_layers_knee(ranked)
        assert 0 in picked

    def test_all_equal_norms(self):
        """With identical norms everywhere, at least one layer is still selected."""
        ranked = [(layer_idx, 5.0) for layer_idx in range(5)]
        picked = AbliterationPipeline._select_layers_knee(ranked)
        assert len(picked) >= 1
| # --------------------------------------------------------------------------- | |
| # Activation collection | |
| # --------------------------------------------------------------------------- | |
class TestActivationCollection:
    """Tests for ``AbliterationPipeline._collect_activations``."""

    def test_collect_activations(self, handle):
        """One CPU activation of hidden-size width is collected per layer per prompt."""
        from obliteratus.strategies.utils import get_layer_modules

        pipe = AbliterationPipeline(model_name="test")
        pipe.handle = handle
        pipe._on_log = lambda m: None

        layer_modules = get_layer_modules(handle)
        test_prompts = ["Hello world", "Test prompt"]

        # The mocked tokenizer hands back the same 5-token batch for every prompt.
        handle.tokenizer.return_value = {
            "input_ids": torch.randint(0, 1000, (1, 5)),
            "attention_mask": torch.ones(1, 5, dtype=torch.long),
        }

        collected = pipe._collect_activations(layer_modules, test_prompts, "test")

        assert len(collected) == len(layer_modules)
        cpu = torch.device("cpu")
        for layer_idx, _module in enumerate(layer_modules):
            per_layer = collected[layer_idx]
            assert len(per_layer) == len(test_prompts)
            for activation in per_layer:
                assert activation.device == cpu
                assert activation.shape[-1] == handle.hidden_size
| # --------------------------------------------------------------------------- | |
| # Distill: single direction (basic method) | |
| # --------------------------------------------------------------------------- | |
class TestDistillBasic:
    """Distill stage with ``method="basic"``."""

    def test_single_direction(self, handle):
        """Probe + distill yields one unit direction and a 1-row subspace per layer."""
        from obliteratus.strategies.utils import get_layer_modules

        pipeline = AbliterationPipeline(
            model_name="test",
            method="basic",
            harmful_prompts=["bad prompt"],
            harmless_prompts=["good prompt"],
        )
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        # Different token ids per tokenizer call (see _make_varied_tokenizer).
        _make_varied_tokenizer(handle)

        pipeline._probe()
        pipeline._distill()

        directions = pipeline.refusal_directions
        assert len(directions) == len(get_layer_modules(handle))
        for layer_idx, direction in directions.items():
            norm = direction.norm().item()
            assert abs(norm - 1.0) < 1e-4
            # Single direction: subspace should be (1, hidden_dim)
            assert pipeline.refusal_subspaces[layer_idx].shape[0] == 1
| # --------------------------------------------------------------------------- | |
| # Distill: multi-direction SVD (advanced/aggressive method) | |
| # --------------------------------------------------------------------------- | |
class TestDistillSVD:
    """Distill stage with ``method="advanced"`` (multi-direction subspaces)."""

    def test_multi_direction_svd(self, handle):
        """Each layer gets a subspace of the expected rank and a unit primary direction.

        Per the original test's note, small models (hidden_size < 2048 or
        < 2B params) have n_directions capped to 2 to prevent over-ablation;
        the tiny test model (hidden_size=64, 4 layers) triggers that safeguard.
        """
        from obliteratus.strategies.utils import get_layer_modules

        harmful = ["bad1", "bad2", "bad3", "bad4", "bad5"]
        harmless = ["good1", "good2", "good3", "good4", "good5"]
        pipeline = AbliterationPipeline(
            model_name="test",
            method="advanced",
            harmful_prompts=harmful,
            harmless_prompts=harmless,
        )
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        _make_varied_tokenizer(handle)

        pipeline._probe()
        pipeline._distill()

        subspaces = pipeline.refusal_subspaces
        assert len(subspaces) == len(get_layer_modules(handle))

        # Small-model cap: n_directions capped to 2 for tiny test model
        hidden_size = handle.hidden_size
        expected_dirs = min(2, pipeline.n_directions, 5, hidden_size)
        for subspace in subspaces.values():
            n_rows, n_cols = subspace.shape[0], subspace.shape[1]
            assert n_rows == expected_dirs
            assert n_cols == hidden_size

        # Primary direction should still be a unit vector
        for direction in pipeline.refusal_directions.values():
            assert abs(direction.norm().item() - 1.0) < 1e-4
| # --------------------------------------------------------------------------- | |
| # Full pipeline: excise with different methods | |
| # --------------------------------------------------------------------------- | |
class TestExcise:
    """Probe -> distill -> excise on the tiny test model."""

    def test_excise_basic(self, handle):
        """After excise with the basic method, at least one layer weight differs."""
        from obliteratus.strategies.utils import get_layer_modules

        pipeline = AbliterationPipeline(
            model_name="test",
            method="basic",
            harmful_prompts=["bad prompt"],
            harmless_prompts=["good prompt"],
        )
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        _make_varied_tokenizer(handle)

        layers = get_layer_modules(handle)
        n_layers = len(layers)

        # Snapshot every layer parameter before the pipeline touches anything.
        before = {
            (layer_idx, name): param.data.clone()
            for layer_idx in range(n_layers)
            for name, param in layers[layer_idx].named_parameters()
        }

        pipeline._probe()
        pipeline._distill()
        pipeline._excise()

        any_changed = any(
            not torch.allclose(before[(layer_idx, name)], param.data, atol=1e-6)
            for layer_idx in range(n_layers)
            for name, param in layers[layer_idx].named_parameters()
        )
        assert any_changed, "Excise should modify at least some weights"

    def test_excise_advanced_norm_preserving(self, handle):
        """Advanced method runs end to end and reports at least one strong layer.

        NOTE(review): despite the name, no weight norm is compared here; the
        only assertion is that ``_strong_layers`` is non-empty.
        """
        from obliteratus.strategies.utils import get_layer_modules

        pipeline = AbliterationPipeline(
            model_name="test",
            method="advanced",
            harmful_prompts=["bad prompt"],
            harmless_prompts=["good prompt"],
        )
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        _make_varied_tokenizer(handle)

        # NOTE(review): result is discarded — looks like a leftover call; confirm
        # it has no needed side effect before removing.
        get_layer_modules(handle)

        for stage in (pipeline._probe, pipeline._distill, pipeline._excise):
            stage()

        assert len(pipeline._strong_layers) > 0
| # --------------------------------------------------------------------------- | |
| # Rebirth (save) | |
| # --------------------------------------------------------------------------- | |
class TestRebirth:
    """Rebirth (save) stage."""

    def test_rebirth_saves_metadata(self, handle, tmp_path):
        """``_rebirth`` returns the output dir and writes the metadata JSON there."""
        out_dir = tmp_path / "output"
        pipeline = AbliterationPipeline(
            model_name="test-model",
            output_dir=str(out_dir),
            method="advanced",
        )
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        pipeline._strong_layers = [0]
        pipeline._quality_metrics = {"perplexity": 8.5, "coherence": 1.0}

        # Replace the real save_pretrained methods with mocks.
        handle.model.save_pretrained = MagicMock()
        handle.tokenizer.save_pretrained = MagicMock()

        result_path = pipeline._rebirth()
        assert result_path == out_dir

        metadata_file = result_path / "abliteration_metadata.json"
        assert metadata_file.exists()
        metadata = json.loads(metadata_file.read_text())

        assert metadata["source_model"] == "test-model"
        assert metadata["technique"] == "refusal_direction_ablation"
        assert metadata["method"] == "advanced"
        assert metadata["strong_layers"] == [0]

        assert "method_config" in metadata
        method_config = metadata["method_config"]
        assert method_config["n_directions"] == METHODS["advanced"]["n_directions"]
        assert method_config["norm_preserve"] is True

        assert "references" in metadata
        assert len(metadata["references"]) >= 3

        assert "quality_metrics" in metadata
        assert metadata["quality_metrics"]["perplexity"] == 8.5
| # --------------------------------------------------------------------------- | |
| # CLI integration | |
| # --------------------------------------------------------------------------- | |
class TestCLI:
    """Argument-parsing checks for an ``abliterate`` subcommand layout.

    NOTE(review): both tests assemble their own ``argparse`` parser instead of
    importing the project's CLI, so they pin argparse behaviour for this flag
    layout rather than the shipped command — confirm that is intended.
    """

    def test_abliterate_parser_with_method(self):
        """Explicit --method / --n-directions are parsed; --dtype keeps its default."""
        import argparse

        root = argparse.ArgumentParser()
        abliterate = root.add_subparsers(dest="command").add_parser("abliterate")
        abliterate.add_argument("model", type=str)

        option_specs = (
            ("--output-dir", {"type": str, "default": None}),
            ("--device", {"type": str, "default": "auto"}),
            ("--dtype", {"type": str, "default": "float16"}),
            (
                "--method",
                {
                    "type": str,
                    "default": "advanced",
                    "choices": ["basic", "advanced", "aggressive"],
                },
            ),
            ("--n-directions", {"type": int, "default": None}),
            ("--regularization", {"type": float, "default": None}),
            ("--refinement-passes", {"type": int, "default": None}),
        )
        for flag, spec in option_specs:
            abliterate.add_argument(flag, **spec)

        argv = ["abliterate", "gpt2", "--method", "aggressive", "--n-directions", "6"]
        args = root.parse_args(argv)

        assert args.command == "abliterate"
        assert args.model == "gpt2"
        assert args.method == "aggressive"
        assert args.n_directions == 6
        assert args.dtype == "float16"

    def test_default_method(self):
        """Omitting --method falls back to ``"advanced"``."""
        import argparse

        root = argparse.ArgumentParser()
        abliterate = root.add_subparsers(dest="command").add_parser("abliterate")
        abliterate.add_argument("model", type=str)
        abliterate.add_argument("--method", type=str, default="advanced")

        parsed = root.parse_args(["abliterate", "gpt2"])
        assert parsed.method == "advanced"
| # --------------------------------------------------------------------------- | |
| # Expert-Granular Abliteration (EGA) | |
| # --------------------------------------------------------------------------- | |
class TestFindRouterModule:
    """Test _find_router_module static method."""

    @staticmethod
    def _moe_with_router(router_attr, hidden=16, n_experts=4):
        """Build a fake MoE whose router ``Linear`` is stored under ``router_attr``.

        The module also carries an empty ``experts`` ModuleList, registered
        after the router.
        """

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                setattr(self, router_attr, torch.nn.Linear(hidden, n_experts, bias=False))
                self.experts = torch.nn.ModuleList()

        return FakeMoE()

    def test_finds_gate(self):
        """Should find a router named 'gate'."""
        moe = self._moe_with_router("gate")
        found = AbliterationPipeline._find_router_module(moe)
        assert found is moe.gate

    def test_finds_router(self):
        """Should find a router named 'router'."""
        moe = self._moe_with_router("router")
        found = AbliterationPipeline._find_router_module(moe)
        assert found is moe.router

    def test_auto_detects_unknown_router(self):
        """Should auto-detect a router stored under the unusual name 'moe_gate_proj'."""
        moe = self._moe_with_router("moe_gate_proj")
        found = AbliterationPipeline._find_router_module(moe)
        assert found is moe.moe_gate_proj

    def test_returns_none_no_router(self):
        """Should return None for a module holding only a square Linear."""

        class NoRouter(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(16, 16)

        assert AbliterationPipeline._find_router_module(NoRouter()) is None
class TestRouterProfilingHooks:
    """Test _install_router_profiling_hooks."""

    def _make_moe_pipeline_and_layers(self, hidden=16, n_experts=4):
        """Create a pipeline with a fake MoE model for router profiling tests.

        Side effect: replaces ``obliteratus.abliterate.get_ffn_module`` with a
        lambda returning ``layer.mlp``. The caller MUST restore it from the
        returned ``orig_get_ffn`` (the tests do so in ``finally``).

        Returns:
            ``(pipeline, layers, layer, abl_module, orig_get_ffn)`` where
            ``layers`` is a one-element ModuleList wrapping ``layer``.
        """

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, hidden, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(n_experts)])

            def forward(self, x):
                return x

        class FakeLayer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.self_attn = torch.nn.Module()
                self.self_attn.o_proj = torch.nn.Linear(hidden, hidden, bias=False)
                self.mlp = FakeMoE()

            def forward(self, x):
                return (x,)

        config = GPT2Config(n_embd=hidden, n_head=2, n_layer=1, vocab_size=100, n_positions=64)
        model = MagicMock()
        model.parameters.return_value = iter([torch.zeros(1)])
        handle = ModelHandle(
            model=model,
            tokenizer=MagicMock(),
            config=config,
            model_name="test",
            task="causal_lm",
        )

        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None

        layer = FakeLayer()
        layers = torch.nn.ModuleList([layer])

        # Monkey-patch get_ffn_module last, so nothing above can fail while
        # the module is left patched.
        import obliteratus.abliterate as abl_module

        orig_get_ffn = abl_module.get_ffn_module
        abl_module.get_ffn_module = lambda lay, a: lay.mlp
        return pipeline, layers, layer, abl_module, orig_get_ffn

    def test_hooks_installed(self):
        """Should install hooks on MoE router modules."""
        pipeline, layers, layer, abl_module, orig_get_ffn = self._make_moe_pipeline_and_layers()
        # Bind before ``try`` so the ``finally`` cleanup cannot hit an unbound
        # name (and mask the real error) if hook installation itself raises.
        hooks = []
        try:
            hooks = pipeline._install_router_profiling_hooks(layers)
            assert len(hooks) == 1
            assert 0 in pipeline._routing_harmful
            assert 0 in pipeline._routing_harmless
        finally:
            for h in hooks:
                h.remove()
            abl_module.get_ffn_module = orig_get_ffn

    def test_hooks_record_logits(self):
        """Hooks should record router logits during forward passes."""
        pipeline, layers, layer, abl_module, orig_get_ffn = self._make_moe_pipeline_and_layers()
        hooks = []  # see test_hooks_installed: keeps ``finally`` safe on early failure
        try:
            hooks = pipeline._install_router_profiling_hooks(layers)

            # Simulate harmful forward pass
            pipeline._routing_is_harmful = True
            x = torch.randn(1, 5, 16)
            layer.mlp.gate(x)  # triggers hook
            assert len(pipeline._routing_harmful[0]) == 1
            assert pipeline._routing_harmful[0][0].shape[0] == 4  # n_experts

            # Simulate harmless forward pass
            pipeline._routing_is_harmful = False
            layer.mlp.gate(x)
            assert len(pipeline._routing_harmless[0]) == 1
        finally:
            for h in hooks:
                h.remove()
            abl_module.get_ffn_module = orig_get_ffn

    def test_no_handle_returns_empty(self):
        """Should return empty list when handle is None."""
        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        pipeline.handle = None
        hooks = pipeline._install_router_profiling_hooks(torch.nn.ModuleList())
        assert hooks == []
class TestComputeExpertGranularDirections:
    """Test _compute_expert_granular_directions."""

    @staticmethod
    def _quiet_pipeline():
        """Return a ``surgical`` pipeline whose log callback is a no-op."""
        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        pipeline._on_log = lambda m: None
        return pipeline

    def test_computes_per_expert_directions(self):
        """Should compute per-expert refusal directions from routing data."""
        hidden, n_experts, n_prompts = 16, 4, 10
        pipeline = self._quiet_pipeline()
        pipeline._on_stage = lambda r: None
        pipeline._strong_layers = [0]

        torch.manual_seed(42)
        # Simulate router logits: expert 0 favored for harmful, expert 3 for
        # harmless. Draws stay interleaved (harmful, harmless) per prompt.
        harmful_logits, harmless_logits = [], []
        for _ in range(n_prompts):
            harmful = torch.randn(n_experts)
            harmful[0] += 2.0
            harmful_logits.append(harmful)
            harmless = torch.randn(n_experts)
            harmless[3] += 2.0
            harmless_logits.append(harmless)
        pipeline._routing_harmful = {0: harmful_logits}
        pipeline._routing_harmless = {0: harmless_logits}

        # Simulate per-prompt activations with harmful/harmless separation
        refusal_dir = torch.randn(hidden)
        refusal_dir = refusal_dir / refusal_dir.norm()
        pipeline._harmful_acts = {
            0: [torch.randn(hidden) + 1.5 * refusal_dir for _ in range(n_prompts)]
        }
        pipeline._harmless_acts = {
            0: [torch.randn(hidden) - 1.5 * refusal_dir for _ in range(n_prompts)]
        }

        pipeline._compute_expert_granular_directions()

        # Should have computed expert directions for layer 0
        assert 0 in pipeline._expert_directions
        assert len(pipeline._expert_directions[0]) > 0

        # Should have dynamic safety scores, one (expert_id, score) pair per expert
        assert 0 in pipeline._expert_safety_scores
        scores = pipeline._expert_safety_scores[0]
        assert len(scores) == n_experts

        # Expert 0 should have higher safety score (more activated for harmful)
        expert_0_score = next(s for eid, s in scores if eid == 0)
        expert_3_score = next(s for eid, s in scores if eid == 3)
        assert expert_0_score > expert_3_score, (
            f"Expert 0 should have higher safety score: {expert_0_score} vs {expert_3_score}"
        )

    def test_directions_are_unit_vectors(self):
        """Per-expert directions should be unit normalized.

        NOTE(review): if no directions are produced for layer 0, this test
        passes without checking anything.
        """
        hidden, n_experts, n_prompts = 16, 4, 10
        pipeline = self._quiet_pipeline()
        pipeline._strong_layers = [0]

        torch.manual_seed(42)
        harmful_logits = [torch.randn(n_experts) for _ in range(n_prompts)]
        harmless_logits = [torch.randn(n_experts) for _ in range(n_prompts)]
        pipeline._routing_harmful = {0: harmful_logits}
        pipeline._routing_harmless = {0: harmless_logits}
        pipeline._harmful_acts = {
            0: [torch.randn(hidden) + torch.ones(hidden) for _ in range(n_prompts)]
        }
        pipeline._harmless_acts = {
            0: [torch.randn(hidden) - torch.ones(hidden) for _ in range(n_prompts)]
        }

        pipeline._compute_expert_granular_directions()

        if 0 not in pipeline._expert_directions:
            return
        for ei, d in pipeline._expert_directions[0].items():
            assert abs(d.norm().item() - 1.0) < 1e-4, (
                f"Expert {ei} direction norm={d.norm().item()}, expected 1.0"
            )

    def test_skips_when_no_routing_data(self):
        """Should skip gracefully when no routing data is available."""
        pipeline = self._quiet_pipeline()
        pipeline._routing_harmful = {}
        pipeline._routing_harmless = {}

        pipeline._compute_expert_granular_directions()

        assert len(pipeline._expert_directions) == 0

    def test_skips_expert_with_low_routing_weight(self):
        """Experts with insufficient routing weight should not get directions."""
        hidden, n_prompts = 16, 3
        pipeline = self._quiet_pipeline()
        pipeline._strong_layers = [0]

        # Expert 3 gets a logit of -100 on every prompt, i.e. is never routed.
        starved = [5.0, 5.0, 5.0, -100.0]
        pipeline._routing_harmful = {0: [torch.tensor(starved) for _ in range(n_prompts)]}
        pipeline._routing_harmless = {0: [torch.tensor(starved) for _ in range(n_prompts)]}

        torch.manual_seed(42)
        pipeline._harmful_acts = {0: [torch.randn(hidden) for _ in range(n_prompts)]}
        pipeline._harmless_acts = {0: [torch.randn(hidden) for _ in range(n_prompts)]}

        pipeline._compute_expert_granular_directions()

        # Expert 3 should NOT have a direction (routing weight too low)
        if 0 in pipeline._expert_directions:
            assert 3 not in pipeline._expert_directions[0], (
                "Expert with near-zero routing weight should not get a direction"
            )
class TestProjectMoEExpertsGranular:
    """Test _project_moe_experts_granular (ModuleList path)."""

    def _make_direction(self, hidden_dim=16):
        """Return a random unit-norm column vector of shape ``(hidden_dim, 1)``."""
        d = torch.randn(hidden_dim, 1)
        return d / d.norm()

    @staticmethod
    def _make_moe(hidden=16, n_experts=4, with_up_proj=True, with_shared_expert=False):
        """Build the fake MoE block shared by the tests in this class.

        Submodules are registered in the order gate, (shared_expert), experts.
        Every expert has a ``down_proj`` Linear(hidden -> 32, no bias) and,
        when ``with_up_proj`` is set, a matching ``up_proj``.
        """

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                if with_up_proj:
                    self.up_proj = torch.nn.Linear(hidden, 32, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
                if with_shared_expert:
                    self.shared_expert = torch.nn.Module()
                    self.shared_expert.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                    self.shared_expert.up_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(n_experts)])

        return FakeMoE()

    @staticmethod
    def _randomize(moe):
        """Seed torch and overwrite every parameter of ``moe`` with normal noise."""
        torch.manual_seed(42)
        for p in moe.parameters():
            p.data = torch.randn_like(p.data)

    @staticmethod
    def _make_pipeline(expert_dirs):
        """Return a quiet ``surgical`` pipeline with ``expert_dirs`` set for layer 0."""
        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        pipeline._on_log = lambda m: None
        pipeline._expert_directions = {0: expert_dirs}
        return pipeline

    def test_per_expert_directions_applied(self):
        """Each expert should use its own direction when available."""
        hidden = 16
        n_experts = 4
        moe = self._make_moe(hidden, n_experts)
        self._randomize(moe)
        shared_dir = self._make_direction(hidden)

        # Create distinct per-expert unit directions
        expert_dirs = {}
        for ei in range(n_experts):
            d = torch.randn(hidden)
            expert_dirs[ei] = d / d.norm()
        pipeline = self._make_pipeline(expert_dirs)

        # Save originals
        orig_weights = {
            ei: moe.experts[ei].down_proj.weight.data.clone()
            for ei in range(n_experts)
        }

        count = pipeline._project_moe_experts_granular(
            moe, shared_dir, layer_idx=0,
        )
        assert count > 0, "Should project some weights"

        # All experts should be modified
        for ei in range(n_experts):
            assert not torch.allclose(
                moe.experts[ei].down_proj.weight.data, orig_weights[ei]
            ), f"Expert {ei} should be modified"

    def test_falls_back_to_shared_direction(self):
        """Experts without per-expert direction should use shared direction."""
        hidden = 16
        n_experts = 4
        moe = self._make_moe(hidden, n_experts)
        self._randomize(moe)
        shared_dir = self._make_direction(hidden)

        # Only expert 0 has a per-expert (unit-norm) direction
        d0 = torch.randn(hidden)
        pipeline = self._make_pipeline({0: d0 / d0.norm()})

        fallback_experts = range(1, n_experts)
        orig_weights = {
            ei: moe.experts[ei].down_proj.weight.data.clone()
            for ei in fallback_experts
        }

        pipeline._project_moe_experts_granular(
            moe, shared_dir, layer_idx=0,
        )

        # Experts 1,2,3 should be modified (using shared direction)
        for ei in fallback_experts:
            assert not torch.allclose(
                moe.experts[ei].down_proj.weight.data, orig_weights[ei]
            ), f"Expert {ei} should use shared direction fallback"

    def test_router_uses_shared_direction(self):
        """Router should always use the shared direction, not per-expert."""
        hidden = 16
        moe = self._make_moe(hidden, n_experts=4, with_up_proj=False)
        shared_dir = self._make_direction(hidden)
        pipeline = self._make_pipeline({0: torch.randn(hidden)})

        orig_gate = moe.gate.weight.data.clone()
        pipeline._project_moe_experts_granular(moe, shared_dir, layer_idx=0)

        # Gate should be projected
        assert not torch.allclose(moe.gate.weight.data, orig_gate), \
            "Router should be projected with shared direction"

        # Gate's projection onto shared direction should be near zero
        proj = (moe.gate.weight.data @ shared_dir).norm().item()
        assert proj < 1e-4, f"Router should have shared dir removed, proj={proj}"

    def test_shared_expert_uses_shared_direction(self):
        """Shared expert should always use the shared direction."""
        hidden = 16
        moe = self._make_moe(hidden, n_experts=2, with_shared_expert=True)
        shared_dir = self._make_direction(hidden)
        pipeline = self._make_pipeline({0: torch.randn(hidden)})

        orig_shared = moe.shared_expert.down_proj.weight.data.clone()
        pipeline._project_moe_experts_granular(moe, shared_dir, layer_idx=0)

        assert not torch.allclose(moe.shared_expert.down_proj.weight.data, orig_shared), \
            "Shared expert should be projected"
class TestProjectFused3DGranular:
    """Test _project_fused_3d_granular for fused 3D expert tensors."""

    @staticmethod
    def _make_fused(n_experts=4, intermediate=32, hidden=16):
        """Return a module whose ``down_proj`` is a random 3D parameter.

        The parameter has shape ``(n_experts, intermediate, hidden)``. Call
        ``torch.manual_seed`` BEFORE this helper: the weights are drawn here.
        """

        class FusedExperts(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Parameter(
                    torch.randn(n_experts, intermediate, hidden)
                )

        return FusedExperts()

    @staticmethod
    def _unit(*shape):
        """Return a random tensor of ``shape`` scaled to unit norm."""
        v = torch.randn(*shape)
        return v / v.norm()

    def test_per_expert_directions_on_fused(self):
        """Each expert slice should use its own direction."""
        hidden = 16
        n_experts = 4
        # Seed first so the fused weights themselves are reproducible.
        torch.manual_seed(42)
        container = self._make_fused(n_experts, 32, hidden)
        shared_dir = self._unit(hidden, 1)

        # Per-expert directions
        expert_dirs = {ei: self._unit(hidden) for ei in range(n_experts)}

        orig_data = container.down_proj.data.clone()
        count = AbliterationPipeline._project_fused_3d_granular(
            container, shared_dir, expert_dirs, ["down_proj"],
            norm_preserve=False, scale=1.0,
        )
        assert count == n_experts, f"Should project {n_experts} experts, got {count}"

        # Each expert should be modified
        for ei in range(n_experts):
            assert not torch.allclose(
                container.down_proj.data[ei], orig_data[ei]
            ), f"Expert {ei} should be modified"

    def test_fallback_to_shared_on_fused(self):
        """Experts without per-expert direction should use shared direction."""
        hidden = 16
        n_experts = 4
        torch.manual_seed(42)
        container = self._make_fused(n_experts, 32, hidden)
        shared_dir = self._unit(hidden, 1)

        # Only expert 0 has a direction
        expert_dirs = {0: self._unit(hidden)}

        orig_data = container.down_proj.data.clone()
        count = AbliterationPipeline._project_fused_3d_granular(
            container, shared_dir, expert_dirs, ["down_proj"],
            norm_preserve=False, scale=1.0,
        )
        assert count == n_experts

        # All experts should be modified (experts 1-3 use shared dir)
        for ei in range(n_experts):
            assert not torch.allclose(
                container.down_proj.data[ei], orig_data[ei]
            ), f"Expert {ei} should be modified"

    def test_norm_preserve_on_fused(self):
        """Fused 3D with norm_preserve should maintain per-expert norms."""
        hidden = 16
        n_experts = 4
        # The tolerance check below depends on the drawn weights, so seeding
        # must precede construction to keep the test deterministic.
        torch.manual_seed(42)
        container = self._make_fused(n_experts, 32, hidden)
        shared_dir = self._unit(hidden, 1)
        expert_dirs = {ei: self._unit(hidden) for ei in range(n_experts)}

        orig_norms = [container.down_proj.data[i].norm().item() for i in range(n_experts)]
        AbliterationPipeline._project_fused_3d_granular(
            container, shared_dir, expert_dirs, ["down_proj"],
            norm_preserve=True, scale=1.0,
        )

        for i, orig_norm in enumerate(orig_norms):
            new_norm = container.down_proj.data[i].norm().item()
            assert abs(orig_norm - new_norm) < 1e-3, (
                f"Expert {i} norm not preserved: {orig_norm:.4f} vs {new_norm:.4f}"
            )

    def test_skips_non_3d_params(self):
        """Should skip parameters that are not 3-dimensional."""
        hidden = 16

        class FlatExperts(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # 2D on purpose: this is the shape the method must ignore.
                self.down_proj = torch.nn.Parameter(torch.randn(32, hidden))

        torch.manual_seed(42)
        container = FlatExperts()
        shared_dir = self._unit(hidden, 1)

        count = AbliterationPipeline._project_fused_3d_granular(
            container, shared_dir, {}, ["down_proj"],
            norm_preserve=False, scale=1.0,
        )
        assert count == 0
class TestEGAExciseIntegration:
    """Test that EGA integrates properly in the excise stage path."""

    def test_ega_pipeline_flags(self):
        """Pipeline with surgical method should enable per_expert_directions."""
        surgical = AbliterationPipeline(model_name="test", method="surgical")
        assert surgical.per_expert_directions is True

    def test_ega_only_on_primary_direction(self):
        """EGA should only apply for dir_idx==0, not higher SVD directions.

        This is a source-text check: it greps ``_excise_inner`` for the guard
        and the granular projection call rather than executing them.
        """
        import inspect

        excise_src = inspect.getsource(AbliterationPipeline._excise_inner)
        assert "dir_idx == 0" in excise_src, "EGA should only apply for primary direction"
        assert "_project_moe_experts_granular" in excise_src, "EGA method should be called in excise"

    def test_ega_distill_integration(self):
        """``_distill`` source must mention the EGA computation and its flag."""
        import inspect

        distill_src = inspect.getsource(AbliterationPipeline._distill)
        for needle in ("_compute_expert_granular_directions", "per_expert_directions"):
            assert needle in distill_src

    def test_nuclear_method_enables_ega(self):
        """Nuclear method should also enable per_expert_directions."""
        assert METHODS["nuclear"]["per_expert_directions"] is True
        nuclear = AbliterationPipeline(model_name="test", method="nuclear")
        assert nuclear.per_expert_directions is True

    def test_basic_method_disables_ega(self):
        """Basic method should not enable per_expert_directions."""
        basic_cfg = METHODS["basic"]
        assert basic_cfg.get("per_expert_directions", False) is False

    def test_inverted_method_enables_ega(self):
        """Inverted method should enable per_expert_directions."""
        assert METHODS["inverted"]["per_expert_directions"] is True

    def test_ega_with_routing_data_end_to_end(self):
        """End-to-end: EGA computes directions and granular projection modifies weights."""
        hidden, n_experts, n_prompts = 16, 4, 5

        class FakeExpert(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.down_proj = torch.nn.Linear(hidden, 32, bias=False)
                self.up_proj = torch.nn.Linear(hidden, 32, bias=False)

        class FakeMoE(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gate = torch.nn.Linear(hidden, n_experts, bias=False)
                self.experts = torch.nn.ModuleList([FakeExpert() for _ in range(n_experts)])

        moe = FakeMoE()
        torch.manual_seed(42)
        for param in moe.parameters():
            param.data = torch.randn_like(param.data)

        pipeline = AbliterationPipeline(model_name="test", method="surgical")
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        pipeline._strong_layers = [0]

        # Simulate EGA routing data (harmful logits drawn first, then harmless)
        pipeline._routing_harmful = {0: [torch.randn(n_experts) for _ in range(n_prompts)]}
        pipeline._routing_harmless = {0: [torch.randn(n_experts) for _ in range(n_prompts)]}

        # Simulate activations with clear separation along one unit direction
        refusal_dir = torch.randn(hidden)
        refusal_dir = refusal_dir / refusal_dir.norm()
        pipeline._harmful_acts = {
            0: [torch.randn(hidden) + 2 * refusal_dir for _ in range(n_prompts)]
        }
        pipeline._harmless_acts = {
            0: [torch.randn(hidden) - 2 * refusal_dir for _ in range(n_prompts)]
        }

        # Step 1: compute EGA directions
        pipeline._compute_expert_granular_directions()
        assert 0 in pipeline._expert_directions
        assert len(pipeline._expert_directions[0]) > 0

        # Step 2: apply granular projection
        shared_dir = torch.randn(hidden, 1)
        shared_dir = shared_dir / shared_dir.norm()
        expert0_weight = moe.experts[0].down_proj.weight
        orig_expert0 = expert0_weight.data.clone()

        count = pipeline._project_moe_experts_granular(
            moe, shared_dir, layer_idx=0,
        )
        assert count > 0
        assert not torch.allclose(moe.experts[0].down_proj.weight.data, orig_expert0), \
            "Expert weights should be modified by EGA"