from pathlib import Path
import pytest
import yaml
from omniff.runtime.config import ExpertConfig, OmniFFConfig, RouterConfig
def test_load_config_from_yaml(tmp_path):
    """A fully specified YAML file loads into the typed config objects."""
    yaml_text = """
name: test-runtime
version: "0.1"
router:
  type: keyword
  path: ""
experts:
  text_small:
    name: text_small
    model_type: causal_lm
    path: models/llm_small
    loading: hot
"""
    config_path = tmp_path / "omniff.yaml"
    config_path.write_text(yaml_text, encoding="utf-8")

    loaded = OmniFFConfig.load(config_path)

    # Top-level metadata; version is quoted in the YAML so it stays a string.
    assert loaded.name == "test-runtime"
    assert loaded.version == "0.1"
    # The YAML key ``type`` is exposed as ``RouterConfig.router_type``.
    assert loaded.router.router_type == "keyword"

    assert "text_small" in loaded.experts
    expert = loaded.experts["text_small"]
    assert expert.model_type == "causal_lm"
    assert expert.loading == "hot"
def test_config_missing_file():
    """Loading from a path that does not exist raises FileNotFoundError."""
    missing = Path("/nonexistent") / "omniff.yaml"
    with pytest.raises(FileNotFoundError):
        OmniFFConfig.load(missing)
def test_config_expert_defaults(tmp_path):
    """A minimal config yields an empty expert map and no graph templates dir."""
    minimal_yaml = """
name: minimal
version: "0.1"
router:
  type: keyword
  path: ""
experts: {}
"""
    cfg_path = tmp_path / "omniff.yaml"
    cfg_path.write_text(minimal_yaml, encoding="utf-8")

    cfg = OmniFFConfig.load(cfg_path)

    assert cfg.experts == {}
    # graph_templates_dir is absent from the YAML, so it should come back unset.
    assert cfg.graph_templates_dir is None
def test_config_invalid_yaml(tmp_path):
    """Syntactically malformed YAML is rejected at load time.

    The fixture is an unterminated flow sequence (``[`` with no closing
    ``]``), which a conforming YAML parser must reject with a parse error.
    Input that merely *looks* broken (for example leading ``": "`` tokens or
    a trailing ``[[[`` inside a plain scalar) may be accepted by permissive
    parsers as empty-keyed mappings or literal text. That would make the test
    depend on how the loader treats an oddly shaped document rather than on
    YAML parsing.
    """
    config_file = tmp_path / "omniff.yaml"
    config_file.write_text("name: [unterminated flow sequence\n", encoding="utf-8")
    # Accept either the raw parser error or a ValueError the loader may wrap
    # it in.
    with pytest.raises((yaml.YAMLError, ValueError)):
        OmniFFConfig.load(config_file)
def test_config_missing_required_field(tmp_path):
    """A config without the top-level ``name`` key fails to load.

    ``version``, a complete ``router`` section and ``experts`` are all
    supplied. That way the missing ``name`` is the only defect in the file,
    and a failure cannot be caused by some other omitted section.
    """
    config_file = tmp_path / "omniff.yaml"
    config_file.write_text(
        """
version: "0.1"
router:
  type: keyword
  path: ""
experts: {}
""",
        encoding="utf-8",
    )
    # Only *that* loading fails is pinned here, not the loader's specific
    # error type. Exception already covers KeyError and TypeError, so listing
    # them separately would add nothing.
    with pytest.raises(Exception):
        OmniFFConfig.load(config_file)
def test_config_missing_router(tmp_path):
    """A config without a ``router`` section fails to load.

    ``name``, ``version`` and ``experts`` are all supplied, so the absent
    ``router`` section is the only defect in the file.
    """
    config_file = tmp_path / "omniff.yaml"
    config_file.write_text(
        """
name: test
version: "0.1"
experts: {}
""",
        encoding="utf-8",
    )
    # Only *that* loading fails is pinned here, not the loader's specific
    # error type. Exception already covers KeyError and TypeError.
    with pytest.raises(Exception):
        OmniFFConfig.load(config_file)
def test_config_missing_expert_model_type(tmp_path):
    """An expert entry lacking ``model_type`` makes the config fail to load."""
    expert_without_model_type = """
name: test
version: "0.1"
router:
  type: keyword
  path: ""
experts:
  bad_expert:
    name: bad
    path: models/bad
"""
    target = tmp_path / "omniff.yaml"
    target.write_text(expert_without_model_type, encoding="utf-8")
    # Any Exception is accepted; that already includes KeyError.
    with pytest.raises(Exception):
        OmniFFConfig.load(target)
def test_expert_config_defaults():
    """Optional ExpertConfig fields take their defaults when not passed."""
    cfg = ExpertConfig(name="test", model_type="causal_lm", path="models/test")
    # loading defaults to "warm"; quantization and device default to unset.
    assert cfg.loading == "warm"
    for unset_field in ("quantization", "device"):
        assert getattr(cfg, unset_field) is None
def test_router_config_defaults():
    """RouterConfig.path defaults to the empty string when not supplied."""
    assert RouterConfig(router_type="keyword").path == ""
def test_config_graph_templates_dir(tmp_path):
    """An explicit graph_templates_dir is loaded unchanged as a string."""
    templates_yaml = """
name: test
version: "0.1"
router:
  type: keyword
  path: ""
graph_templates_dir: /tmp/templates
experts: {}
"""
    cfg_path = tmp_path / "omniff.yaml"
    cfg_path.write_text(templates_yaml, encoding="utf-8")

    cfg = OmniFFConfig.load(cfg_path)

    assert cfg.graph_templates_dir == "/tmp/templates"