# Source: LightDiffusion-Next / tests/unit/test_single_request_chunking.py
# Page-scrape residue (avatar alt text): "Aatricks's picture"
# Commit message: Deploy ZeroGPU Gradio Space snapshot
# Commit: b701455
import asyncio
import base64
import pytest
import server
from src.FileManaging import ImageSaver
@pytest.mark.asyncio
async def test_single_request_large_num_images_is_chunked(monkeypatch):
    """A single request for 10 images completes with the group cap set to 3.

    ``server.LD_MAX_IMAGES_PER_GROUP`` is lowered to 3 while the request asks
    for ``num_images=10``; presumably this forces ``_process_group`` to split
    the work into several pipeline calls (the test itself only checks the
    final outcome, not the number of calls).

    Checks that:
      * the request's future is resolved,
      * the result is a dict carrying exactly 10 entries under ``"images"``,
      * no image bytes remain stored under the ``"LD-REQ-r_big"`` prefix.
    """
    # Patch via monkeypatch rather than assigning the attribute directly so the
    # original value is restored at teardown and cannot leak into other tests.
    # raising=False keeps the old "create if missing" behaviour of assignment.
    monkeypatch.setattr(server, "LD_MAX_IMAGES_PER_GROUP", 3, raising=False)

    async def immediate_to_thread(func, /, *args, **kwargs):
        # Run the callable inline instead of handing it to a worker thread.
        return func(*args, **kwargs)

    monkeypatch.setattr(server.asyncio, "to_thread", immediate_to_thread)

    def fake_pipeline(**kwargs):
        """Stand-in for ``server.pipeline``: store one fake PNG per sample."""
        per_sample_info = kwargs.get("per_sample_info", [])
        results = {}
        for info in per_sample_info:
            rid = info["request_id"]
            # Index is the count of images already produced for this request
            # within the current pipeline call.
            filename = f"{rid}_{len(results.get(rid, []))}_img.png"
            ImageSaver.store_image_bytes(
                f"LD-REQ-{rid}", filename, "Classic", b"PNGDATA"
            )
            results.setdefault(rid, []).append(
                {"filename": filename, "subfolder": "Classic"}
            )
        return {"batched_results": results}

    monkeypatch.setattr(server, "pipeline", fake_pipeline)

    req = server.GenerateRequest(prompt="p", num_images=10)
    pr = server.PendingRequest(req, request_id="r_big")
    buf = server.GenerationBuffer()

    await buf._process_group([pr])

    assert pr.future.done()
    res = pr.future.result()

    # Should have produced multiple images; a response without an "images"
    # entry (single image case) should not happen for num_images=10.
    assert isinstance(res, dict)
    assert "images" in res, "Expected 10 images in response"
    assert len(res["images"]) == 10

    # Buffer should be emptied
    assert ImageSaver.pop_image_bytes("LD-REQ-r_big") == []
@pytest.mark.asyncio
async def test_single_request_respects_batch_size_semantics(monkeypatch):
    """A request with ``num_images=5, batch_size=2`` runs as 2 + 2 + 1.

    The fake pipeline records the ``number`` and ``batch`` keyword arguments of
    every call; the test expects exactly ``[(2, 2), (2, 2), (1, 1)]``.

    Also checks that:
      * the request's future is resolved,
      * the result is a dict carrying exactly 5 entries under ``"images"``,
      * no image bytes remain stored under the ``"LD-REQ-r_batch"`` prefix.
    """
    # Patch via monkeypatch rather than assigning the attribute directly so the
    # original value is restored at teardown and cannot leak into other tests
    # (consistent with the MAX_IMAGES_PER_SAVE patch below).
    # raising=False keeps the old "create if missing" behaviour of assignment.
    monkeypatch.setattr(server, "LD_MAX_IMAGES_PER_GROUP", 32, raising=False)

    async def immediate_to_thread(func, /, *args, **kwargs):
        # Run the callable inline instead of handing it to a worker thread.
        return func(*args, **kwargs)

    monkeypatch.setattr(server.asyncio, "to_thread", immediate_to_thread)
    monkeypatch.setattr(ImageSaver, "MAX_IMAGES_PER_SAVE", 16)

    calls = []

    def fake_pipeline(**kwargs):
        """Stand-in for ``server.pipeline``: log the call, store fake PNGs."""
        calls.append((kwargs["number"], kwargs["batch"]))
        per_sample_info = kwargs.get("per_sample_info", [])
        results = {}
        for idx, info in enumerate(per_sample_info):
            rid = info["request_id"]
            filename = f"{rid}_{idx}_img.png"
            # Store under the prefix supplied with the sample.
            ImageSaver.store_image_bytes(
                info["filename_prefix"], filename, "Classic", b"PNGDATA"
            )
            results.setdefault(rid, []).append(
                {"filename": filename, "subfolder": "Classic"}
            )
        return {"batched_results": results}

    monkeypatch.setattr(server, "pipeline", fake_pipeline)

    req = server.GenerateRequest(prompt="p", num_images=5, batch_size=2)
    pr = server.PendingRequest(req, request_id="r_batch")
    buf = server.GenerationBuffer()

    await buf._process_group([pr])

    # Five images at batch_size=2 -> two full batches and one remainder.
    assert calls == [(2, 2), (2, 2), (1, 1)]

    assert pr.future.done()
    res = pr.future.result()
    assert isinstance(res, dict)
    assert "images" in res
    assert len(res["images"]) == 5

    # Buffer should be emptied
    assert ImageSaver.pop_image_bytes("LD-REQ-r_batch") == []