File size: 5,466 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import unittest
import torch
from PIL import Image

from src.Utilities import color
from src.user.app_instance import AppInstance
from src.AutoEncoders import taesd


class TestPreviewQuality(unittest.TestCase):
    """Tests for preview image quality settings.

    Covers the linear->sRGB helper, the resampling filter used when the
    decoded preview is thumbnailed, the preview defaults exposed by
    ``AppInstance``, and the image format used by the server preview callback.

    Several tests temporarily replace attributes on shared objects
    (``taesd.TAESD.decode``, ``Image.fromarray``, ``Image.Image.save``, the
    global app settings). All such replacements go through ``_swap_attr`` so
    the original value is restored even when the test fails early.
    """

    def _swap_attr(self, target, name, replacement):
        """Set ``target.<name>`` to *replacement* and restore it on cleanup.

        The restore is registered *before* the attribute is overwritten, so a
        failure at any later point in the test can never leave the patched
        value behind, and the cleanup never refers to an unbound local.
        """
        self.addCleanup(setattr, target, name, getattr(target, name))
        setattr(target, name, replacement)

    def test_linear_to_srgb_values(self):
        """``color.linear_to_srgb`` matches the piecewise sRGB formula."""
        # Sample points on both branches of the transfer function, including
        # the 0.0031308 switch-over point and the two endpoints.
        vals = torch.tensor([0.0, 0.0031308, 0.04, 0.5, 1.0], dtype=torch.float32)
        out = color.linear_to_srgb(vals)
        # Expected values (approx): linear segment at/below the threshold,
        # power segment above it.
        expected = torch.tensor([
            0.0,
            0.0031308 * 12.92,
            1.055 * (0.04 ** (1.0 / 2.4)) - 0.055,
            1.055 * (0.5 ** (1.0 / 2.4)) - 0.055,
            1.0,
        ], dtype=torch.float32)
        self.assertTrue(
            torch.allclose(out, expected, atol=1e-6),
            msg=f"got {out.tolist()}, expected {expected.tolist()}",
        )

    def test_decode_uses_lanczos_for_thumbnail(self):
        """``decode_latents_to_images`` thumbnails with the LANCZOS filter."""
        called = {}

        def fake_decode(self, x):
            # Return a [B, C, H, W] tensor in [-1, 1] so that after
            # .add(1).mul(0.5) values are in [0, 1]. The size is chosen large
            # so that the decode path is expected to downsample it.
            return torch.full((1, 3, 1024, 1024), 0.0, dtype=x.dtype)

        class FakeImage:
            """Minimal stand-in for a PIL image that records ``thumbnail`` args."""

            def __init__(self):
                self.width = 1024
                self.height = 1024

            def thumbnail(self, size, resample=None):
                called['resample'] = resample
                # Emulate behavior of PIL thumbnail: never grow the image.
                self.width = min(self.width, size[0])
                self.height = min(self.height, size[1])

            def save(self, *args, **kwargs):
                return None

        def fake_fromarray(arr, mode=None):
            return FakeImage()

        self._swap_attr(taesd.TAESD, "decode", fake_decode)
        self._swap_attr(Image, "fromarray", fake_fromarray)

        latent = torch.zeros((1, 4, 64, 64), dtype=torch.float32)
        taesd.decode_latents_to_images(latent)

        # Confirm our fake thumbnail captured the resample argument.
        self.assertIn('resample', called)
        self.assertEqual(called['resample'], Image.Resampling.LANCZOS)

    def test_app_preview_defaults(self):
        """A fresh ``AppInstance`` has sRGB previews on, WEBP format, quality 90."""
        app = AppInstance()
        self.assertTrue(hasattr(app, 'preview_srgb'))
        self.assertTrue(app.preview_srgb)
        self.assertEqual(app.preview_format, 'WEBP')
        self.assertEqual(app.preview_quality, 90)

    def test_decode_applies_srgb_when_enabled(self):
        """With ``preview_srgb`` on, decoded pixels are sRGB-encoded."""
        # Imported before anything is patched so that an import failure
        # surfaces as itself rather than as a NameError during cleanup.
        from src.user.app_instance import app as global_app

        def fake_decode(self, x):
            # Constant zero output (-> 0.5 after normalisation, per the
            # original author's note on the decode pipeline).
            return torch.zeros((1, 3, 4, 4), dtype=x.dtype, device=x.device)

        self._swap_attr(global_app, "preview_srgb", True)
        self._swap_attr(taesd.TAESD, "decode", fake_decode)

        latent = torch.zeros((1, 4, 4, 4), dtype=torch.float32)
        imgs = taesd.decode_latents_to_images(latent)
        self.assertGreater(len(imgs), 0)
        r, g, b = imgs[0].getpixel((0, 0))

        # Expected sRGB value for linear=0.5
        lin = 0.5
        srgb = 1.055 * (lin ** (1.0 / 2.4)) - 0.055
        # The implementation casts to uint8 (truncates), so expect floor behavior
        expect = int(srgb * 255.0)
        self.assertEqual(r, expect)
        self.assertEqual(g, expect)
        self.assertEqual(b, expect)

    def test_server_callback_uses_preview_format(self):
        """The server preview callback saves using the configured format."""
        # Function-local imports, as in the original test; done up front so
        # every name used during cleanup is bound before patching starts.
        import server as server_mod
        from src.user import app_instance as _app_instance

        saved = []

        def fake_save(self, buffer, format=None, **kwargs):
            saved.append(format)
            # Best effort: write some bytes so the buffer isn't empty. Only
            # the errors a non-writable or closed target can raise are
            # ignored; anything else is a real problem and should surface.
            try:
                buffer.write(b"OK")
            except (AttributeError, TypeError, ValueError, OSError):
                pass

        def fake_decode(latents, flux=False):
            return [Image.new('RGB', (64, 64), color='red')]

        self._swap_attr(Image.Image, "save", fake_save)
        self._swap_attr(server_mod, "decode_latents_to_images", fake_decode)
        self._swap_attr(_app_instance.app, "preview_format", 'WEBP')

        cb = server_mod.make_server_callback(20)
        cb({'i': 0, 'total_steps': 20, 'denoised': torch.zeros((1, 4, 4, 4))})

        self.assertGreater(len(saved), 0)
        self.assertEqual(saved[0].upper(), 'WEBP')

# Allow running this file directly as a script: unittest.main() loads the
# test cases defined in this module and runs them.
if __name__ == "__main__":
    unittest.main()