"""Tests for ablation presets.""" from __future__ import annotations from obliteratus.study_presets import ( STUDY_PRESETS, get_study_preset, get_preset, list_study_presets, list_presets, ) from obliteratus.config import StudyConfig class TestPresets: def test_all_presets_registered(self): expected_keys = {"quick", "full", "attention", "layers", "knowledge", "pruning", "embeddings", "jailbreak", "guardrail", "robustness"} assert expected_keys.issubset(set(STUDY_PRESETS.keys())) def test_get_preset(self): preset = get_study_preset("quick") assert preset.name == "Quick Scan" assert preset.key == "quick" assert len(preset.strategies) == 2 def test_get_preset_alias(self): preset = get_preset("quick") assert preset.name == "Quick Scan" def test_get_unknown_preset_raises(self): import pytest with pytest.raises(KeyError, match="Unknown preset"): get_study_preset("nonexistent") def test_list_presets(self): presets = list_study_presets() assert len(presets) >= 7 keys = [p.key for p in presets] assert "quick" in keys assert "full" in keys def test_list_presets_alias(self): assert list_presets() == list_study_presets() def test_preset_strategies_are_valid(self): from obliteratus.strategies import STRATEGY_REGISTRY for preset in list_study_presets(): for s in preset.strategies: assert s["name"] in STRATEGY_REGISTRY, ( f"Preset {preset.key!r} references unknown strategy {s['name']!r}" ) class TestConfigWithPreset: def test_preset_key_in_config(self): config_dict = { "preset": "quick", "model": {"name": "gpt2", "task": "causal_lm", "dtype": "float32", "device": "cpu"}, "dataset": {"name": "wikitext", "subset": "wikitext-2-raw-v1", "split": "test", "text_column": "text"}, } config = StudyConfig.from_dict(config_dict) # Should inherit strategies from the quick preset assert len(config.strategies) == 2 strategy_names = [s.name for s in config.strategies] assert "layer_removal" in strategy_names assert "ffn_ablation" in strategy_names # Should inherit max_samples assert config.dataset.max_samples == 25 # Should inherit batch_size and max_length assert config.batch_size == 4 assert config.max_length == 128 def test_legacy_study_preset_key_still_works(self): config_dict = { "study_preset": "quick", "model": {"name": "gpt2", "task": "causal_lm", "dtype": "float32", "device": "cpu"}, "dataset": {"name": "wikitext", "subset": "wikitext-2-raw-v1", "split": "test", "text_column": "text"}, } config = StudyConfig.from_dict(config_dict) assert len(config.strategies) == 2 def test_preset_can_be_overridden(self): config_dict = { "preset": "quick", "model": {"name": "gpt2", "task": "causal_lm", "dtype": "float32", "device": "cpu"}, "dataset": {"name": "wikitext", "subset": "wikitext-2-raw-v1", "split": "test", "text_column": "text", "max_samples": 999}, "batch_size": 16, "strategies": [{"name": "head_pruning", "params": {}}], } config = StudyConfig.from_dict(config_dict) # Explicit strategies should override preset assert len(config.strategies) == 1 assert config.strategies[0].name == "head_pruning" # Explicit batch_size should override assert config.batch_size == 16 # Explicit max_samples in dataset should be kept assert config.dataset.max_samples == 999 def test_full_preset(self): config_dict = { "preset": "full", "model": {"name": "gpt2", "task": "causal_lm", "dtype": "float32", "device": "cpu"}, "dataset": {"name": "wikitext", "subset": "wikitext-2-raw-v1", "split": "test", "text_column": "text"}, } config = StudyConfig.from_dict(config_dict) assert len(config.strategies) == 4 strategy_names = {s.name for s in config.strategies} assert strategy_names == {"layer_removal", "head_pruning", "ffn_ablation", "embedding_ablation"}