"""Tests for configuration loading.""" from __future__ import annotations import yaml from obliteratus.config import StudyConfig SAMPLE_CONFIG = { "model": { "name": "gpt2", "task": "causal_lm", "dtype": "float32", "device": "cpu", }, "dataset": { "name": "wikitext", "subset": "wikitext-2-raw-v1", "split": "test", "text_column": "text", "max_samples": 50, }, "strategies": [ {"name": "layer_removal", "params": {}}, {"name": "ffn_ablation", "params": {}}, ], "metrics": ["perplexity"], "batch_size": 4, "max_length": 256, "output_dir": "results/test", } class TestStudyConfig: def test_from_dict(self): config = StudyConfig.from_dict(SAMPLE_CONFIG) assert config.model.name == "gpt2" assert config.model.task == "causal_lm" assert config.dataset.name == "wikitext" assert len(config.strategies) == 2 assert config.strategies[0].name == "layer_removal" def test_from_yaml(self, tmp_path): yaml_path = tmp_path / "test_config.yaml" yaml_path.write_text(yaml.dump(SAMPLE_CONFIG)) config = StudyConfig.from_yaml(yaml_path) assert config.model.name == "gpt2" assert config.batch_size == 4 def test_roundtrip(self): config = StudyConfig.from_dict(SAMPLE_CONFIG) d = config.to_dict() config2 = StudyConfig.from_dict(d) assert config2.model.name == config.model.name assert config2.dataset.name == config.dataset.name assert len(config2.strategies) == len(config.strategies)