File size: 2,965 Bytes
88e3f4a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 | import time
import pytest
from omniff.scheduler.model_scheduler import LoadPolicy, ModelScheduler
class FakeModel:
    """Minimal stand-in for a model registered with ``ModelScheduler`` in these tests.

    Provides ``load``, ``unload`` and a read-only ``is_loaded`` property, and
    counts how many times ``load`` has been called.
    """

    def __init__(self):
        # Total number of load() calls over this object's lifetime.
        self.load_count = 0
        self._loaded = False

    @property
    def is_loaded(self):
        """False initially; True after load() until the next unload()."""
        return self._loaded

    def load(self):
        """Mark the model as loaded and increment ``load_count``."""
        self.load_count += 1
        self._loaded = True

    def unload(self):
        """Mark the model as unloaded; ``load_count`` is not changed."""
        self._loaded = False
def test_register_and_acquire():
    """Acquiring a registered model returns that same object, now loaded."""
    scheduler = ModelScheduler(max_loaded=4)
    model = FakeModel()
    scheduler.register("llm", model, LoadPolicy.WARM)

    assert scheduler.acquire("llm") is model
    assert model.is_loaded
def test_acquire_missing():
    """Acquiring an unregistered name raises KeyError mentioning "not registered"."""
    scheduler = ModelScheduler()
    with pytest.raises(KeyError, match="not registered"):
        scheduler.acquire("nope")
def test_cold_release_unloads():
    """A COLD model is loaded while acquired and unloaded again on release."""
    scheduler = ModelScheduler()
    cold_model = FakeModel()
    scheduler.register("llm", cold_model, LoadPolicy.COLD)

    # acquire() must bring the model up...
    scheduler.acquire("llm")
    assert cold_model.is_loaded

    # ...and release() must take a COLD model back down.
    scheduler.release("llm")
    assert not cold_model.is_loaded
def test_warm_stays_loaded_after_release():
    """Releasing a WARM model leaves it loaded."""
    scheduler = ModelScheduler()
    warm_model = FakeModel()
    scheduler.register("llm", warm_model, LoadPolicy.WARM)

    scheduler.acquire("llm")
    scheduler.release("llm")

    assert warm_model.is_loaded
def test_hot_never_evicted():
    """A HOT model stays loaded even when another model is acquired past capacity."""
    scheduler = ModelScheduler(max_loaded=1)
    hot, warm = FakeModel(), FakeModel()
    scheduler.register("hot", hot, LoadPolicy.HOT)
    scheduler.register("warm", warm, LoadPolicy.WARM)

    # With max_loaded=1, acquiring "warm" after "hot" exceeds the limit;
    # the HOT model must still be loaded afterwards.
    for name in ("hot", "warm"):
        scheduler.acquire(name)

    assert hot.is_loaded
def test_evict_expired():
    """evict_expired() unloads a WARM model past its TTL and reports its name."""
    ttl = 0.01
    scheduler = ModelScheduler(default_ttl=ttl)
    model = FakeModel()
    scheduler.register("llm", model, LoadPolicy.WARM, ttl=ttl)
    scheduler.acquire("llm")

    # Wait twice the TTL so the entry is clearly past expiry.
    time.sleep(2 * ttl)

    assert "llm" in scheduler.evict_expired()
    assert not model.is_loaded
def test_lru_eviction():
    """With room for two models, acquiring a third evicts the least recently used one.

    Models "a", "b", "c" are acquired in that order, so "a" is the LRU victim
    and both "b" and "c" must remain loaded.
    """
    sched = ModelScheduler(max_loaded=2)
    m1 = FakeModel()
    m2 = FakeModel()
    m3 = FakeModel()
    sched.register("a", m1, LoadPolicy.WARM)
    sched.register("b", m2, LoadPolicy.WARM)
    sched.register("c", m3, LoadPolicy.WARM)
    sched.acquire("a")
    sched.acquire("b")
    sched.acquire("c")
    assert not m1.is_loaded  # LRU evicted
    # Check each survivor on its own. An `m2.is_loaded or m3.is_loaded` check
    # always passes because "c" was just acquired, so it could never detect
    # "b" being evicted along with "a".
    assert m2.is_loaded
    assert m3.is_loaded
def test_loaded_models():
    """loaded_models() is empty after registration and lists the model once acquired."""
    scheduler = ModelScheduler()
    scheduler.register("llm", FakeModel(), LoadPolicy.WARM)

    assert scheduler.loaded_models() == []  # registering alone loads nothing
    scheduler.acquire("llm")
    assert scheduler.loaded_models() == ["llm"]
def test_status():
    """status() reports a registered model's policy name and its load state."""
    scheduler = ModelScheduler()
    scheduler.register("llm", FakeModel(), LoadPolicy.WARM)

    entry = scheduler.status()["llm"]
    assert entry["policy"] == "warm"
    assert entry["loaded"] is False
def test_unload_all():
    """unload_all() unloads every acquired model, HOT ones included."""
    scheduler = ModelScheduler()
    hot_model, warm_model = FakeModel(), FakeModel()
    scheduler.register("a", hot_model, LoadPolicy.HOT)
    scheduler.register("b", warm_model, LoadPolicy.WARM)
    scheduler.acquire("a")
    scheduler.acquire("b")

    scheduler.unload_all()

    assert not any(model.is_loaded for model in (hot_model, warm_model))
|