"""Lightweight SD1.5 inpainting demo (LCM-LoRA + TAESD) served with Gradio.

The user uploads an image, paints a mask in the editor and writes a prompt.
By default only a crop around the mask (plus context) is sent to the model,
and the result is pasted back into the untouched original image.
"""

import gc
import logging
import os
import random
import time
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageChops, ImageFilter, ImageOps
from diffusers import AutoPipelineForInpainting, AutoencoderTiny, LCMScheduler

logger = logging.getLogger(__name__)

# UI label -> Hugging Face model id.
MODEL_CHOICES = {
    "DreamShaper 8 Inpainting": "Lykon/dreamshaper-8-inpainting",
    "Official SD1.5 Inpainting": "stable-diffusion-v1-5/stable-diffusion-inpainting",
}
DEFAULT_MODEL_LABEL = "DreamShaper 8 Inpainting"

# Both ids can be overridden through environment variables.
LCM_LORA_ID = os.getenv(
    "LCM_LORA_ID",
    "latent-consistency/lcm-lora-sdv1-5",
)
TINY_VAE_ID = os.getenv(
    "TINY_VAE_ID",
    "madebyollin/taesd",
)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Work-mode choices shown in the UI radio; also compared against in run_inpaint.
CROP_MODE = "Crop around mask"
FULL_MODE = "Full image no resize"

# Single cached pipeline and the model id it was built from.
PIPE = None
PIPE_MODEL_ID = None


def to_pil(x) -> Optional[Image.Image]:
    """Convert a PIL image or numpy array to a PIL image.

    Returns None for None or for any unsupported type. Non-uint8 arrays are
    clipped to [0, 255] and cast to uint8.
    """
    if x is None:
        return None
    if isinstance(x, Image.Image):
        return x
    if isinstance(x, np.ndarray):
        if x.dtype != np.uint8:
            x = np.clip(x, 0, 255).astype(np.uint8)
        return Image.fromarray(x)
    return None


def resolve_model_id(model_label: str) -> str:
    """Map a UI label to a model id, falling back to the default model."""
    return MODEL_CHOICES.get(model_label, MODEL_CHOICES[DEFAULT_MODEL_LABEL])


def unload_pipe():
    """Drop the cached pipeline and release Python / CUDA memory."""
    global PIPE, PIPE_MODEL_ID
    PIPE = None
    PIPE_MODEL_ID = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def load_pipe(model_label: str):
    """Return a ready inpainting pipeline for ``model_label``, loading if needed.

    The pipeline is cached in module globals; switching to a different model
    unloads the previous one first so only one model is held at a time.
    The loaded pipeline uses the LCM scheduler, the LCM-LoRA weights and the
    tiny VAE (TAESD).
    """
    global PIPE, PIPE_MODEL_ID

    model_id = resolve_model_id(model_label)
    if PIPE is not None and PIPE_MODEL_ID == model_id:
        return PIPE

    unload_pipe()

    pipe = AutoPipelineForInpainting.from_pretrained(
        model_id,
        torch_dtype=DTYPE,
        safety_checker=None,
        requires_safety_checker=False,
    )
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.load_lora_weights(LCM_LORA_ID)
    try:
        pipe.fuse_lora()
    except Exception:
        # Best effort: keep going with the LoRA loaded but not fused.
        logger.warning("fuse_lora() failed; continuing with unfused LCM-LoRA.", exc_info=True)

    pipe.vae = AutoencoderTiny.from_pretrained(
        TINY_VAE_ID,
        torch_dtype=DTYPE,
    )
    pipe = pipe.to(DEVICE)
    pipe.set_progress_bar_config(disable=True)
    try:
        pipe.enable_attention_slicing()
    except Exception:
        # Optional memory optimisation; generation works without it.
        logger.info("enable_attention_slicing() unavailable; skipping.", exc_info=True)

    PIPE = pipe
    PIPE_MODEL_ID = model_id
    return PIPE


def extract_image_and_mask(editor_value) -> Tuple[Image.Image, Image.Image]:
    """Split a Gradio ImageEditor value into an RGB image and an "L" mask.

    ``editor_value`` is either a dict with "background", "composite" and
    "layers" entries, or a bare image. The mask is the per-pixel maximum of
    the alpha channels of all layers. If there are no usable layers, it falls
    back to pixels where the composite differs from the background by more
    than 12.

    Raises:
        gr.Error: if there is no image or the resulting mask is empty.
    """
    if editor_value is None:
        raise gr.Error("Upload an image and draw a mask first.")

    if isinstance(editor_value, dict):
        background = to_pil(editor_value.get("background"))
        composite = to_pil(editor_value.get("composite"))
        layers = editor_value.get("layers") or []
    else:
        background = to_pil(editor_value)
        composite = None
        layers = []

    if background is None:
        raise gr.Error("No base image found. Upload an image first.")

    image = background.convert("RGB")
    mask = None

    for layer in layers:
        layer_img = to_pil(layer)
        if layer_img is None:
            continue
        if layer_img.size != image.size:
            layer_img = layer_img.resize(image.size, Image.Resampling.NEAREST)
        if layer_img.mode != "RGBA":
            layer_img = layer_img.convert("RGBA")
        alpha = layer_img.getchannel("A")
        mask = alpha if mask is None else ImageChops.lighter(mask, alpha)

    if mask is None and composite is not None:
        composite = composite.convert("RGB")
        if composite.size != image.size:
            composite = composite.resize(image.size, Image.Resampling.NEAREST)
        diff = ImageChops.difference(image, composite).convert("L")
        mask = diff.point(lambda p: 255 if p > 12 else 0)

    if mask is None:
        raise gr.Error("Draw over the area you want to repaint.")

    mask = mask.convert("L")
    if np.array(mask).max() < 10:
        raise gr.Error("Mask is empty. Draw over the area you want to repaint.")

    return image, mask


def mask_bbox(mask: Image.Image, threshold: int = 10) -> Tuple[int, int, int, int]:
    """Return (left, top, right, bottom) of mask pixels above ``threshold``.

    ``right`` and ``bottom`` are exclusive, matching ``Image.crop``.

    Raises:
        gr.Error: if no pixel exceeds ``threshold``.
    """
    arr = np.array(mask.convert("L"))
    ys, xs = np.where(arr > threshold)
    if len(xs) == 0 or len(ys) == 0:
        raise gr.Error("Mask is empty. Draw over the area you want to repaint.")

    left = int(xs.min())
    right = int(xs.max()) + 1
    top = int(ys.min())
    bottom = int(ys.max()) + 1
    return left, top, right, bottom


def clamp_bbox(
    bbox: Tuple[int, int, int, int],
    image_size: Tuple[int, int],
) -> Tuple[int, int, int, int]:
    """Clamp every bbox coordinate into [0, width] / [0, height]."""
    w, h = image_size
    left, top, right, bottom = bbox
    left = max(0, min(left, w))
    right = max(0, min(right, w))
    top = max(0, min(top, h))
    bottom = max(0, min(bottom, h))
    return left, top, right, bottom


def expand_bbox(
    bbox: Tuple[int, int, int, int],
    image_size: Tuple[int, int],
    padding: int,
    min_side: int,
) -> Tuple[int, int, int, int]:
    """Grow ``bbox`` by ``padding`` and to at least ``min_side`` per axis.

    Growth to ``min_side`` is split evenly on both sides. A box that would
    leave the image is shifted back inside. The final result is clamped to
    the image, so it is smaller than ``min_side`` only when the image itself
    is smaller.
    """
    image_w, image_h = image_size
    left, top, right, bottom = bbox

    left -= padding
    top -= padding
    right += padding
    bottom += padding

    left, top, right, bottom = clamp_bbox((left, top, right, bottom), image_size)

    crop_w = right - left
    crop_h = bottom - top

    if crop_w < min_side:
        extra = min_side - crop_w
        left -= extra // 2
        right += extra - extra // 2

    if crop_h < min_side:
        extra = min_side - crop_h
        top -= extra // 2
        bottom += extra - extra // 2

    # Shift (rather than shrink) the window back inside the image.
    if left < 0:
        right -= left
        left = 0
    if top < 0:
        bottom -= top
        top = 0
    if right > image_w:
        shift = right - image_w
        left -= shift
        right = image_w
    if bottom > image_h:
        shift = bottom - image_h
        top -= shift
        bottom = image_h

    return clamp_bbox((left, top, right, bottom), image_size)


def pad_to_multiple_of_8(
    image: Image.Image,
    mask: Image.Image,
) -> Tuple[Image.Image, Image.Image, Tuple[int, int], Tuple[int, int]]:
    """Pad image and mask on the right/bottom with zeros to multiples of 8.

    Zero mask padding means the padded area is never repainted.

    Returns:
        (image, mask, original_size, padded_size); the inputs are returned
        unchanged when no padding is needed.
    """
    original_w, original_h = image.size
    padded_w = ((original_w + 7) // 8) * 8
    padded_h = ((original_h + 7) // 8) * 8
    pad_w = padded_w - original_w
    pad_h = padded_h - original_h

    if pad_w == 0 and pad_h == 0:
        return image, mask, (original_w, original_h), (padded_w, padded_h)

    image = ImageOps.expand(image, border=(0, 0, pad_w, pad_h), fill=0)
    mask = ImageOps.expand(mask, border=(0, 0, pad_w, pad_h), fill=0)
    return image, mask, (original_w, original_h), (padded_w, padded_h)


def make_crop_inputs(
    image: Image.Image,
    mask: Image.Image,
    crop_padding: int,
    min_crop_side: int,
):
    """Crop image and mask around the mask bbox and pad them for the model.

    Returns a dict with the raw mask bbox, the expanded crop bbox, the crop
    before and after padding, and the crop / padded sizes.
    """
    raw_bbox = mask_bbox(mask)
    crop_bbox = expand_bbox(
        raw_bbox,
        image.size,
        padding=int(crop_padding),
        min_side=int(min_crop_side),
    )

    crop_image = image.crop(crop_bbox)
    crop_mask = mask.crop(crop_bbox)
    padded_image, padded_mask, crop_size, padded_size = pad_to_multiple_of_8(
        crop_image,
        crop_mask,
    )

    return {
        "raw_bbox": raw_bbox,
        "crop_bbox": crop_bbox,
        "crop_image": crop_image,
        "crop_mask": crop_mask,
        "padded_image": padded_image,
        "padded_mask": padded_mask,
        "crop_size": crop_size,
        "padded_size": padded_size,
    }


def run_inpaint(
    editor_value,
    model_label: str,
    prompt: str,
    crop_padding: int,
    work_mode: str,
    min_working_window: int,
    steps: int,
    guidance_scale: float,
    strength: float,
    seed: int,
    random_seed: bool,
    mask_blur: int,
    paste_whole_crop: bool,
):
    """Gradio handler: inpaint the masked area and return (image, markdown info).

    All user-input validation happens before the model is loaded, so mistakes
    are reported without waiting for a model download / load.

    Raises:
        gr.Error: on a missing prompt, image or mask, or a steps/strength
            combination that would run zero denoising steps.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Write a prompt for the masked area.")

    original_image, original_mask = extract_image_and_mask(editor_value)
    original_image = original_image.convert("RGB")
    original_mask = original_mask.convert("L")

    num_steps = int(steps)
    # diffusers runs int(num_inference_steps * strength) denoising steps for
    # inpainting and raises a generic ValueError when that is 0.
    if int(num_steps * float(strength)) < 1:
        raise gr.Error(
            "Steps x strength is too low: the model would run 0 denoising steps. "
            "Increase steps or strength."
        )

    # One (optionally blurred) full-resolution mask drives both generation and
    # the final paste, so the two always agree.
    if mask_blur > 0:
        generation_mask_source = original_mask.filter(
            ImageFilter.GaussianBlur(radius=int(mask_blur))
        )
    else:
        generation_mask_source = original_mask

    if work_mode == CROP_MODE:
        crop_info = make_crop_inputs(
            original_image,
            generation_mask_source,
            crop_padding=int(crop_padding),
            min_crop_side=int(min_working_window),
        )
        gen_image = crop_info["padded_image"]
        gen_mask = crop_info["padded_mask"]
        crop_bbox = crop_info["crop_bbox"]
        crop_size = crop_info["crop_size"]
        padded_size = crop_info["padded_size"]
    else:
        crop_info = None
        gen_image, gen_mask, crop_size, padded_size = pad_to_multiple_of_8(
            original_image,
            generation_mask_source,
        )
        crop_bbox = (0, 0, original_image.size[0], original_image.size[1])

    # A cleared Seed field arrives as None; treat it like "random seed".
    if random_seed or seed is None:
        seed = random.randint(0, 2**31 - 1)

    model_id = resolve_model_id(model_label)
    pipe = load_pipe(model_label)
    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))

    start = time.perf_counter()
    with torch.inference_mode():
        generated = pipe(
            prompt=prompt.strip(),
            image=gen_image,
            mask_image=gen_mask,
            height=gen_image.size[1],
            width=gen_image.size[0],
            num_inference_steps=num_steps,
            guidance_scale=float(guidance_scale),
            strength=float(strength),
            generator=generator,
        ).images[0].convert("RGB")
    elapsed = time.perf_counter() - start

    # Remove the right/bottom padding added for the multiple-of-8 constraint.
    if padded_size != crop_size:
        generated = generated.crop((0, 0, crop_size[0], crop_size[1]))

    if work_mode == CROP_MODE:
        left, top, _, _ = crop_bbox
        final_result = original_image.copy()
        if paste_whole_crop:
            final_result.paste(generated, (left, top))
            paste_mode = "whole generated crop"
        else:
            final_result.paste(
                generated,
                (left, top),
                generation_mask_source.crop(crop_bbox),
            )
            paste_mode = "masked area only"

        crop_report = (
            f"- raw mask bbox: `{crop_info['raw_bbox']}`\n"
            f"- working crop bbox: `{crop_bbox}`\n"
            f"- working crop size: `{crop_size[0]}x{crop_size[1]}`\n"
            f"- generation size: `{padded_size[0]}x{padded_size[1]}`\n"
            f"- paste mode: `{paste_mode}`"
        )
    else:
        if paste_whole_crop:
            final_result = generated
            paste_mode = "whole generated image"
        else:
            final_result = Image.composite(
                generated, original_image, generation_mask_source
            )
            paste_mode = "masked area only"

        crop_report = (
            "- working crop: `none`\n"
            f"- original size: `{original_image.size[0]}x{original_image.size[1]}`\n"
            f"- generation size: `{padded_size[0]}x{padded_size[1]}`\n"
            f"- paste mode: `{paste_mode}`"
        )

    padding_used = "yes" if padded_size != crop_size else "no"

    info = (
        "**Done**\n\n"
        f"- device: `{DEVICE}`\n"
        f"- selected model: `{model_label}`\n"
        f"- model id: `{model_id}`\n"
        "- speed trick: `LCM-LoRA`\n"
        "- vae: `TAESD`\n"
        f"- work mode: `{work_mode}`\n"
        "- resize: `none`\n"
        f"- context around mask: `{crop_padding}` px\n"
        f"- minimum working window: `{min_working_window}` px\n"
        f"- padding to multiple of 8: `{padding_used}`\n"
        f"{crop_report}\n"
        f"- steps: `{steps}`\n"
        f"- guidance: `{guidance_scale}`\n"
        f"- strength: `{strength}`\n"
        f"- mask blur: `{mask_blur}`\n"
        f"- seed: `{seed}`\n"
        f"- time: `{elapsed:.1f}s`\n\n"
        "This version does not downscale or upscale. "
        "In crop mode it sends the mask bbox plus surrounding context to the model, "
        "then pastes the result back into the original image."
    )

    return final_result, info


with gr.Blocks(title="SD15 Light Inpaint CPU") as demo:
    gr.Markdown(
        "# SD15 Light Inpaint CPU\n\n"
        "Upload an image, draw over the area you want to repaint, and describe what should appear there.\n\n"
        "Default mode uses a local crop around the mask instead of resizing the whole image. "
        "This keeps the model working at the original local visual scale while still reducing the amount of image sent to the model."
    )

    with gr.Row():
        with gr.Column():
            editor = gr.ImageEditor(
                label="Image + mask",
                type="pil",
            )

            model_label = gr.Dropdown(
                label="Base inpaint model",
                choices=list(MODEL_CHOICES.keys()),
                value=DEFAULT_MODEL_LABEL,
            )

            prompt = gr.Textbox(
                label="Prompt for masked area",
                value="a beautiful fantasy detail, coherent with the original image, natural lighting",
                lines=3,
            )

            crop_padding = gr.Slider(
                label="Context around mask",
                minimum=32,
                maximum=384,
                step=32,
                value=128,
            )

            with gr.Accordion("Advanced", open=False):
                work_mode = gr.Radio(
                    label="Work mode",
                    choices=[CROP_MODE, FULL_MODE],
                    value=CROP_MODE,
                )

                min_working_window = gr.Slider(
                    label="Minimum working window size",
                    minimum=128,
                    maximum=768,
                    step=64,
                    value=384,
                )

                paste_whole_crop = gr.Checkbox(
                    label="Paste whole generated crop back, diagnostic",
                    value=False,
                )

                steps = gr.Slider(
                    label="Steps",
                    minimum=2,
                    maximum=8,
                    step=1,
                    value=4,
                )

                guidance_scale = gr.Slider(
                    label="Guidance scale / CFG. LCM usually works best around 1.0-2.0",
                    minimum=1.0,
                    maximum=8.0,
                    step=0.1,
                    value=1.5,
                )

                strength = gr.Slider(
                    label="Strength",
                    minimum=0.3,
                    maximum=1.0,
                    step=0.05,
                    value=0.85,
                )

                mask_blur = gr.Slider(
                    label="Mask blur",
                    minimum=0,
                    maximum=16,
                    step=1,
                    value=0,
                )

                seed = gr.Number(
                    label="Seed",
                    value=12345,
                    precision=0,
                )

                random_seed = gr.Checkbox(
                    label="Random seed",
                    value=False,
                )

            button = gr.Button("Generate", variant="primary")

        with gr.Column():
            output = gr.Image(label="Result", type="pil")
            info = gr.Markdown()

    button.click(
        fn=run_inpaint,
        inputs=[
            editor,
            model_label,
            prompt,
            crop_padding,
            work_mode,
            min_working_window,
            steps,
            guidance_scale,
            strength,
            seed,
            random_seed,
            mask_blur,
            paste_whole_crop,
        ],
        outputs=[output, info],
    )


if __name__ == "__main__":
    demo.queue(max_size=4).launch(
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False,
    )