# Spaces: Running on Zero
# Running on Zero
# File size: 5,466 Bytes
# b701455
import unittest
import torch
from PIL import Image
from src.Utilities import color
from src.user.app_instance import AppInstance
from src.AutoEncoders import taesd
class TestPreviewQuality(unittest.TestCase):
    """Tests for preview generation settings and colour handling.

    Covers the linear-to-sRGB transfer function, the resampling filter
    requested when a decoded preview is thumbnailed, the preview defaults
    exposed by ``AppInstance``, and the image format used by the server's
    preview callback.
    """

    def _swap_attr(self, owner, name, replacement):
        """Set ``owner.<name>`` to *replacement* for the rest of the test.

        The previous value is restored through ``addCleanup``, which is
        registered before the attribute is changed.  Restoration therefore
        runs even when the test body raises, and it never depends on local
        variables that may not have been bound yet.

        Args:
            owner: Object (module, class or instance) holding the attribute.
            name: Attribute name to replace.
            replacement: Temporary value to install.

        Returns:
            The value the attribute had before it was replaced.
        """
        original = getattr(owner, name)
        self.addCleanup(setattr, owner, name, original)
        setattr(owner, name, replacement)
        return original

    def test_linear_to_srgb_values(self):
        """``color.linear_to_srgb`` matches the piecewise sRGB formula."""
        # Known values across the transfer function: both end points, the
        # linear/gamma threshold, and two samples on the gamma segment.
        vals = torch.tensor([0.0, 0.0031308, 0.04, 0.5, 1.0], dtype=torch.float32)
        out = color.linear_to_srgb(vals)

        # Expected values (approx)
        expected = torch.tensor(
            [
                0.0,
                0.0031308 * 12.92,
                1.055 * (0.04 ** (1.0 / 2.4)) - 0.055,
                1.055 * (0.5 ** (1.0 / 2.4)) - 0.055,
                1.0,
            ],
            dtype=torch.float32,
        )
        self.assertTrue(
            torch.allclose(out, expected, atol=1e-6),
            msg=f"got {out.tolist()}, expected {expected.tolist()}",
        )

    def test_decode_uses_lanczos_for_thumbnail(self):
        """Thumbnailing a large decoded preview requests the LANCZOS filter.

        ``TAESD.decode`` is replaced with a fake returning a synthetic tensor
        large enough to trigger downsampling, and ``Image.fromarray`` is
        replaced so the ``resample`` argument passed to ``thumbnail`` can be
        captured.
        """
        called = {}

        def fake_decode(self, x):
            # Return a [B, C, H, W] tensor in [-1, 1] so that after
            # .add(1).mul(0.5) values are in [0, 1].  Use a large size to
            # trigger thumbnail.  The tensor follows the input's dtype and
            # device, like the fake used by the sRGB test.
            return torch.full(
                (1, 3, 1024, 1024), 0.0, dtype=x.dtype, device=x.device
            )

        class FakeImage:
            """Minimal stand-in for a PIL image that records ``thumbnail``."""

            def __init__(self):
                self.width = 1024
                self.height = 1024

            def thumbnail(self, size, resample=None):
                called['resample'] = resample
                # Emulate behavior of PIL thumbnail
                self.width = min(self.width, size[0])
                self.height = min(self.height, size[1])

            def save(self, *args, **kwargs):
                return None

        def fake_fromarray(arr, mode=None):
            return FakeImage()

        self._swap_attr(taesd.TAESD, 'decode', fake_decode)
        self._swap_attr(Image, 'fromarray', fake_fromarray)

        latent = torch.zeros((1, 4, 64, 64), dtype=torch.float32)
        taesd.decode_latents_to_images(latent)

        # Confirm our fake thumbnail captured the resample argument
        self.assertIn('resample', called)
        self.assertEqual(called['resample'], Image.Resampling.LANCZOS)

    def test_app_preview_defaults(self):
        """A fresh ``AppInstance`` exposes the expected preview defaults."""
        app = AppInstance()
        self.assertTrue(hasattr(app, 'preview_srgb'))
        self.assertTrue(app.preview_srgb)
        self.assertEqual(app.preview_format, 'WEBP')
        self.assertEqual(app.preview_quality, 90)

    def test_decode_applies_srgb_when_enabled(self):
        """With ``preview_srgb`` enabled, decoded pixels are sRGB-encoded."""
        from src.user.app_instance import app as global_app

        def fake_decode(self, x):
            # Constant zero output (-> 0.5 after normalisation).
            return torch.zeros((1, 3, 4, 4), dtype=x.dtype, device=x.device)

        self._swap_attr(taesd.TAESD, 'decode', fake_decode)
        # Ensure preview_srgb enabled
        self._swap_attr(global_app, 'preview_srgb', True)

        latent = torch.zeros((1, 4, 4, 4), dtype=torch.float32)
        imgs = taesd.decode_latents_to_images(latent)
        self.assertGreater(len(imgs), 0)

        r, g, b = imgs[0].getpixel((0, 0))

        # Expected sRGB value for linear=0.5
        lin = 0.5
        srgb = 1.055 * (lin ** (1.0 / 2.4)) - 0.055
        # The implementation casts to uint8 (truncates), so expect floor behavior
        expect = int(srgb * 255.0)
        self.assertEqual(r, expect)
        self.assertEqual(g, expect)
        self.assertEqual(b, expect)

    def test_server_callback_uses_preview_format(self):
        """The server preview callback saves using the configured format."""
        import server as server_mod
        from src.user import app_instance as _app_instance

        saved = []

        def fake_save(self, buffer, format=None, **kwargs):
            saved.append(format)
            # write some bytes so the buffer isn't empty
            try:
                buffer.write(b"OK")
            except Exception:
                # Best effort only: the test asserts on the recorded format,
                # not on what ends up in the buffer.
                pass

        def fake_decode(latents, flux=False):
            return [Image.new('RGB', (64, 64), color='red')]

        self._swap_attr(Image.Image, 'save', fake_save)
        self._swap_attr(server_mod, 'decode_latents_to_images', fake_decode)
        self._swap_attr(_app_instance.app, 'preview_format', 'WEBP')

        cb = server_mod.make_server_callback(20)
        cb({'i': 0, 'total_steps': 20, 'denoised': torch.zeros((1, 4, 4, 4))})

        self.assertGreater(len(saved), 0)
        # Fail with a clear message if save() was called without a format.
        self.assertIsNotNone(saved[0])
        self.assertEqual(saved[0].upper(), 'WEBP')
if __name__ == "__main__":
    # Run the test cases defined in this module when the file is executed
    # directly rather than imported.
    unittest.main()
|