File size: 4,396 Bytes
2bc8e46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""Tests for ablation presets."""

from __future__ import annotations

import pytest

from obliteratus.config import StudyConfig
from obliteratus.study_presets import (
    STUDY_PRESETS,
    get_preset,
    get_study_preset,
    list_presets,
    list_study_presets,
)


class TestPresets:
    """Registry-level checks for the study presets."""

    # Preset keys this suite requires the registry to expose; extra presets are allowed.
    EXPECTED_PRESET_KEYS = frozenset({
        "quick", "full", "attention", "layers", "knowledge",
        "pruning", "embeddings", "jailbreak", "guardrail", "robustness",
    })

    def test_all_presets_registered(self):
        # Iterating the mapping yields its keys; report exactly which ones are absent.
        missing = self.EXPECTED_PRESET_KEYS.difference(STUDY_PRESETS)
        assert not missing, f"Presets missing from STUDY_PRESETS: {sorted(missing)}"

    def test_get_preset(self):
        preset = get_study_preset("quick")
        assert preset.name == "Quick Scan"
        assert preset.key == "quick"
        assert len(preset.strategies) == 2

    def test_get_preset_alias(self):
        preset = get_preset("quick")
        assert preset.name == "Quick Scan"

    def test_get_unknown_preset_raises(self):
        with pytest.raises(KeyError, match="Unknown preset"):
            get_study_preset("nonexistent")

    def test_list_presets(self):
        # Check the listing against the same expected key set as the registry test,
        # instead of a hard-coded count that drifts as presets are added.
        keys = {p.key for p in list_study_presets()}
        missing = self.EXPECTED_PRESET_KEYS.difference(keys)
        assert not missing, f"list_study_presets() omits presets: {sorted(missing)}"

    def test_list_presets_alias(self):
        assert list_presets() == list_study_presets()

    def test_preset_strategies_are_valid(self):
        from obliteratus.strategies import STRATEGY_REGISTRY
        for preset in list_study_presets():
            for s in preset.strategies:
                assert s["name"] in STRATEGY_REGISTRY, (
                    f"Preset {preset.key!r} references unknown strategy {s['name']!r}"
                )


class TestConfigWithPreset:
    """Check how ``StudyConfig.from_dict`` applies and overrides presets."""

    @staticmethod
    def _config_dict(preset="quick", *, preset_field="preset", dataset_extra=None, **top_level):
        """Build a fresh gpt2 / wikitext-2 config dict that selects *preset*.

        Args:
            preset: Preset key stored under *preset_field*.
            preset_field: Top-level key that names the preset (``"preset"``, or the
                legacy ``"study_preset"``).
            dataset_extra: Extra entries merged into the ``"dataset"`` section.
            **top_level: Extra top-level entries (e.g. ``batch_size``, ``strategies``).

        Returns:
            A newly built dict. Nothing is shared between calls, so one test cannot
            leak mutations into another.
        """
        dataset = {"name": "wikitext", "subset": "wikitext-2-raw-v1", "split": "test", "text_column": "text"}
        if dataset_extra:
            dataset.update(dataset_extra)
        return {
            preset_field: preset,
            "model": {"name": "gpt2", "task": "causal_lm", "dtype": "float32", "device": "cpu"},
            "dataset": dataset,
            **top_level,
        }

    def test_preset_key_in_config(self):
        config = StudyConfig.from_dict(self._config_dict("quick"))
        # Should inherit exactly the two strategies of the quick preset.
        assert len(config.strategies) == 2
        assert {s.name for s in config.strategies} == {"layer_removal", "ffn_ablation"}
        # Should inherit max_samples
        assert config.dataset.max_samples == 25
        # Should inherit batch_size and max_length
        assert config.batch_size == 4
        assert config.max_length == 128

    def test_legacy_study_preset_key_still_works(self):
        config = StudyConfig.from_dict(self._config_dict("quick", preset_field="study_preset"))
        # The legacy key must resolve to the same strategies as "preset".
        assert len(config.strategies) == 2
        assert {s.name for s in config.strategies} == {"layer_removal", "ffn_ablation"}

    def test_preset_can_be_overridden(self):
        config_dict = self._config_dict(
            "quick",
            dataset_extra={"max_samples": 999},
            batch_size=16,
            strategies=[{"name": "head_pruning", "params": {}}],
        )
        config = StudyConfig.from_dict(config_dict)
        # Explicit strategies should override preset
        assert len(config.strategies) == 1
        assert config.strategies[0].name == "head_pruning"
        # Explicit batch_size should override
        assert config.batch_size == 16
        # Explicit max_samples in dataset should be kept
        assert config.dataset.max_samples == 999

    def test_full_preset(self):
        config = StudyConfig.from_dict(self._config_dict("full"))
        assert len(config.strategies) == 4
        strategy_names = {s.name for s in config.strategies}
        assert strategy_names == {"layer_removal", "head_pruning", "ffn_ablation", "embedding_ablation"}