LightDiffusion-Next / tests /unit /test_image_saver.py
Aatricks's picture
Deploy ZeroGPU Gradio Space snapshot
b701455
import logging
import torch
from src.FileManaging import ImageSaver
def test_save_images_guard(tmp_path, caplog):
    """Refuse, with a warning, a request to save more than MAX_IMAGES_PER_SAVE images.

    Every image is a tiny 3x32x32 tensor, so building the oversized request
    stays cheap in memory.
    """
    saver = ImageSaver.SaveImage()
    saver.output_dir = str(tmp_path)

    over_limit = ImageSaver.MAX_IMAGES_PER_SAVE + 1
    oversized_request = [torch.rand(3, 32, 32) for _ in range(over_limit)]

    caplog.set_level(logging.WARNING)
    result = saver.save_images(oversized_request)

    # Nothing should be reported as saved...
    assert isinstance(result, dict)
    assert result["ui"]["images"] == []
    # ...and the refusal must be explained in the log.
    warned = any("Attempting to save" in record.getMessage() for record in caplog.records)
    assert warned
def test_save_images_aborts_on_large_batched_tensor(tmp_path, caplog):
    """Treat a single tensor with a very large batch dimension as many images and abort.

    The call should return an empty UI payload. It should also log a warning
    that names the offending list entry (idx=0), its batch size and the
    filename prefix.
    """
    saver = ImageSaver.SaveImage()
    # Isolate output: if the guard ever regresses, the frames land in a
    # throwaway directory instead of the real output folder.
    saver.output_dir = str(tmp_path)
    batch = 1024
    tensor = torch.zeros((batch, 3, 16, 16))
    caplog.set_level(logging.WARNING)

    # Pass the prefix explicitly so the assertion below does not silently
    # depend on save_images' default value.
    res = saver.save_images([tensor], filename_prefix="LD")

    assert res == {"ui": {"images": []}}
    messages = [rec.getMessage() for rec in caplog.records]
    assert any("Attempting to save" in msg for msg in messages)
    # Diagnostic details should include the idx=0 entry and its batch size.
    assert any("idx=0" in msg and str(batch) in msg for msg in messages)
    # The filename prefix is included for tracing which caller hit the guard.
    # NOTE(review): store_bytes_prefix is not asserted because its log format
    # is not visible here. Add a check once the format is confirmed.
    assert any("filename_prefix=LD" in msg for msg in messages)
def test_save_images_saves_single_image(tmp_path):
    """Save a single batched image and report exactly one entry in the UI payload."""
    saver = ImageSaver.SaveImage()
    saver.output_dir = str(tmp_path)

    single = torch.rand((1, 3, 32, 32))
    out = saver.save_images([single], filename_prefix="LD", prompt="test")

    assert isinstance(out, dict)
    assert "ui" in out
    ui = out["ui"]
    assert "images" in ui
    assert len(ui["images"]) == 1