|
|
from PIL import Image, ImageFilter |
|
|
import torch |
|
|
import math |
|
|
from nodes import common_ksampler, VAEEncode, VAEDecode, VAEDecodeTiled |
|
|
from comfy_extras.nodes_custom_sampler import SamplerCustom |
|
|
from utils import pil_to_tensor, tensor_to_pil, get_crop_region, expand_crop, crop_cond |
|
|
from modules import shared |
|
|
|
|
|
# Compatibility alias: when PIL's ``Image`` module exposes no ``Resampling``
# attribute, point it at the module itself so ``Image.Resampling.LANCZOS``
# (used further down in this file) resolves to ``Image.LANCZOS``.
if not hasattr(Image, 'Resampling'):
    Image.Resampling = Image
|
|
|
|
|
|
|
|
class StableDiffusionProcessing:
    """Bundle of inputs and settings that ``process_images`` reads from.

    The constructor only stores its arguments and creates the VAE helper
    nodes; it performs no sampling itself.
    """

    def __init__(
        self,
        init_img,
        model,
        positive,
        negative,
        vae,
        seed,
        steps,
        cfg,
        sampler_name,
        scheduler,
        denoise,
        upscale_by,
        uniform_tile_mode,
        tiled_decode,
        custom_sampler=None,
        custom_sigmas=None
    ):
        """Store the sampling configuration for one image.

        Args:
            init_img: Source image; must expose ``width`` and ``height``.
            model, positive, negative, vae: Passed through unchanged to the
                sampler / VAE calls made by ``process_images``.
            seed, steps, cfg, sampler_name, scheduler, denoise: Sampler
                settings forwarded to ``sample``.
            upscale_by: Stored as-is; not read by the code in this module.
            uniform_tile_mode: When truthy, ``process_images`` resizes every
                tile to ``(width, height)``.
            tiled_decode: When truthy, ``process_images`` decodes with the
                tiled VAE decoder.
            custom_sampler, custom_sigmas: Optional pair; ``sample`` only
                uses them when both are given.
        """
        # Image-related state. The mask, blur and padding start at neutral
        # values -- presumably the caller overwrites them before invoking
        # process_images (which calls ``image_mask.convert``); TODO confirm.
        self.init_images = [init_img]
        self.image_mask = None
        self.mask_blur = 0
        self.inpaint_full_res_padding = 0
        self.width, self.height = init_img.width, init_img.height

        # Sampler inputs, kept under the same names they arrived with.
        self.model, self.positive, self.negative, self.vae = model, positive, negative, vae
        self.seed, self.steps, self.cfg = seed, steps, cfg
        self.sampler_name, self.scheduler, self.denoise = sampler_name, scheduler, denoise

        # Optional custom sampler / sigma schedule.
        self.custom_sampler, self.custom_sigmas = custom_sampler, custom_sigmas

        # Exactly one of the pair being supplied is unusable: warn only; the
        # values are still stored unchanged.
        sampler_given = custom_sampler is not None
        sigmas_given = custom_sigmas is not None
        if sampler_given != sigmas_given:
            print("[USDU] Both custom sampler and custom sigmas must be provided, defaulting to widget sampler and sigmas")

        # Settings and helper nodes used while processing tiles.
        self.init_size = (init_img.width, init_img.height)
        self.upscale_by = upscale_by
        self.uniform_tile_mode = uniform_tile_mode
        self.tiled_decode = tiled_decode
        self.vae_decoder = VAEDecode()
        self.vae_encoder = VAEEncode()
        self.vae_decoder_tiled = VAEDecodeTiled()

        # Starts empty; nothing in this module's visible code fills it.
        self.extra_generation_params = {}
|
|
|
|
|
|
|
|
class Processed:
    """Plain result holder: a list of images plus the seed and info text."""

    def __init__(self, p: StableDiffusionProcessing, images: list, seed: int, info: str):
        """Record ``images``, ``seed`` and ``info``; ``p`` is accepted but not stored."""
        self.images, self.seed, self.info = images, seed, info

    def infotext(self, p: StableDiffusionProcessing, index):
        """Return ``None`` regardless of the arguments."""
        return None
|
|
|
|
|
|
|
|
def fix_seed(p: StableDiffusionProcessing):
    """Do nothing; ``p`` is neither read nor modified, and ``None`` is returned."""
    return None
|
|
|
|
|
|
|
|
def sample(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise, custom_sampler, custom_sigmas):
    """Run one sampling pass over ``latent`` and return the resulting samples.

    The custom path (``SamplerCustom``) is taken only when both
    ``custom_sampler`` and ``custom_sigmas`` are provided; otherwise the call
    goes to ``common_ksampler`` with the sampler name, scheduler, step count
    and denoise strength.
    """
    use_custom = custom_sampler is not None and custom_sigmas is not None

    if not use_custom:
        # common_ksampler is expected to hand back a one-element tuple.
        (samples,) = common_ksampler(model, seed, steps, cfg, sampler_name,
                                     scheduler, positive, negative, latent, denoise=denoise)
        return samples

    # The node names its entry point in ``FUNCTION``; look the method up by that name.
    node = SamplerCustom()
    run = getattr(node, node.FUNCTION)
    # Two values come back; only the first is returned to the caller.
    (samples, _) = run(
        model=model,
        add_noise=True,
        noise_seed=seed,
        cfg=cfg,
        positive=positive,
        negative=negative,
        sampler=custom_sampler,
        sigmas=custom_sigmas,
        latent_image=latent
    )
    return samples
|
|
|
|
|
|
|
|
def process_images(p: StableDiffusionProcessing) -> Processed:
    """Re-sample the masked tile of every image in ``shared.batch`` and blend it back.

    The mask, conditioning, model and sampler settings are read from ``p``.
    Each batch image is cropped to the region derived from the mask, the crops
    are VAE-encoded, passed through ``sample``, decoded, and composited onto
    the corresponding batch image with the (optionally blurred) mask as alpha.
    ``shared.batch`` is updated in place.

    Returns:
        A ``Processed`` holding only the first image of the updated batch
        together with ``p.seed``.
    """
    mask = p.image_mask.convert('L')
    first_image = p.init_images[0]

    # Region reported by get_crop_region for the mask with the configured padding.
    crop_region = get_crop_region(mask, p.inpaint_full_res_padding)
    left, top, right, bottom = crop_region
    crop_width = right - left
    crop_height = bottom - top

    if p.uniform_tile_mode:
        # Grow the crop to the aspect ratio of (p.width, p.height); the tile
        # itself is then always resized to exactly that size.
        crop_ratio = crop_width / crop_height
        p_ratio = p.width / p.height
        target_width, target_height = (
            (crop_width, round(crop_width / p_ratio))
            if crop_ratio > p_ratio
            else (round(crop_height * p_ratio), crop_height)
        )
        crop_region, _ = expand_crop(crop_region, mask.width, mask.height, target_width, target_height)
        tile_size = p.width, p.height
    else:
        # Round each side of the crop up to a multiple of 8 and let
        # expand_crop decide the final tile size.
        target_width, target_height = (math.ceil(side / 8) * 8 for side in (crop_width, crop_height))
        crop_region, tile_size = expand_crop(crop_region, mask.width,
                                             mask.height, target_width, target_height)

    # Optional feathering of the mask; it is used as alpha when compositing below.
    if p.mask_blur > 0:
        mask = mask.filter(ImageFilter.GaussianBlur(p.mask_blur))

    # One crop per batch image, all taken from the same region.
    tiles = [img.crop(crop_region) for img in shared.batch]

    # Size of the first crop; decoded tiles are resized back to this later.
    initial_tile_size = tiles[0].size

    # Bring every crop to the processing tile size where it differs.
    tiles = [
        tile.resize(tile_size, Image.Resampling.LANCZOS) if tile.size != tile_size else tile
        for tile in tiles
    ]

    # Conditioning restricted to the crop region.
    positive_cropped = crop_cond(p.positive, crop_region, p.init_size, first_image.size, tile_size)
    negative_cropped = crop_cond(p.negative, crop_region, p.init_size, first_image.size, tile_size)

    # Concatenate the per-tile tensors along dim 0 and encode them together.
    tile_tensors = [pil_to_tensor(tile) for tile in tiles]
    batched_tiles = torch.cat(tile_tensors, dim=0)
    (latent,) = p.vae_encoder.encode(p.vae, batched_tiles)

    samples = sample(p.model, p.seed, p.steps, p.cfg, p.sampler_name, p.scheduler, positive_cropped,
                     negative_cropped, latent, p.denoise, p.custom_sampler, p.custom_sigmas)

    if p.tiled_decode:
        print("[USDU] Using tiled decode")
        (decoded,) = p.vae_decoder_tiled.decode(p.vae, samples, 512)
    else:
        (decoded,) = p.vae_decoder.decode(p.vae, samples)

    # Convert every decoded entry back to a PIL image before touching the batch.
    tiles_sampled = [tensor_to_pil(decoded, i) for i in range(len(decoded))]

    for idx, tile_sampled in enumerate(tiles_sampled):
        base_image = shared.batch[idx]

        # Undo the resize applied before encoding.
        if tile_sampled.size != initial_tile_size:
            tile_sampled = tile_sampled.resize(initial_tile_size, Image.Resampling.LANCZOS)

        # New full-size RGBA canvas with the tile pasted at the crop origin.
        overlay = Image.new('RGBA', base_image.size)
        overlay.paste(tile_sampled, crop_region[:2])

        # Give a copy the mask as its alpha channel, then paste that copy back
        # onto the overlay using the overlay itself as the paste mask.
        masked_copy = overlay.copy()
        masked_copy.putalpha(mask)
        overlay.paste(masked_copy, overlay)

        # Composite over the batch image and store the result as RGB.
        composed = base_image.convert('RGBA')
        composed.alpha_composite(overlay)
        shared.batch[idx] = composed.convert('RGB')

    return Processed(p, [shared.batch[0]], p.seed, None)
|
|
|