| """ |
| Unit tests for AsyncMMDataProcessor. |
| |
| Covers: |
| - Async and sync processing paths |
| - Concurrency limiting via semaphore |
| - Per-call timeout behavior (async and sync) |
| - Argument passthrough (images, audios, text/ids, request_obj, kwargs) |
| - Error propagation and shutdown behavior |
| """ |
|
|
| import asyncio |
| import logging |
| import threading |
| import time |
| from unittest.mock import Mock |
|
|
| import pytest |
|
|
| from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor |
|
|
|
|
| class TestAsyncMMDataProcessor: |
| """Test suite for AsyncMMDataProcessor.""" |
|
|
| @pytest.fixture |
| def async_processor(self): |
| """Create a processor exposing an async process_mm_data_async.""" |
|
|
| class AsyncProc: |
| async def process_mm_data_async( |
| self, |
| *, |
| image_data=None, |
| audio_data=None, |
| input_text=None, |
| request_obj=None, |
| **kwargs, |
| ): |
| |
| delay = kwargs.get("delay_s", 0.0) |
| if delay: |
| await asyncio.sleep(delay) |
| return { |
| "path": "async", |
| "images": image_data, |
| "audios": audio_data, |
| "text": input_text, |
| "request": request_obj, |
| "kwargs": kwargs, |
| } |
|
|
| return AsyncProc() |
|
|
| @pytest.fixture |
| def sync_processor(self): |
| """Provide a processor exposing a sync process_mm_data.""" |
|
|
| class SyncProc: |
| def process_mm_data( |
| self, |
| *, |
| image_data=None, |
| audio_data=None, |
| input_text=None, |
| request_obj=None, |
| **kwargs, |
| ): |
| delay = kwargs.get("delay_s", 0.0) |
| if delay: |
| |
| time.sleep(delay) |
| return { |
| "path": "sync", |
| "images": image_data, |
| "audios": audio_data, |
| "text": input_text, |
| "request": request_obj, |
| "kwargs": kwargs, |
| } |
|
|
| return SyncProc() |
|
|
| @pytest.mark.asyncio |
| async def test_async_path_basic(self, async_processor): |
| """Async processor should be awaited directly.""" |
| proc = AsyncMMDataProcessor(async_processor) |
| out = await proc.process( |
| image_data=["img1.png"], |
| audio_data=["a.wav"], |
| input_text_or_ids="hello", |
| request_obj={"rid": 1}, |
| mode="fast", |
| ) |
| assert out["path"] == "async" |
| assert out["images"] == ["img1.png"] |
| assert out["audios"] == ["a.wav"] |
| assert out["text"] == "hello" |
| assert out["request"] == {"rid": 1} |
| assert out["kwargs"]["mode"] == "fast" |
|
|
| @pytest.mark.asyncio |
| async def test_sync_fallback_basic(self, sync_processor): |
| """Sync processor should run in fallback executor.""" |
| proc = AsyncMMDataProcessor(sync_processor) |
| out = await proc.process( |
| image_data=[b"\x00\x01"], |
| audio_data=None, |
| input_text_or_ids=[1, 2, 3], |
| request_obj="req-obj", |
| role="user", |
| ) |
| assert out["path"] == "sync" |
| assert out["images"] == [b"\x00\x01"] |
| assert out["audios"] is None |
| assert out["text"] == [1, 2, 3] |
| assert out["request"] == "req-obj" |
| assert out["kwargs"]["role"] == "user" |
|
|
| @pytest.mark.asyncio |
| async def test_timeout_async(self, async_processor): |
| """Timeout should raise asyncio.TimeoutError for async path.""" |
| proc = AsyncMMDataProcessor(async_processor, timeout_s=0.01) |
| with pytest.raises(asyncio.TimeoutError): |
| await proc.process( |
| input_text_or_ids="slow", |
| request_obj=None, |
| delay_s=0.05, |
| ) |
|
|
| @pytest.mark.asyncio |
| async def test_timeout_sync(self, sync_processor): |
| """Timeout should raise asyncio.TimeoutError for sync fallback path.""" |
| proc = AsyncMMDataProcessor(sync_processor, timeout_s=0.01) |
| with pytest.raises(asyncio.TimeoutError): |
| await proc.process( |
| input_text_or_ids="slow", |
| request_obj=None, |
| delay_s=0.05, |
| ) |
|
|
| @pytest.mark.asyncio |
| async def test_semaphore_release_after_timeout(self, sync_processor): |
| """ |
| If a call times out, the semaphore should be released so a subsequent call can proceed. |
| Use >=2 fallback workers so the timed-out thread doesn't block the next call. |
| """ |
| proc = AsyncMMDataProcessor( |
| sync_processor, |
| max_concurrent_calls=2, |
| timeout_s=0.01, |
| ) |
|
|
| |
| with pytest.raises(asyncio.TimeoutError): |
| await proc.process( |
| input_text_or_ids="slow1", request_obj=None, delay_s=0.05 |
| ) |
|
|
| |
| out = await proc.process(input_text_or_ids="ok", request_obj=None, delay_s=0.0) |
| assert out["text"] == "ok" |
|
|
| @pytest.mark.asyncio |
| async def test_concurrency_limit_async(self): |
| """Ensure max_concurrent_calls caps concurrency for async path.""" |
| current = 0 |
| max_seen = 0 |
|
|
| class AsyncProc: |
| async def process_mm_data_async(self, **kwargs): |
| nonlocal current, max_seen |
| current += 1 |
| max_seen = max(max_seen, current) |
| try: |
| await asyncio.sleep(0.02) |
| return {"ok": True} |
| finally: |
| current -= 1 |
|
|
| proc = AsyncMMDataProcessor(AsyncProc(), max_concurrent_calls=2) |
|
|
| tasks = [ |
| proc.process(input_text_or_ids=f"t{i}", request_obj=None) for i in range(6) |
| ] |
| await asyncio.gather(*tasks) |
|
|
| assert max_seen <= 2 |
|
|
| @pytest.mark.asyncio |
| async def test_concurrency_limit_sync(self): |
| """Ensure max_concurrent_calls caps concurrency for sync fallback path.""" |
| current = 0 |
| max_seen = 0 |
| lock = threading.Lock() |
|
|
| class SyncProc: |
| def process_mm_data(self, **kwargs): |
| nonlocal current, max_seen |
| with lock: |
| current += 1 |
| max_seen = max(max_seen, current) |
| try: |
| time.sleep(0.02) |
| return {"ok": True} |
| finally: |
| with lock: |
| current -= 1 |
|
|
| proc = AsyncMMDataProcessor(SyncProc(), max_concurrent_calls=3) |
|
|
| tasks = [ |
| proc.process(input_text_or_ids=f"s{i}", request_obj=None) for i in range(9) |
| ] |
| await asyncio.gather(*tasks) |
|
|
| assert max_seen <= 3 |
|
|
| @pytest.mark.asyncio |
| async def test_error_from_async_processor(self): |
| """Exceptions raised by the async processor should propagate.""" |
|
|
| class BadAsync: |
| async def process_mm_data_async(self, **_): |
| await asyncio.sleep(0) |
| raise ValueError("async boom") |
|
|
| proc = AsyncMMDataProcessor(BadAsync()) |
| with pytest.raises(ValueError, match="async boom"): |
| await proc.process(input_text_or_ids="x", request_obj=None) |
|
|
| @pytest.mark.asyncio |
| async def test_error_from_sync_processor(self): |
| """Exceptions raised by the sync processor should propagate.""" |
|
|
| class BadSync: |
| def process_mm_data(self, **_): |
| raise RuntimeError("sync boom") |
|
|
| proc = AsyncMMDataProcessor(BadSync()) |
| with pytest.raises(RuntimeError, match="sync boom"): |
| await proc.process(input_text_or_ids="x", request_obj=None) |
|
|
| @pytest.mark.asyncio |
| async def test_missing_both_methods_raises(self): |
| """Processor missing both methods should raise at call time.""" |
|
|
| class Empty: |
| pass |
|
|
| proc = AsyncMMDataProcessor(Empty()) |
| with pytest.raises( |
| RuntimeError, match="neither 'process_mm_data_async' nor 'process_mm_data'" |
| ): |
| await proc.process(input_text_or_ids="x", request_obj=None) |
|
|
| @pytest.mark.asyncio |
| async def test_async_attribute_not_coroutine_uses_sync_fallback(self): |
| """ |
| If `process_mm_data_async` exists but isn't a coroutine function, |
| wrapper should treat it as sync and use `process_mm_data`. |
| """ |
|
|
| class WeirdProc: |
| |
| def process_mm_data_async(self, **_): |
| return {"path": "would-be-async"} |
|
|
| def process_mm_data(self, **_): |
| return {"path": "sync"} |
|
|
| proc = AsyncMMDataProcessor(WeirdProc()) |
| out = await proc.process(input_text_or_ids="x", request_obj=None) |
| assert out["path"] == "sync" |
|
|
| @pytest.mark.asyncio |
| async def test_kwargs_and_request_passthrough_async(self, async_processor): |
| """Extra kwargs and request_obj should be forwarded on async path.""" |
| proc = AsyncMMDataProcessor(async_processor) |
| out = await proc.process( |
| image_data=["i1", "i2"], |
| audio_data=["a1"], |
| input_text_or_ids="hello world", |
| request_obj={"uid": 42}, |
| return_meta=True, |
| delay_s=0.0, |
| ) |
| assert out["images"] == ["i1", "i2"] |
| assert out["audios"] == ["a1"] |
| assert out["text"] == "hello world" |
| assert out["request"] == {"uid": 42} |
| assert out["kwargs"]["return_meta"] is True |
|
|
| @pytest.mark.asyncio |
| async def test_kwargs_and_request_passthrough_sync(self, sync_processor): |
| """Extra kwargs and request_obj should be forwarded on sync path.""" |
| proc = AsyncMMDataProcessor(sync_processor) |
| out = await proc.process( |
| image_data=None, |
| audio_data=[], |
| input_text_or_ids=[101, 102], |
| request_obj=("r", 7), |
| lang="en", |
| ) |
| assert out["images"] is None |
| assert out["audios"] == [] |
| assert out["text"] == [101, 102] |
| assert out["request"] == ("r", 7) |
| assert out["kwargs"]["lang"] == "en" |
|
|
| def test_shutdown_on_sync_executor(self, sync_processor): |
| """Explicit shutdown should close fallback executor for sync path.""" |
| proc = AsyncMMDataProcessor(sync_processor) |
| |
| proc.fallback_exec = Mock() |
| proc.shutdown() |
| proc.fallback_exec.shutdown.assert_called_once_with(wait=False) |
|
|
| def test_del_calls_shutdown(self, sync_processor, caplog): |
| """__del__ should best-effort shutdown without raising.""" |
| caplog.set_level(logging.DEBUG) |
| proc = AsyncMMDataProcessor(sync_processor) |
| proc.fallback_exec = Mock() |
| |
| proc.__del__() |
| proc.fallback_exec.shutdown.assert_called_once_with(wait=False) |
|
|
| @pytest.mark.asyncio |
| async def test_concurrent_mixed_requests(self, async_processor): |
| """Mix different payloads and ensure all complete with valid outputs.""" |
| proc = AsyncMMDataProcessor(async_processor, max_concurrent_calls=4) |
|
|
| tasks = [ |
| proc.process(input_text_or_ids="t1", request_obj=1), |
| proc.process(image_data=["i.png"], input_text_or_ids=[9, 8], request_obj=2), |
| proc.process( |
| audio_data=["v.wav"], input_text_or_ids="speech", request_obj=3 |
| ), |
| proc.process( |
| image_data=[], audio_data=[], input_text_or_ids=None, request_obj=4 |
| ), |
| ] |
| outs = await asyncio.gather(*tasks) |
| assert len(outs) == 4 |
| for out in outs: |
| assert "path" in out |
| assert out["path"] == "async" |
|
|
| @pytest.mark.asyncio |
| async def test_many_requests_values_match_inputs(self, sync_processor): |
| """For sync path, ensure each response corresponds to its specific input.""" |
| proc = AsyncMMDataProcessor(sync_processor, max_concurrent_calls=8) |
| texts = [f"msg-{i}" for i in range(10)] |
| tasks = [ |
| proc.process(input_text_or_ids=t, request_obj=i) |
| for i, t in enumerate(texts) |
| ] |
| outs = await asyncio.gather(*tasks) |
| got = [o["text"] for o in outs] |
| assert got == texts |
|
|
|
|
| if __name__ == "__main__": |
| pytest.main([__file__]) |
|
|