File size: 6,752 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import torch
from unittest.mock import patch, MagicMock

from src.Core.Context import Context
from src.Core.AbstractModel import ModelCapabilities
from src.Processors.HiresFix import HiresFix


from src.Core.AbstractModel import AbstractModel


class FakeModel(AbstractModel):
    """Lightweight stand-in for a real model, used to drive ``HiresFix.apply`` in tests.

    It reports hires-fix support, loads no weights, and returns placeholder
    tensors from ``encode_prompt``, ``generate`` and ``decode``.
    """

    # Pixel-to-latent scale shared by the latent-format stub and ``decode`` so
    # the two cannot drift apart.
    _DOWNSCALE_FACTOR = 8

    def __init__(self):
        super().__init__(model_path=None)
        # Start from default capabilities and enable only what HiresFix needs.
        # Tests toggle further flags (requires_size_conditioning, is_flux) via
        # ``model.capabilities``; NOTE(review): presumably AbstractModel serves
        # that from ``_create_capabilities`` / ``self._capabilities`` — confirm.
        self._capabilities = ModelCapabilities()
        self._capabilities.supports_hires_fix = True
        self._capabilities.requires_size_conditioning = False
        # Non-None placeholder for the wrapped model object; NOTE(review):
        # presumably satisfies a "model attached" check elsewhere — confirm.
        self.model = object()

    def _create_capabilities(self):
        """Return the capability object built in ``__init__``."""
        return self._capabilities

    def load(self, model_path: str = None):
        """Pretend to load weights: mark the model as loaded and return ``self``."""
        self._loaded = True
        return self

    def get_model_object(self, name):
        """Return a latent-format stub exposing ``downscale_factor`` (``name`` is ignored)."""
        factor = self._DOWNSCALE_FACTOR

        class LF:
            downscale_factor = factor

        return LF()

    def encode_prompt(self, prompt, negative_prompt, clip_skip: int = -2):
        """Return dummy ``(positive, negative)`` conditioning, each ``[[tensor, metadata]]``.

        The two sides are built as separate objects: HiresFix may write size
        metadata into the conditioning dicts, and a single shared list would
        let an update on one side silently appear on the other.
        """
        positive = [[torch.zeros(1, 77, 768), {}]]
        negative = [[torch.zeros(1, 77, 768), {}]]
        return positive, negative

    def generate(self, ctx, positive, negative, latent_image=None, start_step=None, last_step=None, disable_noise=False, callback=None):
        """Ignore every input and return random latents of shape (1, 4, 64, 64)."""
        return {"samples": torch.randn(1, 4, 64, 64)}

    def decode(self, latents):
        """Return random 3-channel images whose spatial size is latent size times the downscale factor.

        ``latents`` must be 4-D; its dims are unpacked as (batch, channels, height, width).
        """
        batch, _channels, lh, lw = latents.shape
        scale = self._DOWNSCALE_FACTOR
        return torch.randn(batch, 3, lh * scale, lw * scale)


def _make_latents():
    return {"samples": torch.zeros(1, 4, 64, 64)}


def test_uses_hires_ctx_defaults():
    """HiresFix samples with hires-specific steps/CFG but keeps the sampler and scheduler."""
    ctx = Context(prompt="test")
    ctx.generation.width = 512
    ctx.generation.height = 512
    base_sampling = {"steps": 40, "cfg": 2.5, "sampler": "foo", "scheduler": "bar"}
    for field_name, value in base_sampling.items():
        setattr(ctx.sampling, field_name, value)

    model = FakeModel()
    latents = _make_latents()

    upscale_target = "src.Utilities.upscale.LatentUpscale"
    sampler_target = "src.sample.sampling.KSampler"
    with patch(upscale_target) as upscale_cls, patch(sampler_target) as ksampler_cls:
        upscale_cls.return_value.upscale.return_value = ({"samples": torch.zeros(1, 4, 128, 128)},)
        ksampler_cls.return_value.sample.return_value = ({"samples": torch.zeros(1, 4, 64, 64)},)
        HiresFix.apply(
            latents,
            ctx,
            model,
            positive=[[torch.zeros(1, 77, 768), {}]],
            negative=[[torch.zeros(1, 77, 768), {}]],
        )

    sample = ksampler_cls.return_value.sample
    assert sample.called
    kwargs = sample.call_args.kwargs

    # hires steps = max(10, int(40 * 0.5)) == 20
    assert kwargs["steps"] == 20
    # Non-Flux models use the hires default CFG (8.0), not ctx.sampling.cfg
    assert kwargs["cfg"] == 8.0
    # Sampler and scheduler pass through untouched (scheduler is a regression guard)
    assert kwargs["sampler_name"] == ctx.sampling.sampler
    assert kwargs["scheduler"] == ctx.sampling.scheduler


def test_injects_size_conditioning_for_sdxl():
    """Pre-encoded SDXL conditioning must carry the hires resolution in its metadata."""
    ctx = Context(prompt="sdxl test")
    ctx.generation.width, ctx.generation.height = 512, 512

    model = FakeModel()
    # Behave like an SDXL checkpoint that needs size conditioning
    model.capabilities.requires_size_conditioning = True

    # Two independent [[tensor, metadata]] structures whose metadata HiresFix should fill in
    positive, negative = ([[torch.zeros(1, 77, 2048), {}]] for _ in range(2))

    latents = _make_latents()

    upscale_target = "src.Utilities.upscale.LatentUpscale"
    sampler_target = "src.sample.sampling.KSampler"
    with patch(upscale_target) as upscale_cls, patch(sampler_target) as ksampler_cls:
        upscale_cls.return_value.upscale.return_value = ({"samples": torch.zeros(1, 4, 128, 128)},)
        ksampler_cls.return_value.sample.return_value = ({"samples": torch.zeros(1, 4, 64, 64)},)
        HiresFix.apply(latents, ctx, model, positive=positive, negative=negative)

    sample = ksampler_cls.return_value.sample
    assert sample.called
    sampled_positive = sample.call_args.kwargs.get("positive")

    # Metadata should now match the hires size (512 * 2 = 1024)
    assert isinstance(sampled_positive, list)
    size_meta = sampled_positive[0][1]
    hires_side = 512 * 2
    for key in ("width", "height", "target_width", "target_height"):
        assert size_meta.get(key) == hires_side, key


def test_respects_cfg_for_flux():
    """For Flux models HiresFix must keep the user's CFG instead of the hires default."""
    ctx = Context(prompt="flux test")
    ctx.generation.width = 512
    ctx.generation.height = 512
    ctx.sampling.cfg = 1.2

    model = FakeModel()
    model.capabilities.is_flux = True

    # Flux latents are 16-channel; build the input to match the mocked
    # upscale/sampler outputs below rather than the 4-channel SD fixture.
    latents = {"samples": torch.zeros(1, 16, 64, 64)}

    with patch("src.Utilities.upscale.LatentUpscale") as mock_upscale:
        mock_upscale.return_value.upscale.return_value = ({"samples": torch.zeros(1, 16, 128, 128)},)
        with patch("src.sample.sampling.KSampler") as mock_ksampler:
            mock_ksampler.return_value.sample.return_value = ({"samples": torch.zeros(1, 16, 64, 64)},)

            HiresFix.apply(latents, ctx, model, positive=[[torch.zeros(1, 77, 4096), {}]], negative=[[torch.zeros(1, 77, 4096), {}]])

            assert mock_ksampler.return_value.sample.called
            call_kwargs = mock_ksampler.return_value.sample.call_args.kwargs

            # For Flux models, HiresFix should honor the original ctx.sampling.cfg
            assert call_kwargs["cfg"] == 1.2


def test_reencodes_raw_prompt_for_sdxl():
    """Raw text prompts must be re-encoded and tagged with the hires size for SDXL models."""
    ctx = Context(prompt="raw prompt")
    for dimension in ("width", "height"):
        setattr(ctx.generation, dimension, 512)

    model = FakeModel()
    model.capabilities.requires_size_conditioning = True

    latents = _make_latents()

    upscale_target = "src.Utilities.upscale.LatentUpscale"
    sampler_target = "src.sample.sampling.KSampler"
    with patch(upscale_target) as upscale_cls, patch(sampler_target) as ksampler_cls:
        upscale_cls.return_value.upscale.return_value = ({"samples": torch.zeros(1, 4, 128, 128)},)
        ksampler_cls.return_value.sample.return_value = ({"samples": torch.zeros(1, 4, 64, 64)},)
        # Plain strings instead of encoded conditioning should force re-encoding
        HiresFix.apply(latents, ctx, model, positive="a prompt", negative="")

    sample = ksampler_cls.return_value.sample
    assert sample.called
    encoded_positive = sample.call_args.kwargs.get("positive")

    # Re-encoded conditioning carries the hires size (512 -> 1024)
    assert isinstance(encoded_positive, list)
    size_meta = encoded_positive[0][1]
    assert size_meta.get("width") == 1024
    assert size_meta.get("height") == 1024