# kerdosai / tests/test_config.py
# Anonymous Hunter
# feat: Add robust configuration management, Docker support, initial testing, and quickstart documentation.
# f21249a
"""
Unit tests for configuration management.
"""
import pytest
import tempfile
from pathlib import Path
import yaml
from config import (
KerdosConfig,
LoRAConfig,
QuantizationConfig,
TrainingConfig,
DataConfig,
load_config
)
from exceptions import ConfigurationError
class TestLoRAConfig:
    """Unit tests covering ``LoRAConfig`` defaults, overrides and validation."""

    def test_default_config(self):
        """A bare ``LoRAConfig()`` is enabled with r=8, alpha=32 and dropout=0.1."""
        lora = LoRAConfig()
        assert lora.enabled is True
        assert (lora.r, lora.alpha, lora.dropout) == (8, 32, 0.1)

    def test_custom_config(self):
        """Explicitly passed r/alpha/dropout values are stored unchanged."""
        overrides = {"r": 16, "alpha": 64, "dropout": 0.2}
        lora = LoRAConfig(**overrides)
        for field_name, expected in overrides.items():
            assert getattr(lora, field_name) == expected

    def test_invalid_r(self):
        """Constructing with a rank of 0 or 300 raises."""
        for bad_rank in (0, 300):
            with pytest.raises(Exception):
                LoRAConfig(r=bad_rank)

    def test_invalid_dropout(self):
        """Constructing with dropout below 0 or above 1 (-0.1, 1.5) raises."""
        for bad_dropout in (-0.1, 1.5):
            with pytest.raises(Exception):
                LoRAConfig(dropout=bad_dropout)
class TestQuantizationConfig:
    """Unit tests covering ``QuantizationConfig`` defaults and bit-width validation."""

    def test_default_config(self):
        """By default quantization is off, 4-bit, with double quantization on."""
        quant = QuantizationConfig()
        assert quant.enabled is False
        assert quant.bits == 4
        assert quant.use_double_quant is True

    def test_4bit_config(self):
        """An enabled 4-bit configuration keeps both settings."""
        quant = QuantizationConfig(enabled=True, bits=4)
        assert quant.enabled is True
        assert quant.bits == 4

    def test_8bit_config(self):
        """An enabled 8-bit configuration keeps ``bits == 8``."""
        assert QuantizationConfig(enabled=True, bits=8).bits == 8

    def test_invalid_bits(self):
        """Bit widths of 2 and 16 are rejected at construction time."""
        for unsupported_bits in (2, 16):
            with pytest.raises(Exception):
                QuantizationConfig(bits=unsupported_bits)
class TestTrainingConfig:
    """Unit tests covering ``TrainingConfig`` defaults, overrides and precision checks."""

    def test_default_config(self):
        """Defaults are 3 epochs, batch size 4 and learning rate 2e-5."""
        training = TrainingConfig()
        assert (training.epochs, training.batch_size, training.learning_rate) == (3, 4, 2e-5)

    def test_custom_config(self):
        """Explicit epochs/batch_size/learning_rate values are stored unchanged."""
        expected = {"epochs": 10, "batch_size": 8, "learning_rate": 1e-4}
        training = TrainingConfig(**expected)
        for name, value in expected.items():
            assert getattr(training, name) == value

    def test_precision_conflict(self):
        """Requesting fp16 and bf16 together raises."""
        with pytest.raises(Exception):
            TrainingConfig(fp16=True, bf16=True)
class TestDataConfig:
    """Unit tests covering ``DataConfig`` defaults and file-path fields."""

    def test_default_config(self):
        """Defaults use the ``"text"`` column and 4 preprocessing workers."""
        data = DataConfig()
        assert data.text_column == "text"
        assert data.preprocessing_num_workers == 4

    def test_with_files(self):
        """Train and validation file paths are stored as given."""
        train_path, val_path = "train.json", "val.json"
        data = DataConfig(train_file=train_path, validation_file=val_path)
        assert data.train_file == train_path
        assert data.validation_file == val_path
class TestKerdosConfig:
    """Tests for main KerdosAI configuration."""

    def test_default_config(self):
        """A config built from only ``base_model`` has LoRA on and quantization off."""
        config = KerdosConfig(base_model="gpt2")
        assert config.base_model == "gpt2"
        assert config.lora.enabled is True
        assert config.quantization.enabled is False

    def test_yaml_roundtrip(self):
        """Saving to YAML and loading back preserves ``base_model`` and ``output_dir``."""
        with tempfile.TemporaryDirectory() as tmpdir:
            config_path = Path(tmpdir) / "config.yaml"

            # Create and save config
            config = KerdosConfig(
                base_model="gpt2",
                output_dir=tmpdir,
            )
            config.to_yaml(config_path)

            # Load config
            loaded_config = KerdosConfig.from_yaml(config_path)
            assert loaded_config.base_model == config.base_model
            assert loaded_config.output_dir == config.output_dir

    def test_invalid_yaml_file(self):
        """Loading from a non-existent file raises ``ConfigurationError``."""
        # A bare relative name ("nonexistent.yaml") depends on the current working
        # directory; a path inside a fresh temporary directory is guaranteed missing.
        with tempfile.TemporaryDirectory() as tmpdir:
            missing_path = Path(tmpdir) / "nonexistent.yaml"
            with pytest.raises(ConfigurationError):
                # Passed as str to keep covering string-path input.
                KerdosConfig.from_yaml(str(missing_path))

    def test_validation_no_data_source(self):
        """``validate_compatibility`` raises when no data source is specified."""
        config = KerdosConfig(base_model="gpt2")
        with pytest.raises(ConfigurationError):
            config.validate_compatibility()

    def test_validation_fp16_on_cpu(self):
        """``validate_compatibility`` raises when fp16 is requested on CPU."""
        config = KerdosConfig(
            base_model="gpt2",
            device="cpu",
            data=DataConfig(train_file="train.json"),
        )
        config.training.fp16 = True
        with pytest.raises(ConfigurationError):
            config.validate_compatibility()

    def test_env_var_substitution(self):
        """``${VAR}`` placeholders are replaced with environment variable values."""
        import os
        from unittest import mock

        config_dict = {
            "base_model": "${TEST_MODEL}",
            "output_dir": "./output",
        }
        # patch.dict restores os.environ on exit, so TEST_MODEL does not leak
        # into other tests (the previous direct assignment was never undone).
        with mock.patch.dict(os.environ, {"TEST_MODEL": "test-model"}):
            result = KerdosConfig._substitute_env_vars(config_dict)
            assert result["base_model"] == "test-model"
class TestLoadConfig:
    """Tests for the ``load_config`` convenience helper."""

    def test_load_default_config(self):
        """Calling ``load_config()`` without a path yields base model ``gpt2``."""
        assert load_config().base_model == "gpt2"

    def test_load_from_file(self):
        """``load_config`` picks up ``base_model`` from a YAML file on disk."""
        with tempfile.TemporaryDirectory() as scratch_dir:
            yaml_file = Path(scratch_dir) / "config.yaml"
            payload = {
                "base_model": "test-model",
                "output_dir": scratch_dir,
            }
            # yaml.dump without a stream returns the document as a str.
            yaml_file.write_text(yaml.dump(payload))

            assert load_config(yaml_file).base_model == "test-model"