"""Detailer helpers: crop each detected segment, re-sample it, and paste it back."""

import math
from typing import Any, Dict, Optional, Tuple

import torch

from src.AutoDetailer import AD_util, bbox, tensor_util, SEGS
from src.Utilities import util
from src.AutoEncoders import VariationalAE
from src.Device import Device
from src.sample import ksampler_util, samplers, sampling, sampling_util

# Sampler options that switch multiscale sampling off; merged over the caller's
# options by ``ksampler2`` when ``disable_multiscale`` is true.
_MULTISCALE_DISABLED: Dict[str, Any] = {
    "enable_multiscale": False,
    "multiscale_factor": 1.0,
    "multiscale_fullres_start": 0,
    "multiscale_fullres_end": 0,
    "multiscale_intermittent_fullres": False,
}


def _interrupt_requested() -> bool:
    """Return True when the running app instance has its ``interrupt_flag`` set."""
    # Imported lazily inside the function, as the original call sites did.
    from src.user import app_instance

    app = getattr(app_instance, "app", None)
    return bool(app and getattr(app, "interrupt_flag", False))


class DifferentialDiffusion:
    """Installs a per-step denoise-mask thresholding function on a model clone."""

    def apply(self, model):
        """Clone ``model``, register ``forward`` as its denoise-mask function.

        Returns:
            A 1-tuple containing the patched clone.
        """
        model = model.clone()
        model.set_model_denoise_mask_function(self.forward)
        return (model,)

    def forward(self, sigma, denoise_mask, extra_options):
        """Binarise ``denoise_mask`` against the current step's progress.

        The threshold is the timestep of ``sigma[0]`` normalised between the
        timestep of the first scheduled sigma and the timestep of
        ``sigma_min``. Mask values greater than or equal to that threshold
        become 1, the rest 0, in the mask's own dtype.
        """
        model_sampling = extra_options["model"].inner_model.model_sampling
        step_sigmas = extra_options["sigmas"]
        ts_from = model_sampling.timestep(step_sigmas[0])
        ts_to = model_sampling.timestep(model_sampling.sigma_min)
        threshold = (model_sampling.timestep(sigma[0]) - ts_to) / (ts_from - ts_to)
        return (denoise_mask >= threshold).to(denoise_mask.dtype)


def crop_condition_mask(mask, image, crop_region):
    """Slice ``mask`` to ``crop_region`` given as ``(x1, y1, x2, y2)``.

    4-D masks are sliced on axes 1 and 2; 3-D and 2-D masks on axes 0 and 1.
    ``image`` is accepted for interface compatibility and is not used.

    Raises:
        ValueError: if ``mask`` is not 2-, 3- or 4-dimensional.
    """
    x1, y1, x2, y2 = crop_region
    ndim = len(mask.shape)
    if ndim == 4:
        return mask[:, y1:y2, x1:x2, :]
    if ndim == 3:
        return mask[y1:y2, x1:x2, :]
    if ndim == 2:
        return mask[y1:y2, x1:x2]
    raise ValueError(f"Unsupported mask shape: {mask.shape}")


def to_latent_image(pixels, vae):
    """Encode ``pixels`` with ``vae`` and return the first element of the result."""
    return VariationalAE.VAEEncode().encode(vae, pixels)[0]


def calculate_sigmas2(model, sampler, scheduler, steps):
    """Compute the sigma schedule for ``steps`` steps.

    ``sampler`` is accepted for interface compatibility and is not used.
    """
    return ksampler_util.calculate_sigmas(
        model.get_model_object("model_sampling"), scheduler, steps
    )


def get_noise_sampler(x, cpu, total_sigmas, **kwargs):
    """Build a Brownian-tree noise sampler when a seed is supplied.

    Returns:
        A ``BrownianTreeNoiseSampler`` spanning the smallest positive sigma to
        the largest sigma of ``total_sigmas`` when ``kwargs["extra_args"]``
        contains ``"seed"``; otherwise ``None``.
    """
    extra_args = kwargs.get("extra_args")
    if extra_args is None or "seed" not in extra_args:
        return None
    sigma_min = total_sigmas[total_sigmas > 0].min()
    sigma_max = total_sigmas.max()
    return sampling_util.BrownianTreeNoiseSampler(
        x, sigma_min, sigma_max, seed=extra_args.get("seed"), cpu=cpu
    )


def ksampler2(
    sampler_name,
    total_sigmas,
    extra_options: Optional[Dict[str, Any]] = None,
    inpaint_options: Optional[Dict[str, Any]] = None,
    pipeline=False,
    disable_multiscale=True,
):
    """Create the sampler object used by the detailer.

    Args:
        sampler_name: Sampler identifier; ``"dpmpp_2m_sde"`` gets a wrapper
            that injects a seeded noise sampler built from ``total_sigmas``.
        total_sigmas: Full sigma schedule (used for the noise sampler range).
        extra_options: Extra sampler options; ``None`` means no options.
        inpaint_options: Inpaint options; ``None`` means no options.
        pipeline: Forwarded to the underlying sampler factory.
        disable_multiscale: When true, multiscale-related options are forced
            off on a copy of ``extra_options``.
    """
    # Fresh dicts per call: the previous ``{}`` defaults were shared between
    # calls and handed to the sampler objects when multiscale was left enabled.
    extra_options = {} if extra_options is None else extra_options
    inpaint_options = {} if inpaint_options is None else inpaint_options

    if disable_multiscale:
        extra_options = {**extra_options, **_MULTISCALE_DISABLED}

    if sampler_name == "dpmpp_2m_sde":

        def sample_dpmpp_sde(model, x, sigmas, pipeline, **kwargs):
            noise_sampler = get_noise_sampler(x, True, total_sigmas, **kwargs)
            if noise_sampler is not None:
                kwargs["noise_sampler"] = noise_sampler
            return samplers.sample_dpmpp_2m_sde(
                model, x, sigmas, pipeline=pipeline, **kwargs
            )

        return sampling.KSAMPLER(sample_dpmpp_sde, extra_options, inpaint_options)

    return sampling.ksampler(
        sampler_name, pipeline=pipeline, extra_options=extra_options
    )


class Noise_RandomNoise:
    """Seeded noise generator for a latent dict."""

    def __init__(self, seed):
        self.seed = seed

    def generate_noise(self, input_latent):
        """Return noise prepared for ``input_latent["samples"]`` using the stored seed."""
        return ksampler_util.prepare_noise(
            input_latent["samples"],
            self.seed,
            input_latent.get("batch_index"),
            seeds_per_sample=None,
        )


def sample_with_custom_noise(
    model,
    add_noise,
    noise_seed,
    cfg,
    positive,
    negative,
    sampler,
    sigmas,
    latent_image,
    noise=None,
    callback=None,
    pipeline=False,
):
    """Run ``sampling.sample_custom`` on a latent dict.

    Noise is generated from ``noise_seed`` when ``noise`` is ``None``. The
    latent's optional ``"noise_mask"`` is moved to the torch device and passed
    through. ``add_noise`` is accepted for interface compatibility and is not
    used.

    Returns:
        A 2-tuple holding the same output dict twice: a shallow copy of
        ``latent_image`` whose ``"samples"`` is the sampled result moved to
        the intermediate device.
    """
    out = latent_image.copy()
    if noise is None:
        noise = Noise_RandomNoise(noise_seed).generate_noise(out)

    device = Device.get_torch_device()
    noise = noise.to(device)
    latent = latent_image["samples"].to(device)

    noise_mask = latent_image.get("noise_mask")
    if noise_mask is not None:
        noise_mask = noise_mask.to(device)

    samples = sampling.sample_custom(
        model,
        noise,
        cfg,
        sampler,
        sigmas,
        positive,
        negative,
        latent,
        noise_mask=noise_mask,
        callback=callback,
        disable_pbar=not util.PROGRESS_BAR_ENABLED,
        seed=noise_seed,
        pipeline=pipeline,
    )
    out["samples"] = samples.to(Device.intermediate_device())
    return out, out


def separated_sample(
    model,
    add_noise,
    seed,
    steps,
    cfg,
    sampler_name,
    scheduler,
    positive,
    negative,
    latent_image,
    start_at_step,
    end_at_step,
    return_with_leftover_noise,
    sigma_ratio=1.0,
    sampler_opt=None,
    noise=None,
    callback=None,
    scheduler_func=None,
    pipeline=False,
):
    """Sample from ``start_at_step`` to the end of a ``steps``-long schedule.

    ``end_at_step``, ``return_with_leftover_noise``, ``sampler_opt`` and
    ``scheduler_func`` are accepted for interface compatibility and are not
    used here.

    Returns:
        The output latent dict produced by ``sample_with_custom_noise``.
    """
    total_sigmas = calculate_sigmas2(model, sampler_name, scheduler, steps)
    tail = total_sigmas[start_at_step:] if start_at_step else total_sigmas
    # Scale in every case. Previously the conditional expression bound looser
    # than ``*``, so ``sigma_ratio`` was dropped whenever ``start_at_step``
    # was 0 or None.
    sigmas = tail * sigma_ratio
    return sample_with_custom_noise(
        model,
        add_noise,
        seed,
        cfg,
        positive,
        negative,
        ksampler2(sampler_name, total_sigmas, pipeline=pipeline),
        sigmas,
        latent_image,
        noise=noise,
        callback=callback,
        pipeline=pipeline,
    )[1]


def ksampler_wrapper(
    model,
    seed,
    steps,
    cfg,
    sampler_name,
    scheduler,
    positive,
    negative,
    latent_image,
    denoise,
    refiner_ratio=None,
    refiner_model=None,
    refiner_clip=None,
    refiner_positive=None,
    refiner_negative=None,
    sigma_factor=1.0,
    noise=None,
    callback=None,
    scheduler_func=None,
    pipeline=False,
):
    """Sample the last ``steps`` steps of a schedule stretched by ``1 / denoise``.

    The refiner arguments are accepted for interface compatibility and are not
    used here.

    Returns:
        The refined latent dict, or ``latent_image`` unchanged when
        ``denoise <= 0`` (nothing to denoise; also avoids dividing by zero).
    """
    if denoise <= 0:
        return latent_image

    advanced_steps = math.floor(steps / denoise)
    start_at_step = advanced_steps - steps
    return separated_sample(
        model,
        True,
        seed,
        advanced_steps,
        cfg,
        sampler_name,
        scheduler,
        positive,
        negative,
        latent_image,
        start_at_step,
        advanced_steps,
        False,
        sigma_ratio=sigma_factor,
        noise=noise,
        callback=callback,
        scheduler_func=scheduler_func,
        pipeline=pipeline,
    )


def _round_to_multiple_of_8(value: int) -> int:
    """Round ``value`` to the nearest multiple of 8, never below 8."""
    return max(8, (value + 4) // 8 * 8)


def _compute_detailer_resize(width, height, guide_size, max_size) -> Tuple[float, int, int, bool]:
    """Compute the working resolution for a crop of ``width`` x ``height``.

    The shorter side is scaled to ``guide_size``; if either resulting side
    exceeds ``max_size`` the scale is reduced so the longer side fits. When
    the scaled size degenerates to zero on either axis, the crop is kept at
    its own size (scale 1.0) and ``force_inpaint`` is reported as True.

    Returns:
        ``(upscale, new_w, new_h, force_inpaint)`` with both dimensions
        rounded to a multiple of 8.
    """
    upscale = guide_size / min(width, height)
    new_w, new_h = int(width * upscale), int(height * upscale)
    if new_w > max_size or new_h > max_size:
        upscale *= max_size / max(new_w, new_h)
        new_w, new_h = int(width * upscale), int(height * upscale)

    # Detect a degenerate size BEFORE rounding: the rounding below clamps to a
    # minimum of 8, so checking afterwards (as before) could never trigger.
    force_inpaint = new_w == 0 or new_h == 0
    if force_inpaint:
        upscale, new_w, new_h = 1.0, width, height

    # Round dimensions to nearest multiple of 8 for VAE compatibility.
    # Non-divisible-by-8 dimensions cause NaN in VAE encode when tiled
    # encoding is used, because round(dim * 0.125) != dim // 8.
    return (
        upscale,
        _round_to_multiple_of_8(new_w),
        _round_to_multiple_of_8(new_h),
        force_inpaint,
    )


def enhance_detail(
    image,
    model,
    clip,
    vae,
    guide_size,
    guide_size_for_bbox,
    max_size,
    bbox,
    seed,
    steps,
    cfg,
    sampler_name,
    scheduler,
    positive,
    negative,
    denoise,
    noise_mask,
    force_inpaint,
    wildcard_opt=None,
    wildcard_opt_concat_mode=None,
    detailer_hook=None,
    refiner_ratio=None,
    refiner_model=None,
    refiner_clip=None,
    refiner_positive=None,
    refiner_negative=None,
    control_net_wrapper=None,
    cycle=1,
    inpaint_model=False,
    noise_mask_feather=0,
    callback=None,
    scheduler_func=None,
    pipeline=False,
):
    """Resize a crop, re-sample it ``cycle`` times, decode and resize it back.

    The incoming ``force_inpaint`` value is replaced by the one computed from
    the crop size. ``image`` is indexed as ``shape[1]`` = height and
    ``shape[2]`` = width.

    Returns:
        ``(refined_image, None)`` where ``refined_image`` is on the CPU and
        has the crop's original width and height.
    """
    if noise_mask is not None:
        noise_mask = tensor_util.tensor_gaussian_blur_mask(
            noise_mask, noise_mask_feather
        ).squeeze(3)

    h, w = image.shape[1], image.shape[2]
    upscale, new_w, new_h, force_inpaint = _compute_detailer_resize(
        w, h, guide_size, max_size
    )
    if force_inpaint:
        print("Detailer: force inpaint")
    print(f"Detailer: segment upscale for ({bbox[2]-bbox[0]}, {bbox[3]-bbox[1]}) | crop region {w, h} x {upscale} -> {new_w, new_h}")

    upscaled_image = tensor_util.tensor_resize(image, new_w, new_h)
    latent_image = to_latent_image(upscaled_image, vae)
    if noise_mask is not None:
        latent_image["noise_mask"] = noise_mask

    refined_latent = latent_image
    for i in range(cycle):
        # Each cycle refines the previous cycle's output with a distinct seed.
        refined_latent = ksampler_wrapper(
            model,
            seed + i,
            steps,
            cfg,
            sampler_name,
            scheduler,
            positive,
            negative,
            refined_latent,
            denoise,
            refiner_ratio,
            refiner_model,
            refiner_clip,
            refiner_positive,
            refiner_negative,
            noise=None,
            callback=callback,
            scheduler_func=scheduler_func,
            pipeline=pipeline,
        )

    try:
        refined_image = vae.decode(refined_latent["samples"])
    except Exception as exc:
        # Best-effort fallback, but no longer silent about why it happened.
        print(f"Detailer: VAE decode failed ({exc}); retrying with tiled decode")
        # Standard tile size for SDXL VAE to avoid artifacts
        refined_image = vae.decode_tiled(
            refined_latent["samples"], tile_x=256, tile_y=256
        )

    return tensor_util.tensor_resize(refined_image, w, h).cpu(), None


def _crop_conditioning(cond_list, image, crop_region, width, height):
    """Return a copy of ``cond_list`` adapted to one crop.

    For every ``[tensor, dict]``-shaped entry the dict is shallow-copied (so
    existing keys such as ``"pooled_output"`` are kept), its ``"mask"`` (if
    any) is cropped to ``crop_region``, and the size keys are set to the
    crop's sampling resolution. Other entries are passed through unchanged.
    ``None`` is returned for a ``None`` input.
    """
    if cond_list is None:
        return None

    # Integer coordinates so they are always valid slice bounds.
    region = tuple(int(round(c)) for c in crop_region)
    x1, y1, x2, y2 = region

    cropped = []
    for entry in cond_list:
        is_cond_pair = (
            isinstance(entry, (list, tuple))
            and len(entry) > 1
            and isinstance(entry[1], dict)
        )
        if not is_cond_pair:
            cropped.append(entry)
            continue

        options = entry[1].copy()
        if "mask" in options:
            # Use the rounded coordinates; they were previously computed but
            # the raw crop region was passed to the slicer instead.
            options["mask"] = crop_condition_mask(
                options["mask"], image, (x1, y1, x2, y2)
            )
        # Inject SDXL size conditioning for the crop.
        # Use crop-local dimensions to match the actual sampling resolution.
        options["width"] = width
        options["height"] = height
        options["crop_w"] = 0
        options["crop_h"] = 0
        options["target_width"] = width
        options["target_height"] = height
        cropped.append([entry[0], options])
    return cropped


class DetailerForEach:
    """Refines every segment of an image and pastes the results back."""

    @staticmethod
    def do_detail(
        image,
        segs,
        model,
        clip,
        vae,
        guide_size,
        guide_size_for_bbox,
        max_size,
        seed,
        steps,
        cfg,
        sampler_name,
        scheduler,
        positive,
        negative,
        denoise,
        feather,
        noise_mask,
        force_inpaint,
        wildcard_opt=None,
        detailer_hook=None,
        refiner_ratio=None,
        refiner_model=None,
        refiner_clip=None,
        refiner_positive=None,
        refiner_negative=None,
        cycle=1,
        inpaint_model=False,
        noise_mask_feather=0,
        callback=None,
        scheduler_func_opt=None,
        pipeline=False,
    ):
        """Enhance each segment in ``segs`` and composite it onto a copy of ``image``.

        Segments with an all-zero mask are skipped. Processing stops early
        when an interrupt is requested. The per-segment seed comes from the
        wildcard chooser, falling back to ``seed + segment_index``.

        Returns:
            ``(image_rgb, cropped_list, enhanced_list, enhanced_alpha_list,
            cnet_pil_list, (segs[0], new_segs))``. The three image lists are
            each sorted by tensor shape, largest first; ``cnet_pil_list`` is
            always empty here.
        """
        image = image.clone()
        enhanced_alpha_list, enhanced_list, cropped_list = [], [], []
        cnet_pil_list, new_segs = [], []

        segs = AD_util.segs_scale_match(segs, image.shape)
        _, wildcard_chooser = bbox.process_wildcard_for_segs(wildcard_opt)

        # Feathered noise masks need the thresholding mask function installed.
        if noise_mask_feather > 0 and "denoise_mask_function" not in model.model_options:
            model = DifferentialDiffusion().apply(model)[0]

        for i, seg in enumerate(segs[1]):
            # Check for interrupt before each segment
            if _interrupt_requested():
                print(f"Detailer: Interrupt requested, stopping at segment {i}")
                break

            # Skip before cropping/blurring so empty segments cost nothing.
            if (seg.cropped_mask == 0).all().item():
                print("Detailer: segment skip [empty mask]")
                continue

            cropped_image = tensor_util.to_tensor(
                AD_util.crop_ndarray4(image.cpu().numpy(), seg.crop_region)
            )
            mask = tensor_util.tensor_gaussian_blur_mask(
                tensor_util.to_tensor(seg.cropped_mask), feather
            )

            seg_seed, wildcard_item = wildcard_chooser.get(seg)
            if seg_seed is None:
                seg_seed = seed + i

            crop_h, crop_w = int(cropped_image.shape[1]), int(cropped_image.shape[2])
            _, crop_new_w, crop_new_h, _ = _compute_detailer_resize(
                crop_w, crop_h, guide_size, max_size
            )

            orig_cropped_image = cropped_image.clone()
            enhanced_image, _ = enhance_detail(
                cropped_image,
                model,
                clip,
                vae,
                guide_size,
                guide_size_for_bbox,
                max_size,
                seg.bbox,
                seg_seed,
                steps,
                cfg,
                sampler_name,
                scheduler,
                _crop_conditioning(positive, image, seg.crop_region, crop_new_w, crop_new_h),
                _crop_conditioning(negative, image, seg.crop_region, crop_new_w, crop_new_h),
                denoise,
                seg.cropped_mask,
                force_inpaint,
                wildcard_opt=wildcard_item,
                wildcard_opt_concat_mode=None,
                detailer_hook=detailer_hook,
                refiner_ratio=refiner_ratio,
                refiner_model=refiner_model,
                refiner_clip=refiner_clip,
                refiner_positive=refiner_positive,
                refiner_negative=refiner_negative,
                control_net_wrapper=seg.control_net_wrapper,
                cycle=cycle,
                inpaint_model=inpaint_model,
                noise_mask_feather=noise_mask_feather,
                callback=callback,
                scheduler_func=scheduler_func_opt,
                pipeline=pipeline,
            )

            if enhanced_image is None:
                continue

            # Paste the refined crop back in place, blended by the blurred mask.
            image = image.cpu()
            tensor_util.tensor_paste(
                image,
                enhanced_image.cpu(),
                (seg.crop_region[0], seg.crop_region[1]),
                mask,
            )
            enhanced_list.append(enhanced_image)

            # RGBA copy of the refined crop with the mask as its alpha channel.
            enhanced_image_alpha = tensor_util.tensor_convert_rgba(enhanced_image)
            mask = tensor_util.tensor_resize(
                mask, *tensor_util.tensor_get_size(enhanced_image)
            )
            tensor_util.tensor_putalpha(enhanced_image_alpha, mask)
            enhanced_alpha_list.append(enhanced_image_alpha)

            cropped_list.append(orig_cropped_image)
            new_segs.append(
                SEGS.SEG(
                    enhanced_image.numpy(),
                    seg.cropped_mask,
                    seg.confidence,
                    seg.crop_region,
                    seg.bbox,
                    seg.label,
                    seg.control_net_wrapper,
                )
            )

        for lst in (cropped_list, enhanced_list, enhanced_alpha_list):
            lst.sort(key=lambda x: x.shape, reverse=True)

        return (
            tensor_util.tensor_convert_rgb(image),
            cropped_list,
            enhanced_list,
            enhanced_alpha_list,
            cnet_pil_list,
            (segs[0], new_segs),
        )


def empty_pil_tensor(w=64, h=64):
    """Return an all-zero float32 tensor of shape ``(1, h, w, 3)``."""
    return torch.zeros((1, h, w, 3), dtype=torch.float32)


class DetailerForEachTest(DetailerForEach):
    """Entry point that runs ``do_detail`` on a single image or a batch."""

    def doit(
        self,
        image,
        segs,
        model,
        clip,
        vae,
        guide_size,
        guide_size_for,
        max_size,
        seed,
        steps,
        cfg,
        sampler_name,
        scheduler,
        positive,
        negative,
        denoise,
        feather,
        noise_mask,
        force_inpaint,
        wildcard,
        detailer_hook=None,
        cycle=1,
        inpaint_model=False,
        noise_mask_feather=0,
        callback=None,
        scheduler_func_opt=None,
        pipeline=False,
    ):
        """Detail ``image``; batches (4-D, first dim > 1) are processed item by item.

        Batch item ``i`` uses ``seed + i``. If an interrupt is requested the
        loop stops and only the items processed so far are returned; if that
        happens before the first item, the input ``image`` is returned as-is.

        Returns:
            ``(enhanced, cropped, enhanced_crops, enhanced_alpha_crops,
            cnet_images)``.
        """

        def detail(img, item_seed):
            # Single place for the long do_detail call shared by both paths.
            return DetailerForEach.do_detail(
                img,
                segs,
                model,
                clip,
                vae,
                guide_size,
                guide_size_for,
                max_size,
                item_seed,
                steps,
                cfg,
                sampler_name,
                scheduler,
                positive,
                negative,
                denoise,
                feather,
                noise_mask,
                force_inpaint,
                wildcard,
                detailer_hook,
                cycle=cycle,
                inpaint_model=inpaint_model,
                noise_mask_feather=noise_mask_feather,
                callback=callback,
                scheduler_func_opt=scheduler_func_opt,
                pipeline=pipeline,
            )

        if len(image.shape) == 4 and image.shape[0] > 1:
            enhanced_batch, cropped_all, enh_all, alpha_all, cnet_all = [], [], [], [], []
            for i in range(image.shape[0]):
                # Check for interrupt before each batch item
                if _interrupt_requested():
                    print(f"ADetailer: Interrupt requested, stopping at batch item {i}")
                    break
                enhanced, cropped, enh, enh_alpha, cnet, _ = detail(image[i:i + 1], seed + i)
                enhanced_batch.append(enhanced)
                cropped_all.extend(cropped)
                enh_all.extend(enh)
                alpha_all.extend(enh_alpha)
                cnet_all.extend(cnet)

            if not enhanced_batch:
                # Interrupted before any item was processed: torch.cat on an
                # empty list raises, so hand back the untouched input instead.
                return image, cropped_all, enh_all, alpha_all, [empty_pil_tensor()]

            return (
                torch.cat(enhanced_batch, dim=0),
                cropped_all,
                enh_all,
                alpha_all,
                cnet_all or [empty_pil_tensor()],
            )

        enhanced, cropped, enh, enh_alpha, _, _ = detail(image, seed)
        return enhanced, cropped, enh, enh_alpha, [empty_pil_tensor()]