Hanrui / sglang /test /manual /test_async_mm_data_processor.py
Lekr0's picture
Add files using upload-large-folder tool
61ba51e verified
"""
Unit tests for AsyncMMDataProcessor.
Covers:
- Async and sync processing paths
- Concurrency limiting via semaphore
- Per-call timeout behavior (async and sync)
- Argument passthrough (images, audios, text/ids, request_obj, kwargs)
- Error propagation and shutdown behavior
"""
import asyncio
import logging
import threading
import time
from unittest.mock import Mock
import pytest
from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor
class TestAsyncMMDataProcessor:
"""Test suite for AsyncMMDataProcessor."""
@pytest.fixture
def async_processor(self):
"""Create a processor exposing an async process_mm_data_async."""
class AsyncProc:
async def process_mm_data_async(
self,
*,
image_data=None,
audio_data=None,
input_text=None,
request_obj=None,
**kwargs,
):
# Allow tests to simulate latency via kwargs
delay = kwargs.get("delay_s", 0.0)
if delay:
await asyncio.sleep(delay)
return {
"path": "async",
"images": image_data,
"audios": audio_data,
"text": input_text,
"request": request_obj,
"kwargs": kwargs,
}
return AsyncProc()
@pytest.fixture
def sync_processor(self):
"""Provide a processor exposing a sync process_mm_data."""
class SyncProc:
def process_mm_data(
self,
*,
image_data=None,
audio_data=None,
input_text=None,
request_obj=None,
**kwargs,
):
delay = kwargs.get("delay_s", 0.0)
if delay:
# Simulate CPU/blocking work
time.sleep(delay)
return {
"path": "sync",
"images": image_data,
"audios": audio_data,
"text": input_text,
"request": request_obj,
"kwargs": kwargs,
}
return SyncProc()
@pytest.mark.asyncio
async def test_async_path_basic(self, async_processor):
"""Async processor should be awaited directly."""
proc = AsyncMMDataProcessor(async_processor)
out = await proc.process(
image_data=["img1.png"],
audio_data=["a.wav"],
input_text_or_ids="hello",
request_obj={"rid": 1},
mode="fast",
)
assert out["path"] == "async"
assert out["images"] == ["img1.png"]
assert out["audios"] == ["a.wav"]
assert out["text"] == "hello"
assert out["request"] == {"rid": 1}
assert out["kwargs"]["mode"] == "fast"
@pytest.mark.asyncio
async def test_sync_fallback_basic(self, sync_processor):
"""Sync processor should run in fallback executor."""
proc = AsyncMMDataProcessor(sync_processor)
out = await proc.process(
image_data=[b"\x00\x01"],
audio_data=None,
input_text_or_ids=[1, 2, 3],
request_obj="req-obj",
role="user",
)
assert out["path"] == "sync"
assert out["images"] == [b"\x00\x01"]
assert out["audios"] is None
assert out["text"] == [1, 2, 3]
assert out["request"] == "req-obj"
assert out["kwargs"]["role"] == "user"
@pytest.mark.asyncio
async def test_timeout_async(self, async_processor):
"""Timeout should raise asyncio.TimeoutError for async path."""
proc = AsyncMMDataProcessor(async_processor, timeout_s=0.01)
with pytest.raises(asyncio.TimeoutError):
await proc.process(
input_text_or_ids="slow",
request_obj=None,
delay_s=0.05, # longer than timeout
)
@pytest.mark.asyncio
async def test_timeout_sync(self, sync_processor):
"""Timeout should raise asyncio.TimeoutError for sync fallback path."""
proc = AsyncMMDataProcessor(sync_processor, timeout_s=0.01)
with pytest.raises(asyncio.TimeoutError):
await proc.process(
input_text_or_ids="slow",
request_obj=None,
delay_s=0.05, # longer than timeout
)
@pytest.mark.asyncio
async def test_semaphore_release_after_timeout(self, sync_processor):
"""
If a call times out, the semaphore should be released so a subsequent call can proceed.
Use >=2 fallback workers so the timed-out thread doesn't block the next call.
"""
proc = AsyncMMDataProcessor(
sync_processor,
max_concurrent_calls=2,
timeout_s=0.01,
)
# First call will time out
with pytest.raises(asyncio.TimeoutError):
await proc.process(
input_text_or_ids="slow1", request_obj=None, delay_s=0.05
)
# Second call should be able to acquire the semaphore and complete
out = await proc.process(input_text_or_ids="ok", request_obj=None, delay_s=0.0)
assert out["text"] == "ok"
@pytest.mark.asyncio
async def test_concurrency_limit_async(self):
"""Ensure max_concurrent_calls caps concurrency for async path."""
current = 0
max_seen = 0
class AsyncProc:
async def process_mm_data_async(self, **kwargs):
nonlocal current, max_seen
current += 1
max_seen = max(max_seen, current)
try:
await asyncio.sleep(0.02)
return {"ok": True}
finally:
current -= 1
proc = AsyncMMDataProcessor(AsyncProc(), max_concurrent_calls=2)
tasks = [
proc.process(input_text_or_ids=f"t{i}", request_obj=None) for i in range(6)
]
await asyncio.gather(*tasks)
assert max_seen <= 2
@pytest.mark.asyncio
async def test_concurrency_limit_sync(self):
"""Ensure max_concurrent_calls caps concurrency for sync fallback path."""
current = 0
max_seen = 0
lock = threading.Lock()
class SyncProc:
def process_mm_data(self, **kwargs):
nonlocal current, max_seen
with lock:
current += 1
max_seen = max(max_seen, current)
try:
time.sleep(0.02)
return {"ok": True}
finally:
with lock:
current -= 1
proc = AsyncMMDataProcessor(SyncProc(), max_concurrent_calls=3)
tasks = [
proc.process(input_text_or_ids=f"s{i}", request_obj=None) for i in range(9)
]
await asyncio.gather(*tasks)
assert max_seen <= 3
@pytest.mark.asyncio
async def test_error_from_async_processor(self):
"""Exceptions raised by the async processor should propagate."""
class BadAsync:
async def process_mm_data_async(self, **_):
await asyncio.sleep(0)
raise ValueError("async boom")
proc = AsyncMMDataProcessor(BadAsync())
with pytest.raises(ValueError, match="async boom"):
await proc.process(input_text_or_ids="x", request_obj=None)
@pytest.mark.asyncio
async def test_error_from_sync_processor(self):
"""Exceptions raised by the sync processor should propagate."""
class BadSync:
def process_mm_data(self, **_):
raise RuntimeError("sync boom")
proc = AsyncMMDataProcessor(BadSync())
with pytest.raises(RuntimeError, match="sync boom"):
await proc.process(input_text_or_ids="x", request_obj=None)
@pytest.mark.asyncio
async def test_missing_both_methods_raises(self):
"""Processor missing both methods should raise at call time."""
class Empty:
pass
proc = AsyncMMDataProcessor(Empty())
with pytest.raises(
RuntimeError, match="neither 'process_mm_data_async' nor 'process_mm_data'"
):
await proc.process(input_text_or_ids="x", request_obj=None)
@pytest.mark.asyncio
async def test_async_attribute_not_coroutine_uses_sync_fallback(self):
"""
If `process_mm_data_async` exists but isn't a coroutine function,
wrapper should treat it as sync and use `process_mm_data`.
"""
class WeirdProc:
# Not a coroutine function:
def process_mm_data_async(self, **_):
return {"path": "would-be-async"}
def process_mm_data(self, **_):
return {"path": "sync"}
proc = AsyncMMDataProcessor(WeirdProc())
out = await proc.process(input_text_or_ids="x", request_obj=None)
assert out["path"] == "sync"
@pytest.mark.asyncio
async def test_kwargs_and_request_passthrough_async(self, async_processor):
"""Extra kwargs and request_obj should be forwarded on async path."""
proc = AsyncMMDataProcessor(async_processor)
out = await proc.process(
image_data=["i1", "i2"],
audio_data=["a1"],
input_text_or_ids="hello world",
request_obj={"uid": 42},
return_meta=True,
delay_s=0.0,
)
assert out["images"] == ["i1", "i2"]
assert out["audios"] == ["a1"]
assert out["text"] == "hello world"
assert out["request"] == {"uid": 42}
assert out["kwargs"]["return_meta"] is True
@pytest.mark.asyncio
async def test_kwargs_and_request_passthrough_sync(self, sync_processor):
"""Extra kwargs and request_obj should be forwarded on sync path."""
proc = AsyncMMDataProcessor(sync_processor)
out = await proc.process(
image_data=None,
audio_data=[],
input_text_or_ids=[101, 102],
request_obj=("r", 7),
lang="en",
)
assert out["images"] is None
assert out["audios"] == []
assert out["text"] == [101, 102]
assert out["request"] == ("r", 7)
assert out["kwargs"]["lang"] == "en"
def test_shutdown_on_sync_executor(self, sync_processor):
"""Explicit shutdown should close fallback executor for sync path."""
proc = AsyncMMDataProcessor(sync_processor)
# Swap real executor for a mock to assert shutdown behavior
proc.fallback_exec = Mock()
proc.shutdown()
proc.fallback_exec.shutdown.assert_called_once_with(wait=False)
def test_del_calls_shutdown(self, sync_processor, caplog):
"""__del__ should best-effort shutdown without raising."""
caplog.set_level(logging.DEBUG)
proc = AsyncMMDataProcessor(sync_processor)
proc.fallback_exec = Mock()
# Simulate object destruction
proc.__del__()
proc.fallback_exec.shutdown.assert_called_once_with(wait=False)
@pytest.mark.asyncio
async def test_concurrent_mixed_requests(self, async_processor):
"""Mix different payloads and ensure all complete with valid outputs."""
proc = AsyncMMDataProcessor(async_processor, max_concurrent_calls=4)
tasks = [
proc.process(input_text_or_ids="t1", request_obj=1),
proc.process(image_data=["i.png"], input_text_or_ids=[9, 8], request_obj=2),
proc.process(
audio_data=["v.wav"], input_text_or_ids="speech", request_obj=3
),
proc.process(
image_data=[], audio_data=[], input_text_or_ids=None, request_obj=4
),
]
outs = await asyncio.gather(*tasks)
assert len(outs) == 4
for out in outs:
assert "path" in out
assert out["path"] == "async"
@pytest.mark.asyncio
async def test_many_requests_values_match_inputs(self, sync_processor):
"""For sync path, ensure each response corresponds to its specific input."""
proc = AsyncMMDataProcessor(sync_processor, max_concurrent_calls=8)
texts = [f"msg-{i}" for i in range(10)]
tasks = [
proc.process(input_text_or_ids=t, request_obj=i)
for i, t in enumerate(texts)
]
outs = await asyncio.gather(*tasks)
got = [o["text"] for o in outs]
assert got == texts
if __name__ == "__main__":
pytest.main([__file__])