File size: 1,893 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import logging
import torch

from src.FileManaging import ImageSaver


def test_save_images_guard(tmp_path, caplog):
    """Passing more than MAX_IMAGES_PER_SAVE images must save nothing and log a warning."""
    saver = ImageSaver.SaveImage()
    saver.output_dir = str(tmp_path)

    # One past the limit is enough to trip the guard; 32x32 keeps the tensors tiny.
    too_many = ImageSaver.MAX_IMAGES_PER_SAVE + 1
    images = [torch.rand(3, 32, 32) for _ in range(too_many)]

    caplog.set_level(logging.WARNING)
    result = saver.save_images(images)

    logged = [record.getMessage() for record in caplog.records]
    assert isinstance(result, dict)
    assert result["ui"]["images"] == []
    assert any("Attempting to save" in text for text in logged)


def test_save_images_aborts_on_large_batched_tensor(tmp_path, caplog):
    """A single batched tensor with a very large batch dimension should be treated like many images and abort."""
    saver = ImageSaver.SaveImage()
    # Redirect output to a per-test temp dir, like the sibling tests do, so a
    # regression in the guard cannot write 1024 images into the real output dir.
    saver.output_dir = str(tmp_path)

    batch = 1024
    tensor = torch.zeros((batch, 3, 16, 16))

    caplog.set_level(logging.WARNING)
    res = saver.save_images([tensor])

    messages = [rec.getMessage() for rec in caplog.records]
    assert res == {"ui": {"images": []}}
    assert any("Attempting to save" in msg for msg in messages)
    # Diagnostic details should include an idx=0 entry and the batch size in the same message
    assert any("idx=0" in msg and str(batch) in msg for msg in messages)
    # save_images is called without filename_prefix, so this checks that the
    # default prefix ("LD") is included in the diagnostics for tracing.
    assert any("filename_prefix=LD" in msg for msg in messages)


def test_save_images_saves_single_image(tmp_path):
    """Saving a one-image batch should report exactly one entry under result["ui"]["images"]."""
    saver = ImageSaver.SaveImage()
    saver.output_dir = str(tmp_path)

    single = torch.rand((1, 3, 32, 32))
    result = saver.save_images([single], filename_prefix="LD", prompt="test")

    assert isinstance(result, dict)
    assert "ui" in result
    assert "images" in result["ui"]
    saved = result["ui"]["images"]
    assert len(saved) == 1