likhonsheikh's picture
Upload folder using huggingface_hub
b9b1e87 verified
#!/usr/bin/env python3
"""
Unit tests for configuration components.
"""
import pytest
import sys
from pathlib import Path
# Make the repository root (three directory levels above this file) importable
# so that ``compact_ai_model`` resolves when the tests are run from a checkout.
# ``resolve()`` turns ``__file__`` into an absolute path first, so the entry
# does not depend on the current working directory or on how the file was
# invoked (``__file__`` can be relative when run as a script on older Pythons).
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(project_root))
from compact_ai_model.configs.config import (
Config, ModelConfig, InterleavedThinkingConfig, TrainingConfig, APIConfig,
load_config_from_dict, save_config_to_dict
)
class TestModelConfig:
    """Checks for building a ModelConfig and for its default field values."""

    def test_model_config_creation(self):
        """Keyword arguments given to the constructor are stored unchanged."""
        expected = {"dim": 512, "layers": 12, "heads": 8, "vocab_size": 32000}
        cfg = ModelConfig(**expected)
        for field_name, value in expected.items():
            assert getattr(cfg, field_name) == value

    def test_model_config_defaults(self):
        """A no-argument ModelConfig uses the 'small' preset dimensions."""
        cfg = ModelConfig()
        assert (cfg.model_size, cfg.dim, cfg.layers) == ("small", 512, 12)
class TestInterleavedThinkingConfig:
    """Checks for InterleavedThinkingConfig construction, defaults and sanity."""

    def test_thinking_config_creation(self):
        """Explicit reasoning settings are kept exactly as passed in."""
        overrides = {
            "max_reasoning_paths": 4,
            "reasoning_depth": 5,
            "early_stop_threshold": 0.9,
        }
        cfg = InterleavedThinkingConfig(**overrides)
        for field_name, value in overrides.items():
            assert getattr(cfg, field_name) == value

    def test_thinking_config_defaults(self):
        """Default path/depth counts match the expected baseline values."""
        cfg = InterleavedThinkingConfig()
        assert cfg.max_reasoning_paths == 3
        assert cfg.reasoning_depth == 4
        # The early-stop threshold is treated as a fraction in [0, 1].
        threshold = cfg.early_stop_threshold
        assert 0 <= threshold <= 1

    def test_thinking_config_validation(self):
        """Every count-like default is strictly positive."""
        cfg = InterleavedThinkingConfig()
        for field_name in ("max_reasoning_paths", "reasoning_depth", "token_budget"):
            assert getattr(cfg, field_name) > 0
class TestConfig:
    """Checks for the preset factories on the top-level Config class."""

    def test_balanced_config(self):
        """The balanced preset wires up each sub-config with the right type."""
        cfg = Config.get_balanced_config()
        expected_types = (
            ("model", ModelConfig),
            ("thinking", InterleavedThinkingConfig),
            ("training", TrainingConfig),
            ("api", APIConfig),
        )
        for section, section_type in expected_types:
            assert isinstance(getattr(cfg, section), section_type)

    def test_tiny_config(self):
        """The tiny preset shrinks both the model and the reasoning width."""
        cfg = Config.get_tiny_config()
        model = cfg.model
        assert model.model_size == "tiny"
        assert model.dim == 256
        assert cfg.thinking.max_reasoning_paths == 2

    def test_large_config(self):
        """The large preset widens the model and the reasoning width."""
        cfg = Config.get_large_config()
        model = cfg.model
        # The "large" preset is asserted here to map to the "medium" model size.
        assert model.model_size == "medium"
        assert model.dim == 768
        assert cfg.thinking.max_reasoning_paths == 4

    def test_config_serialization(self):
        """A balanced config survives a save/load round trip through a dict."""
        original = Config.get_balanced_config()
        restored = load_config_from_dict(save_config_to_dict(original))
        assert restored.model.dim == original.model.dim
        assert (
            restored.thinking.max_reasoning_paths
            == original.thinking.max_reasoning_paths
        )
        assert restored.training.learning_rate == original.training.learning_rate
class TestSerialization:
    """Checks for converting a Config to a plain dict and back again."""

    def test_save_load_dict(self):
        """The saved dict has one key per section and loads back into a Config."""
        cfg = Config.get_balanced_config()
        as_dict = save_config_to_dict(cfg)
        assert isinstance(as_dict, dict)
        for section in ("model", "thinking", "training", "api"):
            assert section in as_dict
        assert isinstance(load_config_from_dict(as_dict), Config)

    def test_serialization_consistency(self):
        """Selected field values are identical before and after a round trip."""
        original = Config.get_balanced_config()
        restored = load_config_from_dict(save_config_to_dict(original))
        # (section, field) pairs whose values must survive serialization.
        checked_fields = (
            ("model", "dim"),
            ("model", "layers"),
            ("thinking", "max_reasoning_paths"),
            ("training", "learning_rate"),
        )
        for section, field_name in checked_fields:
            after = getattr(getattr(restored, section), field_name)
            before = getattr(getattr(original, section), field_name)
            assert after == before
if __name__ == "__main__":
    # Run this module's tests when executed as a script. pytest.main returns
    # the session exit code; pass it to sys.exit so failures produce a
    # non-zero process status. Otherwise the script always exits 0.
    sys.exit(pytest.main([__file__]))