Spaces:
Running on Zero
Running on Zero
| """Tests for ablation presets.""" | |
| from __future__ import annotations | |
| from obliteratus.study_presets import ( | |
| STUDY_PRESETS, | |
| get_study_preset, | |
| get_preset, | |
| list_study_presets, | |
| list_presets, | |
| ) | |
| from obliteratus.config import StudyConfig | |
class TestPresets:
    """Registry membership, lookup and listing behaviour of the study presets."""

    def test_all_presets_registered(self):
        # Every shipped preset key must be registered; extra registrations are allowed.
        required = {
            "quick", "full", "attention", "layers", "knowledge",
            "pruning", "embeddings", "jailbreak", "guardrail", "robustness",
        }
        missing = required.difference(STUDY_PRESETS.keys())
        assert not missing

    def test_get_preset(self):
        quick = get_study_preset("quick")
        assert (quick.name, quick.key) == ("Quick Scan", "quick")
        # The quick preset bundles exactly two strategies.
        assert len(quick.strategies) == 2

    def test_get_preset_alias(self):
        # The ``get_preset`` alias resolves "quick" to the Quick Scan preset too.
        assert get_preset("quick").name == "Quick Scan"

    def test_get_unknown_preset_raises(self):
        import pytest

        # An unregistered key must raise KeyError whose message mentions "Unknown preset".
        with pytest.raises(KeyError, match="Unknown preset"):
            get_study_preset("nonexistent")

    def test_list_presets(self):
        listed = list_study_presets()
        assert len(listed) >= 7
        listed_keys = {preset.key for preset in listed}
        assert {"quick", "full"} <= listed_keys

    def test_list_presets_alias(self):
        # ``list_presets`` must return the same listing as ``list_study_presets``.
        assert list_study_presets() == list_presets()

    def test_preset_strategies_are_valid(self):
        from obliteratus.strategies import STRATEGY_REGISTRY

        # Each strategy entry is a dict whose "name" must be a registered strategy.
        for preset in list_study_presets():
            for strategy in preset.strategies:
                strategy_name = strategy["name"]
                assert strategy_name in STRATEGY_REGISTRY, (
                    f"Preset {preset.key!r} references unknown strategy {strategy_name!r}"
                )
class TestConfigWithPreset:
    """How ``StudyConfig.from_dict`` merges a named preset with explicit settings."""

    @staticmethod
    def _config_dict(preset, *, preset_field="preset", dataset_extra=None, **top_level):
        """Build a minimal GPT-2 / wikitext config dict that selects ``preset``.

        ``preset_field`` is the key used to select the preset. ``dataset_extra``
        is merged into the dataset section. Any remaining keyword arguments are
        appended as extra top-level keys, after "model" and "dataset".
        """
        dataset = {"name": "wikitext", "subset": "wikitext-2-raw-v1", "split": "test", "text_column": "text"}
        if dataset_extra:
            dataset.update(dataset_extra)
        return {
            preset_field: preset,
            "model": {"name": "gpt2", "task": "causal_lm", "dtype": "float32", "device": "cpu"},
            "dataset": dataset,
            **top_level,
        }

    def test_preset_key_in_config(self):
        config = StudyConfig.from_dict(self._config_dict("quick"))
        # Strategies are inherited from the quick preset.
        assert len(config.strategies) == 2
        inherited = [strategy.name for strategy in config.strategies]
        assert "layer_removal" in inherited and "ffn_ablation" in inherited
        # So are max_samples, batch_size and max_length.
        assert config.dataset.max_samples == 25
        assert (config.batch_size, config.max_length) == (4, 128)

    def test_legacy_study_preset_key_still_works(self):
        # The older "study_preset" key must select a preset just like "preset".
        legacy = self._config_dict("quick", preset_field="study_preset")
        assert len(StudyConfig.from_dict(legacy).strategies) == 2

    def test_preset_can_be_overridden(self):
        explicit = self._config_dict(
            "quick",
            dataset_extra={"max_samples": 999},
            batch_size=16,
            strategies=[{"name": "head_pruning", "params": {}}],
        )
        config = StudyConfig.from_dict(explicit)
        # Explicit strategies replace the preset's list entirely.
        assert [strategy.name for strategy in config.strategies] == ["head_pruning"]
        # Explicit batch_size and dataset max_samples take precedence over the preset.
        assert config.batch_size == 16
        assert config.dataset.max_samples == 999

    def test_full_preset(self):
        config = StudyConfig.from_dict(self._config_dict("full"))
        assert len(config.strategies) == 4
        expected = {"layer_removal", "head_pruning", "ffn_ablation", "embedding_ablation"}
        assert {strategy.name for strategy in config.strategies} == expected