Spaces:
Running on Zero
Running on Zero
| """Extended tests for novel abliteration pipeline features. | |
| Tests the new capabilities added to the OBLITERATUS abliteration pipeline: | |
| - Bias projection | |
| - Chat template wrapping | |
| - Method presets with new parameters | |
| - True iterative refinement | |
| - Whitened SVD integration | |
| """ | |
| from __future__ import annotations | |
| from unittest.mock import MagicMock | |
| import torch | |
| from transformers import GPT2Config, GPT2LMHeadModel | |
| from obliteratus.abliterate import ( | |
| METHODS, | |
| AbliterationPipeline, | |
| ) | |
| from obliteratus.models.loader import ModelHandle | |
def _make_tiny_handle():
    """Build a ModelHandle around a small, randomly initialised GPT-2.

    The model uses 4 layers, 2 heads, a 64-wide embedding, a 256-wide inner
    MLP, 128 positions and a 1000-token vocabulary, and is switched to eval
    mode. The tokenizer is a ``MagicMock``: calling it always returns the
    same dict (one row of 10 random token ids plus an all-ones attention
    mask), and ``decode`` always returns the same English sentence.

    Returns:
        The handle, after ``handle.snapshot()`` has been called on it.
    """
    tiny_dims = {
        "vocab_size": 1000,
        "n_positions": 128,
        "n_embd": 64,
        "n_layer": 4,
        "n_head": 2,
        "n_inner": 256,
    }
    tiny_config = GPT2Config(**tiny_dims)
    tiny_model = GPT2LMHeadModel(tiny_config).eval()

    # Attribute values passed to the constructor are set on the mock itself.
    fake_tokenizer = MagicMock(pad_token="<pad>", eos_token="<eos>")
    canned_batch = {
        "input_ids": torch.randint(0, 1000, (1, 10)),
        "attention_mask": torch.ones(1, 10, dtype=torch.long),
    }
    fake_tokenizer.return_value = canned_batch
    fake_tokenizer.decode.return_value = (
        "The capital of France is Paris, a beautiful city"
    )

    tiny_handle = ModelHandle(
        model=tiny_model,
        tokenizer=fake_tokenizer,
        config=tiny_config,
        model_name="gpt2-test",
        task="causal_lm",
    )
    # NOTE(review): presumably records the pristine weights so tests can
    # restore them — confirm against ModelHandle.snapshot.
    tiny_handle.snapshot()
    return tiny_handle
def _make_varied_tokenizer(handle):
    """Make ``handle.tokenizer`` return different token ids on each call.

    A replacement callable is installed as ``handle.tokenizer.side_effect``.
    On its k-th invocation (counting from 1) it reseeds torch's global RNG
    with ``k`` and returns a dict holding one row of 5 random token ids in
    ``[0, 1000)`` and a matching all-ones attention mask, so the sequence of
    outputs is reproducible. The prompt and any keyword arguments are
    ignored.
    """
    calls_made = 0

    def _fake_tokenize(prompt, **kwargs):
        nonlocal calls_made
        calls_made += 1
        # Seeding with the call index makes each call distinct yet repeatable.
        torch.manual_seed(calls_made)
        token_ids = torch.randint(0, 1000, (1, 5))
        mask = torch.ones(1, 5, dtype=torch.long)
        return {"input_ids": token_ids, "attention_mask": mask}

    handle.tokenizer.side_effect = _fake_tokenize
| # --------------------------------------------------------------------------- | |
| # New method preset parameters | |
| # --------------------------------------------------------------------------- | |
class TestNewMethodPresets:
    """Check the new switches on the "basic", "advanced" and "aggressive" presets."""

    def test_basic_has_new_params(self):
        """"basic" defines all four new keys; bias projection and chat templating are off."""
        preset = METHODS["basic"]
        for key in (
            "project_biases",
            "use_chat_template",
            "use_whitened_svd",
            "true_iterative_refinement",
        ):
            assert key in preset
        for key in ("project_biases", "use_chat_template"):
            assert preset[key] is False

    def test_advanced_has_new_params(self):
        """"advanced" enables bias projection and chat templating only."""
        preset = METHODS["advanced"]
        expected = {
            "project_biases": True,
            "use_chat_template": True,
            "use_whitened_svd": False,
            "true_iterative_refinement": False,
        }
        for key, value in expected.items():
            assert preset[key] is value

    def test_aggressive_has_new_params(self):
        """"aggressive" enables all four new switches."""
        preset = METHODS["aggressive"]
        for key in (
            "project_biases",
            "use_chat_template",
            "use_whitened_svd",
            "true_iterative_refinement",
        ):
            assert preset[key] is True
| # --------------------------------------------------------------------------- | |
| # Pipeline initialization with new parameters | |
| # --------------------------------------------------------------------------- | |
class TestNewPipelineInit:
    """Check how ``AbliterationPipeline`` resolves the new flags at construction."""

    def test_default_new_params(self):
        """With no method given, the flags match the advanced preset's values."""
        pipe = AbliterationPipeline(model_name="test-model")
        # advanced method defaults
        expected = {
            "project_biases": True,
            "use_chat_template": True,
            "use_whitened_svd": False,
            "true_iterative_refinement": False,
        }
        for attr, value in expected.items():
            assert getattr(pipe, attr) is value

    def test_basic_method_new_params(self):
        """method="basic" leaves all four new flags off."""
        pipe = AbliterationPipeline(model_name="test-model", method="basic")
        for attr in (
            "project_biases",
            "use_chat_template",
            "use_whitened_svd",
            "true_iterative_refinement",
        ):
            assert getattr(pipe, attr) is False

    def test_aggressive_method_new_params(self):
        """method="aggressive" turns all four new flags on."""
        pipe = AbliterationPipeline(model_name="test-model", method="aggressive")
        for attr in (
            "project_biases",
            "use_chat_template",
            "use_whitened_svd",
            "true_iterative_refinement",
        ):
            assert getattr(pipe, attr) is True

    def test_explicit_overrides_new_params(self):
        """Explicit keyword arguments win over the "basic" preset's values."""
        overrides = {
            "project_biases": True,
            "use_chat_template": True,
            "use_whitened_svd": True,
            "true_iterative_refinement": True,
        }
        pipe = AbliterationPipeline(
            model_name="test-model",
            method="basic",
            **overrides,
        )
        for attr in overrides:
            assert getattr(pipe, attr) is True
| # --------------------------------------------------------------------------- | |
| # Bias projection | |
| # --------------------------------------------------------------------------- | |
class TestBiasProjection:
    """Tests for ``AbliterationPipeline._project_bias``.

    Each test wraps a single 4x4 ``torch.nn.Linear`` in a throwaway module
    and checks the count the helper returns for the candidate name
    ``"o_proj"``.
    """

    def test_project_bias_removes_component(self):
        """Bias projection should remove refusal direction component from bias."""
        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.o_proj = torch.nn.Linear(4, 4, bias=True)

        # Seed BEFORE the module is built: Linear's weight initialisation is
        # the only random operation here, so seeding afterwards has no effect.
        torch.manual_seed(42)
        module = Wrapper()
        module.o_proj.bias.data = torch.tensor([1.0, 2.0, 3.0, 4.0])
        # Column vector of shape (4, 1): the unit vector along dim 0.
        direction = torch.tensor([1.0, 0.0, 0.0, 0.0]).unsqueeze(-1)

        count = AbliterationPipeline._project_bias(module, direction, ["o_proj"])
        assert count == 1

        # The component along direction [1,0,0,0] was 1.0, should now be ~0
        new_bias = module.o_proj.bias.data
        projection_onto_dir = (new_bias @ direction.squeeze()).item()
        assert abs(projection_onto_dir) < 1e-5

        # Other components should be unchanged
        assert abs(new_bias[1].item() - 2.0) < 1e-5
        assert abs(new_bias[2].item() - 3.0) < 1e-5
        assert abs(new_bias[3].item() - 4.0) < 1e-5

    def test_project_bias_no_bias(self):
        """Should handle modules without bias gracefully."""
        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.o_proj = torch.nn.Linear(4, 4, bias=False)

        module = Wrapper()
        # The direction's values are irrelevant: there is no bias to project.
        direction = torch.randn(4, 1)
        count = AbliterationPipeline._project_bias(module, direction, ["o_proj"])
        assert count == 0

    def test_project_bias_no_matching_module(self):
        """Should return 0 when no candidate names match."""
        class Wrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # Deliberately not named "o_proj".
                self.something = torch.nn.Linear(4, 4, bias=True)

        module = Wrapper()
        direction = torch.randn(4, 1)
        count = AbliterationPipeline._project_bias(module, direction, ["o_proj"])
        assert count == 0
| # --------------------------------------------------------------------------- | |
| # Chat template wrapping | |
| # --------------------------------------------------------------------------- | |
class TestChatTemplate:
    """Tests for ``AbliterationPipeline._maybe_apply_chat_template``."""

    def test_no_wrap_when_disabled(self):
        """Prompts come back unchanged when use_chat_template is False."""
        pipe = AbliterationPipeline(
            model_name="test-model",
            method="basic",
            use_chat_template=False,
        )
        raw_prompts = ["Hello", "World"]
        assert pipe._maybe_apply_chat_template(raw_prompts) == raw_prompts

    def test_no_wrap_without_handle(self):
        """Prompts come back unchanged when the test never assigns a handle."""
        pipe = AbliterationPipeline(
            model_name="test-model",
            use_chat_template=True,
        )
        raw_prompts = ["Hello"]
        assert pipe._maybe_apply_chat_template(raw_prompts) == raw_prompts

    def test_wraps_with_template(self):
        """Prompts are wrapped when the tokenizer offers apply_chat_template."""
        pipe = AbliterationPipeline(
            model_name="test-model",
            use_chat_template=True,
        )

        def fake_apply(messages, tokenize=False, add_generation_prompt=True):
            # Wrap only the first message's content in marker tags.
            return f"<user>{messages[0]['content']}</user><assistant>"

        fake_tokenizer = MagicMock()
        fake_tokenizer.apply_chat_template = fake_apply
        pipe.handle = MagicMock(tokenizer=fake_tokenizer)
        pipe._on_log = lambda m: None  # silence logging

        wrapped = pipe._maybe_apply_chat_template(["Hello"])
        assert "<user>Hello</user>" in wrapped[0]

    def test_fallback_when_no_template(self):
        """Raw prompts are returned when apply_chat_template raises."""
        pipe = AbliterationPipeline(
            model_name="test-model",
            use_chat_template=True,
        )

        failing_tokenizer = MagicMock()
        failing_tokenizer.apply_chat_template.side_effect = Exception("No template")
        pipe.handle = MagicMock(tokenizer=failing_tokenizer)
        pipe._on_log = lambda m: None  # silence logging

        assert pipe._maybe_apply_chat_template(["Hello"]) == ["Hello"]
| # --------------------------------------------------------------------------- | |
| # Metadata includes new fields | |
| # --------------------------------------------------------------------------- | |
class TestMetadata:
    """Checks on the ``abliteration_metadata.json`` file read back after ``_rebirth``."""

    def test_rebirth_includes_new_config(self):
        """The saved metadata lists the new config switches and references."""
        import json
        import tempfile
        from pathlib import Path

        tiny = _make_tiny_handle()
        pipe = AbliterationPipeline(
            model_name="test-model",
            method="aggressive",
        )
        pipe.handle = tiny
        pipe._on_log = lambda m: None
        pipe._on_stage = lambda r: None
        pipe._strong_layers = [0]
        pipe._quality_metrics = {"perplexity": 8.5, "coherence": 1.0}

        # Stub out the savers so no model or tokenizer files are written.
        tiny.model.save_pretrained = MagicMock()
        tiny.tokenizer.save_pretrained = MagicMock()

        with tempfile.TemporaryDirectory() as scratch:
            pipe.output_dir = Path(scratch) / "output"
            pipe._rebirth()

            # Read while the temporary directory still exists.
            metadata_path = pipe.output_dir / "abliteration_metadata.json"
            metadata = json.loads(metadata_path.read_text())

        method_cfg = metadata["method_config"]
        for key in (
            "project_biases",
            "use_chat_template",
            "use_whitened_svd",
            "true_iterative_refinement",
        ):
            assert key in method_cfg
        assert method_cfg["project_biases"] is True
        assert method_cfg["use_whitened_svd"] is True

        # Should have more references now
        references = metadata["references"]
        assert len(references) >= 5
        assert any("OBLITERATUS" in ref for ref in references)