| | |
| | """ |
| | Unit tests for configuration components. |
| | """ |
| |
|
| | import pytest |
| | import sys |
| | from pathlib import Path |
| |
|
| | |
# Put the repository root at the front of sys.path so the
# ``compact_ai_model`` import below resolves without installing the package.
# The root is assumed to be three directory levels above this file
# (e.g. <root>/tests/<subdir>/test_config.py) — confirm against the layout.
# resolve() first: for a relative __file__ such as "test_config.py" the
# parent is Path("."), and Path(".").parent is still ".", so an unresolved
# .parent chain can stop short of the real project root.
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(project_root))
| |
|
| | from compact_ai_model.configs.config import ( |
| | Config, ModelConfig, InterleavedThinkingConfig, TrainingConfig, APIConfig, |
| | load_config_from_dict, save_config_to_dict |
| | ) |
| |
|
| |
|
class TestModelConfig:
    """Tests for constructing ModelConfig and for its default values."""

    def test_model_config_creation(self):
        """Explicit constructor arguments must override the defaults.

        dim and layers are deliberately different from the defaults checked
        in ``test_model_config_defaults`` (dim 512, layers 12). If they
        matched, this test would still pass when the constructor silently
        ignored its keyword arguments. dim=256 is the width the tiny preset
        is expected to use, so it is a value the project already accepts.
        """
        config = ModelConfig(dim=256, layers=6, heads=4, vocab_size=16000)
        assert config.dim == 256
        assert config.layers == 6
        assert config.heads == 4
        assert config.vocab_size == 16000

    def test_model_config_defaults(self):
        """A bare ModelConfig() is the "small" model with dim 512 and 12 layers."""
        config = ModelConfig()
        assert config.model_size == "small"
        assert config.dim == 512
        assert config.layers == 12
| |
|
| |
|
class TestInterleavedThinkingConfig:
    """Checks on InterleavedThinkingConfig construction, defaults and sanity."""

    def test_thinking_config_creation(self):
        """Keyword arguments given to the constructor appear on the instance."""
        expected = {
            "max_reasoning_paths": 4,
            "reasoning_depth": 5,
            "early_stop_threshold": 0.9,
        }
        cfg = InterleavedThinkingConfig(**expected)
        for field_name, value in expected.items():
            assert getattr(cfg, field_name) == value, field_name

    def test_thinking_config_defaults(self):
        """Defaults: 3 paths, depth 4, and an early-stop threshold within [0, 1]."""
        cfg = InterleavedThinkingConfig()
        assert (cfg.max_reasoning_paths, cfg.reasoning_depth) == (3, 4)
        threshold = cfg.early_stop_threshold
        assert 0 <= threshold <= 1

    def test_thinking_config_validation(self):
        """Every count-like setting of the default config is strictly positive."""
        cfg = InterleavedThinkingConfig()
        for field_name in ("max_reasoning_paths", "reasoning_depth", "token_budget"):
            assert getattr(cfg, field_name) > 0, field_name
| |
|
| |
|
class TestConfig:
    """Tests for the top-level Config presets and their dict round trip."""

    def test_balanced_config(self):
        """The balanced preset holds one sub-config of each expected type."""
        cfg = Config.get_balanced_config()
        sections = (
            (cfg.model, ModelConfig),
            (cfg.thinking, InterleavedThinkingConfig),
            (cfg.training, TrainingConfig),
            (cfg.api, APIConfig),
        )
        for section, expected_type in sections:
            assert isinstance(section, expected_type)

    def test_tiny_config(self):
        """The tiny preset: model_size "tiny", dim 256, two reasoning paths."""
        cfg = Config.get_tiny_config()
        model = cfg.model
        assert model.model_size == "tiny"
        assert model.dim == 256
        assert cfg.thinking.max_reasoning_paths == 2

    def test_large_config(self):
        """The large preset: model_size "medium", dim 768, four reasoning paths."""
        cfg = Config.get_large_config()
        model = cfg.model
        assert model.model_size == "medium"
        assert model.dim == 768
        assert cfg.thinking.max_reasoning_paths == 4

    def test_config_serialization(self):
        """Saving to a dict and loading back keeps dim, paths and learning rate."""
        before = Config.get_balanced_config()
        after = load_config_from_dict(save_config_to_dict(before))

        assert after.model.dim == before.model.dim
        assert after.thinking.max_reasoning_paths == before.thinking.max_reasoning_paths
        assert after.training.learning_rate == before.training.learning_rate
| |
|
| |
|
class TestSerialization:
    """Tests for save_config_to_dict / load_config_from_dict."""

    def test_save_load_dict(self):
        """The saved form is a dict keyed by section, and it loads back into a Config."""
        cfg = Config.get_balanced_config()
        as_dict = save_config_to_dict(cfg)

        assert isinstance(as_dict, dict)
        for section in ("model", "thinking", "training", "api"):
            assert section in as_dict, section

        restored = load_config_from_dict(as_dict)
        assert isinstance(restored, Config)

    def test_serialization_consistency(self):
        """Selected field values survive a save/load round trip unchanged."""
        source = Config.get_balanced_config()
        restored = load_config_from_dict(save_config_to_dict(source))

        pairs = (
            (restored.model.dim, source.model.dim),
            (restored.model.layers, source.model.layers),
            (restored.thinking.max_reasoning_paths, source.thinking.max_reasoning_paths),
            (restored.training.learning_rate, source.training.learning_rate),
        )
        for got, want in pairs:
            assert got == want
| |
|
| |
|
if __name__ == "__main__":
    # pytest.main() returns the exit code rather than exiting. Pass it to
    # sys.exit so that running this file directly reports test failures
    # through the process exit status instead of always exiting 0.
    sys.exit(pytest.main([__file__]))