# obliteratus/tests/test_abliterate_extended.py
"""Extended tests for novel abliteration pipeline features.
Tests the new capabilities added to the OBLITERATUS abliteration pipeline:
- Bias projection
- Chat template wrapping
- Method presets with new parameters
- True iterative refinement
- Whitened SVD integration
"""
from __future__ import annotations
from unittest.mock import MagicMock
import torch
from transformers import GPT2Config, GPT2LMHeadModel
from obliteratus.abliterate import (
METHODS,
AbliterationPipeline,
)
from obliteratus.models.loader import ModelHandle
def _make_tiny_handle():
    """Build a ModelHandle around a tiny 4-layer GPT-2 for fast unit tests.

    The tokenizer is a MagicMock that always returns a fixed-size batch and
    decodes to a short coherent sentence, so no real tokenizer download is
    needed.
    """
    tiny_config = GPT2Config(
        vocab_size=1000,
        n_positions=128,
        n_embd=64,
        n_layer=4,
        n_head=2,
        n_inner=256,
    )
    tiny_model = GPT2LMHeadModel(tiny_config)
    tiny_model.eval()

    mock_tok = MagicMock()
    mock_tok.pad_token = "<pad>"
    mock_tok.eos_token = "<eos>"
    mock_tok.return_value = {
        "input_ids": torch.randint(0, 1000, (1, 10)),
        "attention_mask": torch.ones(1, 10, dtype=torch.long),
    }
    mock_tok.decode.return_value = "The capital of France is Paris, a beautiful city"

    handle = ModelHandle(
        model=tiny_model,
        tokenizer=mock_tok,
        config=tiny_config,
        model_name="gpt2-test",
        task="causal_lm",
    )
    # Snapshot the pristine weights so the pipeline can diff/restore later.
    handle.snapshot()
    return handle
def _make_varied_tokenizer(handle):
"""Set up a tokenizer mock that returns different tokens per call."""
call_count = [0]
def mock_tokenizer(prompt, **kwargs):
call_count[0] += 1
torch.manual_seed(call_count[0])
return {
"input_ids": torch.randint(0, 1000, (1, 5)),
"attention_mask": torch.ones(1, 5, dtype=torch.long),
}
handle.tokenizer.side_effect = mock_tokenizer
# ---------------------------------------------------------------------------
# New method preset parameters
# ---------------------------------------------------------------------------
class TestNewMethodPresets:
    """The method presets expose the newly added configuration keys."""

    # Keys introduced by the extended pipeline features.
    _NEW_KEYS = (
        "project_biases",
        "use_chat_template",
        "use_whitened_svd",
        "true_iterative_refinement",
    )

    def test_basic_has_new_params(self):
        cfg = METHODS["basic"]
        for key in self._NEW_KEYS:
            assert key in cfg
        assert cfg["project_biases"] is False
        assert cfg["use_chat_template"] is False

    def test_advanced_has_new_params(self):
        cfg = METHODS["advanced"]
        expected = {
            "project_biases": True,
            "use_chat_template": True,
            "use_whitened_svd": False,
            "true_iterative_refinement": False,
        }
        for key, value in expected.items():
            assert cfg[key] is value

    def test_aggressive_has_new_params(self):
        cfg = METHODS["aggressive"]
        # The aggressive preset enables every new feature.
        for key in self._NEW_KEYS:
            assert cfg[key] is True
# ---------------------------------------------------------------------------
# Pipeline initialization with new parameters
# ---------------------------------------------------------------------------
class TestNewPipelineInit:
    """Pipeline construction resolves the new flags from method presets."""

    _FLAGS = (
        "project_biases",
        "use_chat_template",
        "use_whitened_svd",
        "true_iterative_refinement",
    )

    def test_default_new_params(self):
        # Without an explicit method, "advanced" defaults apply.
        pipeline = AbliterationPipeline(model_name="test-model")
        assert pipeline.project_biases is True
        assert pipeline.use_chat_template is True
        assert pipeline.use_whitened_svd is False
        assert pipeline.true_iterative_refinement is False

    def test_basic_method_new_params(self):
        pipeline = AbliterationPipeline(model_name="test-model", method="basic")
        for flag in self._FLAGS:
            assert getattr(pipeline, flag) is False

    def test_aggressive_method_new_params(self):
        pipeline = AbliterationPipeline(model_name="test-model", method="aggressive")
        for flag in self._FLAGS:
            assert getattr(pipeline, flag) is True

    def test_explicit_overrides_new_params(self):
        # Explicit keyword arguments win over the method preset.
        overrides = {flag: True for flag in self._FLAGS}
        pipeline = AbliterationPipeline(
            model_name="test-model",
            method="basic",
            **overrides,
        )
        for flag in self._FLAGS:
            assert getattr(pipeline, flag) is True
# ---------------------------------------------------------------------------
# Bias projection
# ---------------------------------------------------------------------------
class TestBiasProjection:
    """Behavioral tests for AbliterationPipeline._project_bias."""

    @staticmethod
    def _holder_with_o_proj(with_bias):
        """Return a module exposing a single 4x4 ``o_proj`` linear layer."""
        class _Holder(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.o_proj = torch.nn.Linear(4, 4, bias=with_bias)
        return _Holder()

    def test_project_bias_removes_component(self):
        """Bias projection should remove refusal direction component from bias."""
        holder = self._holder_with_o_proj(with_bias=True)
        torch.manual_seed(42)
        holder.o_proj.bias.data = torch.tensor([1.0, 2.0, 3.0, 4.0])
        # Unit vector along dimension 0, shaped (4, 1).
        direction = torch.tensor([1.0, 0.0, 0.0, 0.0]).unsqueeze(-1)
        touched = AbliterationPipeline._project_bias(holder, direction, ["o_proj"])
        assert touched == 1
        updated = holder.o_proj.bias.data
        # The 1.0 component along the direction must now be ~0 ...
        assert abs((updated @ direction.squeeze()).item()) < 1e-5
        # ... while the orthogonal components stay intact.
        for idx, expected in ((1, 2.0), (2, 3.0), (3, 4.0)):
            assert abs(updated[idx].item() - expected) < 1e-5

    def test_project_bias_no_bias(self):
        """Should handle modules without bias gracefully."""
        holder = self._holder_with_o_proj(with_bias=False)
        direction = torch.randn(4, 1)
        assert AbliterationPipeline._project_bias(holder, direction, ["o_proj"]) == 0

    def test_project_bias_no_matching_module(self):
        """Should return 0 when no candidate names match."""
        class _Other(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.something = torch.nn.Linear(4, 4, bias=True)
        direction = torch.randn(4, 1)
        assert AbliterationPipeline._project_bias(_Other(), direction, ["o_proj"]) == 0
# ---------------------------------------------------------------------------
# Chat template wrapping
# ---------------------------------------------------------------------------
class TestChatTemplate:
    """Tests for AbliterationPipeline._maybe_apply_chat_template."""

    @staticmethod
    def _silence(pipeline):
        """Swallow pipeline log output during the test."""
        pipeline._on_log = lambda m: None

    def test_no_wrap_when_disabled(self):
        """Should not wrap prompts when use_chat_template is False."""
        pipeline = AbliterationPipeline(
            model_name="test-model",
            method="basic",
            use_chat_template=False,
        )
        raw = ["Hello", "World"]
        assert pipeline._maybe_apply_chat_template(raw) == raw

    def test_no_wrap_without_handle(self):
        """Should return raw prompts when handle is not set."""
        pipeline = AbliterationPipeline(
            model_name="test-model",
            use_chat_template=True,
        )
        assert pipeline._maybe_apply_chat_template(["Hello"]) == ["Hello"]

    def test_wraps_with_template(self):
        """Should wrap prompts when tokenizer has apply_chat_template."""
        pipeline = AbliterationPipeline(
            model_name="test-model",
            use_chat_template=True,
        )

        def fake_template(messages, tokenize=False, add_generation_prompt=True):
            return f"<user>{messages[0]['content']}</user><assistant>"

        fake_tokenizer = MagicMock()
        fake_tokenizer.apply_chat_template = fake_template
        fake_handle = MagicMock()
        fake_handle.tokenizer = fake_tokenizer
        pipeline.handle = fake_handle
        self._silence(pipeline)
        wrapped = pipeline._maybe_apply_chat_template(["Hello"])
        assert "<user>Hello</user>" in wrapped[0]

    def test_fallback_when_no_template(self):
        """Should fall back to raw prompts when template is not configured."""
        pipeline = AbliterationPipeline(
            model_name="test-model",
            use_chat_template=True,
        )
        fake_handle = MagicMock()
        fake_tokenizer = MagicMock()
        fake_tokenizer.apply_chat_template.side_effect = Exception("No template")
        fake_handle.tokenizer = fake_tokenizer
        pipeline.handle = fake_handle
        self._silence(pipeline)
        assert pipeline._maybe_apply_chat_template(["Hello"]) == ["Hello"]
# ---------------------------------------------------------------------------
# Metadata includes new fields
# ---------------------------------------------------------------------------
class TestMetadata:
    """The rebirth step writes metadata describing the active configuration."""

    def test_rebirth_includes_new_config(self):
        """Metadata should include all new configuration parameters."""
        import json
        import tempfile
        from pathlib import Path

        handle = _make_tiny_handle()
        pipeline = AbliterationPipeline(
            model_name="test-model",
            method="aggressive",
        )
        pipeline.handle = handle
        pipeline._on_log = lambda m: None
        pipeline._on_stage = lambda r: None
        pipeline._strong_layers = [0]
        pipeline._quality_metrics = {"perplexity": 8.5, "coherence": 1.0}
        # Stub out model/tokenizer serialization so no real weights hit disk.
        handle.model.save_pretrained = MagicMock()
        handle.tokenizer.save_pretrained = MagicMock()

        with tempfile.TemporaryDirectory() as tmp:
            pipeline.output_dir = Path(tmp) / "output"
            pipeline._rebirth()
            meta_path = pipeline.output_dir / "abliteration_metadata.json"
            metadata = json.loads(meta_path.read_text())

        cfg = metadata["method_config"]
        for key in (
            "project_biases",
            "use_chat_template",
            "use_whitened_svd",
            "true_iterative_refinement",
        ):
            assert key in cfg
        # Aggressive preset: bias projection and whitened SVD are enabled.
        assert cfg["project_biases"] is True
        assert cfg["use_whitened_svd"] is True
        # Should have more references now
        assert len(metadata["references"]) >= 5
        assert any("OBLITERATUS" in ref for ref in metadata["references"])