Anonymous Hunter
feat: Add robust configuration management, Docker support, initial testing, and quickstart documentation.
f21249a
| """ | |
| Unit tests for configuration management. | |
| """ | |
| import pytest | |
| import tempfile | |
| from pathlib import Path | |
| import yaml | |
| from config import ( | |
| KerdosConfig, | |
| LoRAConfig, | |
| QuantizationConfig, | |
| TrainingConfig, | |
| DataConfig, | |
| load_config | |
| ) | |
| from exceptions import ConfigurationError | |
class TestLoRAConfig:
    """Checks LoRAConfig defaults, overrides, and rejection of bad values."""

    def test_default_config(self):
        """A LoRAConfig built without arguments carries the expected defaults."""
        lora = LoRAConfig()

        assert lora.enabled is True
        assert lora.r == 8
        assert lora.alpha == 32
        assert lora.dropout == 0.1

    def test_custom_config(self):
        """Keyword overrides for r, alpha and dropout are reflected on the config."""
        overrides = {"r": 16, "alpha": 64, "dropout": 0.2}
        lora = LoRAConfig(**overrides)

        assert lora.r == 16
        assert lora.alpha == 64
        assert lora.dropout == 0.2

    def test_invalid_r(self):
        """Constructing with rank 0 or rank 300 must raise."""
        # NOTE(review): Exception is a very broad expectation; narrow it to the
        # validation error type raised by the config module once confirmed.
        for bad_rank in (0, 300):
            with pytest.raises(Exception):
                LoRAConfig(r=bad_rank)

    def test_invalid_dropout(self):
        """Constructing with dropout -0.1 or 1.5 must raise."""
        for bad_dropout in (-0.1, 1.5):
            with pytest.raises(Exception):
                LoRAConfig(dropout=bad_dropout)
class TestQuantizationConfig:
    """Checks QuantizationConfig defaults, 4/8-bit modes, and bad bit widths."""

    def test_default_config(self):
        """Defaults: disabled, 4 bits, double quantization switched on."""
        quant = QuantizationConfig()

        assert quant.enabled is False
        assert quant.bits == 4
        assert quant.use_double_quant is True

    def test_4bit_config(self):
        """An enabled 4-bit configuration is accepted and stored."""
        quant = QuantizationConfig(enabled=True, bits=4)

        assert quant.enabled is True
        assert quant.bits == 4

    def test_8bit_config(self):
        """An enabled 8-bit configuration is accepted and stored."""
        quant = QuantizationConfig(enabled=True, bits=8)

        assert quant.bits == 8

    def test_invalid_bits(self):
        """Constructing with 2 or 16 bits must raise."""
        # NOTE(review): Exception is a very broad expectation; narrow it to the
        # validation error type raised by the config module once confirmed.
        for bad_bits in (2, 16):
            with pytest.raises(Exception):
                QuantizationConfig(bits=bad_bits)
class TestTrainingConfig:
    """Checks TrainingConfig defaults, overrides, and the fp16/bf16 conflict."""

    def test_default_config(self):
        """Defaults: 3 epochs, batch size 4, learning rate 2e-5."""
        training = TrainingConfig()

        assert training.epochs == 3
        assert training.batch_size == 4
        assert training.learning_rate == 2e-5

    def test_custom_config(self):
        """Keyword overrides for epochs, batch size and learning rate are kept."""
        overrides = {"epochs": 10, "batch_size": 8, "learning_rate": 1e-4}
        training = TrainingConfig(**overrides)

        assert training.epochs == 10
        assert training.batch_size == 8
        assert training.learning_rate == 1e-4

    def test_precision_conflict(self):
        """Requesting fp16 and bf16 together must raise at construction."""
        # NOTE(review): Exception is a very broad expectation; narrow it to the
        # validation error type raised by the config module once confirmed.
        conflicting_flags = {"fp16": True, "bf16": True}
        with pytest.raises(Exception):
            TrainingConfig(**conflicting_flags)
class TestDataConfig:
    """Checks DataConfig defaults and explicit train/validation file paths."""

    def test_default_config(self):
        """Defaults: text column named "text" and 4 preprocessing workers."""
        data = DataConfig()

        assert data.text_column == "text"
        assert data.preprocessing_num_workers == 4

    def test_with_files(self):
        """Train and validation file names passed in are stored as given."""
        train_name = "train.json"
        validation_name = "val.json"

        data = DataConfig(train_file=train_name, validation_file=validation_name)

        assert data.train_file == "train.json"
        assert data.validation_file == "val.json"
class TestKerdosConfig:
    """Tests for the top-level KerdosConfig object.

    Covers defaults, YAML save/load, the error raised for a missing YAML
    file, compatibility validation failures, and ``${VAR}`` substitution.
    """

    def test_default_config(self):
        """Only base_model is given; LoRA is on and quantization is off."""
        config = KerdosConfig(base_model="gpt2")

        assert config.base_model == "gpt2"
        assert config.lora.enabled is True
        assert config.quantization.enabled is False

    def test_yaml_roundtrip(self):
        """A config written with to_yaml is read back equal by from_yaml."""
        with tempfile.TemporaryDirectory() as tmpdir:
            config_path = Path(tmpdir) / "config.yaml"

            # Create and save config
            config = KerdosConfig(
                base_model="gpt2",
                output_dir=tmpdir,
            )
            config.to_yaml(config_path)

            # Load config
            loaded_config = KerdosConfig.from_yaml(config_path)

            assert loaded_config.base_model == config.base_model
            assert loaded_config.output_dir == config.output_dir

    def test_invalid_yaml_file(self, tmp_path):
        """Loading a path that does not exist raises ConfigurationError."""
        # Anchor the path in pytest's fresh per-test directory so the file is
        # guaranteed to be absent; a bare relative name would depend on
        # whatever happens to be in the current working directory.
        missing_path = tmp_path / "nonexistent.yaml"

        with pytest.raises(ConfigurationError):
            KerdosConfig.from_yaml(str(missing_path))

    def test_validation_no_data_source(self):
        """validate_compatibility raises when no data source is configured."""
        config = KerdosConfig(base_model="gpt2")

        with pytest.raises(ConfigurationError):
            config.validate_compatibility()

    def test_validation_fp16_on_cpu(self):
        """validate_compatibility raises when fp16 is enabled with device="cpu"."""
        config = KerdosConfig(
            base_model="gpt2",
            device="cpu",
            data=DataConfig(train_file="train.json"),
        )
        # A data source is supplied above, so fp16-on-CPU is the only
        # incompatibility this test introduces.
        config.training.fp16 = True

        with pytest.raises(ConfigurationError):
            config.validate_compatibility()

    def test_env_var_substitution(self, monkeypatch):
        """A ``${TEST_MODEL}`` placeholder is replaced by the variable's value."""
        # monkeypatch restores the environment after the test, so TEST_MODEL
        # does not leak into other tests in the same process.
        monkeypatch.setenv("TEST_MODEL", "test-model")

        config_dict = {
            "base_model": "${TEST_MODEL}",
            "output_dir": "./output",
        }
        result = KerdosConfig._substitute_env_vars(config_dict)

        assert result["base_model"] == "test-model"
class TestLoadConfig:
    """Checks the load_config helper with and without a config file."""

    def test_load_default_config(self):
        """Calling load_config with no arguments yields base_model "gpt2"."""
        default_config = load_config()

        assert default_config.base_model == "gpt2"

    def test_load_from_file(self):
        """Values written to a YAML file are picked up by load_config."""
        with tempfile.TemporaryDirectory() as workdir:
            yaml_path = Path(workdir) / "config.yaml"

            # Serialize the fixture to YAML text and write it in one step.
            payload = {
                "base_model": "test-model",
                "output_dir": workdir,
            }
            yaml_path.write_text(yaml.dump(payload))

            # Read it back through the helper under test.
            loaded = load_config(yaml_path)

            assert loaded.base_model == "test-model"