File size: 3,021 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import asyncio
import base64

import pytest

import server
from src.FileManaging import ImageSaver


@pytest.mark.asyncio
async def test_single_request_large_num_images_is_chunked(monkeypatch):
    """A request larger than LD_MAX_IMAGES_PER_GROUP still yields every image.

    The per-group cap is patched to 3, so a single 10-image request should be
    split across several pipeline calls. The test checks that:

    * more than one pipeline call is made;
    * the request's future resolves with all 10 images;
    * the request's ImageSaver buffer is drained afterwards.
    """
    # monkeypatch restores the module global when the test ends; a bare
    # assignment would leak the cap of 3 into every later test.
    monkeypatch.setattr(server, "LD_MAX_IMAGES_PER_GROUP", 3)

    async def immediate_to_thread(func, /, *args, **kwargs):
        # Replace asyncio.to_thread with a direct inline call so no worker
        # thread is involved.
        return func(*args, **kwargs)

    monkeypatch.setattr(server.asyncio, "to_thread", immediate_to_thread)

    chunk_sizes = []  # number of samples seen by each pipeline call
    produced = 0  # running image count across *all* pipeline calls

    def fake_pipeline(**kwargs):
        # Emit one fake PNG per sample: buffer its bytes in ImageSaver and
        # report it under its request id in "batched_results".
        nonlocal produced
        per_sample_info = kwargs.get("per_sample_info", [])
        chunk_sizes.append(len(per_sample_info))
        results = {}
        for info in per_sample_info:
            rid = info["request_id"]
            # Number filenames globally so images from different chunks never
            # share a name (a per-call index would restart at 0 every chunk).
            filename = f"{rid}_{produced}_img.png"
            produced += 1
            ImageSaver.store_image_bytes(f"LD-REQ-{rid}", filename, "Classic", b"PNGDATA")
            results.setdefault(rid, []).append({"filename": filename, "subfolder": "Classic"})
        return {"batched_results": results}

    monkeypatch.setattr(server, "pipeline", fake_pipeline)

    req = server.GenerateRequest(prompt="p", num_images=10)
    pr = server.PendingRequest(req, request_id="r_big")
    buf = server.GenerationBuffer()

    await buf._process_group([pr])

    # The request exceeded the cap, so it must not have run as a single call.
    assert len(chunk_sizes) > 1, f"Expected chunked pipeline calls, got {chunk_sizes}"

    assert pr.future.done()
    res = pr.future.result()
    assert isinstance(res, dict)
    # A multi-image request must come back in the "images" list form.
    assert "images" in res, "Expected 10 images in response"
    assert len(res["images"]) == 10

    # Buffer should be emptied
    assert ImageSaver.pop_image_bytes("LD-REQ-r_big") == []


@pytest.mark.asyncio
async def test_single_request_respects_batch_size_semantics(monkeypatch):
    """num_images=5 with batch_size=2 runs as pipeline batches of 2, 2 and 1.

    The group cap is set to 32 (above num_images). The test then checks that:

    * the (number, batch) arguments of each pipeline call are pinned;
    * all five images come back;
    * the ImageSaver buffer is drained afterwards.
    """
    # monkeypatch restores the module global when the test ends; a bare
    # assignment would leak the cap of 32 into every later test.
    monkeypatch.setattr(server, "LD_MAX_IMAGES_PER_GROUP", 32)

    async def immediate_to_thread(func, /, *args, **kwargs):
        # Replace asyncio.to_thread with a direct inline call so no worker
        # thread is involved.
        return func(*args, **kwargs)

    monkeypatch.setattr(server.asyncio, "to_thread", immediate_to_thread)
    monkeypatch.setattr(ImageSaver, "MAX_IMAGES_PER_SAVE", 16)

    calls = []  # (number, batch) kwargs of every pipeline invocation
    produced = 0  # running image count across *all* pipeline calls

    def fake_pipeline(**kwargs):
        # Record the call shape, then emit one fake PNG per sample: buffer its
        # bytes under the sample's filename_prefix and report it by request id.
        nonlocal produced
        calls.append((kwargs["number"], kwargs["batch"]))
        results = {}
        for info in kwargs.get("per_sample_info", []):
            rid = info["request_id"]
            # Number filenames globally; a per-call enumerate index would
            # restart at 0 for each of the 2/2/1 batches and collide.
            filename = f"{rid}_{produced}_img.png"
            produced += 1
            ImageSaver.store_image_bytes(info["filename_prefix"], filename, "Classic", b"PNGDATA")
            results.setdefault(rid, []).append({"filename": filename, "subfolder": "Classic"})
        return {"batched_results": results}

    monkeypatch.setattr(server, "pipeline", fake_pipeline)

    req = server.GenerateRequest(prompt="p", num_images=5, batch_size=2)
    pr = server.PendingRequest(req, request_id="r_batch")
    buf = server.GenerationBuffer()

    await buf._process_group([pr])

    assert calls == [(2, 2), (2, 2), (1, 1)]
    assert pr.future.done()
    res = pr.future.result()
    assert isinstance(res, dict)
    assert "images" in res
    assert len(res["images"]) == 5
    assert ImageSaver.pop_image_bytes("LD-REQ-r_batch") == []