"""Tests for ablation presets."""
from __future__ import annotations
from obliteratus.study_presets import (
STUDY_PRESETS,
get_study_preset,
get_preset,
list_study_presets,
list_presets,
)
from obliteratus.config import StudyConfig
class TestPresets:
    """Tests for the study-preset registry in ``obliteratus.study_presets``."""

    def test_all_presets_registered(self):
        """All expected preset keys are present in the registry."""
        expected_keys = {
            "quick", "full", "attention", "layers", "knowledge",
            "pruning", "embeddings", "jailbreak", "guardrail", "robustness",
        }
        assert expected_keys.issubset(set(STUDY_PRESETS.keys()))

    def test_get_preset(self):
        """``get_study_preset`` returns the preset matching the given key."""
        preset = get_study_preset("quick")
        assert preset.name == "Quick Scan"
        assert preset.key == "quick"
        assert len(preset.strategies) == 2

    def test_get_preset_alias(self):
        """``get_preset`` is an alias for ``get_study_preset``."""
        preset = get_preset("quick")
        assert preset.name == "Quick Scan"

    def test_get_unknown_preset_raises(self):
        """Looking up an unknown key raises ``KeyError`` with a clear message."""
        # Local import so the module itself stays importable without pytest.
        import pytest

        with pytest.raises(KeyError, match="Unknown preset"):
            get_study_preset("nonexistent")

    def test_list_presets(self):
        """``list_study_presets`` returns all presets, including quick and full."""
        presets = list_study_presets()
        assert len(presets) >= 7
        keys = [p.key for p in presets]
        assert "quick" in keys
        assert "full" in keys

    def test_list_presets_alias(self):
        """``list_presets`` is an alias for ``list_study_presets``."""
        assert list_presets() == list_study_presets()

    def test_preset_strategies_are_valid(self):
        """Every strategy a preset references exists in the strategy registry."""
        from obliteratus.strategies import STRATEGY_REGISTRY

        for preset in list_study_presets():
            for s in preset.strategies:
                assert s["name"] in STRATEGY_REGISTRY, (
                    f"Preset {preset.key!r} references unknown strategy {s['name']!r}"
                )
class TestConfigWithPreset:
    """Tests for how ``StudyConfig.from_dict`` resolves and merges presets."""

    def test_preset_key_in_config(self):
        """A ``preset`` key pulls strategies and defaults from that preset."""
        config_dict = {
            "preset": "quick",
            "model": {"name": "gpt2", "task": "causal_lm", "dtype": "float32", "device": "cpu"},
            "dataset": {"name": "wikitext", "subset": "wikitext-2-raw-v1", "split": "test", "text_column": "text"},
        }
        config = StudyConfig.from_dict(config_dict)
        # Should inherit strategies from the quick preset
        assert len(config.strategies) == 2
        strategy_names = [s.name for s in config.strategies]
        assert "layer_removal" in strategy_names
        assert "ffn_ablation" in strategy_names
        # Should inherit max_samples
        assert config.dataset.max_samples == 25
        # Should inherit batch_size and max_length
        assert config.batch_size == 4
        assert config.max_length == 128

    def test_legacy_study_preset_key_still_works(self):
        """The older ``study_preset`` key is still honored for compatibility."""
        config_dict = {
            "study_preset": "quick",
            "model": {"name": "gpt2", "task": "causal_lm", "dtype": "float32", "device": "cpu"},
            "dataset": {"name": "wikitext", "subset": "wikitext-2-raw-v1", "split": "test", "text_column": "text"},
        }
        config = StudyConfig.from_dict(config_dict)
        assert len(config.strategies) == 2

    def test_preset_can_be_overridden(self):
        """Explicit config values take precedence over preset-provided ones."""
        config_dict = {
            "preset": "quick",
            "model": {"name": "gpt2", "task": "causal_lm", "dtype": "float32", "device": "cpu"},
            "dataset": {"name": "wikitext", "subset": "wikitext-2-raw-v1", "split": "test", "text_column": "text", "max_samples": 999},
            "batch_size": 16,
            "strategies": [{"name": "head_pruning", "params": {}}],
        }
        config = StudyConfig.from_dict(config_dict)
        # Explicit strategies should override preset
        assert len(config.strategies) == 1
        assert config.strategies[0].name == "head_pruning"
        # Explicit batch_size should override
        assert config.batch_size == 16
        # Explicit max_samples in dataset should be kept
        assert config.dataset.max_samples == 999

    def test_full_preset(self):
        """The ``full`` preset expands to all four ablation strategies."""
        config_dict = {
            "preset": "full",
            "model": {"name": "gpt2", "task": "causal_lm", "dtype": "float32", "device": "cpu"},
            "dataset": {"name": "wikitext", "subset": "wikitext-2-raw-v1", "split": "test", "text_column": "text"},
        }
        config = StudyConfig.from_dict(config_dict)
        assert len(config.strategies) == 4
        strategy_names = {s.name for s in config.strategies}
        assert strategy_names == {"layer_removal", "head_pruning", "ffn_ablation", "embedding_ablation"}