#!/usr/bin/env python3
"""Unit tests for configuration components.

Covers the individual config sections (``ModelConfig``,
``InterleavedThinkingConfig``), the preset constructors on ``Config``, and
the dict round trip through ``save_config_to_dict`` /
``load_config_from_dict``.
"""

import pytest
import sys
from pathlib import Path

# Add project root to path: three directory levels above this file, so the
# ``compact_ai_model`` package resolves without being installed.
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from compact_ai_model.configs.config import (
    Config,
    ModelConfig,
    InterleavedThinkingConfig,
    TrainingConfig,
    APIConfig,
    load_config_from_dict,
    save_config_to_dict,
)


class TestModelConfig:
    """Test ModelConfig functionality."""

    def test_model_config_creation(self):
        """Explicit constructor arguments are stored on the instance."""
        config = ModelConfig(dim=512, layers=12, heads=8, vocab_size=32000)

        assert config.dim == 512
        assert config.layers == 12
        assert config.heads == 8
        assert config.vocab_size == 32000

    def test_model_config_defaults(self):
        """A ModelConfig built without arguments uses the expected defaults."""
        config = ModelConfig()

        assert config.model_size == "small"
        assert config.dim == 512
        assert config.layers == 12


class TestInterleavedThinkingConfig:
    """Test InterleavedThinkingConfig functionality."""

    def test_thinking_config_creation(self):
        """Explicit constructor arguments are stored on the instance."""
        config = InterleavedThinkingConfig(
            max_reasoning_paths=4,
            reasoning_depth=5,
            early_stop_threshold=0.9,
        )

        assert config.max_reasoning_paths == 4
        assert config.reasoning_depth == 5
        # Exact comparison is intentional: the same literal was passed in.
        assert config.early_stop_threshold == 0.9

    def test_thinking_config_defaults(self):
        """A config built without arguments uses the expected defaults."""
        config = InterleavedThinkingConfig()

        assert config.max_reasoning_paths == 3
        assert config.reasoning_depth == 4
        assert 0 <= config.early_stop_threshold <= 1

    def test_thinking_config_validation(self):
        """Default count/budget fields are strictly positive."""
        config = InterleavedThinkingConfig()

        assert config.max_reasoning_paths > 0
        assert config.reasoning_depth > 0
        assert config.token_budget > 0


class TestConfig:
    """Test main Config class."""

    def test_balanced_config(self):
        """The balanced preset exposes all four typed sections."""
        config = Config.get_balanced_config()

        assert isinstance(config.model, ModelConfig)
        assert isinstance(config.thinking, InterleavedThinkingConfig)
        assert isinstance(config.training, TrainingConfig)
        assert isinstance(config.api, APIConfig)

    def test_tiny_config(self):
        """The tiny preset reports the expected size, width and path count."""
        config = Config.get_tiny_config()

        assert config.model.model_size == "tiny"
        assert config.model.dim == 256
        assert config.thinking.max_reasoning_paths == 2

    def test_large_config(self):
        """The large preset reports the expected size, width and path count."""
        config = Config.get_large_config()

        # NOTE(review): the "large" preset is expected to report
        # model_size == "medium" -- confirm this naming is intentional.
        assert config.model.model_size == "medium"
        assert config.model.dim == 768
        assert config.thinking.max_reasoning_paths == 4

    def test_config_serialization(self):
        """A balanced config survives a save/load round trip."""
        config = Config.get_balanced_config()

        config_dict = save_config_to_dict(config)
        loaded_config = load_config_from_dict(config_dict)

        assert loaded_config.model.dim == config.model.dim
        assert (
            loaded_config.thinking.max_reasoning_paths
            == config.thinking.max_reasoning_paths
        )
        assert (
            loaded_config.training.learning_rate
            == config.training.learning_rate
        )


class TestSerialization:
    """Test configuration serialization."""

    def test_save_load_dict(self):
        """Saving yields a dict with one key per section; loading rebuilds a Config."""
        config = Config.get_balanced_config()

        config_dict = save_config_to_dict(config)

        assert isinstance(config_dict, dict)
        assert "model" in config_dict
        assert "thinking" in config_dict
        assert "training" in config_dict
        assert "api" in config_dict

        loaded_config = load_config_from_dict(config_dict)
        assert isinstance(loaded_config, Config)

    def test_serialization_consistency(self):
        """Serialization preserves key values across sections."""
        original_config = Config.get_balanced_config()

        # Serialize and deserialize
        config_dict = save_config_to_dict(original_config)
        loaded_config = load_config_from_dict(config_dict)

        # Check key values
        assert loaded_config.model.dim == original_config.model.dim
        assert loaded_config.model.layers == original_config.model.layers
        assert (
            loaded_config.thinking.max_reasoning_paths
            == original_config.thinking.max_reasoning_paths
        )
        assert (
            loaded_config.training.learning_rate
            == original_config.training.learning_rate
        )


if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly
    # reports test failures to the calling shell / CI step.
    sys.exit(pytest.main([__file__]))