File size: 18,436 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
"""Ultimate SD Upscale - tiled upscaling with seam fix."""
from src.AutoEncoders import VariationalAE
from src.sample import sampling
from src.UltimateSDUpscale import USDU_upscaler, image_util
import torch
from PIL import ImageFilter, ImageDraw, Image
from enum import Enum
import math

# Module-level alias for the shared USDU_upscaler state object. The redraw loop
# reads `state.interrupted` and USDUpscaler.calc_jobs_count writes `state.job_count`.
state = USDU_upscaler.state


class UnsupportedModel(Exception):
    """Exception subclass declared for unsupported-model errors.

    It adds nothing to ``Exception`` and is not raised by the code visible in
    this module.
    """


class StableDiffusionProcessing:
    """Mutable parameter bag handed to the tile processors.

    The redraw and seam-fix stages overwrite ``width``, ``height``,
    ``image_mask``, ``mask_blur`` and ``inpaint_full_res_padding`` on this
    object before each call to ``process_images``.
    """

    def __init__(self, init_img: Image.Image, model, positive, negative, vae, seed, steps, cfg,
                 sampler_name, scheduler, denoise, upscale_by, uniform_tile_mode, callback=None):
        # Source image plus the inpainting state that later stages fill in.
        self.init_images = [init_img]
        self.image_mask = None
        self.mask_blur = 0
        self.inpaint_full_res_padding = 0

        # Working size starts out as the size of the input image.
        self.width = init_img.width
        self.height = init_img.height

        # Model and conditioning objects, stored as given.
        self.model = model
        self.positive = positive
        self.negative = negative
        self.vae = vae

        # Sampler settings, forwarded verbatim to the sampler by process_images.
        self.seed = seed
        self.steps = steps
        self.cfg = cfg
        self.sampler_name = sampler_name
        self.scheduler = scheduler
        self.denoise = denoise

        # Size of the input image, kept separately because width/height get overwritten.
        self.init_size = (self.width, self.height)
        self.upscale_by = upscale_by
        self.uniform_tile_mode = uniform_tile_mode
        self.extra_generation_params = {}
        self.callback = callback


class Processed:
    """Result bundle returned by the processing functions: images, seed and info."""

    def __init__(self, p, images, seed, info):
        # `p` is accepted for call-site compatibility but is not stored.
        self.images = images
        self.seed = seed
        self.info = info

    def infotext(self, p, index):
        """Return ``None`` unconditionally; both arguments are ignored."""
        return None


def fix_seed(p):
    """No-op placeholder: accepts the processing object and leaves it untouched."""
    return None


def process_images(p, pipeline=False):
    """Re-sample the masked region of every image in ``USDU_upscaler.batch``.

    The region is derived from ``p.image_mask``, cropped out of each batch
    image, resized to ``(p.width, p.height)``, VAE-encoded, sampled, decoded
    and composited back through the mask. ``USDU_upscaler.batch`` is updated
    in place.

    Args:
        p: Processing parameters; reads ``image_mask``, ``init_images``,
            ``inpaint_full_res_padding``, ``width``, ``height``, ``mask_blur``,
            ``positive``, ``negative``, ``init_size``, ``vae``, ``model``,
            ``seed``, ``steps``, ``cfg``, ``sampler_name``, ``scheduler``,
            ``denoise`` and ``callback``.
        pipeline: Forwarded unchanged to ``sampling.common_ksampler``.

    Returns:
        A ``Processed`` holding only the first image of the updated batch,
        ``p.seed`` and ``None`` as info.
    """
    image_mask = p.image_mask.convert("L")
    init_image = p.init_images[0]

    # Region of interest derived from the mask and padding
    # (exact semantics live in image_util.get_crop_region).
    crop_region = image_util.get_crop_region(image_mask, p.inpaint_full_res_padding)
    x1, y1, x2, y2 = crop_region
    crop_ratio = (x2 - x1) / (y2 - y1)
    p_ratio = p.width / p.height
    
    # Pick a target size with the aspect ratio p.width / p.height that keeps the
    # longer relative side of the crop and grows the other one.
    if crop_ratio > p_ratio:
        target_width, target_height = x2 - x1, round((x2 - x1) / p_ratio)
    else:
        target_width, target_height = round((y2 - y1) * p_ratio), y2 - y1
    
    crop_region, _ = image_util.expand_crop(crop_region, image_mask.width, image_mask.height, target_width, target_height)
    tile_size = (p.width, p.height)

    # The blur is applied after the crop region was computed from the unblurred mask.
    if p.mask_blur > 0:
        image_mask = image_mask.filter(ImageFilter.GaussianBlur(p.mask_blur))

    # Cut the same region out of every batch image and bring it to the sampling size.
    tiles = [img.crop(crop_region) for img in USDU_upscaler.batch]
    initial_tile_size = tiles[0].size
    tiles = [t.resize(tile_size, Image.Resampling.LANCZOS) if t.size != tile_size else t for t in tiles]

    # Conditioning is cropped to the same region by image_util.crop_cond.
    positive_cropped = image_util.crop_cond(p.positive, crop_region, p.init_size, init_image.size, tile_size)
    negative_cropped = image_util.crop_cond(p.negative, crop_region, p.init_size, init_image.size, tile_size)

    batched_tiles = torch.cat([image_util.pil_to_tensor(t) for t in tiles], dim=0)
    (latent,) = VariationalAE.VAEEncode().encode(p.vae, batched_tiles)
    
    # Auto-detect Flux for disabling multi-scale and setting correct flags
    model_sampling_obj = p.model.get_model_object("model_sampling")
    from src.sample.sampling import ModelSamplingFlux, ModelSamplingFlux2
    is_flux = isinstance(model_sampling_obj, (ModelSamplingFlux, ModelSamplingFlux2))
    is_flux2 = isinstance(model_sampling_obj, ModelSamplingFlux2)

    # Pass crop offsets for positional embedding coherence (Critical for Flux/DiT)
    # Both dicts are shallow-copied so the model's own options are not mutated.
    # NOTE(review): x1/y1 come from the crop region *before* expand_crop, not from
    # the expanded crop_region actually used for the tiles - confirm this is intended.
    model_options = p.model.model_options.copy()
    transformer_options = model_options.get("transformer_options", {}).copy()
    transformer_options["top"] = y1
    transformer_options["left"] = x1
    model_options["transformer_options"] = transformer_options

    (samples,) = sampling.common_ksampler(p.model, p.seed, p.steps, p.cfg, p.sampler_name, p.scheduler,
                                          positive_cropped, negative_cropped, latent, denoise=p.denoise, 
                                          pipeline=pipeline, flux=is_flux, flux2=is_flux2,
                                          model_options=model_options, callback=p.callback)
    (decoded,) = VariationalAE.VAEDecode().decode(p.vae, samples)

    # Composite each sampled tile back onto its batch image through the mask.
    for i, tile_sampled in enumerate([image_util.tensor_to_pil(decoded, j) for j in range(len(decoded))]):
        init_image = USDU_upscaler.batch[i]
        # Undo the resize to tile_size so the tile fits the cropped region again.
        if tile_sampled.size != initial_tile_size:
            tile_sampled = tile_sampled.resize(initial_tile_size, Image.Resampling.LANCZOS)

        # Transparent full-size canvas carrying only the sampled tile at the crop origin.
        image_tile_only = Image.new("RGBA", init_image.size)
        image_tile_only.paste(tile_sampled, crop_region[:2])
        # Use the mask as the alpha channel so only masked pixels are replaced.
        temp = image_tile_only.copy()
        temp.putalpha(image_mask.resize(temp.size))
        image_tile_only.paste(temp, image_tile_only)
        result = init_image.convert("RGBA")
        result.alpha_composite(image_tile_only)
        USDU_upscaler.batch[i] = result.convert("RGB")

    # Only the first batch image is returned; the others are reachable via USDU_upscaler.batch.
    return Processed(p, [USDU_upscaler.batch[0]], p.seed, None)


class USDUMode(Enum):
    """Tile redraw mode; ``NONE`` disables the redraw stage."""

    LINEAR = 0
    CHESS = 1
    NONE = 2


class USDUSFMode(Enum):
    """Seam-fix mode; ``NONE`` disables the seam-fix stage."""

    NONE = 0
    BAND_PASS = 1
    HALF_TILE = 2
    HALF_TILE_PLUS_INTERSECTIONS = 3


class USDUpscaler:
    """Drives the whole upscale: model upscaling, tile redraw and seam fixing."""

    def __init__(self, p, image, upscaler_index, save_redraw, save_seams_fix, tile_width, tile_height):
        """Set up tile geometry for ``image`` against the target size in ``p``.

        ``save_redraw`` and ``save_seams_fix`` are accepted but not used.
        A falsy ``tile_width`` or ``tile_height`` falls back to the other one.
        """
        self.p, self.image = p, image
        # Integer factor needed to grow the longest image side to the longest target side.
        self.scale_factor = math.ceil(max(p.width, p.height) / max(image.width, image.height))
        self.upscaler = USDU_upscaler.sd_upscalers[upscaler_index]
        self.redraw = USDURedraw()
        self.redraw.tile_width = tile_width or tile_height
        self.redraw.tile_height = tile_height or tile_width
        self.seams_fix = USDUSeamsFix()
        self.seams_fix.tile_width = self.redraw.tile_width
        self.seams_fix.tile_height = self.redraw.tile_height
        self.initial_info = None
        self.rows = math.ceil(self.p.height / self.redraw.tile_height)
        self.cols = math.ceil(self.p.width / self.redraw.tile_width)

    def get_factor(self, num):
        """Return 2 for ``num == 1``, else the largest of 4, 3, 2 dividing ``num``, else 0."""
        if num == 1:
            return 2
        for f in [4, 3, 2]:
            if num % f == 0:
                return f
        return 0

    def get_factors(self):
        """Split ``scale_factor`` into per-iteration steps stored in ``self.scales``.

        ``self.scales`` is an ``enumerate`` over the steps, so it can be
        iterated only once. The product of the steps is at least
        ``scale_factor``; ``upscale`` resizes to the exact target afterwards.
        """
        scales, current = [], 1
        while current < self.scale_factor:
            remaining = self.scale_factor // current
            # get_factor returns 0 when the remainder has no factor of 2, 3 or 4
            # (5, 7, 11, ...). A 0 step would zero `current` and make the next
            # iteration divide by zero, so round the remainder up by one: such a
            # remainder is odd, hence remainder + 1 is even and always yields a step.
            f = self.get_factor(remaining) or self.get_factor(remaining + 1)
            scales.append(f)
            current *= f
        self.scales = enumerate(scales)

    def upscale(self):
        """Run the upscaler once per step, then resize to exactly ``p.width`` x ``p.height``."""
        print(f"Canva: {self.p.width}x{self.p.height}, Image: {self.image.width}x{self.image.height}, Scale: {self.scale_factor}")
        self.get_factors()
        for idx, val in self.scales:
            print(f"Upscaling iteration {idx + 1} with scale factor {val}")
            self.image = self.upscaler.scaler.upscale(self.image, val, self.upscaler.data_path)
        self.image = self.image.resize((self.p.width, self.p.height), resample=Image.LANCZOS)

    def setup_redraw(self, mode, padding, mask_blur):
        """Configure the redraw stage; it is enabled unless ``mode`` is ``USDUMode.NONE``."""
        self.redraw.mode = USDUMode(mode)
        self.redraw.enabled = self.redraw.mode != USDUMode.NONE
        self.redraw.padding = padding
        self.p.mask_blur = mask_blur

    def setup_seams_fix(self, padding, denoise, mask_blur, width, mode):
        """Configure the seam-fix stage; it is enabled unless ``mode`` is ``USDUSFMode.NONE``."""
        self.seams_fix.padding, self.seams_fix.denoise = padding, denoise
        self.seams_fix.mask_blur, self.seams_fix.width = mask_blur, width
        self.seams_fix.mode = USDUSFMode(mode)
        self.seams_fix.enabled = self.seams_fix.mode != USDUSFMode.NONE

    def calc_jobs_count(self):
        """Store the number of tile jobs the enabled stages will run in ``state.job_count``."""
        global state
        redraw = (self.rows * self.cols) if self.redraw.enabled else 0
        # One job per seam between vertically and horizontally adjacent tiles,
        # counted only when the seam-fix stage will actually run.
        seams = 0
        if self.seams_fix.enabled:
            seams = self.rows * (self.cols - 1) + (self.rows - 1) * self.cols
        state.job_count = redraw + seams

    def print_info(self):
        """Print the tile size and grid dimensions."""
        print(f"Tile: {self.redraw.tile_width}x{self.redraw.tile_height}, Grid: {self.rows}x{self.cols}")

    def add_extra_info(self):
        """Record the upscaler name and tile size in ``p.extra_generation_params``."""
        self.p.extra_generation_params.update({
            "Ultimate SD upscale upscaler": self.upscaler.name,
            "Ultimate SD upscale tile_width": self.redraw.tile_width,
            "Ultimate SD upscale tile_height": self.redraw.tile_height,
        })

    def process(self, pipeline):
        """Run the enabled stages in order and collect their outputs in ``result_images``.

        The image after the (optional) redraw stage is always appended; the
        seam-fixed image is appended as well when that stage is enabled.
        """
        USDU_upscaler.state.begin()
        self.calc_jobs_count()
        self.result_images = []
        if self.redraw.enabled:
            self.image = self.redraw.start(self.p, self.image, self.rows, self.cols, pipeline)
            self.initial_info = self.redraw.initial_info
        self.result_images.append(self.image)
        if self.seams_fix.enabled:
            self.image = self.seams_fix.start(self.p, self.image, self.rows, self.cols, pipeline)
            self.initial_info = self.seams_fix.initial_info
            self.result_images.append(self.image)
        USDU_upscaler.state.end()


class USDURedraw:
    """Tile redraw stage: re-samples the image one grid tile at a time.

    ``tile_width``, ``tile_height`` and ``padding`` must be set on the
    instance before ``start`` is called.
    """

    def init_draw(self, p, width, height):
        """Prepare ``p`` for tile inpainting and return a black mask and its drawer.

        ``p.width`` / ``p.height`` become the padded tile size rounded up to a
        multiple of 64. The mask is a ``width`` x ``height`` "L" image.
        """
        p.inpaint_full_res = True
        p.inpaint_full_res_padding = self.padding
        p.width = math.ceil((self.tile_width + self.padding) / 64) * 64
        p.height = math.ceil((self.tile_height + self.padding) / 64) * 64
        mask = Image.new("L", (width, height), "black")
        return mask, ImageDraw.Draw(mask)

    def calc_rectangle(self, xi, yi):
        """Return ``(x1, y1, x2, y2)`` of the tile at column ``xi``, row ``yi``."""
        return xi * self.tile_width, yi * self.tile_height, (xi + 1) * self.tile_width, (yi + 1) * self.tile_height

    def linear_process(self, p, image, rows, cols, pipeline=False):
        """Redraw the tiles row by row, left to right, and return the final image.

        For each tile the mask is painted white over that tile only, the tile
        is processed, and the mask is painted black again. Afterwards
        ``p.width`` / ``p.height`` are reset to the image size.
        """
        global state
        mask, draw = self.init_draw(p, image.width, image.height)
        # Stays None when no tile runs (empty grid or interrupted before the
        # first tile); without this the infotext lookup below would raise.
        processed = None
        for yi in range(rows):
            for xi in range(cols):
                if state.interrupted:
                    break
                draw.rectangle(self.calc_rectangle(xi, yi), fill="white")
                p.init_images, p.image_mask = [image], mask
                processed = process_images(p, pipeline)
                draw.rectangle(self.calc_rectangle(xi, yi), fill="black")
                if processed.images:
                    image = processed.images[0]
        p.width, p.height = image.width, image.height
        self.initial_info = processed.infotext(p, 0) if processed is not None else None
        return image

    def start(self, p, image, rows, cols, pipeline=False):
        """Run the redraw stage; linear order is used regardless of ``self.mode``."""
        self.initial_info = None
        return self.linear_process(p, image, rows, cols, pipeline)


class USDUSeamsFix:
    """Seam-fix stage: re-samples gradient-masked bands that straddle tile borders.

    ``tile_width``, ``tile_height``, ``padding``, ``denoise`` and ``mask_blur``
    must be set on the instance before ``start`` is called.
    """

    def init_draw(self, p):
        """Reset ``initial_info`` and set ``p`` to the padded tile size rounded up to 64."""
        self.initial_info = None
        p.width = math.ceil((self.tile_width + self.padding) / 64) * 64
        p.height = math.ceil((self.tile_height + self.padding) / 64) * 64

    def half_tile_process(self, p, image, rows, cols, pipeline=False):
        """Re-sample one tile-sized region per seam, offset by half a tile.

        Horizontal seams (between vertically adjacent tiles) are handled
        first, then vertical seams. Each pass uses a tile-sized gradient mask
        built from two mirrored halves of ``Image.linear_gradient``.
        Afterwards ``p.width`` / ``p.height`` are reset to the image size.
        """
        global state
        self.init_draw(p)
        processed = None
        gradient = Image.linear_gradient("L")

        half_w, half_h = self.tile_width // 2, self.tile_height // 2

        # Mask for horizontal seams: gradient in the top half, its 180-degree
        # rotation in the bottom half.
        row_gradient = Image.new("L", (self.tile_width, self.tile_height), "black")
        row_gradient.paste(gradient.resize((self.tile_width, half_h), Image.BICUBIC), (0, 0))
        row_gradient.paste(gradient.rotate(180).resize((self.tile_width, half_h), Image.BICUBIC), (0, half_h))

        # Mask for vertical seams: the same construction turned by 90 degrees.
        col_gradient = Image.new("L", (self.tile_width, self.tile_height), "black")
        col_gradient.paste(gradient.rotate(90).resize((half_w, self.tile_height), Image.BICUBIC), (0, 0))
        col_gradient.paste(gradient.rotate(270).resize((half_w, self.tile_height), Image.BICUBIC), (half_w, 0))

        # process_images() samples with p.denoise, so the seam-fix strength has to
        # be written there to take effect. p.denoising_strength is still set as
        # before for anything else that reads it.
        p.denoise = p.denoising_strength = self.denoise
        p.mask_blur = self.mask_blur

        # Horizontal seams: the mask is shifted down by half a tile so it
        # straddles the border between row yi and row yi + 1.
        for yi in range(rows - 1):
            for xi in range(cols):
                # Same interruption check as the redraw stage uses.
                if state.interrupted:
                    break
                p.width, p.height = self.tile_width, self.tile_height
                p.inpaint_full_res, p.inpaint_full_res_padding = True, self.padding
                mask = Image.new("L", (image.width, image.height), "black")
                mask.paste(row_gradient, (xi * self.tile_width, yi * self.tile_height + half_h))
                p.init_images, p.image_mask = [image], mask
                processed = process_images(p, pipeline)
                if processed.images:
                    image = processed.images[0]

        # Vertical seams: the mask is shifted right by half a tile so it
        # straddles the border between column xi and column xi + 1.
        for yi in range(rows):
            for xi in range(cols - 1):
                if state.interrupted:
                    break
                p.width, p.height = self.tile_width, self.tile_height
                p.inpaint_full_res, p.inpaint_full_res_padding = True, self.padding
                mask = Image.new("L", (image.width, image.height), "black")
                mask.paste(col_gradient, (xi * self.tile_width + half_w, yi * self.tile_height))
                p.init_images, p.image_mask = [image], mask
                processed = process_images(p, pipeline)
                if processed.images:
                    image = processed.images[0]

        p.width, p.height = image.width, image.height
        if processed:
            self.initial_info = processed.infotext(p, 0)
        return image

    def start(self, p, image, rows, cols, pipeline=False):
        """Run the seam-fix stage; the half-tile pass is used regardless of ``self.mode``."""
        return self.half_tile_process(p, image, rows, cols, pipeline)


class Script(USDU_upscaler.Script):
    """Entry point that wires a processing object through the full upscale pass."""

    def run(self, p, _, tile_width, tile_height, mask_blur, padding, seams_fix_width, seams_fix_denoise,
            seams_fix_padding, upscaler_index, save_upscaled_image, redraw_mode, save_seams_fix_image,
            seams_fix_mask_blur, seams_fix_type, target_size_type, custom_width, custom_height, custom_scale, pipeline=False):
        """Upscale ``p.init_images[0]``, then run the redraw and seam-fix stages.

        ``_``, ``target_size_type``, ``custom_width`` and ``custom_height`` are
        accepted but not used; the target size comes from ``custom_scale``.

        Returns:
            A ``Processed`` with the stage outputs, ``p.seed`` and the last
            stage's info text (``""`` when there is none).
        """
        fix_seed(p)
        USDU_upscaler.torch_gc()

        # Fixed processing flags set before the run.
        p.do_not_save_grid = True
        p.do_not_save_samples = True
        p.inpaint_full_res = False
        p.inpainting_fill = 1
        p.n_iter = 1
        p.batch_size = 1

        source = image_util.flatten(p.init_images[0], USDU_upscaler.opts.img2img_background_color)
        # Target canvas: scaled source size rounded up to a multiple of 64.
        scaled_width = source.width * custom_scale
        scaled_height = source.height * custom_scale
        p.width = math.ceil(scaled_width / 64) * 64
        p.height = math.ceil(scaled_height / 64) * 64

        runner = USDUpscaler(p, source, upscaler_index, save_upscaled_image, save_seams_fix_image, tile_width, tile_height)
        runner.upscale()
        runner.setup_redraw(redraw_mode, padding, mask_blur)
        runner.setup_seams_fix(seams_fix_padding, seams_fix_denoise, seams_fix_mask_blur, seams_fix_width, seams_fix_type)
        runner.print_info()
        runner.add_extra_info()
        runner.process(pipeline)

        info = runner.initial_info or ""
        return Processed(p, runner.result_images, p.seed, info)


# Monkey-patch overrides
def _latent_downscale_factor(p, default=8):
    """Return the model latent format's ``downscale_factor``, or ``default``.

    Determine downscale factor from model (8 for SD, 16 for Flux). The lookup
    is deliberately best-effort: any failure to reach the latent format falls
    back to ``default`` instead of aborting the upscale.
    """
    try:
        latent_format = p.model.get_model_object("latent_format")
        return getattr(latent_format, "downscale_factor", default)
    except Exception:
        return default


def _ceil_to_multiple(value, multiple):
    """Round ``value`` up to the nearest multiple of ``multiple``."""
    return math.ceil(value / multiple) * multiple


_old_init = USDUpscaler.__init__


def _new_init(self, p, image, upscaler_index, save_redraw, save_seams_fix, tile_width, tile_height):
    """Set the target size from ``p.upscale_by`` before running the original ``__init__``.

    The size is rounded up to the model's latent downscale factor, and it
    replaces whatever ``p.width`` / ``p.height`` held on entry.
    """
    factor = _latent_downscale_factor(p)
    p.width = _ceil_to_multiple(image.width * p.upscale_by, factor)
    p.height = _ceil_to_multiple(image.height * p.upscale_by, factor)
    _old_init(self, p, image, upscaler_index, save_redraw, save_seams_fix, tile_width, tile_height)


USDUpscaler.__init__ = _new_init

_old_redraw = USDURedraw.init_draw


def _new_redraw(self, p, width, height):
    """Run the original ``init_draw``, then re-round the tile size to the latent factor."""
    mask, draw = _old_redraw(self, p, width, height)

    # Replace the multiple-of-64 rounding done by the original with the model's own factor.
    factor = _latent_downscale_factor(p)
    p.width = _ceil_to_multiple(self.tile_width + self.padding, factor)
    p.height = _ceil_to_multiple(self.tile_height + self.padding, factor)
    return mask, draw


USDURedraw.init_draw = _new_redraw

_old_seams = USDUSeamsFix.init_draw


def _new_seams(self, p):
    """Run the original ``init_draw``, then re-round the tile size to the latent factor."""
    _old_seams(self, p)

    factor = _latent_downscale_factor(p)
    p.width = _ceil_to_multiple(self.tile_width + self.padding, factor)
    p.height = _ceil_to_multiple(self.tile_height + self.padding, factor)


USDUSeamsFix.init_draw = _new_seams

_old_upscale = USDUpscaler.upscale


def _new_upscale(self):
    """Run the original ``upscale``, then rebuild ``USDU_upscaler.batch``.

    The first entry becomes the upscaled ``self.image``; every other entry of
    the existing batch is resized to ``p.width`` x ``p.height``.
    """
    _old_upscale(self)
    target_size = (self.p.width, self.p.height)
    resized_rest = [extra.resize(target_size, Image.LANCZOS) for extra in USDU_upscaler.batch[1:]]
    USDU_upscaler.batch = [self.image, *resized_rest]


USDUpscaler.upscale = _new_upscale

# Resolution limit constant; not referenced anywhere in the code visible here.
MAX_RESOLUTION = 8192

# Display name -> redraw mode.
MODES = {
    "Linear": USDUMode.LINEAR,
    "Chess": USDUMode.CHESS,
    "None": USDUMode.NONE,
}

# Display name -> seam-fix mode.
SEAM_FIX_MODES = {
    "None": USDUSFMode.NONE,
    "Band Pass": USDUSFMode.BAND_PASS,
    "Half Tile": USDUSFMode.HALF_TILE,
    "Half Tile + Intersections": USDUSFMode.HALF_TILE_PLUS_INTERSECTIONS,
}


class UltimateSDUpscale:
    """Public entry point: takes an image tensor batch and returns the upscaled batch."""

    def upscale(self, image, model, positive, negative, vae, upscale_by, seed, steps, cfg, sampler_name,
                scheduler, denoise, upscale_model, mode_type, tile_width, tile_height, mask_blur, tile_padding,
                seam_fix_mode, seam_fix_denoise, seam_fix_mask_blur, seam_fix_width, seam_fix_padding,
                force_uniform_tiles, pipeline=False, callback=None):
        """Run the full upscale on ``image`` and return a 1-tuple with the result tensor.

        ``mode_type`` and ``seam_fix_mode`` are display names looked up in
        ``MODES`` and ``SEAM_FIX_MODES``; an unknown name raises ``KeyError``.
        The call stores its working images in ``USDU_upscaler.batch``.
        """
        # Module-level state on USDU_upscaler that the run reads back later.
        USDU_upscaler.sd_upscalers[0] = USDU_upscaler.UpscalerData()
        USDU_upscaler.actual_upscaler = upscale_model
        USDU_upscaler.batch = [image_util.tensor_to_pil(image, i) for i in range(len(image))]

        processing = StableDiffusionProcessing(
            image_util.tensor_to_pil(image),
            model,
            positive,
            negative,
            vae,
            seed,
            steps,
            cfg,
            sampler_name,
            scheduler,
            denoise,
            upscale_by,
            force_uniform_tiles,
            callback=callback,
        )

        run_args = {
            "p": processing,
            "_": None,
            "tile_width": tile_width,
            "tile_height": tile_height,
            "mask_blur": mask_blur,
            "padding": tile_padding,
            "seams_fix_width": seam_fix_width,
            "seams_fix_denoise": seam_fix_denoise,
            "seams_fix_padding": seam_fix_padding,
            "upscaler_index": 0,
            "save_upscaled_image": False,
            "redraw_mode": MODES[mode_type],
            "save_seams_fix_image": False,
            "seams_fix_mask_blur": seam_fix_mask_blur,
            "seams_fix_type": SEAM_FIX_MODES[seam_fix_mode],
            "target_size_type": 2,
            "custom_width": None,
            "custom_height": None,
            "custom_scale": upscale_by,
            "pipeline": pipeline,
        }
        Script().run(**run_args)

        # The result is read from the shared batch, not from run()'s return value.
        output_tensors = [image_util.pil_to_tensor(img) for img in USDU_upscaler.batch]
        return (torch.cat(output_tensors, dim=0),)