"""
Integration test for abort_request functionality with a SGLang server.
Run with:
python -m unittest sglang.test.srt.entrypoints.http_server.test_abort_request -v
"""
import threading
import time
import unittest
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestAbortRequest(CustomTestCase):
    """Integration tests for the ``/abort_request`` endpoint of a SGLang server.

    A server is launched once for the whole class in ``setUpClass`` and torn
    down in ``tearDownClass``.  Each test issues non-streaming ``/generate``
    requests from worker threads, aborts one of them by its ``rid`` and checks
    the ``finish_reason`` reported in the response ``meta_info``.
    """

    model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
    base_url = DEFAULT_URL_FOR_TEST

    # Upper bound (seconds) when joining worker threads.  Deliberately longer
    # than the 30 s HTTP timeout used by ``_send_completion_request`` so that a
    # worker always gets the chance to record its outcome before we give up.
    _THREAD_JOIN_TIMEOUT = 60

    @classmethod
    def setUpClass(cls):
        """Launch the server and derive the endpoint URLs used by the tests."""
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--disable-cuda-graph"],
        )
        cls.completion_url = f"{cls.base_url}/generate"
        cls.abort_url = f"{cls.base_url}/abort_request"
        cls.health_url = f"{cls.base_url}/health"
        print(f"Server started at {cls.base_url}")

    @classmethod
    def tearDownClass(cls):
        """Clean up the server."""
        kill_process_tree(cls.process.pid)

    def _send_completion_request(
        self,
        text: str,
        request_id: str,
        max_tokens: int = 50,
        temperature: float = 0.8,
        stream: bool = True,
    ) -> requests.Response:
        """Send a completion request to the server.

        Args:
            text: Prompt sent as the ``text`` field of the payload.
            request_id: Value sent as ``rid``; the same id is later used to
                abort the request.
            max_tokens: Sent as ``sampling_params.max_new_tokens``.
            temperature: Sent as ``sampling_params.temperature``.
            stream: Sent as the ``stream`` field and also used as the
                ``stream`` flag of the underlying ``requests.post`` call.

        Returns:
            The ``requests.Response`` of the POST to ``completion_url``.
        """
        payload = {
            "text": text,
            "sampling_params": {
                "max_new_tokens": max_tokens,
                "temperature": temperature,
            },
            "stream": stream,
            "rid": request_id,
        }
        response = requests.post(
            self.completion_url,
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=30,
            stream=stream,
        )
        return response

    def _send_abort_request(self, request_id: str) -> requests.Response:
        """Send an abort request for ``request_id`` and return the response."""
        payload = {"rid": request_id}
        return requests.post(self.abort_url, json=payload, timeout=10)

    def _check_server_health(self) -> bool:
        """Return True if the health endpoint answers with HTTP 200.

        Transport-level failures (connection refused, timeout, ...) are
        reported as ``False``; anything else is allowed to propagate so that
        programming errors and ``KeyboardInterrupt`` are not hidden.
        """
        try:
            response = requests.get(self.health_url, timeout=5)
        except requests.RequestException:
            return False
        return response.status_code == 200

    def _run_completion(
        self,
        results: dict,
        errors: dict,
        request_id: str,
        text: str,
        max_tokens: int,
        temperature: float,
    ) -> None:
        """Thread target: run one non-streaming completion and record the outcome.

        On HTTP 200 the entry ``results[request_id]`` is set to a dict with the
        keys ``"text"`` and ``"finish_reason"``.  On any failure nothing is
        written to ``results``; a short description is stored in
        ``errors[request_id]`` instead so the test can show it in its
        assertion message.
        """
        try:
            response = self._send_completion_request(
                text,
                request_id=request_id,
                max_tokens=max_tokens,
                temperature=temperature,
                stream=False,
            )
            if response.status_code != 200:
                errors[request_id] = f"HTTP {response.status_code}"
                return
            result = response.json()
        except (requests.RequestException, ValueError) as exc:
            # ValueError covers a body that is not valid JSON.
            errors[request_id] = repr(exc)
            return
        results[request_id] = {
            "text": result.get("text", ""),
            "finish_reason": result.get("meta_info", {}).get("finish_reason"),
        }

    def test_abort_during_non_streaming_generation(self):
        """Test aborting a non-streaming request during generation."""
        self.assertTrue(self._check_server_health(), "Server should be healthy")
        request_id = "test_abort_non_streaming"
        results = {}
        errors = {}

        completion_thread = threading.Thread(
            target=self._run_completion,
            args=(
                results,
                errors,
                request_id,
                "Write a detailed essay about artificial intelligence",
            ),
            kwargs={"max_tokens": 500, "temperature": 1},
        )
        completion_thread.start()
        try:
            # NOTE(review): the fixed delay presumably gives the server time to
            # register the request before the abort arrives - confirm it is
            # long enough on slow machines.
            time.sleep(0.1)
            abort_response = self._send_abort_request(request_id)
        finally:
            # Always reap the worker, even if sending the abort raised.
            completion_thread.join(timeout=self._THREAD_JOIN_TIMEOUT)

        self.assertFalse(
            completion_thread.is_alive(), "Completion request should have finished"
        )
        self.assertEqual(abort_response.status_code, 200)

        # A missing entry means the worker never got an HTTP 200 response; that
        # must fail the test instead of silently skipping the checks below.
        completion_result = results.get(request_id)
        self.assertIsNotNone(
            completion_result,
            f"Should have completion result (error: {errors.get(request_id)})",
        )
        finish_reason_obj = completion_result.get("finish_reason")
        self.assertIsNotNone(finish_reason_obj, "Should have finish_reason")
        self.assertEqual(finish_reason_obj.get("type"), "abort", "Should be aborted")

    def test_batch_requests_with_selective_abort(self):
        """Test multiple concurrent requests with selective abort of one request."""
        self.assertTrue(self._check_server_health(), "Server should be healthy")
        request_ids = ["batch_test_0", "batch_test_1", "batch_test_2"]
        prompts = ["a knight's adventure", "a space discovery", "a chef's restaurant"]
        abort_target_id = "batch_test_1"
        results = {}
        errors = {}

        # Start all requests
        threads = []
        for req_id, prompt in zip(request_ids, prompts):
            thread = threading.Thread(
                target=self._run_completion,
                args=(results, errors, req_id, f"Write a story about {prompt}"),
                kwargs={"max_tokens": 100, "temperature": 0.8},
            )
            threads.append(thread)
            thread.start()

        try:
            # Abort one request
            time.sleep(0.1)
            abort_response = self._send_abort_request(abort_target_id)
        finally:
            # Wait for completion, even if sending the abort raised.
            for thread in threads:
                thread.join(timeout=self._THREAD_JOIN_TIMEOUT)

        # Verify results
        for thread in threads:
            self.assertFalse(thread.is_alive(), "All requests should have finished")
        self.assertEqual(abort_response.status_code, 200)

        # Check aborted request
        aborted_result = results.get(abort_target_id)
        self.assertIsNotNone(
            aborted_result,
            f"Aborted request {abort_target_id} should have result "
            f"(error: {errors.get(abort_target_id)})",
        )
        aborted_finish_reason = aborted_result.get("finish_reason")
        self.assertIsNotNone(
            aborted_finish_reason, "Aborted request should have finish_reason"
        )
        self.assertEqual(aborted_finish_reason.get("type"), "abort")

        # Check other requests completed normally, i.e. stopped with finish
        # reason type "length" rather than being aborted.
        for req_id in request_ids:
            if req_id == abort_target_id:
                continue
            result = results.get(req_id)
            self.assertIsNotNone(
                result,
                f"Request {req_id} should have result (error: {errors.get(req_id)})",
            )
            finish_reason = result.get("finish_reason")
            self.assertIsNotNone(
                finish_reason, f"Request {req_id} should have finish_reason"
            )
            self.assertEqual(
                finish_reason.get("type"),
                "length",
                f"Request {req_id} should complete normally",
            )
if __name__ == "__main__":
    # Script entry point: run this module's tests verbosely, warnings silenced.
    unittest.main(warnings="ignore", verbosity=2)
|