File size: 2,049 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import asyncio

import pytest

import server
from src.FileManaging import ImageSaver


@pytest.mark.asyncio
async def test_chunking_respects_image_saver_limit(monkeypatch):
    """Group processing must split work into chunks no larger than the saver limit.

    Ten single-image requests are submitted with a large ``batch_size`` so that
    ``ImageSaver.MAX_IMAGES_PER_SAVE`` (patched to a small value) is the only
    thing forcing ``GenerationBuffer._process_group`` to issue several
    pipeline calls.
    """
    saver_limit = 3
    num_requests = 10

    # Patch through monkeypatch (not plain assignment) so the module global is
    # restored when the test ends and cannot leak into other tests.
    monkeypatch.setattr(server, "LD_MAX_IMAGES_PER_GROUP", 32)

    async def immediate_to_thread(func, /, *args, **kwargs):
        # Run the "threaded" callable inline so the test stays deterministic.
        return func(*args, **kwargs)

    monkeypatch.setattr(server.asyncio, "to_thread", immediate_to_thread)

    # Force a small ImageSaver limit so the buffer must chunk accordingly.
    monkeypatch.setattr(ImageSaver, "MAX_IMAGES_PER_SAVE", saver_limit)

    calls = []

    def fake_pipeline(**kwargs):
        """Stand-in for ``server.pipeline`` that records the kwargs of each call.

        For every entry in ``per_sample_info`` it stores a dummy PNG via
        ``ImageSaver.store_image_bytes`` and returns per-request file metadata
        under ``"batched_results"``, keyed by request id.
        """
        calls.append(kwargs)
        results = {}
        for info in kwargs.get("per_sample_info", []):
            rid = info["request_id"]
            # Suffix with the per-request index so filenames stay unique.
            filename = f"{rid}_{len(results.get(rid, []))}_img.png"
            ImageSaver.store_image_bytes(info["filename_prefix"], filename, "Classic", b"PNGDATA")
            results.setdefault(rid, []).append({"filename": filename, "subfolder": "Classic"})
        return {"batched_results": results}

    monkeypatch.setattr(server, "pipeline", fake_pipeline)

    # Small requests (1 image each) with a large batch_size so the saver
    # limit, not the request batch size, is what forces chunking.
    items = [
        server.PendingRequest(
            server.GenerateRequest(prompt=f"p{i}", num_images=1, batch_size=10),
            request_id=str(i),
        )
        for i in range(num_requests)
    ]

    buf = server.GenerationBuffer()

    await buf._process_group(items)

    # One image per request, so chunking yields ceil(num_requests / saver_limit)
    # pipeline invocations (4 for 10 requests with a limit of 3).
    expected_calls = (num_requests + saver_limit - 1) // saver_limit
    assert len(calls) == expected_calls
    # Each chunk should hold at most ImageSaver.MAX_IMAGES_PER_SAVE images.
    for c in calls:
        assert c["number"] <= saver_limit

    # Ensure all requests were completed (result() re-raises on failure).
    for p in items:
        assert p.future.done()
        res = p.future.result()
        assert isinstance(res, dict)
        assert "images" in res or "image" in res