# obliteratus/tests/test_config.py
# Uploaded by pliny-the-prompter ("Upload 127 files")
# Commit 45113e6 (verified)
"""Tests for configuration loading."""
from __future__ import annotations
import yaml
from obliteratus.config import StudyConfig
# Shared fixture: a plain-dict study configuration consumed by every test
# in this module (passed to StudyConfig.from_dict and dumped to YAML).
SAMPLE_CONFIG = {
    "model": {
        "name": "gpt2",
        "task": "causal_lm",
        "dtype": "float32",
        "device": "cpu",
    },
    "dataset": {
        "name": "wikitext",
        "subset": "wikitext-2-raw-v1",
        "split": "test",
        "text_column": "text",
        "max_samples": 50,
    },
    # Two strategies, each with its own (empty) params mapping.
    "strategies": [
        {"name": strategy_name, "params": {}}
        for strategy_name in ("layer_removal", "ffn_ablation")
    ],
    "metrics": ["perplexity"],
    "batch_size": 4,
    "max_length": 256,
    "output_dir": "results/test",
}
class TestStudyConfig:
    """Check ``StudyConfig`` loading from a dict, from YAML, and a round-trip."""

    def test_from_dict(self):
        """Build a config from ``SAMPLE_CONFIG`` and spot-check parsed fields."""
        cfg = StudyConfig.from_dict(SAMPLE_CONFIG)

        assert cfg.model.name == "gpt2"
        assert cfg.model.task == "causal_lm"
        assert cfg.dataset.name == "wikitext"

        strategies = cfg.strategies
        assert len(strategies) == 2
        assert strategies[0].name == "layer_removal"

    def test_from_yaml(self, tmp_path):
        """Dump ``SAMPLE_CONFIG`` to a YAML file and load it via ``from_yaml``."""
        target = tmp_path / "test_config.yaml"
        serialized = yaml.dump(SAMPLE_CONFIG)
        target.write_text(serialized)

        loaded = StudyConfig.from_yaml(target)

        assert loaded.model.name == "gpt2"
        assert loaded.batch_size == 4

    def test_roundtrip(self):
        """Feed ``to_dict`` output back into ``from_dict`` and compare key fields."""
        original = StudyConfig.from_dict(SAMPLE_CONFIG)
        exported = original.to_dict()
        rebuilt = StudyConfig.from_dict(exported)

        assert rebuilt.model.name == original.model.name
        assert rebuilt.dataset.name == original.dataset.name
        assert len(rebuilt.strategies) == len(original.strategies)