# Spaces: Running on Zero
import asyncio
import base64

import pytest

import server
from src.FileManaging import ImageSaver
async def test_single_request_large_num_images_is_chunked(monkeypatch):
    """A single request for more images than LD_MAX_IMAGES_PER_GROUP must be
    chunked internally, yet still resolve its future with all images at once.

    NOTE(review): no @pytest.mark.asyncio marker — presumably pytest-asyncio
    runs in auto mode; confirm against project pytest configuration.
    """
    # Shrink the cap so num_images=10 is guaranteed to exceed it.
    server.LD_MAX_IMAGES_PER_GROUP = 3

    # Replace asyncio.to_thread with a synchronous call so the fake
    # pipeline runs inline and the test stays deterministic.
    async def run_inline(func, /, *args, **kwargs):
        return func(*args, **kwargs)

    monkeypatch.setattr(server.asyncio, "to_thread", run_inline)

    def fake_pipeline(**kwargs):
        # Store one image per requested sample, grouped by request id.
        by_request = {}
        for sample in kwargs.get("per_sample_info", []):
            rid = sample["request_id"]
            bucket = by_request.setdefault(rid, [])
            filename = f"{rid}_{len(bucket)}_img.png"
            ImageSaver.store_image_bytes(f"LD-REQ-{rid}", filename, "Classic", b"PNGDATA")
            bucket.append({"filename": filename, "subfolder": "Classic"})
        return {"batched_results": by_request}

    monkeypatch.setattr(server, "pipeline", fake_pipeline)

    req = server.GenerateRequest(prompt="p", num_images=10)
    pr = server.PendingRequest(req, request_id="r_big")
    buf = server.GenerationBuffer()
    await buf._process_group([pr])

    assert pr.future.done()
    res = pr.future.result()
    assert isinstance(res, dict)
    # A num_images=10 request must come back as a multi-image payload;
    # a single-image shape here would be a regression.
    if "images" not in res:
        pytest.fail("Expected 10 images in response")
    assert len(res["images"]) == 10
    # Buffer should be emptied
    assert ImageSaver.pop_image_bytes("LD-REQ-r_big") == []
async def test_single_request_respects_batch_size_semantics(monkeypatch):
    """num_images=5 with batch_size=2 should drive the pipeline as
    (2, 2), (2, 2), (1, 1) and resolve with all five images in one response.

    NOTE(review): no @pytest.mark.asyncio marker — presumably pytest-asyncio
    runs in auto mode; confirm against project pytest configuration.
    """
    # Cap high enough that chunking-by-group never interferes here.
    server.LD_MAX_IMAGES_PER_GROUP = 32

    # Run asyncio.to_thread work synchronously for determinism.
    async def run_inline(func, /, *args, **kwargs):
        return func(*args, **kwargs)

    monkeypatch.setattr(server.asyncio, "to_thread", run_inline)
    monkeypatch.setattr(ImageSaver, "MAX_IMAGES_PER_SAVE", 16)

    # Record the (number, batch) arguments of every pipeline invocation.
    recorded = []

    def fake_pipeline(**kwargs):
        recorded.append((kwargs["number"], kwargs["batch"]))
        by_request = {}
        for idx, sample in enumerate(kwargs.get("per_sample_info", [])):
            rid = sample["request_id"]
            filename = f"{rid}_{idx}_img.png"
            ImageSaver.store_image_bytes(sample["filename_prefix"], filename, "Classic", b"PNGDATA")
            by_request.setdefault(rid, []).append({"filename": filename, "subfolder": "Classic"})
        return {"batched_results": by_request}

    monkeypatch.setattr(server, "pipeline", fake_pipeline)

    req = server.GenerateRequest(prompt="p", num_images=5, batch_size=2)
    pr = server.PendingRequest(req, request_id="r_batch")
    buf = server.GenerationBuffer()
    await buf._process_group([pr])

    # Two full batches of 2 followed by a remainder batch of 1.
    assert recorded == [(2, 2), (2, 2), (1, 1)]
    assert pr.future.done()
    res = pr.future.result()
    assert isinstance(res, dict)
    assert "images" in res
    assert len(res["images"]) == 5
    assert ImageSaver.pop_image_bytes("LD-REQ-r_batch") == []