import threading import pytest from omniff.reliability import VRAM_ESTIMATES, ModelMutex, retry_on_oom def test_model_mutex_acquire_release(): mutex = ModelMutex() mutex.acquire("test") mutex.release("test") def test_model_mutex_concurrent(): mutex = ModelMutex() results = [] def worker(n): mutex.acquire("shared") results.append(n) mutex.release("shared") threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)] for t in threads: t.start() for t in threads: t.join() assert sorted(results) == [0, 1, 2, 3, 4] def test_retry_on_oom_no_error(): @retry_on_oom(max_retries=2) def ok(): return 42 assert ok() == 42 def test_retry_on_oom_non_oom_raises(): @retry_on_oom(max_retries=2) def fail(): raise ValueError("not oom") with pytest.raises(ValueError, match="not oom"): fail() def test_retry_on_oom_retries_oom(): call_count = 0 @retry_on_oom(max_retries=2, backoff_base=0.01) def flaky(): nonlocal call_count call_count += 1 if call_count < 3: raise RuntimeError("CUDA out of memory") return "ok" assert flaky() == "ok" assert call_count == 3 def test_vram_estimates_exist(): assert "Qwen/Qwen3-4B" in VRAM_ESTIMATES assert "openai/whisper-large-v3" in VRAM_ESTIMATES def test_mutex_double_release_safe(): mutex = ModelMutex() mutex.acquire("test") mutex.release("test") mutex.release("test")