zeroshotGPU / tests /test_zero_gpu.py
Arjunvir Singh
Add ZeroGPU integration
fa2127b
"""Tests for the ZeroGPU integration helper."""
from __future__ import annotations
import unittest
from unittest.mock import MagicMock, patch
from zsgdp.gpu import zero_gpu
class IsZeroGPUAvailableTests(unittest.TestCase):
    """``is_zero_gpu_available`` follows the module's ``_SPACES_AVAILABLE`` flag."""

    def _available_with_flag(self, flag):
        """Return ``is_zero_gpu_available()`` while the flag is forced to *flag*."""
        with patch.object(zero_gpu, "_SPACES_AVAILABLE", flag):
            return zero_gpu.is_zero_gpu_available()

    def test_returns_false_when_spaces_module_missing(self):
        # Mimics a local development machine, where `spaces` is not installed.
        self.assertFalse(self._available_with_flag(False))

    def test_returns_true_when_spaces_module_present(self):
        # Mimics an environment in which the `spaces` SDK could be imported.
        self.assertTrue(self._available_with_flag(True))
class GpuDecoratorTests(unittest.TestCase):
    """Behaviour of ``zero_gpu.gpu(duration=...)`` with and without the SDK."""

    @staticmethod
    def _fake_spaces(decorator):
        """Build a stand-in for the ``spaces`` module.

        Calling ``GPU(...)`` on the returned mock yields *decorator*, which lets
        a test check exactly which object ``zero_gpu.gpu`` hands back.
        """
        fake = MagicMock()
        fake.GPU.return_value = decorator
        return fake

    def test_passthrough_when_off_space(self):
        def double(value: int) -> int:
            return value * 2

        with patch.object(zero_gpu, "_SPACES_AVAILABLE", False):
            result = zero_gpu.gpu(duration=42)(double)
            # Without the SDK the decorator is a no-op and returns the same object.
            self.assertIs(result, double)
            self.assertEqual(result(7), 14)

    def test_delegates_to_spaces_gpu_when_available(self):
        def sentinel(fn):
            # Identity decorator whose object identity the test can recognise.
            return fn

        fake_spaces = self._fake_spaces(sentinel)
        with patch.object(zero_gpu, "_SPACES_AVAILABLE", True):
            with patch.object(zero_gpu, "_spaces", fake_spaces):
                returned = zero_gpu.gpu(duration=180)
                fake_spaces.GPU.assert_called_once_with(duration=180)
                self.assertIs(returned, sentinel)

    def test_decorator_passes_duration_through(self):
        fake_spaces = self._fake_spaces(lambda fn: fn)
        with patch.object(zero_gpu, "_SPACES_AVAILABLE", True):
            with patch.object(zero_gpu, "_spaces", fake_spaces):
                for seconds in (60, 300):
                    zero_gpu.gpu(duration=seconds)
        # Every call must forward its duration to spaces.GPU unchanged.
        observed = [recorded.kwargs for recorded in fake_spaces.GPU.call_args_list]
        self.assertEqual(observed, [{"duration": 60}, {"duration": 300}])
class RuntimeStatusSurfacesZeroGPUTests(unittest.TestCase):
    """``collect_gpu_runtime_status`` exposes ZeroGPU availability in its dict form."""

    def test_zero_gpu_available_field_in_runtime_status(self):
        from zsgdp.config import load_config
        from zsgdp.gpu.runtime import collect_gpu_runtime_status

        config = load_config()

        def status_with(flag):
            # Collect the runtime status while _SPACES_AVAILABLE is forced to *flag*.
            with patch.object(zero_gpu, "_SPACES_AVAILABLE", flag):
                return collect_gpu_runtime_status(config).to_dict()

        off_space = status_with(False)
        self.assertFalse(off_space["zero_gpu_available"])

        status = status_with(True)
        self.assertTrue(status["zero_gpu_available"])
        # The on-Space status should also carry a human-readable detection note.
        detected = any("ZeroGPU SDK detected" in note for note in status["notes"])
        self.assertTrue(detected, msg=f"notes={status['notes']}")
class DecoratedClientsStillWorkOffSpaceTests(unittest.TestCase):
    """End-to-end check that the ZeroGPU-decorated clients (EmbeddingRetriever,
    TransformersClient) keep their off-Space behaviour.

    NOTE(review): patching ``_SPACES_AVAILABLE`` here only affects code that
    reads the flag when it is called. If the decorators are applied when the
    client modules are imported, these patches cannot change which decorator
    was used. Confirm against the client modules.
    """

    def test_embedding_retriever_index_and_query_unchanged_off_space(self):
        from zsgdp.benchmarks.embedding_retriever import EmbeddingRetriever
        from zsgdp.schema import Chunk

        def embedder(texts):
            # Deterministic two-component vector: (character length, 1.0).
            return [[float(len(t)), 1.0] for t in texts]

        def make_chunk(i, text):
            # A single-page prose chunk from doc "d".
            return Chunk(
                chunk_id=f"c{i}",
                doc_id="d",
                page_start=1,
                page_end=1,
                section_path=[],
                content_type="prose",
                text=text,
                token_count=len(text.split()),
            )

        retriever = EmbeddingRetriever(embedder=embedder)
        sample_texts = ["short text", "this is a much longer sentence"]
        chunks = [make_chunk(i, text) for i, text in enumerate(sample_texts)]

        with patch.object(zero_gpu, "_SPACES_AVAILABLE", False):
            retriever.index(chunks)
            ranking = retriever.query("a much longer sentence", top_k=2)
        self.assertEqual(set(ranking), {"c0", "c1"})

    def test_transformers_client_returns_unavailable_off_space(self):
        from zsgdp.gpu.transformers_client import TransformersClient

        # With no model_id the client refuses to run the task; the
        # @zero_gpu_slot decorator must leave that contract intact.
        client = TransformersClient(model_id=None)
        with patch.object(zero_gpu, "_SPACES_AVAILABLE", False):
            outcome = client.execute_task({"task_id": "t1"})
        self.assertEqual(outcome["status"], "backend_unavailable")
if __name__ == "__main__":
    # Lets the module run as a script in addition to pytest / unittest discovery.
    unittest.main()