File size: 1,550 Bytes
88e3f4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import threading

import pytest

from omniff.reliability import VRAM_ESTIMATES, ModelMutex, retry_on_oom


def test_model_mutex_acquire_release():
    """Acquiring and then releasing one key on a fresh mutex must not raise."""
    lock_name = "test"
    model_mutex = ModelMutex()
    model_mutex.acquire(lock_name)
    model_mutex.release(lock_name)


def test_model_mutex_concurrent():
    """Several threads contending for one key must all get through the mutex.

    Each worker acquires ``"shared"``, records its index, and releases. Joins
    are bounded so that a mutex which never hands the key back makes the test
    fail with a clear message instead of hanging the whole test run.
    """
    mutex = ModelMutex()
    results = []

    def worker(n):
        mutex.acquire("shared")
        try:
            results.append(n)
        finally:
            # Always release, so one failing worker cannot block the others.
            mutex.release("shared")

    # daemon=True: a thread stuck in acquire() must not keep the interpreter alive.
    threads = [
        threading.Thread(target=worker, args=(i,), daemon=True) for i in range(5)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join(timeout=10)

    stuck = [t.name for t in threads if t.is_alive()]
    assert not stuck, f"workers still blocked on the mutex: {stuck}"
    assert sorted(results) == [0, 1, 2, 3, 4]


def test_retry_on_oom_no_error():
    """A function that succeeds on its first call returns its value through the wrapper."""

    def ok():
        return 42

    wrapped = retry_on_oom(max_retries=2)(ok)
    assert wrapped() == 42


def test_retry_on_oom_non_oom_raises():
    """An error that is not out-of-memory propagates to the caller as-is."""

    def fail():
        raise ValueError("not oom")

    guarded = retry_on_oom(max_retries=2)(fail)
    with pytest.raises(ValueError, match="not oom"):
        guarded()


def test_retry_on_oom_retries_oom():
    """An OOM-style RuntimeError is retried until the wrapped call succeeds.

    The wrapped function fails on its first two calls and succeeds on the
    third. The test expects that to be within ``max_retries=2``. The small
    ``backoff_base`` is presumably there to keep the waits between retries short.
    """
    attempts = []

    @retry_on_oom(max_retries=2, backoff_base=0.01)
    def flaky():
        attempts.append(None)
        if len(attempts) <= 2:
            raise RuntimeError("CUDA out of memory")
        return "ok"

    assert flaky() == "ok"
    assert len(attempts) == 3


def test_vram_estimates_exist():
    """The VRAM estimate table contains the Qwen3-4B and Whisper large-v3 model ids."""
    for model_id in ("Qwen/Qwen3-4B", "openai/whisper-large-v3"):
        assert model_id in VRAM_ESTIMATES


def test_mutex_double_release_safe():
    """Releasing a key a second time after it was released must not raise."""
    mutex = ModelMutex()
    key = "test"
    mutex.acquire(key)
    for _ in range(2):
        mutex.release(key)