| """Integration tests for RayEngine and Ray HTTP server (requires GPU + Ray). |
| |
| Tests the Ray actor scheduler backend: |
| - Offline inference via Engine(use_ray=True) inside a Ray actor on a placement group |
| - Error paths in RayEngine._launch_scheduler_processes() |
| - HTTP server launched via --use-ray flag |
| |
| Usage: |
| # 1-GPU tests |
| python -m pytest test/manual/test_ray_engine.py::TestRayEngineOfflineTP1 -v -s |
| python -m pytest test/manual/test_ray_engine.py::TestRayEngineErrors -v -s |
| python -m pytest test/manual/test_ray_engine.py::TestRayHTTPServerTP1 -v -s |
| |
| # 2-GPU tests |
| python -m pytest test/manual/test_ray_engine.py::TestRayEngineOfflineTP2 -v -s |
| python -m pytest test/manual/test_ray_engine.py::TestRayEngineOfflinePP2 -v -s |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| import time |
| import unittest |
|
|
| import torch |
|
|
| from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST |
|
|
| |
| _MODEL = os.environ.get("SGLANG_TEST_MODEL", DEFAULT_SMALL_MODEL_NAME_FOR_TEST) |
|
|
| try: |
| import ray |
| from ray.runtime_env import RuntimeEnv |
| from ray.util.placement_group import placement_group |
| from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy |
|
|
| |
| |
| _env_vars = {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"} |
| if os.environ.get("HF_TOKEN"): |
| _env_vars["HF_TOKEN"] = os.environ["HF_TOKEN"] |
| _RAY_RUNTIME_ENV = RuntimeEnv(env_vars=_env_vars) |
| _has_ray = True |
| except ImportError: |
| _has_ray = False |
| _RAY_RUNTIME_ENV = None |
|
|
|
|
| _NUM_GPUS = torch.cuda.device_count() |
|
|
| _SAMPLING_PARAMS = {"max_new_tokens": 32, "temperature": 0.0} |
|
|
| _PROMPTS = [ |
| "The capital of France is", |
| "Explain quantum computing in simple terms:", |
| "Write a haiku about programming:", |
| "What is 2 + 2?", |
| ] |
|
|
|
|
| |
| |
| |
|
|
|
|
| def _create_engine_on_pg(tp_size, pp_size=1, model=_MODEL, extra_kwargs=None): |
| """Create an EngineActor on a placement group and wait for it to be ready. |
| |
| Returns (engine_actor, placement_group). |
| """ |
|
|
| @ray.remote |
| class EngineActor: |
| def __init__(self, **kwargs): |
| from sglang.srt.ray.engine import RayEngine |
|
|
| self.engine = RayEngine(**kwargs) |
|
|
| def is_ready(self): |
| return True |
|
|
| def generate(self, prompt, sampling_params): |
| return self.engine.generate(prompt=prompt, sampling_params=sampling_params) |
|
|
| def shutdown(self): |
| if self.engine: |
| self.engine.shutdown() |
| self.engine = None |
|
|
| total_gpus = tp_size * pp_size |
| pg = placement_group( |
| [{"CPU": 1, "GPU": total_gpus}], |
| strategy="STRICT_PACK", |
| ) |
| ray.get(pg.ready()) |
|
|
| kwargs = dict( |
| model_path=model, |
| tp_size=tp_size, |
| pp_size=pp_size, |
| ) |
| if extra_kwargs: |
| kwargs.update(extra_kwargs) |
|
|
| actor = EngineActor.options( |
| num_cpus=1, |
| num_gpus=0, |
| scheduling_strategy=PlacementGroupSchedulingStrategy( |
| placement_group=pg, |
| placement_group_bundle_index=0, |
| ), |
| ).remote(**kwargs) |
|
|
| ray.get(actor.is_ready.remote(), timeout=600) |
| return actor, pg |
|
|
|
|
| def _cleanup(actor, pg): |
| """Shutdown engine actor and remove placement group.""" |
| try: |
| ray.get(actor.shutdown.remote(), timeout=60) |
| except Exception: |
| pass |
| try: |
| ray.util.remove_placement_group(pg) |
| except Exception: |
| pass |
|
|
|
|
| |
| |
| |
|
|
|
|
| @unittest.skipUnless(_has_ray, "ray is not installed") |
| @unittest.skipUnless(_NUM_GPUS >= 1, "requires at least 1 GPU") |
| class TestRayEngineOfflineTP1(unittest.TestCase): |
|
|
| @classmethod |
| def setUpClass(cls): |
| if not ray.is_initialized(): |
| ray.init(log_to_driver=True, runtime_env=_RAY_RUNTIME_ENV) |
| cls.actor, cls.pg = _create_engine_on_pg(tp_size=1) |
|
|
| @classmethod |
| def tearDownClass(cls): |
| _cleanup(cls.actor, cls.pg) |
| ray.shutdown() |
|
|
| def test_offline_generate(self): |
| result = ray.get( |
| self.actor.generate.remote("The capital of France is", _SAMPLING_PARAMS) |
| ) |
| self.assertIn("text", result) |
| self.assertGreater(len(result["text"]), 0) |
| print(f"Generated: {result['text'][:200]}") |
|
|
| def test_batch_generate(self): |
| for prompt in _PROMPTS: |
| result = ray.get(self.actor.generate.remote(prompt, _SAMPLING_PARAMS)) |
| self.assertIn("text", result) |
| self.assertGreater(len(result["text"]), 0, f"Empty output for: {prompt}") |
|
|
| def test_deterministic(self): |
| prompt = "The meaning of life is" |
| r1 = ray.get(self.actor.generate.remote(prompt, _SAMPLING_PARAMS)) |
| r2 = ray.get(self.actor.generate.remote(prompt, _SAMPLING_PARAMS)) |
| self.assertEqual(r1["text"], r2["text"]) |
|
|
|
|
| |
| |
| |
|
|
|
|
| @unittest.skipUnless(_has_ray, "ray is not installed") |
| @unittest.skipUnless(_NUM_GPUS >= 2, "requires at least 2 GPUs") |
| class TestRayEngineOfflineTP2(unittest.TestCase): |
|
|
| @classmethod |
| def setUpClass(cls): |
| if not ray.is_initialized(): |
| ray.init(log_to_driver=True, runtime_env=_RAY_RUNTIME_ENV) |
| cls.actor, cls.pg = _create_engine_on_pg(tp_size=2) |
|
|
| @classmethod |
| def tearDownClass(cls): |
| _cleanup(cls.actor, cls.pg) |
| ray.shutdown() |
|
|
| def test_offline_generate_tp2(self): |
| result = ray.get( |
| self.actor.generate.remote("The capital of France is", _SAMPLING_PARAMS) |
| ) |
| self.assertIn("text", result) |
| self.assertGreater(len(result["text"]), 0) |
| print(f"Generated (TP=2): {result['text'][:200]}") |
|
|
| def test_batch_generate_tp2(self): |
| for prompt in _PROMPTS: |
| result = ray.get(self.actor.generate.remote(prompt, _SAMPLING_PARAMS)) |
| self.assertIn("text", result) |
| self.assertGreater(len(result["text"]), 0, f"Empty output for: {prompt}") |
|
|
|
|
| |
| |
| |
|
|
|
|
| @unittest.skipUnless(_has_ray, "ray is not installed") |
| @unittest.skipUnless(_NUM_GPUS >= 2, "requires at least 2 GPUs") |
| class TestRayEngineOfflinePP2(unittest.TestCase): |
|
|
| @classmethod |
| def setUpClass(cls): |
| if not ray.is_initialized(): |
| ray.init(log_to_driver=True, runtime_env=_RAY_RUNTIME_ENV) |
| cls.actor, cls.pg = _create_engine_on_pg(tp_size=1, pp_size=2) |
|
|
| @classmethod |
| def tearDownClass(cls): |
| _cleanup(cls.actor, cls.pg) |
| ray.shutdown() |
|
|
| def test_offline_generate_pp2(self): |
| result = ray.get( |
| self.actor.generate.remote("The capital of France is", _SAMPLING_PARAMS) |
| ) |
| self.assertIn("text", result) |
| self.assertGreater(len(result["text"]), 0) |
| print(f"Generated (PP=2): {result['text'][:200]}") |
|
|
| def test_batch_generate_pp2(self): |
| for prompt in _PROMPTS: |
| result = ray.get(self.actor.generate.remote(prompt, _SAMPLING_PARAMS)) |
| self.assertIn("text", result) |
| self.assertGreater(len(result["text"]), 0, f"Empty output for: {prompt}") |
|
|
|
|
| |
| |
| |
|
|
|
|
| @unittest.skipUnless(_has_ray, "ray is not installed") |
| @unittest.skipUnless(_NUM_GPUS >= 1, "requires at least 1 GPU") |
| class TestRayEngineErrors(unittest.TestCase): |
|
|
| @classmethod |
| def setUpClass(cls): |
| if not ray.is_initialized(): |
| ray.init(log_to_driver=True, runtime_env=_RAY_RUNTIME_ENV) |
|
|
| @classmethod |
| def tearDownClass(cls): |
| ray.shutdown() |
|
|
| def test_dp_greater_than_1_raises(self): |
| """RayEngine with dp_size > 1 should raise NotImplementedError.""" |
|
|
| @ray.remote |
| class _BadActor: |
| def try_create(self): |
| from sglang.srt.ray.engine import RayEngine |
|
|
| try: |
| RayEngine( |
| model_path=_MODEL, |
| tp_size=1, |
| dp_size=2, |
| use_ray=True, |
| ) |
| return None |
| except (NotImplementedError, RuntimeError) as e: |
| return str(e) |
|
|
| pg = placement_group([{"CPU": 1, "GPU": 1}], strategy="STRICT_PACK") |
| ray.get(pg.ready()) |
|
|
| actor = _BadActor.options( |
| num_cpus=1, |
| num_gpus=0, |
| scheduling_strategy=PlacementGroupSchedulingStrategy( |
| placement_group=pg, |
| placement_group_bundle_index=0, |
| ), |
| ).remote() |
|
|
| try: |
| error_msg = ray.get(actor.try_create.remote(), timeout=120) |
| self.assertIsNotNone(error_msg, "Expected error but RayEngine created OK") |
| self.assertIn("dp_size", error_msg.lower()) |
| finally: |
| ray.util.remove_placement_group(pg) |
|
|
| def test_missing_placement_group_raises(self): |
| """RayEngine without a placement group should raise RuntimeError.""" |
|
|
| @ray.remote(num_gpus=1) |
| def _try_create_without_pg(): |
| from sglang.srt.ray.engine import RayEngine |
|
|
| try: |
| RayEngine( |
| model_path=_MODEL, |
| tp_size=1, |
| use_ray=True, |
| ) |
| return None |
| except RuntimeError as e: |
| return str(e) |
|
|
| error_msg = ray.get(_try_create_without_pg.remote(), timeout=120) |
| self.assertIsNotNone( |
| error_msg, "Expected RuntimeError but RayEngine created OK" |
| ) |
| self.assertIn("placement group", error_msg.lower()) |
|
|
|
|
| |
| |
| |
|
|
|
|
| @unittest.skipUnless(_has_ray, "ray is not installed") |
| @unittest.skipUnless(_NUM_GPUS >= 1, "requires at least 1 GPU") |
| class TestRayHTTPServerTP1(unittest.TestCase): |
| """Test the Ray HTTP server path (launch_server.py --use-ray). |
| |
| Launches the server inside a Ray task on a placement group (mirrors |
| examples/anyscale/driver_online.py) and sends HTTP requests to it. |
| """ |
|
|
| @classmethod |
| def setUpClass(cls): |
| import requests as req_lib |
|
|
| if not ray.is_initialized(): |
| ray.init(log_to_driver=True, runtime_env=_RAY_RUNTIME_ENV) |
|
|
| cls.port = 30100 |
| cls.pg = placement_group( |
| [{"CPU": 1, "GPU": 1}], |
| strategy="STRICT_PACK", |
| ) |
| ray.get(cls.pg.ready()) |
|
|
| pg_strategy = PlacementGroupSchedulingStrategy( |
| placement_group=cls.pg, |
| placement_group_bundle_index=0, |
| ) |
|
|
| |
| @ray.remote(num_cpus=0, num_gpus=0) |
| def _get_ip(): |
| return ray.util.get_node_ip_address() |
|
|
| cls.node_ip = ray.get(_get_ip.options(scheduling_strategy=pg_strategy).remote()) |
| cls.base_url = f"http://{cls.node_ip}:{cls.port}" |
|
|
| |
| @ray.remote |
| def _launch(**kwargs): |
| from sglang.srt.ray.http_server import launch_server |
| from sglang.srt.server_args import ServerArgs |
|
|
| launch_server(ServerArgs(**kwargs)) |
|
|
| cls.server_ref = _launch.options( |
| num_cpus=1, |
| num_gpus=0, |
| scheduling_strategy=pg_strategy, |
| ).remote( |
| model_path=_MODEL, |
| tp_size=1, |
| port=cls.port, |
| host="0.0.0.0", |
| use_ray=True, |
| ) |
|
|
| |
| t0 = time.time() |
| timeout = 600 |
| healthy = False |
| while time.time() - t0 < timeout: |
| ready, _ = ray.wait([cls.server_ref], timeout=0) |
| if ready: |
| try: |
| ray.get(cls.server_ref) |
| except Exception as e: |
| raise RuntimeError(f"Server task crashed: {e}") from e |
| raise RuntimeError("Server task exited before becoming healthy") |
| try: |
| if req_lib.get(f"{cls.base_url}/health", timeout=5).status_code == 200: |
| healthy = True |
| break |
| except req_lib.exceptions.RequestException: |
| pass |
| time.sleep(3) |
|
|
| if not healthy: |
| ray.cancel(cls.server_ref, force=True) |
| raise RuntimeError(f"Server did not become healthy within {timeout}s") |
|
|
| @classmethod |
| def tearDownClass(cls): |
| try: |
| ray.cancel(cls.server_ref, force=True) |
| except Exception: |
| pass |
| try: |
| ray.util.remove_placement_group(cls.pg) |
| except Exception: |
| pass |
| ray.shutdown() |
|
|
| def test_health_endpoint(self): |
| import requests |
|
|
| resp = requests.get(f"{self.base_url}/health", timeout=10) |
| self.assertEqual(resp.status_code, 200) |
|
|
| def test_generate_endpoint(self): |
| import requests |
|
|
| resp = requests.post( |
| f"{self.base_url}/generate", |
| json={ |
| "text": "The capital of France is", |
| "sampling_params": _SAMPLING_PARAMS, |
| }, |
| timeout=60, |
| ) |
| resp.raise_for_status() |
| data = resp.json() |
| self.assertIn("text", data) |
| self.assertGreater(len(data["text"]), 0) |
| print(f"HTTP response: {data['text'][:200]}") |
|
|
| def test_generate_multiple(self): |
| import requests |
|
|
| for prompt in _PROMPTS: |
| resp = requests.post( |
| f"{self.base_url}/generate", |
| json={ |
| "text": prompt, |
| "sampling_params": _SAMPLING_PARAMS, |
| }, |
| timeout=60, |
| ) |
| resp.raise_for_status() |
| data = resp.json() |
| self.assertIn("text", data) |
| self.assertGreater(len(data["text"]), 0, f"Empty output for: {prompt}") |
|
|
|
|
| if __name__ == "__main__": |
| unittest.main() |
|
|