"""Tests for preview image quality settings.

Covers the linear-to-sRGB transfer function, the resampling filter used when
previews are thumbnailed, the preview defaults exposed by ``AppInstance`` and
the image format used by the server preview callback.
"""

import unittest
from unittest import mock

import torch
from PIL import Image

from src.Utilities import color
from src.user.app_instance import AppInstance
from src.AutoEncoders import taesd


class TestPreviewQuality(unittest.TestCase):
    """Checks on preview colour conversion, resampling and output format.

    All monkeypatching goes through ``mock.patch.object`` so that patched
    attributes are restored even when a test fails part-way through.
    """

    def test_linear_to_srgb_values(self):
        """``color.linear_to_srgb`` matches the piecewise sRGB formula."""
        # Known values across the transfer function: both end points, the
        # linear/power-law breakpoint, and two points on the power-law branch.
        vals = torch.tensor(
            [0.0, 0.0031308, 0.04, 0.5, 1.0], dtype=torch.float32
        )
        out = color.linear_to_srgb(vals)

        # Expected values (approx)
        expected = torch.tensor(
            [
                0.0,
                0.0031308 * 12.92,
                1.055 * (0.04 ** (1.0 / 2.4)) - 0.055,
                1.055 * (0.5 ** (1.0 / 2.4)) - 0.055,
                1.0,
            ],
            dtype=torch.float32,
        )
        self.assertTrue(torch.allclose(out, expected, atol=1e-6))

    def test_decode_uses_lanczos_for_thumbnail(self):
        """Thumbnails created while decoding are resampled with LANCZOS.

        ``TAESD.decode`` is replaced with a stub returning a large synthetic
        tensor, and ``Image.fromarray`` with a stub returning a fake image
        that records the ``resample`` argument passed to ``thumbnail``.
        """
        captured = {}

        def fake_decode(self, x):
            # Return a [B, C, H, W] tensor in [-1, 1] so that after
            # .add(1).mul(0.5) values are in [0, 1]. Use a large size to
            # trigger thumbnail.
            return torch.full((1, 3, 1024, 1024), 0.0, dtype=x.dtype)

        class FakeImage:
            """Minimal stand-in for a PIL image that records ``thumbnail``."""

            def __init__(self):
                self.width = 1024
                self.height = 1024

            def thumbnail(self, size, resample=None):
                captured['resample'] = resample
                # Emulate behavior of PIL thumbnail
                self.width = min(self.width, size[0])
                self.height = min(self.height, size[1])

            def save(self, *args, **kwargs):
                return None

        def fake_fromarray(arr, mode=None):
            return FakeImage()

        with mock.patch.object(taesd.TAESD, 'decode', fake_decode), \
                mock.patch.object(Image, 'fromarray', fake_fromarray):
            latent = torch.zeros((1, 4, 64, 64), dtype=torch.float32)
            taesd.decode_latents_to_images(latent)

        # Confirm our fake thumbnail captured the resample argument
        self.assertIn('resample', captured)
        self.assertEqual(captured['resample'], Image.Resampling.LANCZOS)

    def test_app_preview_defaults(self):
        """A fresh ``AppInstance`` has sRGB previews on, WEBP, quality 90."""
        app = AppInstance()
        self.assertTrue(hasattr(app, 'preview_srgb'))
        self.assertTrue(app.preview_srgb)
        self.assertEqual(app.preview_format, 'WEBP')
        self.assertEqual(app.preview_quality, 90)

    def test_decode_applies_srgb_when_enabled(self):
        """With ``preview_srgb`` enabled, decoded pixels are sRGB-encoded."""
        # Imported before any patching so that a failed import cannot leave
        # cleanup code referring to names that were never bound.
        from src.user.app_instance import app as global_app

        def fake_decode(self, x):
            # Constant zero (-> 0.5 after norm)
            return torch.zeros((1, 3, 4, 4), dtype=x.dtype, device=x.device)

        with mock.patch.object(taesd.TAESD, 'decode', fake_decode), \
                mock.patch.object(global_app, 'preview_srgb', True):
            latent = torch.zeros((1, 4, 4, 4), dtype=torch.float32)
            imgs = taesd.decode_latents_to_images(latent)

            self.assertGreater(len(imgs), 0)
            pixel = tuple(imgs[0].getpixel((0, 0)))

        # Expected sRGB value for linear=0.5
        lin = 0.5
        srgb = 1.055 * (lin ** (1.0 / 2.4)) - 0.055
        # The implementation casts to uint8 (truncates), so expect floor
        # behavior
        expect = int(srgb * 255.0)
        self.assertEqual(pixel, (expect, expect, expect))

    def test_server_callback_uses_preview_format(self):
        """The server preview callback saves using the configured format."""
        # Local imports: resolved up front so cleanup never depends on names
        # that might not have been bound.
        import server as server_mod
        from src.user import app_instance as _app_instance

        saved = []

        def fake_save(self, buffer, format=None, **kwargs):
            saved.append(format)
            # Write some bytes so the buffer isn't empty. This is best-effort:
            # the target may not be a writable binary stream.
            try:
                buffer.write(b"OK")
            except (AttributeError, TypeError, ValueError, OSError):
                pass

        def fake_decode(latents, flux=False):
            return [Image.new('RGB', (64, 64), color='red')]

        with mock.patch.object(Image.Image, 'save', fake_save), \
                mock.patch.object(
                    server_mod, 'decode_latents_to_images', fake_decode
                ), \
                mock.patch.object(
                    _app_instance.app, 'preview_format', 'WEBP'
                ):
            cb = server_mod.make_server_callback(20)
            cb({
                'i': 0,
                'total_steps': 20,
                'denoised': torch.zeros((1, 4, 4, 4)),
            })

        self.assertGreater(len(saved), 0)
        self.assertEqual(saved[0].upper(), 'WEBP')


if __name__ == "__main__":
    unittest.main()