Spaces:
Running on Zero
File size: 2,126 Bytes
db06ffa | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 | import unittest
from unittest.mock import patch
from zsgdp.config import load_config
from zsgdp.gpu import GPUModelConfig, collect_gpu_runtime_status
class GPURuntimeTests(unittest.TestCase):
    """Tests for GPU model configuration and runtime-status collection in ``zsgdp.gpu``."""

    def test_model_config_reads_gpu_section(self):
        """GPUModelConfig.from_config exposes the values given in the ``gpu`` override section."""
        expected_gpu = {
            "backend": "vllm",
            "provider": "huggingface_spaces",
            "space_name": "zeroshotGPU",
            "max_batch_size": 8,
        }
        # Hand load_config a copy so the expectations below stay independent of
        # anything load_config might do to the dict it receives.
        cfg = load_config(overrides={"gpu": dict(expected_gpu)})
        gpu_cfg = GPUModelConfig.from_config(cfg)
        for attr_name, wanted in expected_gpu.items():
            self.assertEqual(getattr(gpu_cfg, attr_name), wanted)

    def test_collect_runtime_detects_space_environment(self):
        """Space env vars make the runtime status report a Hugging Face Spaces deployment."""
        cfg = load_config()
        fake_space_env = {"SPACE_ID": "user/zeroshotGPU", "SPACE_HARDWARE": "l4x1"}
        with patch.dict("os.environ", fake_space_env, clear=False):
            report = collect_gpu_runtime_status(cfg).to_dict()

        for key, wanted in (
            ("provider", "huggingface_spaces"),
            ("space_name", "zeroshotGPU"),
            ("gpu_models_target", "zeroshotGPU"),
        ):
            self.assertEqual(report[key], wanted)
        self.assertTrue(report["running_on_huggingface_space"])
        for key, wanted in (
            ("space_id", "user/zeroshotGPU"),
            ("hardware", "l4x1"),
        ):
            self.assertEqual(report[key], wanted)

        # The device depends on the machine running the tests, so accept any known backend.
        self.assertIn(report["device"], {"cpu", "cuda", "mps"})
        self.assertIn("torch_available", report)

        # These model ids are the defaults that load_config() produces with no overrides.
        models = report["configured_models"]
        self.assertEqual(models["vlm"]["model_id"], "Qwen/Qwen2.5-VL-3B-Instruct")
        self.assertEqual(models["embedding"]["model_id"], "jinaai/jina-embeddings-v3")

    def test_collect_runtime_reports_local_note(self):
        """Blank Space env vars yield a non-Space status carrying a "local run" note."""
        cfg = load_config()
        # The variables are set to "" rather than removed. The test therefore
        # relies on collect_gpu_runtime_status treating empty values as unset.
        blank_space_env = dict.fromkeys(("SPACE_ID", "SPACE_HOST", "SPACE_HARDWARE"), "")
        with patch.dict("os.environ", blank_space_env, clear=False):
            runtime = collect_gpu_runtime_status(cfg)

        self.assertFalse(runtime.running_on_huggingface_space)
        has_local_note = any("local run" in note for note in runtime.notes)
        self.assertTrue(has_local_note)
# Allow running this test module directly as a script, in addition to via a test runner.
if __name__ == "__main__":
    unittest.main()
|