# File size: 2,965 Bytes
# 88e3f4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import time

import pytest

from omniff.scheduler.model_scheduler import LoadPolicy, ModelScheduler


class FakeModel:
    """Minimal in-memory stand-in for a loadable model.

    Keeps a loaded/unloaded flag and counts how many times ``load`` has
    been called, so tests can observe what was done to it.
    """

    def __init__(self):
        self.load_count = 0
        self._loaded = False

    @property
    def is_loaded(self):
        """Return whether the model is currently flagged as loaded."""
        return self._loaded

    def load(self):
        """Flag the model as loaded and increment ``load_count``."""
        self.load_count += 1
        self._loaded = True

    def unload(self):
        """Flag the model as unloaded; ``load_count`` is left unchanged."""
        self._loaded = False


def test_register_and_acquire():
    """Acquiring a registered model returns that same object, loaded."""
    scheduler = ModelScheduler(max_loaded=4)
    model = FakeModel()
    scheduler.register("llm", model, LoadPolicy.WARM)

    acquired = scheduler.acquire("llm")

    assert acquired is model
    assert model.is_loaded


def test_acquire_missing():
    """Acquiring an unregistered name raises a "not registered" KeyError."""
    scheduler = ModelScheduler()
    expect_key_error = pytest.raises(KeyError, match="not registered")
    with expect_key_error:
        scheduler.acquire("nope")


def test_cold_release_unloads():
    """A COLD model is loaded once acquired and unloaded once released."""
    scheduler = ModelScheduler()
    model = FakeModel()
    scheduler.register("llm", model, LoadPolicy.COLD)

    scheduler.acquire("llm")
    loaded_while_held = model.is_loaded
    assert loaded_while_held

    scheduler.release("llm")
    loaded_after_release = model.is_loaded
    assert not loaded_after_release


def test_warm_stays_loaded_after_release():
    """Releasing a WARM model leaves it loaded."""
    scheduler = ModelScheduler()
    model = FakeModel()
    scheduler.register("llm", model, LoadPolicy.WARM)

    # One full acquire/release cycle.
    for step in (scheduler.acquire, scheduler.release):
        step("llm")

    assert model.is_loaded


def test_hot_never_evicted():
    """A HOT model stays loaded when a second model is acquired at capacity 1."""
    scheduler = ModelScheduler(max_loaded=1)
    pinned, other = FakeModel(), FakeModel()
    registrations = (
        ("hot", pinned, LoadPolicy.HOT),
        ("warm", other, LoadPolicy.WARM),
    )
    for name, model, policy in registrations:
        scheduler.register(name, model, policy)

    # Acquire in registration order: the HOT model first, then the WARM one.
    for name, _model, _policy in registrations:
        scheduler.acquire(name)

    assert pinned.is_loaded


def test_evict_expired():
    """Once its TTL has elapsed, evict_expired() reports and unloads the model."""
    ttl_seconds = 0.01
    wait_seconds = 0.02  # longer than the TTL so the entry has expired
    scheduler = ModelScheduler(default_ttl=ttl_seconds)
    model = FakeModel()
    scheduler.register("llm", model, LoadPolicy.WARM, ttl=ttl_seconds)
    scheduler.acquire("llm")

    time.sleep(wait_seconds)

    expired_names = scheduler.evict_expired()
    assert "llm" in expired_names
    assert not model.is_loaded


def test_lru_eviction():
    """Acquiring a third model at capacity 2 unloads the first-acquired one."""
    scheduler = ModelScheduler(max_loaded=2)
    oldest, middle, newest = FakeModel(), FakeModel(), FakeModel()
    models = {"a": oldest, "b": middle, "c": newest}

    for name, model in models.items():
        scheduler.register(name, model, LoadPolicy.WARM)
    # Acquire in insertion order so "a" is the least recently used.
    for name in models:
        scheduler.acquire(name)

    assert not oldest.is_loaded  # LRU evicted
    # NOTE(review): this only requires one of the two later models to remain
    # loaded; consider whether both should be asserted -- confirm against
    # ModelScheduler's eviction behavior.
    assert any(survivor.is_loaded for survivor in (middle, newest))


def test_loaded_models():
    """loaded_models() is empty before acquire and lists the name afterwards."""
    scheduler = ModelScheduler()
    model = FakeModel()
    scheduler.register("llm", model, LoadPolicy.WARM)

    names_before = scheduler.loaded_models()
    assert names_before == []

    scheduler.acquire("llm")
    names_after = scheduler.loaded_models()
    assert names_after == ["llm"]


def test_status():
    """status() reports policy "warm" and loaded False for an unacquired model."""
    scheduler = ModelScheduler()
    model = FakeModel()
    scheduler.register("llm", model, LoadPolicy.WARM)

    entry = scheduler.status()["llm"]

    assert entry["policy"] == "warm"
    assert entry["loaded"] is False


def test_unload_all():
    """unload_all() unloads every acquired model, including a HOT one."""
    scheduler = ModelScheduler()
    hot_model, warm_model = FakeModel(), FakeModel()
    scheduler.register("a", hot_model, LoadPolicy.HOT)
    scheduler.register("b", warm_model, LoadPolicy.WARM)
    for name in ("a", "b"):
        scheduler.acquire(name)

    scheduler.unload_all()

    assert not hot_model.is_loaded
    assert not warm_model.is_loaded
    assert not m2.is_loaded