File size: 2,997 Bytes
88e3f4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from pathlib import Path

import pytest
import yaml

from omniff.runtime.config import ExpertConfig, OmniFFConfig, RouterConfig


def test_load_config_from_yaml(tmp_path):
    """Load a complete config file and check every parsed section.

    Covers the top-level ``name``/``version`` fields, the ``router`` mapping
    (whose YAML ``type`` key is exposed as ``RouterConfig.router_type``), and
    a single expert entry looked up by its mapping key.
    """
    yaml_text = """
name: test-runtime
version: "0.1"

router:
  type: keyword
  path: ""

experts:
  text_small:
    name: text_small
    model_type: causal_lm
    path: models/llm_small
    loading: hot
"""
    path = tmp_path / "omniff.yaml"
    path.write_text(yaml_text)

    cfg = OmniFFConfig.load(path)

    assert cfg.name == "test-runtime"
    assert cfg.version == "0.1"
    assert cfg.router.router_type == "keyword"
    assert "text_small" in cfg.experts
    expert = cfg.experts["text_small"]
    assert expert.model_type == "causal_lm"
    assert expert.loading == "hot"


def test_config_missing_file(tmp_path):
    """Loading a path that does not exist raises ``FileNotFoundError``.

    The missing path is built under pytest's per-test ``tmp_path`` instead of
    a hard-coded absolute location such as ``/nonexistent/...``. That makes it
    guaranteed-absent on any host or OS, so the test stays hermetic.
    """
    missing: Path = tmp_path / "does-not-exist" / "omniff.yaml"
    # Guard the premise: if this ever fails, the test would no longer be
    # exercising the missing-file branch at all.
    assert not missing.exists()
    with pytest.raises(FileNotFoundError):
        OmniFFConfig.load(missing)


def test_config_expert_defaults(tmp_path):
    """A minimal config with no experts and no templates dir loads cleanly.

    Despite the name, this checks top-level defaults only:
    ``experts: {}`` must come back as an empty mapping, and omitting
    ``graph_templates_dir`` must leave it as ``None``. Per-expert field
    defaults are exercised by ``test_expert_config_defaults``.
    """
    cfg_path = tmp_path / "omniff.yaml"
    minimal_yaml = """
name: minimal
version: "0.1"
router:
  type: keyword
  path: ""
experts: {}
"""
    cfg_path.write_text(minimal_yaml)

    loaded = OmniFFConfig.load(cfg_path)

    assert loaded.experts == {}
    assert loaded.graph_templates_dir is None


def test_config_invalid_yaml(tmp_path):
    """Deliberately unparseable YAML must raise instead of loading.

    Either a raw ``yaml.YAMLError`` from the parser or a ``ValueError`` is
    accepted (the latter presumably in case the loader wraps parse errors —
    confirm against ``OmniFFConfig.load``).
    """
    broken = tmp_path / "omniff.yaml"
    broken.write_text(": : : not valid yaml [[[")
    accepted_errors = (yaml.YAMLError, ValueError)
    with pytest.raises(accepted_errors):
        OmniFFConfig.load(broken)


def test_config_missing_required_field(tmp_path):
    """A config lacking the top-level ``name`` key must fail to load.

    NOTE(review): the fixture also omits ``experts``, so the failure may come
    from either missing key. ``Exception`` in the accepted tuple also subsumes
    ``KeyError`` and ``TypeError``, so this passes on *any* error. Narrow it
    to the exception ``OmniFFConfig.load`` actually raises once confirmed.
    """
    no_name = tmp_path / "omniff.yaml"
    fixture = """
version: "0.1"
router:
  type: keyword
"""
    no_name.write_text(fixture)
    expected = (KeyError, TypeError, Exception)
    with pytest.raises(expected):
        OmniFFConfig.load(no_name)


def test_config_missing_router(tmp_path):
    """A config with no ``router`` section must fail to load.

    NOTE(review): ``experts`` is absent here too, and ``Exception`` in the
    accepted tuple makes ``KeyError``/``TypeError`` redundant, so any error
    satisfies this test. Tighten once the real exception type is known.
    """
    routerless = tmp_path / "omniff.yaml"
    fixture = """
name: test
version: "0.1"
"""
    routerless.write_text(fixture)
    expected = (KeyError, TypeError, Exception)
    with pytest.raises(expected):
        OmniFFConfig.load(routerless)


def test_config_missing_expert_model_type(tmp_path):
    """An expert entry without ``model_type`` must make the load fail.

    NOTE(review): ``Exception`` subsumes ``KeyError`` in the accepted tuple,
    so any error passes. Narrow it to the specific exception once confirmed.
    """
    cfg_path = tmp_path / "omniff.yaml"
    fixture = """
name: test
version: "0.1"
router:
  type: keyword
  path: ""
experts:
  bad_expert:
    name: bad
    path: models/bad
"""
    cfg_path.write_text(fixture)
    expected = (KeyError, Exception)
    with pytest.raises(expected):
        OmniFFConfig.load(cfg_path)


def test_expert_config_defaults():
    """Fields omitted when constructing ``ExpertConfig`` take their defaults.

    With only ``name``, ``model_type`` and ``path`` supplied, ``loading``
    should be ``"warm"`` and both ``quantization`` and ``device`` ``None``.
    """
    cfg = ExpertConfig(
        name="test",
        model_type="causal_lm",
        path="models/test",
    )
    assert cfg.loading == "warm"
    assert cfg.quantization is None
    assert cfg.device is None


def test_router_config_defaults():
    """``RouterConfig.path`` falls back to the empty string when omitted."""
    rc = RouterConfig(router_type="keyword")
    assert rc.path == ""


def test_config_graph_templates_dir(tmp_path):
    """An explicit ``graph_templates_dir`` is surfaced as the YAML string.

    The value is only compared, not touched on disk, so the absolute
    ``/tmp/templates`` path need not exist.
    """
    cfg_path = tmp_path / "omniff.yaml"
    fixture = """
name: test
version: "0.1"
router:
  type: keyword
  path: ""
graph_templates_dir: /tmp/templates
experts: {}
"""
    cfg_path.write_text(fixture)

    loaded = OmniFFConfig.load(cfg_path)

    assert loaded.graph_templates_dir == "/tmp/templates"