"""
Unit tests for configuration management.
"""

import tempfile
from pathlib import Path

import pytest
import yaml

from config import (
    KerdosConfig,
    LoRAConfig,
    QuantizationConfig,
    TrainingConfig,
    DataConfig,
    load_config
)
from exceptions import ConfigurationError


# NOTE(review): several tests below use ``pytest.raises(Exception)``, which
# would also pass on unrelated errors (e.g. a TypeError from a typo). Narrow
# these to the validation error type the ``config`` module actually raises
# (possibly pydantic's ValidationError) — confirm against config.py.


class TestLoRAConfig:
    """Tests for LoRA configuration."""

    def test_default_config(self):
        """Test default LoRA configuration."""
        config = LoRAConfig()
        assert config.enabled is True
        assert config.r == 8
        assert config.alpha == 32
        assert config.dropout == 0.1

    def test_custom_config(self):
        """Test custom LoRA configuration."""
        config = LoRAConfig(r=16, alpha=64, dropout=0.2)
        assert config.r == 16
        assert config.alpha == 64
        assert config.dropout == 0.2

    def test_invalid_r(self):
        """Test invalid LoRA rank (both below and above the accepted range)."""
        with pytest.raises(Exception):
            LoRAConfig(r=0)
        with pytest.raises(Exception):
            LoRAConfig(r=300)

    def test_invalid_dropout(self):
        """Test invalid dropout value (negative and greater than 1)."""
        with pytest.raises(Exception):
            LoRAConfig(dropout=-0.1)
        with pytest.raises(Exception):
            LoRAConfig(dropout=1.5)


class TestQuantizationConfig:
    """Tests for quantization configuration."""

    def test_default_config(self):
        """Test default quantization configuration."""
        config = QuantizationConfig()
        assert config.enabled is False
        assert config.bits == 4
        assert config.use_double_quant is True

    def test_4bit_config(self):
        """Test 4-bit quantization."""
        config = QuantizationConfig(enabled=True, bits=4)
        assert config.enabled is True
        assert config.bits == 4

    def test_8bit_config(self):
        """Test 8-bit quantization."""
        config = QuantizationConfig(enabled=True, bits=8)
        assert config.bits == 8

    def test_invalid_bits(self):
        """Test that bit widths other than the supported ones are rejected."""
        with pytest.raises(Exception):
            QuantizationConfig(bits=2)
        with pytest.raises(Exception):
            QuantizationConfig(bits=16)


class TestTrainingConfig:
    """Tests for training configuration."""

    def test_default_config(self):
        """Test default training configuration."""
        config = TrainingConfig()
        assert config.epochs == 3
        assert config.batch_size == 4
        assert config.learning_rate == 2e-5

    def test_custom_config(self):
        """Test custom training configuration."""
        config = TrainingConfig(
            epochs=10,
            batch_size=8,
            learning_rate=1e-4
        )
        assert config.epochs == 10
        assert config.batch_size == 8
        assert config.learning_rate == 1e-4

    def test_precision_conflict(self):
        """Test that fp16 and bf16 cannot both be enabled."""
        with pytest.raises(Exception):
            TrainingConfig(fp16=True, bf16=True)


class TestDataConfig:
    """Tests for data configuration."""

    def test_default_config(self):
        """Test default data configuration."""
        config = DataConfig()
        assert config.text_column == "text"
        assert config.preprocessing_num_workers == 4

    def test_with_files(self):
        """Test data configuration with file paths."""
        config = DataConfig(
            train_file="train.json",
            validation_file="val.json"
        )
        assert config.train_file == "train.json"
        assert config.validation_file == "val.json"


class TestKerdosConfig:
    """Tests for main KerdosAI configuration."""

    def test_default_config(self):
        """Test default configuration."""
        config = KerdosConfig(base_model="gpt2")
        assert config.base_model == "gpt2"
        assert config.lora.enabled is True
        assert config.quantization.enabled is False

    def test_yaml_roundtrip(self):
        """Test saving and loading YAML configuration."""
        with tempfile.TemporaryDirectory() as tmpdir:
            config_path = Path(tmpdir) / "config.yaml"

            # Create and save config
            config = KerdosConfig(
                base_model="gpt2",
                output_dir=tmpdir
            )
            config.to_yaml(config_path)

            # Load config
            loaded_config = KerdosConfig.from_yaml(config_path)

            assert loaded_config.base_model == config.base_model
            assert loaded_config.output_dir == config.output_dir

    def test_invalid_yaml_file(self, tmp_path):
        """Test loading from non-existent file."""
        # Use a path inside pytest's fresh per-test directory so the file is
        # guaranteed absent; a bare relative name depends on the CWD contents.
        missing = tmp_path / "nonexistent.yaml"
        with pytest.raises(ConfigurationError):
            KerdosConfig.from_yaml(str(missing))

    def test_validation_no_data_source(self):
        """Test validation fails when no data source is specified."""
        config = KerdosConfig(base_model="gpt2")

        with pytest.raises(ConfigurationError):
            config.validate_compatibility()

    def test_validation_fp16_on_cpu(self):
        """Test validation fails for fp16 on CPU."""
        config = KerdosConfig(
            base_model="gpt2",
            device="cpu",
            data=DataConfig(train_file="train.json")
        )
        # Enabled after construction; the cross-field check under test lives
        # in validate_compatibility(), not in the constructor.
        config.training.fp16 = True

        with pytest.raises(ConfigurationError):
            config.validate_compatibility()

    def test_env_var_substitution(self, monkeypatch):
        """Test environment variable substitution."""
        # monkeypatch restores the environment after the test, so TEST_MODEL
        # does not leak into other tests in the session.
        monkeypatch.setenv("TEST_MODEL", "test-model")

        config_dict = {
            "base_model": "${TEST_MODEL}",
            "output_dir": "./output"
        }

        result = KerdosConfig._substitute_env_vars(config_dict)
        assert result["base_model"] == "test-model"


class TestLoadConfig:
    """Tests for config loading helper."""

    def test_load_default_config(self):
        """Test loading default configuration."""
        config = load_config()
        assert config.base_model == "gpt2"

    def test_load_from_file(self):
        """Test loading from file."""
        with tempfile.TemporaryDirectory() as tmpdir:
            config_path = Path(tmpdir) / "config.yaml"

            # Create config file
            config_data = {
                "base_model": "test-model",
                "output_dir": tmpdir
            }
            with open(config_path, 'w', encoding="utf-8") as f:
                yaml.dump(config_data, f)

            # Load config
            config = load_config(config_path)
            assert config.base_model == "test-model"